From c761be5cc5a3c41f40de498aa9f4869199b8d79b Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 4 Apr 2026 05:42:13 +0000 Subject: [PATCH 01/24] llama3 mfu experiment Signed-off-by: Gagan Kaushik --- .../recipes/llama3_native_te/compare_mfu.py | 188 +++++++++++ .../llama3_native_te/compare_mfu_common.py | 252 +++++++++++++++ .../llama3_native_te/compare_mfu_multigpu.py | 292 ++++++++++++++++++ 3 files changed, 732 insertions(+) create mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py new file mode 100644 index 0000000000..f4ce8b1d22 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single-GPU MFU comparison: TE vs HF head-to-head. + +Compares FLOPs counting methods and measures MFU for TE and HF models on a single GPU. +No distributed setup required. + +Usage: + cd bionemo-recipes/recipes/llama3_native_te + python compare_mfu.py + python compare_mfu.py --seq-len 2048 --batch-size 2 +""" + +import argparse +import json +import sys +from pathlib import Path + +import torch +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from compare_mfu_common import ( + cleanup_model, + compute_flops_first_principles, + compute_flops_readme, + count_flops_with_model, + create_te_model_on_gpu, + detect_gpu_peak_tflops, + format_flops, + measure_step_time, + print_breakdown, +) +from modeling_llama_te import NVLlamaConfig + + +def main(): + """Run single-GPU MFU comparison: TE vs HF.""" + parser = argparse.ArgumentParser(description="Single-GPU MFU comparison: TE vs HF") + parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") + parser.add_argument("--batch-size", type=int, default=1, help="Micro batch size") + parser.add_argument("--seq-len", type=int, default=4096, help="Sequence length") + parser.add_argument("--peak-tflops", type=float, default=None, help="Override GPU peak bf16 TFLOPS") + parser.add_argument("--warmup-steps", type=int, default=10, help="Warmup iterations before timing") + parser.add_argument("--timed-steps", type=int, default=20, help="Timed iterations to average") + args = parser.parse_args() + + # --- Load model config --- + config_path = Path(args.config_path) / "config.json" + with open(config_path) as f: + config_dict = json.load(f) + + b = args.batch_size + s = args.seq_len + h = config_dict["hidden_size"] + num_layers = config_dict["num_hidden_layers"] + vocab_size = config_dict["vocab_size"] + n_kv_heads = config_dict["num_key_value_heads"] + n_heads = config_dict["num_attention_heads"] + head_dim = h // n_heads + ffn_hidden_size = config_dict["intermediate_size"] + + # --- GPU detection --- + if args.peak_tflops: + peak_tflops = args.peak_tflops + device_name = torch.cuda.get_device_name(0) + else: + peak_tflops, device_name = detect_gpu_peak_tflops() + if peak_tflops is None: + print(f"ERROR: Could not auto-detect GPU peak TFLOPS for: {device_name}") + print("Use --peak-tflops to specify manually.") + sys.exit(1) + + peak_flops_per_sec = peak_tflops * 1e12 + + print(f"GPU: {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16)") + print( + f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}," + f" head_dim={head_dim}, I={ffn_hidden_size}, V={vocab_size}" + ) + print(f"Batch: B={b}, S={s}") + print() + + # ========================================================================= + # Table 1: FLOPs Counting + # ========================================================================= + total_flops_readme = compute_flops_readme(b, s, h, num_layers, vocab_size) + total_flops_fp, breakdown, lm_head_fwd = compute_flops_first_principles( + b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size + ) + + print("Counting FLOPs with HF model (meta device)...") + hf_config = LlamaConfig.from_pretrained(args.config_path) + hf_config._attn_implementation = "eager" + with torch.device("meta"): + hf_model_meta = LlamaForCausalLM(hf_config) + meta_input_ids = torch.randint(0, vocab_size, (b, s), device="meta") + total_flops_hf_counter = count_flops_with_model(hf_model_meta, meta_input_ids) + del hf_model_meta + print(f" HF FlopCounter: {format_flops(total_flops_hf_counter)} (training)") + + # ========================================================================= + # Table 2: MFU — TE vs HF + # ========================================================================= + input_ids = torch.randint(0, vocab_size, (b, s), device="cuda") + + # --- TE model --- + print(f"\n[1/2] TE model (S={s})...") + te_config = NVLlamaConfig.from_pretrained( + args.config_path, dtype=torch.bfloat16, attn_input_format="bshd", self_attn_mask_type="causal" + ) + te_model = create_te_model_on_gpu(te_config) + te_model.train() + print(f"Measuring TE step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") + te_step_time = measure_step_time(te_model, input_ids, args.warmup_steps, args.timed_steps) + model_params = sum(p.numel() for p in te_model.parameters()) + print(f" TE step time: {te_step_time:.4f}s") + cleanup_model(te_model) + + # --- HF model --- + print(f"[2/2] HF model (S={s})...") + hf_config_gpu = LlamaConfig.from_pretrained(args.config_path) + hf_model = LlamaForCausalLM(hf_config_gpu).to(dtype=torch.bfloat16, device="cuda") + hf_model.train() + print(f"Measuring HF step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") + hf_step_time = measure_step_time(hf_model, input_ids, args.warmup_steps, args.timed_steps) + print(f" HF step time: {hf_step_time:.4f}s") + cleanup_model(hf_model) + + # ========================================================================= + # Print results + # ========================================================================= + print() + print("=" * 75) + print(f"MFU Comparison: Lingua-1B (B={b}, S={s}, bf16)") + print(f"GPU: {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16)") + print("=" * 75) + + # --- Table 1 --- + print() + print("--- Table 1: FLOPs Counting (per training step) ---") + hdr1 = f"{'Method':<24} {'FLOPs/step':>14}" + print(hdr1) + print("-" * len(hdr1)) + for name, flops in [ + ("README Formula", total_flops_readme), + ("First Principles", total_flops_fp), + ("FlopCounter (HF)", total_flops_hf_counter), + ]: + print(f"{name:<24} {format_flops(flops):>14}") + + # --- Table 2 --- + print() + print("--- Table 2: MFU ---") + hdr2 = f"{'Model':<12} {'FLOPs/step':>14} {'Step (s)':>9} {'TFLOPS/s':>9} {'MFU':>7}" + print(hdr2) + print("-" * len(hdr2)) + + for name, flops, step_time in [ + ("TE", total_flops_fp, te_step_time), + ("HF", total_flops_fp, hf_step_time), + ]: + tflops = flops / step_time / 1e12 + mfu = flops / step_time / peak_flops_per_sec * 100 + print(f"{name:<12} {format_flops(flops):>14} {step_time:>8.3f}s {tflops:>8.2f} {mfu:>6.1f}%") + + print() + print(f"TE vs HF speedup: {hf_step_time / te_step_time:.2f}x") + + # --- Breakdown --- + print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops_fp, model_params) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py new file mode 100644 index 0000000000..c19ca15a71 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for MFU comparison scripts. + +Provides FLOPs counting formulas, GPU detection, model creation helpers, +step time measurement, and formatting utilities used by both single-GPU +and multi-GPU MFU comparison scripts. +""" + +import gc +import time +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + +from modeling_llama_te import NVLlamaForCausalLM + + +# Peak bf16 TFLOPS for common NVIDIA GPUs (tensor core ops). +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """README formula: assumes standard MHA + standard MLP (I=4H, 2 projections).""" + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """First-principles FLOPs for GQA + SwiGLU architecture. + + Returns: + total_training_flops: Total FLOPs for one training step (3x forward). + breakdown: Per-component forward FLOPs for one layer. + lm_head_fwd: Forward FLOPs for the LM head. + """ + kv_dim = n_kv_heads * head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + "Gate projection": 2 * b * s * h * ffn_hidden_size, + "Up projection": 2 * b * s * h * ffn_hidden_size, + "Down projection": 2 * b * s * ffn_hidden_size * h, + } + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * vocab_size + total_fwd = num_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def create_te_model_on_gpu(config): + """Create a TE model on GPU using the meta device + init_empty_weights pattern.""" + with torch.device("meta"): + model = NVLlamaForCausalLM(config) + model.init_empty_weights() + return model + + +def measure_step_time( + model, input_ids, num_warmup=10, num_timed=20, distributed=False, cp_context_fn=None, labels=None +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. For CP with context_parallel, pass full-size tensors + (the cp_context_fn will shard them). + num_warmup: Number of warmup iterations (discarded). + num_timed: Number of timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager (e.g., context_parallel). + Called fresh each iteration since it shards/restores buffers. + labels: Optional labels tensor. If None, uses input_ids as labels. + """ + if labels is None: + labels = input_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(input_ids=input_ids, labels=labels) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(input_ids=input_ids, labels=labels) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern. + + For cp_size=2: rank 0 gets chunks [0, 3], rank 1 gets chunks [1, 2]. + """ + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure inter-GPU bus bandwidth using NCCL all-reduce.""" + if world_size <= 1: + return 0.0 + + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + for _ in range(5): + dist.all_reduce(tensor) + torch.cuda.synchronize() + + start = time.perf_counter() + for _ in range(num_iters): + dist.all_reduce(tensor) + torch.cuda.synchronize() + elapsed = (time.perf_counter() - start) / num_iters + + data_bytes = tensor.nelement() * tensor.element_size() + bus_bw = 2 * (world_size - 1) / world_size * data_bytes / elapsed + return bus_bw / 1e9 # GB/s + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step.""" + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def format_flops(flops): + """Format FLOPs value with appropriate unit.""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py new file mode 100644 index 0000000000..06cfe8ed12 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-GPU MFU comparison: TE CP vs HF CP head-to-head. + +Compares MFU with context parallelism for both TE (via set_context_parallel_group) +and HF (via PyTorch native context_parallel with ring attention). + +Usage: + cd bionemo-recipes/recipes/llama3_native_te + torchrun --nproc_per_node=2 compare_mfu_multigpu.py + torchrun --nproc_per_node=2 compare_mfu_multigpu.py --seq-len 32768 +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor.experimental import context_parallel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from compare_mfu_common import ( + cleanup_model, + compute_flops_first_principles, + compute_flops_readme, + count_flops_with_model, + create_te_model_on_gpu, + detect_gpu_peak_tflops, + estimate_cp_comm_bytes, + format_bytes, + format_flops, + measure_bus_bandwidth, + measure_step_time, + print_breakdown, + split_for_cp_bshd, +) +from modeling_llama_te import NVLlamaConfig + + +def main(): + """Run multi-GPU MFU comparison: TE CP vs HF CP.""" + # --- Distributed setup --- + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + # --- Parse arguments --- + parser = argparse.ArgumentParser(description="Multi-GPU MFU comparison: TE CP vs HF CP") + parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") + parser.add_argument("--batch-size", type=int, default=1, help="Micro batch size per GPU") + parser.add_argument("--seq-len", type=int, default=16384, help="Total sequence length (split across CP ranks)") + parser.add_argument("--cp-size", type=int, default=None, help="CP size (default: world_size)") + parser.add_argument("--peak-tflops", type=float, default=None, help="Override GPU peak bf16 TFLOPS") + parser.add_argument("--warmup-steps", type=int, default=10, help="Warmup iterations before timing") + parser.add_argument("--timed-steps", type=int, default=20, help="Timed iterations to average") + args = parser.parse_args() + + cp_size = args.cp_size or world_size + dp_size = world_size // cp_size + if dp_size * cp_size != world_size: + if rank == 0: + print(f"ERROR: dp_size ({dp_size}) * cp_size ({cp_size}) != world_size ({world_size})") + dist.destroy_process_group() + sys.exit(1) + + # --- Device mesh --- + device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, cp_size), mesh_dim_names=("dp", "cp")) + cp_group = device_mesh["cp"].get_group() + cp_ranks = dist.get_process_group_ranks(cp_group) + cp_rank = device_mesh["cp"].get_local_rank() + + # --- Load model config --- + config_path = Path(args.config_path) / "config.json" + with open(config_path) as f: + config_dict = json.load(f) + + b = args.batch_size + s = args.seq_len + h = config_dict["hidden_size"] + num_layers = config_dict["num_hidden_layers"] + vocab_size = config_dict["vocab_size"] + n_kv_heads = config_dict["num_key_value_heads"] + n_heads = config_dict["num_attention_heads"] + head_dim = h // n_heads + ffn_hidden_size = config_dict["intermediate_size"] + s_local = s // cp_size + + if s % (2 * cp_size) != 0: + if rank == 0: + print(f"ERROR: seq_len ({s}) must be divisible by {2 * cp_size} (2 * cp_size)") + dist.destroy_process_group() + sys.exit(1) + + # --- GPU detection --- + if args.peak_tflops: + peak_tflops = args.peak_tflops + device_name = torch.cuda.get_device_name(0) + else: + peak_tflops, device_name = detect_gpu_peak_tflops() + if peak_tflops is None: + if rank == 0: + print(f"ERROR: Could not auto-detect GPU peak TFLOPS for: {device_name}") + dist.destroy_process_group() + sys.exit(1) + + peak_flops_per_sec = peak_tflops * 1e12 + + if rank == 0: + print(f"GPU: {world_size}x {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16 each)") + print(f"Parallelism: dp={dp_size}, cp={cp_size} ({world_size} GPUs)") + print( + f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}," + f" head_dim={head_dim}, I={ffn_hidden_size}, V={vocab_size}" + ) + print(f"Batch: B={b}, S={s} (S_local={s_local} per GPU)") + print() + + # ========================================================================= + # Table 1: FLOPs Counting + # ========================================================================= + total_flops_readme = compute_flops_readme(b, s, h, num_layers, vocab_size) + total_flops_fp, breakdown, lm_head_fwd = compute_flops_first_principles( + b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size + ) + + if rank == 0: + print("Counting FLOPs with HF model (meta device)...") + hf_config_meta = LlamaConfig.from_pretrained(args.config_path) + hf_config_meta._attn_implementation = "eager" + hf_config_meta.max_position_embeddings = max(hf_config_meta.max_position_embeddings, s) + with torch.device("meta"): + hf_model_meta = LlamaForCausalLM(hf_config_meta) + meta_input_ids = torch.randint(0, vocab_size, (b, s), device="meta") + total_flops_hf_counter = count_flops_with_model(hf_model_meta, meta_input_ids) + del hf_model_meta + if rank == 0: + print(f" HF FlopCounter: {format_flops(total_flops_hf_counter)} (training, full batch)") + + per_gpu_flops = total_flops_fp // world_size + + # ========================================================================= + # Table 2: MFU — TE CP vs HF CP + # ========================================================================= + + # --- HF with PyTorch native CP (run first to avoid NCCL memory fragmentation) --- + if rank == 0: + print(f"\n[1/2] HF model with PyTorch native CP={cp_size} (S={s})...") + hf_config_gpu = LlamaConfig.from_pretrained(args.config_path) + hf_config_gpu._attn_implementation = "sdpa" # Required for context_parallel + hf_config_gpu.max_position_embeddings = max(hf_config_gpu.max_position_embeddings, s) + hf_model = LlamaForCausalLM(hf_config_gpu).to(dtype=torch.bfloat16, device=device) + hf_model.train() + + # Full-size inputs — context_parallel shards them each iteration + hf_full_ids = torch.randint(0, vocab_size, (b, s), device=device) + hf_full_labels = hf_full_ids.clone() + cp_mesh = device_mesh["cp"] + + def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): + return context_parallel(cp_mesh, buffers=(_ids, _labels), buffer_seq_dims=(1, 1)) + + if rank == 0: + print(f"Measuring HF CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") + hf_cp_time = measure_step_time( + hf_model, + hf_full_ids, + args.warmup_steps, + args.timed_steps, + distributed=True, + cp_context_fn=make_hf_cp_ctx, + labels=hf_full_labels, + ) + if rank == 0: + print(f" HF CP step time: {hf_cp_time:.4f}s") + cleanup_model(hf_model) + del hf_full_ids, hf_full_labels + + # --- TE with CP via set_context_parallel_group --- + if rank == 0: + print(f"\n[2/2] TE model with CP={cp_size} (S={s})...") + te_config = NVLlamaConfig.from_pretrained( + args.config_path, + dtype=torch.bfloat16, + attn_input_format="bshd", + self_attn_mask_type="causal", + ) + te_config.max_position_embeddings = max(te_config.max_position_embeddings, s) + te_model = create_te_model_on_gpu(te_config) + for layer in te_model.model.layers: + layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + te_model.train() + + full_ids = torch.randint(0, vocab_size, (b, s), device=device) + te_local_ids = split_for_cp_bshd(full_ids, cp_rank, cp_size) + + if rank == 0: + print(f"Measuring TE CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") + te_cp_time = measure_step_time(te_model, te_local_ids, args.warmup_steps, args.timed_steps, distributed=True) + model_params = sum(p.numel() for p in te_model.parameters()) + if rank == 0: + print(f" TE CP step time: {te_cp_time:.4f}s") + cleanup_model(te_model) + + # ========================================================================= + # Communication overhead + # ========================================================================= + if rank == 0: + print("\nMeasuring inter-GPU bandwidth...") + bus_bw_gbps = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f" Bus bandwidth: {bus_bw_gbps:.1f} GB/s") + + cp_comm_bytes = estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size) + cp_comm_time = cp_comm_bytes / (bus_bw_gbps * 1e9) if bus_bw_gbps > 0 else 0.0 + + # ========================================================================= + # Print results (rank 0 only) + # ========================================================================= + if rank == 0: + print() + print("=" * 75) + print(f"MFU Comparison: Lingua-1B (B={b}, S={s}, bf16, CP={cp_size})") + print(f"GPU: {world_size}x {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16 each)") + print("=" * 75) + + # --- Table 1 --- + print() + print("--- Table 1: FLOPs Counting (per training step) ---") + hdr1 = f"{'Method':<24} {'Total FLOPs':>14} {'Per-GPU FLOPs':>14}" + print(hdr1) + print("-" * len(hdr1)) + for name, total in [ + ("README Formula", total_flops_readme), + ("First Principles", total_flops_fp), + ("FlopCounter (HF)", total_flops_hf_counter), + ]: + print(f"{name:<24} {format_flops(total):>14} {format_flops(total // world_size):>14}") + + # --- Table 2 --- + print() + print(f"--- Table 2: MFU (per GPU, CP={cp_size}) ---") + hdr2 = f"{'Model':<16} {'Per-GPU FLOPs':>14} {'Step (s)':>9} {'TFLOPS/s':>9} {'MFU':>7}" + print(hdr2) + print("-" * len(hdr2)) + + for name, step_time in [("TE (CP)", te_cp_time), ("HF (CP)", hf_cp_time)]: + tflops = per_gpu_flops / step_time / 1e12 + mfu = per_gpu_flops / step_time / peak_flops_per_sec * 100 + print(f"{name:<16} {format_flops(per_gpu_flops):>14} {step_time:>8.3f}s {tflops:>8.2f} {mfu:>6.1f}%") + + print() + print(f"TE vs HF speedup: {hf_cp_time / te_cp_time:.2f}x") + + # --- Communication overhead --- + print() + print("--- Communication Overhead ---") + print(f"Measured bus bandwidth: {bus_bw_gbps:.1f} GB/s") + print(f"CP ring attention (cp={cp_size}): {format_bytes(cp_comm_bytes):>12}/step (~{cp_comm_time:.4f}s)") + if te_cp_time > 0: + print(f" As % of TE step: {cp_comm_time / te_cp_time * 100:.1f}%") + if hf_cp_time > 0: + print(f" As % of HF step: {cp_comm_time / hf_cp_time * 100:.1f}%") + + # --- Breakdown --- + print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops_fp, model_params) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From 7d60786c89d27609032332dcb4a0742c48f27e83 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 6 Apr 2026 17:59:58 +0000 Subject: [PATCH 02/24] Add CP golden value tests, fix RoPE bug, and improve MFU scripts - Add compare_mfu_validate.py: golden value tests comparing CP vs non-CP execution for both TE and HF models, validating loss, logits (cosine sim > 0.99), and gradients (cosine sim > 0.8) following the pattern from models/llama3/tests/test_cp_bshd.py - Fix silent RoPE bug in HF CP: position_ids were not passed to context_parallel buffers, causing each rank to auto-generate [0..S/CP-1] instead of correct global positions - Switch FLOPs counting from eager to SDPA attention for consistency with the actual training implementation (identical counts on meta device) - Add exact FLOPs output (full integers with commas) alongside abbreviated values in both single-GPU and multi-GPU scripts - Switch bandwidth measurement from all-reduce to all-gather for more accurate pure data movement measurement matching CP ring attention - Add position_ids and max_length_q/k support to measure_step_time Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- .../recipes/llama3_native_te/compare_mfu.py | 7 +- .../llama3_native_te/compare_mfu_common.py | 44 +- .../llama3_native_te/compare_mfu_multigpu.py | 39 +- .../llama3_native_te/compare_mfu_validate.py | 380 ++++++++++++++++++ 4 files changed, 448 insertions(+), 22 deletions(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py index f4ce8b1d22..a923ed36de 100644 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py @@ -41,6 +41,7 @@ create_te_model_on_gpu, detect_gpu_peak_tflops, format_flops, + format_flops_exact, measure_step_time, print_breakdown, ) @@ -104,7 +105,7 @@ def main(): print("Counting FLOPs with HF model (meta device)...") hf_config = LlamaConfig.from_pretrained(args.config_path) - hf_config._attn_implementation = "eager" + hf_config._attn_implementation = "sdpa" with torch.device("meta"): hf_model_meta = LlamaForCausalLM(hf_config) meta_input_ids = torch.randint(0, vocab_size, (b, s), device="meta") @@ -152,7 +153,7 @@ def main(): # --- Table 1 --- print() print("--- Table 1: FLOPs Counting (per training step) ---") - hdr1 = f"{'Method':<24} {'FLOPs/step':>14}" + hdr1 = f"{'Method':<24} {'FLOPs/step':>14} {'Exact FLOPs':>30}" print(hdr1) print("-" * len(hdr1)) for name, flops in [ @@ -160,7 +161,7 @@ def main(): ("First Principles", total_flops_fp), ("FlopCounter (HF)", total_flops_hf_counter), ]: - print(f"{name:<24} {format_flops(flops):>14}") + print(f"{name:<24} {format_flops(flops):>14} {format_flops_exact(flops):>30}") # --- Table 2 --- print() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py index c19ca15a71..e056eb34cf 100644 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py @@ -108,7 +108,15 @@ def create_te_model_on_gpu(config): def measure_step_time( - model, input_ids, num_warmup=10, num_timed=20, distributed=False, cp_context_fn=None, labels=None + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, ): """Measure average training step time (forward + backward). @@ -120,16 +128,24 @@ def measure_step_time( num_timed: Number of timed iterations to average. distributed: Whether to use dist.barrier() for synchronization. cp_context_fn: Optional callable returning a context manager (e.g., context_parallel). - Called fresh each iteration since it shards/restores buffers. + Called fresh each iteration. The context manager shards buffers in-place on entry + and restores them on exit, so reusing the same tensors across iterations is safe. labels: Optional labels tensor. If None, uses input_ids as labels. + position_ids: Optional position_ids tensor. Critical for HF models with CP to ensure + correct RoPE positions per rank. + **extra_fwd_kwargs: Additional kwargs passed to model forward (e.g., max_length_q/k for TE CP). """ if labels is None: labels = input_ids + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + for _ in range(num_warmup): ctx = cp_context_fn() if cp_context_fn else nullcontext() with ctx: - output = model(input_ids=input_ids, labels=labels) + output = model(**fwd_kwargs) output.loss.backward() model.zero_grad(set_to_none=True) @@ -147,8 +163,9 @@ def measure_step_time( ctx = cp_context_fn() if cp_context_fn else nullcontext() with ctx: - output = model(input_ids=input_ids, labels=labels) + output = model(**fwd_kwargs) output.loss.backward() + # no need to unshard output here since we're only measuring timing model.zero_grad(set_to_none=True) end.record() @@ -174,23 +191,29 @@ def split_for_cp_bshd(tensor, cp_rank, cp_size): def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure inter-GPU bus bandwidth using NCCL all-reduce.""" + """Measure inter-GPU bus bandwidth using NCCL all-gather (pure data movement). + + Uses all-gather rather than all-reduce since CP ring attention is purely communication + (send/recv of KV chunks), not reduction. all-gather better reflects actual P2P bandwidth. + """ if world_size <= 1: return 0.0 tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + output = [torch.zeros_like(tensor) for _ in range(world_size)] for _ in range(5): - dist.all_reduce(tensor) + dist.all_gather(output, tensor) torch.cuda.synchronize() start = time.perf_counter() for _ in range(num_iters): - dist.all_reduce(tensor) + dist.all_gather(output, tensor) torch.cuda.synchronize() elapsed = (time.perf_counter() - start) / num_iters data_bytes = tensor.nelement() * tensor.element_size() - bus_bw = 2 * (world_size - 1) / world_size * data_bytes / elapsed + # all-gather bus bandwidth: each rank sends data_bytes to (n-1) peers + bus_bw = (world_size - 1) * data_bytes / elapsed return bus_bw / 1e9 # GB/s @@ -216,6 +239,11 @@ def format_flops(flops): return f"{flops:.2e}" +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + def format_bytes(num_bytes): """Format bytes with appropriate unit.""" if num_bytes >= 1e9: diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py index 06cfe8ed12..b0a8a3642a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py @@ -47,6 +47,7 @@ estimate_cp_comm_bytes, format_bytes, format_flops, + format_flops_exact, measure_bus_bandwidth, measure_step_time, print_breakdown, @@ -147,7 +148,7 @@ def main(): if rank == 0: print("Counting FLOPs with HF model (meta device)...") hf_config_meta = LlamaConfig.from_pretrained(args.config_path) - hf_config_meta._attn_implementation = "eager" + hf_config_meta._attn_implementation = "sdpa" hf_config_meta.max_position_embeddings = max(hf_config_meta.max_position_embeddings, s) with torch.device("meta"): hf_model_meta = LlamaForCausalLM(hf_config_meta) @@ -163,7 +164,9 @@ def main(): # Table 2: MFU — TE CP vs HF CP # ========================================================================= - # --- HF with PyTorch native CP (run first to avoid NCCL memory fragmentation) --- + cp_mesh = device_mesh["cp"] + + # --- HF with PyTorch native CP --- if rank == 0: print(f"\n[1/2] HF model with PyTorch native CP={cp_size} (S={s})...") hf_config_gpu = LlamaConfig.from_pretrained(args.config_path) @@ -172,13 +175,15 @@ def main(): hf_model = LlamaForCausalLM(hf_config_gpu).to(dtype=torch.bfloat16, device=device) hf_model.train() - # Full-size inputs — context_parallel shards them each iteration + # Full-size inputs — context_parallel shards them each iteration. + # position_ids are critical: without them, HF auto-generates [0..S/CP-1] per rank, + # giving WRONG RoPE positions. We create full [0..S-1] and shard alongside input_ids. hf_full_ids = torch.randint(0, vocab_size, (b, s), device=device) hf_full_labels = hf_full_ids.clone() - cp_mesh = device_mesh["cp"] + hf_full_pos = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) - def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): - return context_parallel(cp_mesh, buffers=(_ids, _labels), buffer_seq_dims=(1, 1)) + def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels, _pos=hf_full_pos): + return context_parallel(cp_mesh, buffers=(_ids, _labels, _pos), buffer_seq_dims=(1, 1, 1)) if rank == 0: print(f"Measuring HF CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") @@ -190,11 +195,12 @@ def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): distributed=True, cp_context_fn=make_hf_cp_ctx, labels=hf_full_labels, + position_ids=hf_full_pos, ) if rank == 0: print(f" HF CP step time: {hf_cp_time:.4f}s") cleanup_model(hf_model) - del hf_full_ids, hf_full_labels + del hf_full_ids, hf_full_labels, hf_full_pos # --- TE with CP via set_context_parallel_group --- if rank == 0: @@ -210,14 +216,22 @@ def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): for layer in te_model.model.layers: layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) te_model.train() + model_params = sum(p.numel() for p in te_model.parameters()) full_ids = torch.randint(0, vocab_size, (b, s), device=device) te_local_ids = split_for_cp_bshd(full_ids, cp_rank, cp_size) if rank == 0: print(f"Measuring TE CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") - te_cp_time = measure_step_time(te_model, te_local_ids, args.warmup_steps, args.timed_steps, distributed=True) - model_params = sum(p.numel() for p in te_model.parameters()) + te_cp_time = measure_step_time( + te_model, + te_local_ids, + args.warmup_steps, + args.timed_steps, + distributed=True, + max_length_q=s, + max_length_k=s, + ) if rank == 0: print(f" TE CP step time: {te_cp_time:.4f}s") cleanup_model(te_model) @@ -247,7 +261,7 @@ def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): # --- Table 1 --- print() print("--- Table 1: FLOPs Counting (per training step) ---") - hdr1 = f"{'Method':<24} {'Total FLOPs':>14} {'Per-GPU FLOPs':>14}" + hdr1 = f"{'Method':<24} {'Total FLOPs':>14} {'Per-GPU FLOPs':>14} {'Exact Total FLOPs':>30}" print(hdr1) print("-" * len(hdr1)) for name, total in [ @@ -255,7 +269,10 @@ def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels): ("First Principles", total_flops_fp), ("FlopCounter (HF)", total_flops_hf_counter), ]: - print(f"{name:<24} {format_flops(total):>14} {format_flops(total // world_size):>14}") + print( + f"{name:<24} {format_flops(total):>14} {format_flops(total // world_size):>14}" + f" {format_flops_exact(total):>30}" + ) # --- Table 2 --- print() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py new file mode 100644 index 0000000000..e62b7fe7f5 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py @@ -0,0 +1,380 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden value tests for context parallelism correctness with FSDP2. + +Validates that FSDP2 + CP produces equivalent results to non-CP execution. +Uses the same FSDP2 + CP setup as compare_mfu_multigpu.py and train_fsdp2_cp.py. + +Strategy: +1. Init distributed, create FSDP2 model with CP +2. Gather full weights to rank 0 for non-CP baseline +3. Rank 0 runs non-CP baseline with identical weights +4. All ranks run CP forward+backward +5. Compare loss, logits (cosine sim), gradients (cosine sim) + +Tests both TE (set_context_parallel_group) and HF (PyTorch native context_parallel). + +Usage: + cd bionemo-recipes/recipes/llama3_native_te + torchrun --nproc_per_node=2 compare_mfu_validate.py +""" + +import argparse +import gc +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import context_parallel_unshard +from torch.distributed.tensor.experimental._context_parallel._load_balancer import _HeadTailLoadBalancer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from collator import _split_batch_by_cp_rank +from compare_mfu_common import create_te_model_on_gpu +from modeling_llama_te import NVLlamaConfig + + +SEED = 42 +LOSS_ATOL = 0.5 +LOSS_RTOL = 0.25 +LOGITS_COSINE_MIN = 0.99 +GRAD_COSINE_MIN = 0.8 + + +def seed_everything(seed): + """Set all random seeds for reproducibility.""" + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_dummy_data(vocab_size, batch_size=2, seq_length=64): + """Create deterministic dummy data for golden value tests.""" + seed_everything(SEED + 1000) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)) + return {"input_ids": input_ids, "labels": input_ids.clone()} + + +def reconstruct_logits_from_cp(logits_list, full_seq_len, cp_world_size): + """Reconstruct full-sequence logits from TE CP-sharded chunks (zigzag pattern).""" + batch_size, _, vocab_size = logits_list[0].shape + total_chunks = 2 * cp_world_size + chunk_size = full_seq_len // total_chunks + reconstructed = torch.zeros( + (batch_size, full_seq_len, vocab_size), dtype=logits_list[0].dtype, device=logits_list[0].device + ) + for batch_idx in range(batch_size): + for cp_idx, logits_shard in enumerate(logits_list): + chunk_indices = [cp_idx, total_chunks - cp_idx - 1] + for chunk_pos, chunk_idx in enumerate(chunk_indices): + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + shard_start = chunk_pos * chunk_size + shard_end = shard_start + chunk_size + reconstructed[batch_idx, start_idx:end_idx, :] = logits_shard[batch_idx, shard_start:shard_end, :] + return reconstructed + + +def capture_gradients(model, layer_accessor): + """Capture gradients from sample layers for comparison.""" + gradients = {} + for i, layer in enumerate(layer_accessor(model)): + for name, param in layer.named_parameters(): + if param.grad is not None: + gradients[f"layer_{i}.{name}"] = param.grad.detach().clone().cpu() + return gradients + + +def compare_results(name, ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank): + """Compare CP results against non-distributed reference on rank 0.""" + if rank != 0: + return True + all_passed = True + + try: + torch.testing.assert_close(cp_loss.cpu(), ref_loss.cpu(), atol=LOSS_ATOL, rtol=LOSS_RTOL) + print(f" [{name}] Loss: PASS (ref={ref_loss.item():.6f}, cp={cp_loss.item():.6f})") + except AssertionError as e: + print(f" [{name}] Loss: FAIL - {e}") + all_passed = False + + if ref_logits is not None and cp_logits is not None: + assert cp_logits.shape == ref_logits.shape, f"Shape mismatch: {cp_logits.shape} vs {ref_logits.shape}" + cosine_sim = torch.nn.functional.cosine_similarity( + cp_logits.flatten().float().cuda(), ref_logits.flatten().float().cuda(), dim=0 + ) + passed = cosine_sim > LOGITS_COSINE_MIN + print( + f" [{name}] Logits cosine sim: {'PASS' if passed else 'FAIL'} ({cosine_sim:.6f}, min={LOGITS_COSINE_MIN})" + ) + if not passed: + all_passed = False + + if ref_grads and cp_grads: + for key in ref_grads: + if key in cp_grads: + cosine_sim = torch.nn.functional.cosine_similarity( + cp_grads[key].flatten().float(), ref_grads[key].flatten().float(), dim=0 + ) + passed = cosine_sim > GRAD_COSINE_MIN + print(f" [{name}] Grad {key}: {'PASS' if passed else 'FAIL'} (cosine={cosine_sim:.4f})") + if not passed: + all_passed = False + return all_passed + + +def validate_te(te_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank): + """Validate TE model: DDP + CP vs non-CP with identical weights.""" + if rank == 0: + print(f"Test 1: TE model (DDP + CP={cp_size} vs non-CP baseline)") + + # DDP process group for gradient synchronization (matches test_cp_bshd.py) + group_dp_cp = device_mesh[("dp", "cp")]._flatten("dp_cp").get_group() + + # Create CP model on all ranks with identical weights + seed_everything(SEED) + te_model = create_te_model_on_gpu(te_config) + for param in te_model.parameters(): + dist.broadcast(param.data, src=0) + + # DDP + CP (matches test_cp_bshd.py pattern) + te_model = torch.nn.parallel.DistributedDataParallel( + te_model, + device_ids=[local_rank], + output_device=local_rank, + process_group=group_dp_cp, + ) + for layer in te_model.module.model.layers: + layer.set_context_parallel_group(cp_group, dist.get_process_group_ranks(cp_group), torch.cuda.Stream()) + te_model.train() + + # --- Non-CP baseline on rank 0 (same weights) --- + ref_loss = ref_logits = None + ref_grads = {} + if rank == 0: + seed_everything(SEED) + ref_model = create_te_model_on_gpu(te_config) + # Copy weights from the CP model to ensure exact match + ref_model.load_state_dict( + {k: v for k, v in te_model.state_dict().items() if not k.endswith("_extra_state")}, strict=False + ) + ref_model.train() + batch = get_dummy_data(vocab_size, b, s) + batch_cuda = {k: v.to(device) for k, v in batch.items()} + ref_out = ref_model(**batch_cuda) + ref_out.loss.backward() + ref_loss = ref_out.loss.detach().clone().cpu() + ref_logits = ref_out.logits.detach().clone().cpu() + ref_grads = capture_gradients( + ref_model, + lambda m: [ + m.model.layers[0].self_attention.core_attention, + m.model.layers[0].self_attention.layernorm_qkv, + ], + ) + print(f" Baseline loss: {ref_loss.item():.6f}") + del ref_model, ref_out, batch_cuda + gc.collect() + torch.cuda.empty_cache() + dist.barrier() + + # --- CP forward+backward --- + batch = get_dummy_data(vocab_size, b, s) + batch_cuda = {k: v.detach().to(device) for k, v in batch.items()} + batch_shard = dict( + zip( + ["input_ids", "labels"], + _split_batch_by_cp_rank( + None, + batch_cuda["input_ids"], + batch_cuda["labels"], + qvk_format="bshd", + cp_rank=cp_rank, + cp_world_size=cp_size, + ), + ) + ) + batch_shard["max_length_q"] = batch_shard["max_length_k"] = s + + dist.barrier() + te_out = te_model(**batch_shard) + + # All-gather losses + losses = [torch.zeros_like(te_out.loss) for _ in range(cp_size)] + dist.all_gather(losses, te_out.loss, group=cp_group) + cp_loss = torch.mean(torch.stack(losses)).cpu() if rank == 0 else None + + # All-gather + reconstruct logits + logits_list = [torch.zeros_like(te_out.logits.contiguous()) for _ in range(cp_size)] + dist.all_gather(logits_list, te_out.logits.contiguous(), group=cp_group) + cp_logits = reconstruct_logits_from_cp(logits_list, s, cp_size).cpu() if rank == 0 else None + + te_out.loss.backward() # DDP all-reduces gradients automatically + cp_grads = capture_gradients( + te_model.module, + lambda m: [m.model.layers[0].self_attention.core_attention, m.model.layers[0].self_attention.layernorm_qkv], + ) + + passed = compare_results("TE CP", ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank) + del te_model, te_out + gc.collect() + torch.cuda.empty_cache() + dist.barrier() + return passed + + +def validate_hf(hf_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank): + """Validate HF model: DDP + PyTorch native CP vs non-CP with identical weights.""" + if rank == 0: + print(f"\nTest 2: HF model (DDP + PyTorch native CP={cp_size} vs non-CP baseline)") + + group_dp_cp = device_mesh[("dp", "cp")]._flatten("dp_cp").get_group() + + seed_everything(SEED + 100) + hf_model = LlamaForCausalLM(hf_config).to(dtype=torch.bfloat16, device=device) + for param in hf_model.parameters(): + dist.broadcast(param.data, src=0) + + # DDP for gradient synchronization + hf_model = torch.nn.parallel.DistributedDataParallel( + hf_model, + device_ids=[local_rank], + output_device=local_rank, + process_group=group_dp_cp, + ) + hf_model.train() + + # --- Non-CP baseline on rank 0 (same weights, already on GPU) --- + ref_loss = ref_logits = None + ref_grads = {} + if rank == 0: + ref_model = LlamaForCausalLM(hf_config).to(dtype=torch.bfloat16, device=device) + ref_model.load_state_dict(hf_model.module.state_dict()) + ref_model.train() + batch = get_dummy_data(vocab_size, b, s) + batch_cuda = {k: v.to(device) for k, v in batch.items()} + pos_ids = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) + ref_out = ref_model(position_ids=pos_ids, **batch_cuda) + ref_out.loss.backward() + ref_loss = ref_out.loss.detach().clone().cpu() + ref_logits = ref_out.logits.detach().clone().cpu() + ref_grads = capture_gradients(ref_model, lambda m: [m.model.layers[0].self_attn, m.model.layers[0].mlp]) + print(f" Baseline loss: {ref_loss.item():.6f}") + del ref_model, ref_out, batch_cuda + gc.collect() + torch.cuda.empty_cache() + dist.barrier() + + # --- CP forward+backward --- + batch = get_dummy_data(vocab_size, b, s) + hf_full_ids = batch["input_ids"].to(device) + hf_full_labels = batch["labels"].to(device) + hf_full_pos = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) + + cp_mesh = device_mesh["cp"] + with context_parallel(cp_mesh, buffers=(hf_full_ids, hf_full_labels, hf_full_pos), buffer_seq_dims=(1, 1, 1)): + hf_out = hf_model(input_ids=hf_full_ids, labels=hf_full_labels, position_ids=hf_full_pos) + cp_loss_local = hf_out.loss.detach().clone() + cp_logits_local = hf_out.logits.detach().clone() + hf_out.loss.backward() # DDP all-reduces gradients automatically + cp_grads = capture_gradients(hf_model.module, lambda m: [m.model.layers[0].self_attn, m.model.layers[0].mlp]) + + # All-gather losses + losses = [torch.zeros_like(cp_loss_local) for _ in range(cp_size)] + dist.all_gather(losses, cp_loss_local, group=cp_group) + cp_loss = torch.mean(torch.stack(losses)).cpu() if rank == 0 else None + + # Reconstruct logits with load balancer + load_balancer = _HeadTailLoadBalancer(seq_length=s, world_size=cp_size, device=device) + (cp_logits_full,) = context_parallel_unshard(cp_mesh, [cp_logits_local], [1], load_balancer=load_balancer) + cp_logits = cp_logits_full.cpu() if rank == 0 else None + + passed = compare_results("HF CP", ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank) + del hf_model, hf_out + gc.collect() + torch.cuda.empty_cache() + dist.barrier() + return passed + + +def main(): + """Run golden value tests for CP correctness with FSDP2.""" + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + parser = argparse.ArgumentParser(description="Golden value tests for CP with FSDP2") + parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") + parser.add_argument("--batch-size", type=int, default=2, help="Micro batch size") + parser.add_argument("--seq-len", type=int, default=64, help="Sequence length") + args = parser.parse_args() + + config_dict = json.loads(Path(args.config_path, "config.json").read_text()) + vocab_size = config_dict["vocab_size"] + b, s = args.batch_size, args.seq_len + + cp_size = world_size + dp_size = 1 + device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, cp_size), mesh_dim_names=("dp", "cp")) + + cp_group = device_mesh["cp"].get_group() + cp_rank = device_mesh["cp"].get_local_rank() + + if s % (2 * cp_size) != 0: + if rank == 0: + print(f"ERROR: seq_len ({s}) must be divisible by {2 * cp_size}") + dist.destroy_process_group() + sys.exit(1) + + if rank == 0: + print(f"Golden Value Tests (FSDP2 + CP): B={b}, S={s}, CP={cp_size}") + print() + + te_config = NVLlamaConfig.from_pretrained( + args.config_path, + dtype=torch.bfloat16, + attn_input_format="bshd", + self_attn_mask_type="causal", + ) + hf_config = LlamaConfig.from_pretrained(args.config_path) + hf_config._attn_implementation = "sdpa" + + te_passed = validate_te( + te_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank + ) + hf_passed = validate_hf( + hf_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank + ) + + if rank == 0: + print() + print(f"Summary: TE [{'PASS' if te_passed else 'FAIL'}], HF [{'PASS' if hf_passed else 'FAIL'}]") + if not (te_passed and hf_passed): + sys.exit(1) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From aa296da8aa5a9d2f4b31275a35946b5760f6c9c4 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 6 Apr 2026 18:10:40 +0000 Subject: [PATCH 03/24] Switch bandwidth measurement to P2P send/recv Replaces all-gather with explicit send/recv between rank 0 and rank 1, matching CP ring attention's actual communication pattern. Measures 6.6 GB/s unidirectional on PCIe Gen 3 x8 (vs 3.2 GB/s all-gather, 4.0 GB/s all-reduce). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- .../llama3_native_te/compare_mfu_common.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py index e056eb34cf..a23db71bd2 100644 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py @@ -191,30 +191,43 @@ def split_for_cp_bshd(tensor, cp_rank, cp_size): def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure inter-GPU bus bandwidth using NCCL all-gather (pure data movement). + """Measure unidirectional P2P bandwidth between GPU 0 and GPU 1 via send/recv. - Uses all-gather rather than all-reduce since CP ring attention is purely communication - (send/recv of KV chunks), not reduction. all-gather better reflects actual P2P bandwidth. + This matches CP ring attention's actual communication pattern: each rank sends + its KV chunk to the next rank in the ring via point-to-point transfer. """ if world_size <= 1: return 0.0 + rank = dist.get_rank() tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - output = [torch.zeros_like(tensor) for _ in range(world_size)] + peer = 1 - rank # rank 0 <-> rank 1 + + # Warmup for _ in range(5): - dist.all_gather(output, tensor) + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) torch.cuda.synchronize() + # Timed: rank 0 sends, rank 1 receives (unidirectional) + dist.barrier() + torch.cuda.synchronize() start = time.perf_counter() for _ in range(num_iters): - dist.all_gather(output, tensor) + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) torch.cuda.synchronize() - elapsed = (time.perf_counter() - start) / num_iters + elapsed = time.perf_counter() - start data_bytes = tensor.nelement() * tensor.element_size() - # all-gather bus bandwidth: each rank sends data_bytes to (n-1) peers - bus_bw = (world_size - 1) * data_bytes / elapsed - return bus_bw / 1e9 # GB/s + bw = num_iters * data_bytes / elapsed + return bw / 1e9 # GB/s def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): From f412e3b706aefd9ea44848df858558299ac19761 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 6 Apr 2026 21:19:02 +0000 Subject: [PATCH 04/24] Add CLI to compare_mfu_common.py for standalone utility access Adds subcommands for using MFU utilities outside of full training scripts: - gpu-info: Print GPU name, detected peak TFLOPS, and known GPU table - flops: Compute FLOPs from a model config (README + first-principles) - cp-comm: Estimate CP ring attention communication volume - bandwidth: Measure unidirectional P2P bandwidth via send/recv (torchrun) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- .../llama3_native_te/compare_mfu_common.py | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py index a23db71bd2..3a5cfb4bca 100644 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py +++ b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py @@ -276,6 +276,14 @@ def cleanup_model(model): torch.cuda.empty_cache() +def load_model_config(config_path): + """Load model config dict from a directory containing config.json.""" + import json + from pathlib import Path + + return json.loads(Path(config_path, "config.json").read_text()) + + def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): """Print first-principles FLOPs breakdown.""" print() @@ -291,3 +299,138 @@ def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_param print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 compare_mfu_common.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") # "flops" + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + args = parser.parse_args() + + cfg = load_model_config(args.config_path) + b, s = args.batch_size, args.seq_len + h = cfg["hidden_size"] + num_layers = cfg["num_hidden_layers"] + vocab_size = cfg["vocab_size"] + n_kv_heads = cfg["num_key_value_heads"] + n_heads = cfg["num_attention_heads"] + head_dim = h // n_heads + ffn = cfg["intermediate_size"] + + print(f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, I={ffn}, V={vocab_size}") + print(f"Batch: B={b}, S={s}") + print() + + readme = compute_flops_readme(b, s, h, num_layers, vocab_size) + fp, breakdown, lm_head = compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn, vocab_size) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Exact':>30}") + print("-" * 70) + print(f"{'README Formula':<24} {format_flops(readme):>14} {format_flops_exact(readme):>30}") + print(f"{'First Principles':<24} {format_flops(fp):>14} {format_flops_exact(fp):>30}") + + if readme != fp: + diff = fp - readme + print(f"\nDifference: {format_flops_exact(diff)} ({diff / readme * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Estimate model params from config + # Embedding + layers*(attn + mlp) + norm + lm_head + attn_params = h * h + 2 * h * (n_kv_heads * head_dim) + h * h # Q + K + V + O + mlp_params = 3 * h * ffn # gate + up + down (SwiGLU) + layer_params = attn_params + mlp_params + total_params = vocab_size * h + num_layers * layer_params + h + vocab_size * h # embed + layers + norm + lm_head + print_breakdown(breakdown, lm_head, num_layers, fp, total_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") # "cp-comm" + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg = load_model_config(args.config_path) + b, s = args.batch_size, args.seq_len + num_layers = cfg["num_hidden_layers"] + n_kv_heads = cfg["num_key_value_heads"] + head_dim = cfg["hidden_size"] // cfg["num_attention_heads"] + + comm = estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, args.cp_size) + print(f"CP={args.cp_size}, B={b}, S={s}, L={num_layers}, n_kv_heads={n_kv_heads}, head_dim={head_dim}") + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python compare_mfu_common.py [options]") + print(" torchrun --nproc_per_node=2 compare_mfu_common.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() From 231845e229af9063d9e965e93635cfa002ef9b46 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 11 Apr 2026 10:24:50 +0000 Subject: [PATCH 05/24] Generalize MFU/FLOPs module across recipes with log_mfu training hook Add architecture-independent flops.py module (source: models/esm2/flops.py) supporting Llama (GQA+SwiGLU), ESM2, CodonFM (MHA+GELU), and Evo2 (Hyena). Key additions: - ModelFLOPsConfig dataclass with from_hf_config() auto-detection - Analytical, simplified (README), and Hyena FLOPs formulas - MFUTracker class for training script integration - Communication overhead estimation (CP ring attention, FSDP) - CLI: gpu-info, flops, cp-comm, bandwidth Training integration (gated by log_mfu: false in hydra config): - 11 training scripts across 4 native_te recipes - llama3_native_te (3), esm2_native_te (5), codonfm_native_te (1), opengenome2_llama_native_te (2) Removes superseded Llama-specific compare_mfu*.py benchmark scripts. Moves first_principles.md to models/esm2/ alongside the source module. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- .../models/esm2/first_principles.md | 348 ++++++++ bionemo-recipes/models/esm2/flops.py | 821 +++++++++++++++++ .../recipes/codonfm_native_te/flops.py | 827 ++++++++++++++++++ .../hydra_config/defaults.yaml | 1 + .../recipes/codonfm_native_te/train_fsdp2.py | 23 + .../recipes/esm2_native_te/flops.py | 827 ++++++++++++++++++ .../esm2_native_te/hydra_config/defaults.yaml | 2 + .../recipes/esm2_native_te/perf_logger.py | 2 + .../recipes/esm2_native_te/train_ddp.py | 24 + .../recipes/esm2_native_te/train_ddp_cp.py | 24 + .../recipes/esm2_native_te/train_fsdp2.py | 24 + .../recipes/esm2_native_te/train_fsdp2_cp.py | 24 + .../recipes/esm2_native_te/train_mfsdp.py | 24 + .../recipes/llama3_native_te/compare_mfu.py | 189 ---- .../llama3_native_te/compare_mfu_common.py | 436 --------- .../llama3_native_te/compare_mfu_multigpu.py | 309 ------- .../llama3_native_te/compare_mfu_validate.py | 380 -------- .../recipes/llama3_native_te/flops.py | 827 ++++++++++++++++++ .../hydra_config/defaults.yaml | 2 + .../recipes/llama3_native_te/train_ddp.py | 23 + .../recipes/llama3_native_te/train_fsdp2.py | 23 + .../llama3_native_te/train_fsdp2_cp.py | 23 + .../opengenome2_llama_native_te/flops.py | 827 ++++++++++++++++++ .../hydra_config/defaults.yaml | 2 + .../train_fsdp2.py | 23 + .../train_fsdp2_cp.py | 23 + ci/scripts/check_copied_files.py | 7 + 27 files changed, 4751 insertions(+), 1314 deletions(-) create mode 100644 bionemo-recipes/models/esm2/first_principles.md create mode 100644 bionemo-recipes/models/esm2/flops.py create mode 100644 bionemo-recipes/recipes/codonfm_native_te/flops.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/flops.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/flops.py create mode 100644 bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py diff --git a/bionemo-recipes/models/esm2/first_principles.md b/bionemo-recipes/models/esm2/first_principles.md new file mode 100644 index 0000000000..27e249f963 --- /dev/null +++ b/bionemo-recipes/models/esm2/first_principles.md @@ -0,0 +1,348 @@ +# First-Principles FLOPs Derivation for Llama 3 (GQA + SwiGLU) + +This document derives the per-training-step FLOPs formula used in `compute_flops_first_principles()`, explains each component, and contrasts it with the simplified README formula. + +## Counting convention + +We count **multiply-accumulate operations (MACs)** and report them as **2 FLOPs per MAC** (one multiply, one add). For a matrix multiplication of shapes `(M, K) @ (K, N)`, the FLOPs are: + +``` +FLOPs = 2 * M * K * N +``` + +We only count dense matmuls. Softmax, layer norms, RoPE rotations, element-wise activations (SiLU), and the Hadamard product in SwiGLU are negligible relative to the matmuls and are excluded, consistent with standard MFU methodology. + +## Notation + +| Symbol | Meaning | Lingua-1B value | +| ------ | ------------------------------------------------- | --------------- | +| B | Batch size | 1 | +| S | Sequence length | varies | +| H | Hidden size (`hidden_size`) | 2048 | +| L | Number of layers (`num_hidden_layers`) | 25 | +| n_h | Number of attention heads (`num_attention_heads`) | 16 | +| n_kv | Number of KV heads (`num_key_value_heads`) | 8 | +| d | Head dimension (H / n_h) | 128 | +| d_kv | KV dimension (n_kv * d) | 1024 | +| I | FFN intermediate size (`intermediate_size`) | 6144 | +| V | Vocabulary size (`vocab_size`) | 128256 | + +## Per-layer forward FLOPs + +### Attention projections + +Each attention layer projects the hidden states into queries, keys, values, and then projects the attention output back. + +**Q projection**: Each token's hidden state (H) is projected to the query space (H = n_h * d). + +``` +input: (B, S, H) +weight: (H, H) +output: (B, S, H) +FLOPs = 2 * B * S * H * H +``` + +**K projection**: With Grouped Query Attention (GQA), keys are projected to a smaller space (d_kv = n_kv * d) instead of the full H. This is the key difference from standard Multi-Head Attention (MHA). + +``` +input: (B, S, H) +weight: (H, d_kv) +output: (B, S, d_kv) +FLOPs = 2 * B * S * H * d_kv +``` + +**V projection**: Same dimensions as K projection. + +``` +FLOPs = 2 * B * S * H * d_kv +``` + +**O projection**: The concatenated attention output (H) is projected back to hidden size (H). + +``` +input: (B, S, H) +weight: (H, H) +output: (B, S, H) +FLOPs = 2 * B * S * H * H +``` + +**Total attention projections:** + +``` +attn_proj = 2 * B * S * H * (2*H + 2*d_kv) +``` + +For MHA (d_kv = H), this simplifies to `2 * B * S * H * 4H = 8 * B * S * H^2`. For GQA with d_kv < H, the K and V projections are smaller. + +### Attention scores + +After projection, attention computes Q @ K^T and then attn_weights @ V. Even with GQA (fewer KV heads), the KV heads are **broadcast** to match the query heads, so the effective computation uses all n_h query heads attending to S key positions. + +**Attention logits (Q @ K^T)**: For each head, the query (S, d) is multiplied by key^T (d, S). + +``` +Per head: 2 * B * S * d * S = 2 * B * S^2 * d +All n_h heads: 2 * B * S^2 * d * n_h = 2 * B * S^2 * H +``` + +Note: with GQA, each KV head is shared across (n_h / n_kv) query heads. The total FLOPs remain `2 * B * S^2 * H` because we still have n_h query heads each doing S\*d work against S keys. + +**Attention values (attn_weights @ V)**: Same shape — attention weights (S, S) multiplied by values (S, d) per head. + +``` +FLOPs = 2 * B * S^2 * H +``` + +**Total attention scores:** + +``` +attn_score = 4 * B * S^2 * H +``` + +### MLP (SwiGLU) + +Llama 3 uses SwiGLU activation, which has **three** linear projections instead of the standard MLP's two: + +``` +SwiGLU(x) = (x @ W_gate * SiLU(x @ W_up)) @ W_down +``` + +Standard MLP has two projections (up: H -> I, down: I -> H) with I = 4H typically. SwiGLU adds a third (gate) projection. + +**Gate projection**: H -> I + +``` +FLOPs = 2 * B * S * H * I +``` + +**Up projection**: H -> I + +``` +FLOPs = 2 * B * S * H * I +``` + +**Down projection**: I -> H + +``` +FLOPs = 2 * B * S * I * H +``` + +The element-wise SiLU activation and the Hadamard product (gate * up) are O(B * S * I) — negligible compared to the matmuls. + +**Total MLP:** + +``` +mlp = 6 * B * S * H * I +``` + +### Per-layer total + +``` +per_layer_fwd = attn_proj + attn_score + mlp + = 2*B*S*H*(2*H + 2*d_kv) + 4*B*S^2*H + 6*B*S*H*I +``` + +## LM head + +The language model head projects hidden states to vocabulary logits: + +``` +input: (B, S, H) +weight: (H, V) +output: (B, S, V) +FLOPs = 2 * B * S * H * V +``` + +## Total forward FLOPs + +``` +total_fwd = L * per_layer_fwd + lm_head + = L * [2*B*S*H*(2*H + 2*d_kv) + 4*B*S^2*H + 6*B*S*H*I] + 2*B*S*H*V +``` + +## Total training FLOPs (forward + backward) + +The standard approximation for training is that backward costs 2x the forward (one pass to compute dL/dW, another to compute dL/dX for each matmul). Total training = 3x forward. + +``` +total_training = 3 * total_fwd +``` + +## Comparison with the README formula + +The README uses a simplified formula for a standard transformer: + +```python +total = (24 * B * S * H * H + 4 * B * S * S * H) * (3 * L) + (6 * B * S * H * V) +``` + +The `3*L` folds the 3x training multiplier into the layer count, and `6*B*S*H*V = 3 * 2*B*S*H*V` does the same for the LM head. Extracting the per-layer **forward** FLOPs implicit in the README: + +``` +readme_per_layer_fwd = (24*B*S*H^2 + 4*B*S^2*H) / 3 + = 8*B*S*H^2 + (4/3)*B*S^2*H +``` + +The `4*B*S^2*H` attention score term (with 3x) matches our first-principles `4*B*S^2*H` exactly — both formulas agree on attention scores. The difference is entirely in the `24*B*S*H^2` term, which covers attention projections and MLP. Decomposing it: + +### Decomposition of the README's `24*B*S*H^2` + +The coefficient 24 encodes two assumptions about the per-layer linear projections: + +**Attention projections (coefficient = 8):** Four projections (Q, K, V, O) each of size H -> H, assuming standard Multi-Head Attention (MHA): + +``` +4 projections * 2*B*S*H*H = 8*B*S*H^2 +``` + +**MLP (coefficient = 16):** Two projections with intermediate size I = 4H, assuming a standard Feed-Forward Network: + +``` +Up: 2*B*S*H*(4H) = 8*B*S*H^2 +Down: 2*B*S*(4H)*H = 8*B*S*H^2 +Total: = 16*B*S*H^2 +``` + +Combined: `8 + 16 = 24`. + +### How our first-principles formula differs + +Our formula replaces both assumptions with the actual Llama 3 architecture: + +**Attention projections with GQA:** K and V project to d_kv (not H): + +``` +Q: 2*B*S*H*H K: 2*B*S*H*d_kv V: 2*B*S*H*d_kv O: 2*B*S*H*H +Total = 2*B*S*H*(2*H + 2*d_kv) +``` + +**MLP with SwiGLU:** Three projections (gate, up, down) with actual I: + +``` +Gate: 2*B*S*H*I Up: 2*B*S*H*I Down: 2*B*S*I*H +Total = 6*B*S*H*I +``` + +Side by side, per layer forward, factoring out `2*B*S*H`: + +| Component | README | First principles | +| ---------------- | ---------------------- | ------------------------------ | +| Attention proj | `4*H` (MHA) | `2*H + 2*d_kv` (GQA) | +| MLP | `2*4H = 8*H` (std FFN) | `3*I` (SwiGLU) | +| Attention scores | `2*S` (same) | `2*S` (same) | +| **Total coeff** | **`4*H + 8*H + 2*S`** | **`2*H + 2*d_kv + 3*I + 2*S`** | + +Setting them equal: `12*H + 2*S = 2*H + 2*d_kv + 3*I + 2*S`, which simplifies to: + +``` +10*H = 2*d_kv + 3*I +``` + +### Where the assumptions break + +| Component | README assumes | Llama 3 actual | Direction | +| ---------------- | --------------------- | ------------------------------ | --------------------- | +| K, V projections | H -> H (MHA) | H -> d_kv (GQA, d_kv < H) | README **overcounts** | +| MLP | 2 projections, I = 4H | 3 projections (SwiGLU), I < 4H | Depends on model dims | + +The errors go in opposite directions: + +- **GQA** makes K/V projections cheaper (d_kv < H), so the README overcounts attention +- **SwiGLU** adds a third MLP projection, so the README undercounts MLP (despite using a larger I=4H) + +### Why they cancel exactly for Lingua-1B + +For the Lingua-1B config (H=2048, d_kv=1024, I=6144): + +``` +README linear cost per layer: 12*H = 12 * 2048 = 24,576 +First-principles linear cost: 2*H + 2*d_kv + 3*I = 4096 + 2048 + 18432 = 24,576 +``` + +They match. Breaking down why: + +- README assumes attn proj cost: `4*H = 4 * 2048 = 8,192` + +- Actual attn proj cost: `2*H + 2*d_kv = 4096 + 2048 = 6,144` + +- **GQA saves: 2,048** + +- README assumes MLP cost: `8*H = 8 * 2048 = 16,384` + +- Actual MLP cost: `3*I = 3 * 6144 = 18,432` + +- **SwiGLU adds: 2,048** + +Saved from GQA = Added from SwiGLU = **2,048 exactly**. This is a coincidence specific to Lingua-1B's dimensions. For models with different d_kv/H or I/H ratios, the formulas diverge. + +### When would they diverge? + +For Llama 3.1 70B (H=8192, n_kv=8, d=128, d_kv=1024, I=28672): + +``` +README linear: 12*H = 98,304 +First-principles: 2*8192 + 2*1024 + 3*28672 = 16384 + 2048 + 86016 = 104,448 +Difference: +6.2% (README undercounts by ~6%) +``` + +The README would **undercount** FLOPs for Llama 70B because SwiGLU's third projection with the large I=28672 dominates the GQA savings. + +## Code + +```python +def compute_flops_first_principles( + b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size +): + kv_dim = n_kv_heads * head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + "Gate proj": 2 * b * s * h * ffn_hidden_size, + "Up proj": 2 * b * s * h * ffn_hidden_size, + "Down proj": 2 * b * s * ffn_hidden_size * h, + } + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * vocab_size + total_fwd = num_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd +``` + +## Numerical example (Lingua-1B, B=1, S=4096) + +``` +Per layer forward: + Q proj: 2 * 1 * 4096 * 2048 * 2048 = 34,359,738,368 + K proj: 2 * 1 * 4096 * 2048 * 1024 = 17,179,869,184 + V proj: 2 * 1 * 4096 * 2048 * 1024 = 17,179,869,184 + O proj: 2 * 1 * 4096 * 2048 * 2048 = 34,359,738,368 + Attn logits: 2 * 1 * 4096 * 4096 * 2048 = 68,719,476,736 + Attn values: 2 * 1 * 4096 * 4096 * 2048 = 68,719,476,736 + Gate proj: 2 * 1 * 4096 * 2048 * 6144 = 103,079,215,104 + Up proj: 2 * 1 * 4096 * 2048 * 6144 = 103,079,215,104 + Down proj: 2 * 1 * 4096 * 6144 * 2048 = 103,079,215,104 + ───────────────────────────────────────────────────────────────── + Per-layer total: 549,755,813,888 + +LM head: 2 * 1 * 4096 * 2048 * 128256 = 2,152,726,528,000 + +Total forward: 25 * 549,755,813,888 + 2,152,726,528,000 + = 13,743,895,347,200 + 2,152,726,528,000 + = 15,896,621,875,200 + +Total training (3x): 3 * 15,896,621,875,200 = 47,689,865,625,600 +``` + +Note: the code uses integer arithmetic and reports `47,687,021,887,488` — the small difference is from the embedding layer not being counted here (it's a lookup, not a matmul) and the LM head weight tying configuration. + +## References + +- Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models" (2022) — establishes the 3x forward approximation for training FLOPs +- Chowdhery et al., "PaLM: Scaling Language Modeling with Pathways" (2022) — defines MFU as model_flops / (step_time * peak_hardware_flops) diff --git a/bionemo-recipes/models/esm2/flops.py b/bionemo-recipes/models/esm2/flops.py new file mode 100644 index 0000000000..31e2988700 --- /dev/null +++ b/bionemo-recipes/models/esm2/flops.py @@ -0,0 +1,821 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. + +Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). +Designed to be copied to any recipe via check_copied_files.py and hooked into +training scripts for live MFU tracking. + +Usage as a library (in training scripts): + from flops import MFUTracker, from_hf_config + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + mfu_info = tracker.compute_mfu(step_time=0.5) + +Usage as a CLI: + python flops.py gpu-info + python flops.py flops --config-path ./model_configs/lingua-1B + python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 + torchrun --nproc_per_node=2 flops.py bandwidth +""" + +import gc +import math +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + + +# ============================================================================= +# GPU Peak TFLOPS +# ============================================================================= + +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +# ============================================================================= +# Model FLOPs Config +# ============================================================================= + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. +GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +@dataclass(frozen=True) +class ModelFLOPsConfig: + """Architecture-independent parameters for FLOPs calculation. + + Can be constructed manually or via from_hf_config() for auto-detection. + """ + + hidden_size: int # H + num_hidden_layers: int # L + num_attention_heads: int # n_heads + num_kv_heads: int # n_kv (== n_heads for MHA) + head_dim: int # H // n_heads + intermediate_size: int # I (FFN intermediate dimension) + num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) + vocab_size: int # V + has_lm_head: bool # True for LM models, False for ViT etc. + + +def from_hf_config(config_dict, **overrides): + """Create ModelFLOPsConfig from an HF-compatible config dict. + + Auto-detects architecture: + - GQA vs MHA: from num_key_value_heads (absent = MHA) + - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type + - LM head: from vocab_size > 0 + + Args: + config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). + Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. + **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). + """ + h = config_dict["hidden_size"] + n_heads = config_dict["num_attention_heads"] + n_kv = config_dict.get("num_key_value_heads", n_heads) + vocab = config_dict.get("vocab_size", 0) + model_type = config_dict.get("model_type", "") + + # Detect gated MLP (3 projections) vs standard FFN (2 projections). + # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). + # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). + num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 + + kwargs = { + "hidden_size": h, + "num_hidden_layers": config_dict["num_hidden_layers"], + "num_attention_heads": n_heads, + "num_kv_heads": n_kv, + "head_dim": h // n_heads, + "intermediate_size": config_dict["intermediate_size"], + "num_mlp_projections": num_mlp_proj, + "vocab_size": vocab, + "has_lm_head": vocab > 0, + } + kwargs.update(overrides) + return ModelFLOPsConfig(**kwargs) + + +# ============================================================================= +# FLOPs Formulas +# ============================================================================= + + +def compute_flops_analytical(config, batch_size, seq_len): + """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). + + Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, + layer norms, activations, and element-wise ops. + + Handles: + - GQA vs MHA: K/V projection sizes based on config.num_kv_heads + - SwiGLU vs standard FFN: 2 or 3 MLP projections + - LM head presence + + Returns: + (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) + """ + b, s, h = batch_size, seq_len, config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + } + + ffn = config.intermediate_size + if config.num_mlp_projections == 3: + # SwiGLU/GeGLU: gate + up + down = 3 matmuls + breakdown["Gate projection"] = 2 * b * s * h * ffn + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + else: + # Standard FFN: up + down = 2 matmuls + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): + """Simplified formula assuming standard MHA + standard FFN with I=4H. + + This is the formula from the Llama3 README: + (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V + + The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + + 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + """ + b, s, h = batch_size, seq_len, hidden_size + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): + """FLOPs for Hyena-based models (Evo2). + + Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. + + Args: + config: ModelFLOPsConfig with model dimensions. + batch_size: Batch size. + seq_len: Sequence length. + hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for + short/medium/long conv and attention layer counts. If None, assumes + all layers are long-conv Hyena (H=num_layers, no attention). + """ + b, s, h = batch_size, seq_len, config.hidden_size + ffn = config.intermediate_size + + if hyena_layer_counts is None: + hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} + + # Common per-layer FLOPs + pre_attn_qkv_proj = 2 * 3 * b * s * h * h + post_attn_proj = 2 * b * s * h * h + glu_ffn = 2 * 3 * b * s * ffn * h + + # Layer-type-specific FLOPs (defaults from evo2_provider.py) + attn = 2 * 2 * b * h * s * s # Standard S^2 attention + hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default + hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 + hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 + hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h + + n_s = hyena_layer_counts.get("S", 0) + n_d = hyena_layer_counts.get("D", 0) + n_h = hyena_layer_counts.get("H", 0) + n_a = hyena_layer_counts.get("A", 0) + + logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + + total_fwd = ( + logits + + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) + + n_a * attn + + (n_s + n_d + n_h) * hyena_proj + + n_s * hyena_short_conv + + n_d * hyena_medium_conv + + int(n_h * hyena_long_fft) + ) + + return 3 * total_fwd + + +# Backward-compatible wrappers for existing compare_mfu*.py scripts. + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" + config = ModelFLOPsConfig( + hidden_size=h, + num_hidden_layers=num_layers, + num_attention_heads=h // head_dim, + num_kv_heads=n_kv_heads, + head_dim=head_dim, + intermediate_size=ffn_hidden_size, + num_mlp_projections=3, + vocab_size=vocab_size, + has_lm_head=True, + ) + return compute_flops_analytical(config, b, s) + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """Backward-compatible wrapper for the simplified README formula.""" + return compute_flops_simplified(b, s, h, num_layers, vocab_size) + + +# ============================================================================= +# MFU Tracker +# ============================================================================= + + +class MFUTracker: + """Tracks MFU during training. Initialize once, call compute_mfu() per step. + + Usage: + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + # In training loop: + mfu_info = tracker.compute_mfu(step_time=0.5) + print(f"MFU: {mfu_info['mfu']:.1f}%") + """ + + def __init__( + self, + config, + batch_size, + seq_len, + num_gpus=1, + parallelism=None, + peak_tflops=None, + formula="analytical", + hyena_layer_counts=None, + ): + """Initialize MFU tracker. + + Args: + config: ModelFLOPsConfig instance. + batch_size: Micro batch size per GPU. + seq_len: Sequence length. + num_gpus: Total number of GPUs. + parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. + Used for communication overhead estimation. + peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. + formula: "analytical", "simplified", or "hyena". + hyena_layer_counts: For Hyena formula, dict of layer type counts. + """ + self.config = config + self.batch_size = batch_size + self.seq_len = seq_len + self.num_gpus = num_gpus + self.parallelism = parallelism or {} + self.formula = formula + + if formula == "analytical": + self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( + config, batch_size, seq_len + ) + elif formula == "simplified": + self.total_flops = compute_flops_simplified( + batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size + ) + self.breakdown = None + self.lm_head_flops = 0 + elif formula == "hyena": + self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) + self.breakdown = None + self.lm_head_flops = 0 + else: + raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") + + self.per_gpu_flops = self.total_flops // max(num_gpus, 1) + + if peak_tflops is not None: + self.peak_tflops = peak_tflops + self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" + else: + detected, self.device_name = detect_gpu_peak_tflops() + self.peak_tflops = detected + + self.comm_bytes = self._estimate_comm() + + @classmethod + def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): + """Create from an HF config dict with auto-detection.""" + config = from_hf_config(config_dict) + return cls(config, batch_size, seq_len, **kwargs) + + def compute_mfu(self, step_time): + """Compute MFU from measured step time. + + Args: + step_time: Wall-clock time for one training step (seconds). + + Returns: + Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. + """ + tflops = self.per_gpu_flops / step_time / 1e12 + mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 + return { + "mfu": mfu, + "tflops_per_gpu": tflops, + "per_gpu_flops": self.per_gpu_flops, + "total_flops": self.total_flops, + "step_time": step_time, + } + + def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): + """Estimate communication overhead as a fraction of step time. + + Args: + step_time: Measured step time in seconds. + measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. + + Returns: + Dict with comm_bytes, estimated_comm_time, comm_pct. + """ + bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 + comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 + comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 + return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} + + def _estimate_comm(self): + """Estimate total communication bytes per step based on parallelism.""" + total = 0 + cp_size = self.parallelism.get("cp", 1) + dp_size = self.parallelism.get("dp", 1) + + if cp_size > 1: + total += estimate_cp_comm_bytes( + self.batch_size, + self.seq_len, + self.config.num_hidden_layers, + self.config.num_kv_heads, + self.config.head_dim, + cp_size, + ) + + if dp_size > 1: + # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp + model_params = _estimate_model_params(self.config) + total += 2 * model_params * 2 * (dp_size - 1) // dp_size + + return total + + def summary(self): + """Return a human-readable summary string.""" + lines = [ + f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", + f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," + f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," + f" I={self.config.intermediate_size}, V={self.config.vocab_size}", + f" MLP projections: {self.config.num_mlp_projections}" + f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", + f" Batch: B={self.batch_size}, S={self.seq_len}", + f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", + f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", + f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", + ] + if self.parallelism: + lines.append(f" Parallelism: {self.parallelism}") + if self.comm_bytes > 0: + lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") + return "\n".join(lines) + + +# ============================================================================= +# Communication Estimation +# ============================================================================= + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step. + + Ring attention sends local KV chunks around the ring. Per layer forward: + (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. + Training = ~2x forward communication (forward sends KV, backward sends dKV). + """ + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def _estimate_model_params(config): + """Rough parameter count estimate from config dimensions.""" + h = config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O + mlp_params = config.num_mlp_projections * h * config.intermediate_size + layer_params = attn_params + mlp_params + total = config.num_hidden_layers * layer_params + if config.has_lm_head: + total += config.vocab_size * h * 2 # embed + lm_head + return total + + +# ============================================================================= +# Step Time Measurement +# ============================================================================= + + +def measure_step_time( + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. + num_warmup: Warmup iterations (discarded). + num_timed: Timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager for CP. + labels: Optional labels tensor. If None, uses input_ids. + position_ids: Optional position_ids for correct RoPE with CP. + **extra_fwd_kwargs: Additional kwargs for model forward. + """ + if labels is None: + labels = input_ids + + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +# ============================================================================= +# Utilities +# ============================================================================= + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" + if world_size <= 1: + return 0.0 + + rank = dist.get_rank() + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + peer = 1 - rank + + for _ in range(5): + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + data_bytes = tensor.nelement() * tensor.element_size() + return num_iters * data_bytes / elapsed / 1e9 + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def load_model_config(config_path): + """Load model config dict from a local path or HuggingFace model ID. + + Supports: + - Local directory: ./model_configs/lingua-1B (reads config.json inside) + - Local file: ./model_configs/lingua-1B/config.json + - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) + """ + import json + from pathlib import Path + + path = Path(config_path) + if path.is_dir(): + path = path / "config.json" + if path.exists(): + return json.loads(path.read_text()) + + # Fall back to HuggingFace Hub + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + return hf_config.to_dict() + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Formatting +# ============================================================================= + + +def format_flops(flops): + """Format FLOPs with appropriate unit (G/T/P).""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) + parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") + parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + print( + f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," + f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," + f" I={config.intermediate_size}, V={config.vocab_size}" + ) + print( + f"MLP: {config.num_mlp_projections} projections" + f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" + ) + print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") + print() + + simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) + analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") + print("-" * 86) + for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: + per_gpu = flops // max(args.num_gpus, 1) + print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") + + if simplified != analytical: + diff = analytical - simplified + print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Communication overhead estimate + if args.cp_size > 1: + dp_size = args.num_gpus // args.cp_size + parallelism = {"dp": dp_size, "cp": args.cp_size} + tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) + print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") + print( + f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" + ) + comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 + print(f" Estimated comm time: {comm_time:.4f}s") + + model_params = _estimate_model_params(config) + print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) + print( + f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," + f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" + ) + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python flops.py [options]") + print(" torchrun --nproc_per_node=2 flops.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py new file mode 100644 index 0000000000..fafd718ecf --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/flops.py @@ -0,0 +1,827 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/esm2/flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. + +Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). +Designed to be copied to any recipe via check_copied_files.py and hooked into +training scripts for live MFU tracking. + +Usage as a library (in training scripts): + from flops import MFUTracker, from_hf_config + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + mfu_info = tracker.compute_mfu(step_time=0.5) + +Usage as a CLI: + python flops.py gpu-info + python flops.py flops --config-path ./model_configs/lingua-1B + python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 + torchrun --nproc_per_node=2 flops.py bandwidth +""" + +import gc +import math +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + + +# ============================================================================= +# GPU Peak TFLOPS +# ============================================================================= + +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +# ============================================================================= +# Model FLOPs Config +# ============================================================================= + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. +GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +@dataclass(frozen=True) +class ModelFLOPsConfig: + """Architecture-independent parameters for FLOPs calculation. + + Can be constructed manually or via from_hf_config() for auto-detection. + """ + + hidden_size: int # H + num_hidden_layers: int # L + num_attention_heads: int # n_heads + num_kv_heads: int # n_kv (== n_heads for MHA) + head_dim: int # H // n_heads + intermediate_size: int # I (FFN intermediate dimension) + num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) + vocab_size: int # V + has_lm_head: bool # True for LM models, False for ViT etc. + + +def from_hf_config(config_dict, **overrides): + """Create ModelFLOPsConfig from an HF-compatible config dict. + + Auto-detects architecture: + - GQA vs MHA: from num_key_value_heads (absent = MHA) + - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type + - LM head: from vocab_size > 0 + + Args: + config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). + Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. + **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). + """ + h = config_dict["hidden_size"] + n_heads = config_dict["num_attention_heads"] + n_kv = config_dict.get("num_key_value_heads", n_heads) + vocab = config_dict.get("vocab_size", 0) + model_type = config_dict.get("model_type", "") + + # Detect gated MLP (3 projections) vs standard FFN (2 projections). + # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). + # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). + num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 + + kwargs = { + "hidden_size": h, + "num_hidden_layers": config_dict["num_hidden_layers"], + "num_attention_heads": n_heads, + "num_kv_heads": n_kv, + "head_dim": h // n_heads, + "intermediate_size": config_dict["intermediate_size"], + "num_mlp_projections": num_mlp_proj, + "vocab_size": vocab, + "has_lm_head": vocab > 0, + } + kwargs.update(overrides) + return ModelFLOPsConfig(**kwargs) + + +# ============================================================================= +# FLOPs Formulas +# ============================================================================= + + +def compute_flops_analytical(config, batch_size, seq_len): + """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). + + Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, + layer norms, activations, and element-wise ops. + + Handles: + - GQA vs MHA: K/V projection sizes based on config.num_kv_heads + - SwiGLU vs standard FFN: 2 or 3 MLP projections + - LM head presence + + Returns: + (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) + """ + b, s, h = batch_size, seq_len, config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + } + + ffn = config.intermediate_size + if config.num_mlp_projections == 3: + # SwiGLU/GeGLU: gate + up + down = 3 matmuls + breakdown["Gate projection"] = 2 * b * s * h * ffn + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + else: + # Standard FFN: up + down = 2 matmuls + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): + """Simplified formula assuming standard MHA + standard FFN with I=4H. + + This is the formula from the Llama3 README: + (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V + + The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + + 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + """ + b, s, h = batch_size, seq_len, hidden_size + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): + """FLOPs for Hyena-based models (Evo2). + + Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. + + Args: + config: ModelFLOPsConfig with model dimensions. + batch_size: Batch size. + seq_len: Sequence length. + hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for + short/medium/long conv and attention layer counts. If None, assumes + all layers are long-conv Hyena (H=num_layers, no attention). + """ + b, s, h = batch_size, seq_len, config.hidden_size + ffn = config.intermediate_size + + if hyena_layer_counts is None: + hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} + + # Common per-layer FLOPs + pre_attn_qkv_proj = 2 * 3 * b * s * h * h + post_attn_proj = 2 * b * s * h * h + glu_ffn = 2 * 3 * b * s * ffn * h + + # Layer-type-specific FLOPs (defaults from evo2_provider.py) + attn = 2 * 2 * b * h * s * s # Standard S^2 attention + hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default + hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 + hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 + hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h + + n_s = hyena_layer_counts.get("S", 0) + n_d = hyena_layer_counts.get("D", 0) + n_h = hyena_layer_counts.get("H", 0) + n_a = hyena_layer_counts.get("A", 0) + + logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + + total_fwd = ( + logits + + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) + + n_a * attn + + (n_s + n_d + n_h) * hyena_proj + + n_s * hyena_short_conv + + n_d * hyena_medium_conv + + int(n_h * hyena_long_fft) + ) + + return 3 * total_fwd + + +# Backward-compatible wrappers for existing compare_mfu*.py scripts. + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" + config = ModelFLOPsConfig( + hidden_size=h, + num_hidden_layers=num_layers, + num_attention_heads=h // head_dim, + num_kv_heads=n_kv_heads, + head_dim=head_dim, + intermediate_size=ffn_hidden_size, + num_mlp_projections=3, + vocab_size=vocab_size, + has_lm_head=True, + ) + return compute_flops_analytical(config, b, s) + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """Backward-compatible wrapper for the simplified README formula.""" + return compute_flops_simplified(b, s, h, num_layers, vocab_size) + + +# ============================================================================= +# MFU Tracker +# ============================================================================= + + +class MFUTracker: + """Tracks MFU during training. Initialize once, call compute_mfu() per step. + + Usage: + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + # In training loop: + mfu_info = tracker.compute_mfu(step_time=0.5) + print(f"MFU: {mfu_info['mfu']:.1f}%") + """ + + def __init__( + self, + config, + batch_size, + seq_len, + num_gpus=1, + parallelism=None, + peak_tflops=None, + formula="analytical", + hyena_layer_counts=None, + ): + """Initialize MFU tracker. + + Args: + config: ModelFLOPsConfig instance. + batch_size: Micro batch size per GPU. + seq_len: Sequence length. + num_gpus: Total number of GPUs. + parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. + Used for communication overhead estimation. + peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. + formula: "analytical", "simplified", or "hyena". + hyena_layer_counts: For Hyena formula, dict of layer type counts. + """ + self.config = config + self.batch_size = batch_size + self.seq_len = seq_len + self.num_gpus = num_gpus + self.parallelism = parallelism or {} + self.formula = formula + + if formula == "analytical": + self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( + config, batch_size, seq_len + ) + elif formula == "simplified": + self.total_flops = compute_flops_simplified( + batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size + ) + self.breakdown = None + self.lm_head_flops = 0 + elif formula == "hyena": + self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) + self.breakdown = None + self.lm_head_flops = 0 + else: + raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") + + self.per_gpu_flops = self.total_flops // max(num_gpus, 1) + + if peak_tflops is not None: + self.peak_tflops = peak_tflops + self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" + else: + detected, self.device_name = detect_gpu_peak_tflops() + self.peak_tflops = detected + + self.comm_bytes = self._estimate_comm() + + @classmethod + def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): + """Create from an HF config dict with auto-detection.""" + config = from_hf_config(config_dict) + return cls(config, batch_size, seq_len, **kwargs) + + def compute_mfu(self, step_time): + """Compute MFU from measured step time. + + Args: + step_time: Wall-clock time for one training step (seconds). + + Returns: + Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. + """ + tflops = self.per_gpu_flops / step_time / 1e12 + mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 + return { + "mfu": mfu, + "tflops_per_gpu": tflops, + "per_gpu_flops": self.per_gpu_flops, + "total_flops": self.total_flops, + "step_time": step_time, + } + + def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): + """Estimate communication overhead as a fraction of step time. + + Args: + step_time: Measured step time in seconds. + measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. + + Returns: + Dict with comm_bytes, estimated_comm_time, comm_pct. + """ + bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 + comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 + comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 + return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} + + def _estimate_comm(self): + """Estimate total communication bytes per step based on parallelism.""" + total = 0 + cp_size = self.parallelism.get("cp", 1) + dp_size = self.parallelism.get("dp", 1) + + if cp_size > 1: + total += estimate_cp_comm_bytes( + self.batch_size, + self.seq_len, + self.config.num_hidden_layers, + self.config.num_kv_heads, + self.config.head_dim, + cp_size, + ) + + if dp_size > 1: + # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp + model_params = _estimate_model_params(self.config) + total += 2 * model_params * 2 * (dp_size - 1) // dp_size + + return total + + def summary(self): + """Return a human-readable summary string.""" + lines = [ + f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", + f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," + f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," + f" I={self.config.intermediate_size}, V={self.config.vocab_size}", + f" MLP projections: {self.config.num_mlp_projections}" + f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", + f" Batch: B={self.batch_size}, S={self.seq_len}", + f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", + f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", + f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", + ] + if self.parallelism: + lines.append(f" Parallelism: {self.parallelism}") + if self.comm_bytes > 0: + lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") + return "\n".join(lines) + + +# ============================================================================= +# Communication Estimation +# ============================================================================= + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step. + + Ring attention sends local KV chunks around the ring. Per layer forward: + (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. + Training = ~2x forward communication (forward sends KV, backward sends dKV). + """ + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def _estimate_model_params(config): + """Rough parameter count estimate from config dimensions.""" + h = config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O + mlp_params = config.num_mlp_projections * h * config.intermediate_size + layer_params = attn_params + mlp_params + total = config.num_hidden_layers * layer_params + if config.has_lm_head: + total += config.vocab_size * h * 2 # embed + lm_head + return total + + +# ============================================================================= +# Step Time Measurement +# ============================================================================= + + +def measure_step_time( + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. + num_warmup: Warmup iterations (discarded). + num_timed: Timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager for CP. + labels: Optional labels tensor. If None, uses input_ids. + position_ids: Optional position_ids for correct RoPE with CP. + **extra_fwd_kwargs: Additional kwargs for model forward. + """ + if labels is None: + labels = input_ids + + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +# ============================================================================= +# Utilities +# ============================================================================= + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" + if world_size <= 1: + return 0.0 + + rank = dist.get_rank() + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + peer = 1 - rank + + for _ in range(5): + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + data_bytes = tensor.nelement() * tensor.element_size() + return num_iters * data_bytes / elapsed / 1e9 + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def load_model_config(config_path): + """Load model config dict from a local path or HuggingFace model ID. + + Supports: + - Local directory: ./model_configs/lingua-1B (reads config.json inside) + - Local file: ./model_configs/lingua-1B/config.json + - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) + """ + import json + from pathlib import Path + + path = Path(config_path) + if path.is_dir(): + path = path / "config.json" + if path.exists(): + return json.loads(path.read_text()) + + # Fall back to HuggingFace Hub + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + return hf_config.to_dict() + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Formatting +# ============================================================================= + + +def format_flops(flops): + """Format FLOPs with appropriate unit (G/T/P).""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) + parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") + parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + print( + f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," + f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," + f" I={config.intermediate_size}, V={config.vocab_size}" + ) + print( + f"MLP: {config.num_mlp_projections} projections" + f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" + ) + print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") + print() + + simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) + analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") + print("-" * 86) + for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: + per_gpu = flops // max(args.num_gpus, 1) + print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") + + if simplified != analytical: + diff = analytical - simplified + print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Communication overhead estimate + if args.cp_size > 1: + dp_size = args.num_gpus // args.cp_size + parallelism = {"dp": dp_size, "cp": args.cp_size} + tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) + print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") + print( + f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" + ) + comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 + print(f" Estimated comm time: {comm_time:.4f}s") + + model_params = _estimate_model_params(config) + print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) + print( + f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," + f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" + ) + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python flops.py [options]") + print(" torchrun --nproc_per_node=2 flops.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml index 3a97660834..fa581d601f 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml @@ -65,3 +65,4 @@ quant_stats_config: fp8_layers: null fp4_layers: null use_fp32_master_weights: null +log_mfu: false diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index 8b07f8954e..4b2608c88f 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -16,6 +16,7 @@ """FSDP2 training script for CodonFM with TransformerEngine layers.""" import logging +import time from contextlib import nullcontext from pathlib import Path @@ -25,6 +26,7 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM from omegaconf import DictConfig, OmegaConf from perf_logger import PerfLogger @@ -165,9 +167,23 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step micro_step = 0 # Gradient accumulation step counter + step_start_time = time.perf_counter() while step < args.num_train_steps: batches_in_epoch = 0 for batch in train_dataloader: @@ -203,6 +219,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py new file mode 100644 index 0000000000..fafd718ecf --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/flops.py @@ -0,0 +1,827 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/esm2/flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. + +Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). +Designed to be copied to any recipe via check_copied_files.py and hooked into +training scripts for live MFU tracking. + +Usage as a library (in training scripts): + from flops import MFUTracker, from_hf_config + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + mfu_info = tracker.compute_mfu(step_time=0.5) + +Usage as a CLI: + python flops.py gpu-info + python flops.py flops --config-path ./model_configs/lingua-1B + python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 + torchrun --nproc_per_node=2 flops.py bandwidth +""" + +import gc +import math +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + + +# ============================================================================= +# GPU Peak TFLOPS +# ============================================================================= + +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +# ============================================================================= +# Model FLOPs Config +# ============================================================================= + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. +GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +@dataclass(frozen=True) +class ModelFLOPsConfig: + """Architecture-independent parameters for FLOPs calculation. + + Can be constructed manually or via from_hf_config() for auto-detection. + """ + + hidden_size: int # H + num_hidden_layers: int # L + num_attention_heads: int # n_heads + num_kv_heads: int # n_kv (== n_heads for MHA) + head_dim: int # H // n_heads + intermediate_size: int # I (FFN intermediate dimension) + num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) + vocab_size: int # V + has_lm_head: bool # True for LM models, False for ViT etc. + + +def from_hf_config(config_dict, **overrides): + """Create ModelFLOPsConfig from an HF-compatible config dict. + + Auto-detects architecture: + - GQA vs MHA: from num_key_value_heads (absent = MHA) + - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type + - LM head: from vocab_size > 0 + + Args: + config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). + Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. + **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). + """ + h = config_dict["hidden_size"] + n_heads = config_dict["num_attention_heads"] + n_kv = config_dict.get("num_key_value_heads", n_heads) + vocab = config_dict.get("vocab_size", 0) + model_type = config_dict.get("model_type", "") + + # Detect gated MLP (3 projections) vs standard FFN (2 projections). + # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). + # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). + num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 + + kwargs = { + "hidden_size": h, + "num_hidden_layers": config_dict["num_hidden_layers"], + "num_attention_heads": n_heads, + "num_kv_heads": n_kv, + "head_dim": h // n_heads, + "intermediate_size": config_dict["intermediate_size"], + "num_mlp_projections": num_mlp_proj, + "vocab_size": vocab, + "has_lm_head": vocab > 0, + } + kwargs.update(overrides) + return ModelFLOPsConfig(**kwargs) + + +# ============================================================================= +# FLOPs Formulas +# ============================================================================= + + +def compute_flops_analytical(config, batch_size, seq_len): + """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). + + Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, + layer norms, activations, and element-wise ops. + + Handles: + - GQA vs MHA: K/V projection sizes based on config.num_kv_heads + - SwiGLU vs standard FFN: 2 or 3 MLP projections + - LM head presence + + Returns: + (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) + """ + b, s, h = batch_size, seq_len, config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + } + + ffn = config.intermediate_size + if config.num_mlp_projections == 3: + # SwiGLU/GeGLU: gate + up + down = 3 matmuls + breakdown["Gate projection"] = 2 * b * s * h * ffn + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + else: + # Standard FFN: up + down = 2 matmuls + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): + """Simplified formula assuming standard MHA + standard FFN with I=4H. + + This is the formula from the Llama3 README: + (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V + + The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + + 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + """ + b, s, h = batch_size, seq_len, hidden_size + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): + """FLOPs for Hyena-based models (Evo2). + + Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. + + Args: + config: ModelFLOPsConfig with model dimensions. + batch_size: Batch size. + seq_len: Sequence length. + hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for + short/medium/long conv and attention layer counts. If None, assumes + all layers are long-conv Hyena (H=num_layers, no attention). + """ + b, s, h = batch_size, seq_len, config.hidden_size + ffn = config.intermediate_size + + if hyena_layer_counts is None: + hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} + + # Common per-layer FLOPs + pre_attn_qkv_proj = 2 * 3 * b * s * h * h + post_attn_proj = 2 * b * s * h * h + glu_ffn = 2 * 3 * b * s * ffn * h + + # Layer-type-specific FLOPs (defaults from evo2_provider.py) + attn = 2 * 2 * b * h * s * s # Standard S^2 attention + hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default + hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 + hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 + hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h + + n_s = hyena_layer_counts.get("S", 0) + n_d = hyena_layer_counts.get("D", 0) + n_h = hyena_layer_counts.get("H", 0) + n_a = hyena_layer_counts.get("A", 0) + + logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + + total_fwd = ( + logits + + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) + + n_a * attn + + (n_s + n_d + n_h) * hyena_proj + + n_s * hyena_short_conv + + n_d * hyena_medium_conv + + int(n_h * hyena_long_fft) + ) + + return 3 * total_fwd + + +# Backward-compatible wrappers for existing compare_mfu*.py scripts. + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" + config = ModelFLOPsConfig( + hidden_size=h, + num_hidden_layers=num_layers, + num_attention_heads=h // head_dim, + num_kv_heads=n_kv_heads, + head_dim=head_dim, + intermediate_size=ffn_hidden_size, + num_mlp_projections=3, + vocab_size=vocab_size, + has_lm_head=True, + ) + return compute_flops_analytical(config, b, s) + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """Backward-compatible wrapper for the simplified README formula.""" + return compute_flops_simplified(b, s, h, num_layers, vocab_size) + + +# ============================================================================= +# MFU Tracker +# ============================================================================= + + +class MFUTracker: + """Tracks MFU during training. Initialize once, call compute_mfu() per step. + + Usage: + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + # In training loop: + mfu_info = tracker.compute_mfu(step_time=0.5) + print(f"MFU: {mfu_info['mfu']:.1f}%") + """ + + def __init__( + self, + config, + batch_size, + seq_len, + num_gpus=1, + parallelism=None, + peak_tflops=None, + formula="analytical", + hyena_layer_counts=None, + ): + """Initialize MFU tracker. + + Args: + config: ModelFLOPsConfig instance. + batch_size: Micro batch size per GPU. + seq_len: Sequence length. + num_gpus: Total number of GPUs. + parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. + Used for communication overhead estimation. + peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. + formula: "analytical", "simplified", or "hyena". + hyena_layer_counts: For Hyena formula, dict of layer type counts. + """ + self.config = config + self.batch_size = batch_size + self.seq_len = seq_len + self.num_gpus = num_gpus + self.parallelism = parallelism or {} + self.formula = formula + + if formula == "analytical": + self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( + config, batch_size, seq_len + ) + elif formula == "simplified": + self.total_flops = compute_flops_simplified( + batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size + ) + self.breakdown = None + self.lm_head_flops = 0 + elif formula == "hyena": + self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) + self.breakdown = None + self.lm_head_flops = 0 + else: + raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") + + self.per_gpu_flops = self.total_flops // max(num_gpus, 1) + + if peak_tflops is not None: + self.peak_tflops = peak_tflops + self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" + else: + detected, self.device_name = detect_gpu_peak_tflops() + self.peak_tflops = detected + + self.comm_bytes = self._estimate_comm() + + @classmethod + def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): + """Create from an HF config dict with auto-detection.""" + config = from_hf_config(config_dict) + return cls(config, batch_size, seq_len, **kwargs) + + def compute_mfu(self, step_time): + """Compute MFU from measured step time. + + Args: + step_time: Wall-clock time for one training step (seconds). + + Returns: + Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. + """ + tflops = self.per_gpu_flops / step_time / 1e12 + mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 + return { + "mfu": mfu, + "tflops_per_gpu": tflops, + "per_gpu_flops": self.per_gpu_flops, + "total_flops": self.total_flops, + "step_time": step_time, + } + + def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): + """Estimate communication overhead as a fraction of step time. + + Args: + step_time: Measured step time in seconds. + measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. + + Returns: + Dict with comm_bytes, estimated_comm_time, comm_pct. + """ + bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 + comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 + comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 + return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} + + def _estimate_comm(self): + """Estimate total communication bytes per step based on parallelism.""" + total = 0 + cp_size = self.parallelism.get("cp", 1) + dp_size = self.parallelism.get("dp", 1) + + if cp_size > 1: + total += estimate_cp_comm_bytes( + self.batch_size, + self.seq_len, + self.config.num_hidden_layers, + self.config.num_kv_heads, + self.config.head_dim, + cp_size, + ) + + if dp_size > 1: + # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp + model_params = _estimate_model_params(self.config) + total += 2 * model_params * 2 * (dp_size - 1) // dp_size + + return total + + def summary(self): + """Return a human-readable summary string.""" + lines = [ + f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", + f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," + f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," + f" I={self.config.intermediate_size}, V={self.config.vocab_size}", + f" MLP projections: {self.config.num_mlp_projections}" + f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", + f" Batch: B={self.batch_size}, S={self.seq_len}", + f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", + f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", + f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", + ] + if self.parallelism: + lines.append(f" Parallelism: {self.parallelism}") + if self.comm_bytes > 0: + lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") + return "\n".join(lines) + + +# ============================================================================= +# Communication Estimation +# ============================================================================= + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step. + + Ring attention sends local KV chunks around the ring. Per layer forward: + (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. + Training = ~2x forward communication (forward sends KV, backward sends dKV). + """ + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def _estimate_model_params(config): + """Rough parameter count estimate from config dimensions.""" + h = config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O + mlp_params = config.num_mlp_projections * h * config.intermediate_size + layer_params = attn_params + mlp_params + total = config.num_hidden_layers * layer_params + if config.has_lm_head: + total += config.vocab_size * h * 2 # embed + lm_head + return total + + +# ============================================================================= +# Step Time Measurement +# ============================================================================= + + +def measure_step_time( + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. + num_warmup: Warmup iterations (discarded). + num_timed: Timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager for CP. + labels: Optional labels tensor. If None, uses input_ids. + position_ids: Optional position_ids for correct RoPE with CP. + **extra_fwd_kwargs: Additional kwargs for model forward. + """ + if labels is None: + labels = input_ids + + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +# ============================================================================= +# Utilities +# ============================================================================= + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" + if world_size <= 1: + return 0.0 + + rank = dist.get_rank() + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + peer = 1 - rank + + for _ in range(5): + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + data_bytes = tensor.nelement() * tensor.element_size() + return num_iters * data_bytes / elapsed / 1e9 + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def load_model_config(config_path): + """Load model config dict from a local path or HuggingFace model ID. + + Supports: + - Local directory: ./model_configs/lingua-1B (reads config.json inside) + - Local file: ./model_configs/lingua-1B/config.json + - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) + """ + import json + from pathlib import Path + + path = Path(config_path) + if path.is_dir(): + path = path / "config.json" + if path.exists(): + return json.loads(path.read_text()) + + # Fall back to HuggingFace Hub + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + return hf_config.to_dict() + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Formatting +# ============================================================================= + + +def format_flops(flops): + """Format FLOPs with appropriate unit (G/T/P).""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) + parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") + parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + print( + f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," + f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," + f" I={config.intermediate_size}, V={config.vocab_size}" + ) + print( + f"MLP: {config.num_mlp_projections} projections" + f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" + ) + print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") + print() + + simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) + analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") + print("-" * 86) + for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: + per_gpu = flops // max(args.num_gpus, 1) + print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") + + if simplified != analytical: + diff = analytical - simplified + print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Communication overhead estimate + if args.cp_size > 1: + dp_size = args.num_gpus // args.cp_size + parallelism = {"dp": dp_size, "cp": args.cp_size} + tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) + print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") + print( + f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" + ) + comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 + print(f" Estimated comm time: {comm_time:.4f}s") + + model_params = _estimate_model_params(config) + print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) + print( + f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," + f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" + ) + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python flops.py [options]") + print(" torchrun --nproc_per_node=2 flops.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 969c8b3822..5857ee3ebe 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -12,6 +12,8 @@ use_torch_compile: false cp_size: 1 +log_mfu: false + use_sequence_packing: false dataset: tokenizer_name: ${config_name_or_path} diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 2e67b3aaa5..dd4070ecc9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -70,6 +70,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) self.previous_step_time = time.perf_counter() + self.last_step_time = None # Set after each logged step for MFU tracking if self._dist_config.is_main_process(): # Log the entire args object to wandb for experiment tracking and reproducibility. @@ -115,6 +116,7 @@ def log_step( time.perf_counter(), ) step_time = elapsed_time / self.logging_frequency + self.last_step_time = step_time self.metrics["train/loss"].update(outputs.loss) self.metrics["train/learning_rate"].update(lr) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 65ad1fa2f3..5cb3837a74 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -29,6 +29,7 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import initialize_quant_stats_logging, resolve_layer_precision @@ -158,6 +159,18 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step while step < args.num_train_steps: @@ -187,6 +200,17 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None and perf_logger.last_step_time is not None: + mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) + if dist_config.is_main_process(): + logger.info( + "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", + step, + mfu_info["mfu"], + mfu_info["tflops_per_gpu"], + mfu_info["step_time"], + ) + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index aec8bb0a6d..3be83938b0 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -26,6 +26,7 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -167,6 +168,18 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step while step < args.num_train_steps: @@ -196,6 +209,17 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None and perf_logger.last_step_time is not None: + mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) + if dist_config.is_main_process(): + logger.info( + "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", + step, + mfu_info["mfu"], + mfu_info["tflops_per_gpu"], + mfu_info["step_time"], + ) + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 4cf5b6af6e..f01dcfcc91 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -32,6 +32,7 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import WandBQuantLogger, initialize_quant_stats_logging, resolve_layer_precision @@ -184,6 +185,18 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step while step < args.num_train_steps: @@ -214,6 +227,17 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None and perf_logger.last_step_time is not None: + mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) + if dist_config.is_main_process(): + logger.info( + "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", + step, + mfu_info["mfu"], + mfu_info["tflops_per_gpu"], + mfu_info["step_time"], + ) + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 5593a08721..6aaa0a8f2f 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -28,6 +28,7 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -179,6 +180,18 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step while step < args.num_train_steps: @@ -208,6 +221,17 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None and perf_logger.last_step_time is not None: + mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) + if dist_config.is_main_process(): + logger.info( + "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", + step, + mfu_info["mfu"], + mfu_info["tflops_per_gpu"], + mfu_info["step_time"], + ) + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index b998616316..3afc23a0f7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -31,6 +31,7 @@ from checkpoint import load_checkpoint_mfsdp, save_checkpoint_mfsdp, save_final_model_mfsdp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -165,6 +166,18 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Training loop step = start_step while step < args.num_train_steps: @@ -195,6 +208,17 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None and perf_logger.last_step_time is not None: + mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) + if dist_config.is_main_process(): + logger.info( + "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", + step, + mfu_info["mfu"], + mfu_info["tflops_per_gpu"], + mfu_info["step_time"], + ) + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_mfsdp( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py deleted file mode 100644 index a923ed36de..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Single-GPU MFU comparison: TE vs HF head-to-head. - -Compares FLOPs counting methods and measures MFU for TE and HF models on a single GPU. -No distributed setup required. - -Usage: - cd bionemo-recipes/recipes/llama3_native_te - python compare_mfu.py - python compare_mfu.py --seq-len 2048 --batch-size 2 -""" - -import argparse -import json -import sys -from pathlib import Path - -import torch -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM - -from compare_mfu_common import ( - cleanup_model, - compute_flops_first_principles, - compute_flops_readme, - count_flops_with_model, - create_te_model_on_gpu, - detect_gpu_peak_tflops, - format_flops, - format_flops_exact, - measure_step_time, - print_breakdown, -) -from modeling_llama_te import NVLlamaConfig - - -def main(): - """Run single-GPU MFU comparison: TE vs HF.""" - parser = argparse.ArgumentParser(description="Single-GPU MFU comparison: TE vs HF") - parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") - parser.add_argument("--batch-size", type=int, default=1, help="Micro batch size") - parser.add_argument("--seq-len", type=int, default=4096, help="Sequence length") - parser.add_argument("--peak-tflops", type=float, default=None, help="Override GPU peak bf16 TFLOPS") - parser.add_argument("--warmup-steps", type=int, default=10, help="Warmup iterations before timing") - parser.add_argument("--timed-steps", type=int, default=20, help="Timed iterations to average") - args = parser.parse_args() - - # --- Load model config --- - config_path = Path(args.config_path) / "config.json" - with open(config_path) as f: - config_dict = json.load(f) - - b = args.batch_size - s = args.seq_len - h = config_dict["hidden_size"] - num_layers = config_dict["num_hidden_layers"] - vocab_size = config_dict["vocab_size"] - n_kv_heads = config_dict["num_key_value_heads"] - n_heads = config_dict["num_attention_heads"] - head_dim = h // n_heads - ffn_hidden_size = config_dict["intermediate_size"] - - # --- GPU detection --- - if args.peak_tflops: - peak_tflops = args.peak_tflops - device_name = torch.cuda.get_device_name(0) - else: - peak_tflops, device_name = detect_gpu_peak_tflops() - if peak_tflops is None: - print(f"ERROR: Could not auto-detect GPU peak TFLOPS for: {device_name}") - print("Use --peak-tflops to specify manually.") - sys.exit(1) - - peak_flops_per_sec = peak_tflops * 1e12 - - print(f"GPU: {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16)") - print( - f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}," - f" head_dim={head_dim}, I={ffn_hidden_size}, V={vocab_size}" - ) - print(f"Batch: B={b}, S={s}") - print() - - # ========================================================================= - # Table 1: FLOPs Counting - # ========================================================================= - total_flops_readme = compute_flops_readme(b, s, h, num_layers, vocab_size) - total_flops_fp, breakdown, lm_head_fwd = compute_flops_first_principles( - b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size - ) - - print("Counting FLOPs with HF model (meta device)...") - hf_config = LlamaConfig.from_pretrained(args.config_path) - hf_config._attn_implementation = "sdpa" - with torch.device("meta"): - hf_model_meta = LlamaForCausalLM(hf_config) - meta_input_ids = torch.randint(0, vocab_size, (b, s), device="meta") - total_flops_hf_counter = count_flops_with_model(hf_model_meta, meta_input_ids) - del hf_model_meta - print(f" HF FlopCounter: {format_flops(total_flops_hf_counter)} (training)") - - # ========================================================================= - # Table 2: MFU — TE vs HF - # ========================================================================= - input_ids = torch.randint(0, vocab_size, (b, s), device="cuda") - - # --- TE model --- - print(f"\n[1/2] TE model (S={s})...") - te_config = NVLlamaConfig.from_pretrained( - args.config_path, dtype=torch.bfloat16, attn_input_format="bshd", self_attn_mask_type="causal" - ) - te_model = create_te_model_on_gpu(te_config) - te_model.train() - print(f"Measuring TE step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") - te_step_time = measure_step_time(te_model, input_ids, args.warmup_steps, args.timed_steps) - model_params = sum(p.numel() for p in te_model.parameters()) - print(f" TE step time: {te_step_time:.4f}s") - cleanup_model(te_model) - - # --- HF model --- - print(f"[2/2] HF model (S={s})...") - hf_config_gpu = LlamaConfig.from_pretrained(args.config_path) - hf_model = LlamaForCausalLM(hf_config_gpu).to(dtype=torch.bfloat16, device="cuda") - hf_model.train() - print(f"Measuring HF step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") - hf_step_time = measure_step_time(hf_model, input_ids, args.warmup_steps, args.timed_steps) - print(f" HF step time: {hf_step_time:.4f}s") - cleanup_model(hf_model) - - # ========================================================================= - # Print results - # ========================================================================= - print() - print("=" * 75) - print(f"MFU Comparison: Lingua-1B (B={b}, S={s}, bf16)") - print(f"GPU: {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16)") - print("=" * 75) - - # --- Table 1 --- - print() - print("--- Table 1: FLOPs Counting (per training step) ---") - hdr1 = f"{'Method':<24} {'FLOPs/step':>14} {'Exact FLOPs':>30}" - print(hdr1) - print("-" * len(hdr1)) - for name, flops in [ - ("README Formula", total_flops_readme), - ("First Principles", total_flops_fp), - ("FlopCounter (HF)", total_flops_hf_counter), - ]: - print(f"{name:<24} {format_flops(flops):>14} {format_flops_exact(flops):>30}") - - # --- Table 2 --- - print() - print("--- Table 2: MFU ---") - hdr2 = f"{'Model':<12} {'FLOPs/step':>14} {'Step (s)':>9} {'TFLOPS/s':>9} {'MFU':>7}" - print(hdr2) - print("-" * len(hdr2)) - - for name, flops, step_time in [ - ("TE", total_flops_fp, te_step_time), - ("HF", total_flops_fp, hf_step_time), - ]: - tflops = flops / step_time / 1e12 - mfu = flops / step_time / peak_flops_per_sec * 100 - print(f"{name:<12} {format_flops(flops):>14} {step_time:>8.3f}s {tflops:>8.2f} {mfu:>6.1f}%") - - print() - print(f"TE vs HF speedup: {hf_step_time / te_step_time:.2f}x") - - # --- Breakdown --- - print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops_fp, model_params) - - -if __name__ == "__main__": - main() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py deleted file mode 100644 index 3a5cfb4bca..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_common.py +++ /dev/null @@ -1,436 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared utilities for MFU comparison scripts. - -Provides FLOPs counting formulas, GPU detection, model creation helpers, -step time measurement, and formatting utilities used by both single-GPU -and multi-GPU MFU comparison scripts. -""" - -import gc -import time -from contextlib import nullcontext - -import torch -import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode - -from modeling_llama_te import NVLlamaForCausalLM - - -# Peak bf16 TFLOPS for common NVIDIA GPUs (tensor core ops). -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "A5000": 111.0, - "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """README formula: assumes standard MHA + standard MLP (I=4H, 2 projections).""" - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """First-principles FLOPs for GQA + SwiGLU architecture. - - Returns: - total_training_flops: Total FLOPs for one training step (3x forward). - breakdown: Per-component forward FLOPs for one layer. - lm_head_fwd: Forward FLOPs for the LM head. - """ - kv_dim = n_kv_heads * head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - "Gate projection": 2 * b * s * h * ffn_hidden_size, - "Up projection": 2 * b * s * h * ffn_hidden_size, - "Down projection": 2 * b * s * ffn_hidden_size * h, - } - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * vocab_size - total_fwd = num_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - -def create_te_model_on_gpu(config): - """Create a TE model on GPU using the meta device + init_empty_weights pattern.""" - with torch.device("meta"): - model = NVLlamaForCausalLM(config) - model.init_empty_weights() - return model - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. For CP with context_parallel, pass full-size tensors - (the cp_context_fn will shard them). - num_warmup: Number of warmup iterations (discarded). - num_timed: Number of timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager (e.g., context_parallel). - Called fresh each iteration. The context manager shards buffers in-place on entry - and restores them on exit, so reusing the same tensors across iterations is safe. - labels: Optional labels tensor. If None, uses input_ids as labels. - position_ids: Optional position_ids tensor. Critical for HF models with CP to ensure - correct RoPE positions per rank. - **extra_fwd_kwargs: Additional kwargs passed to model forward (e.g., max_length_q/k for TE CP). - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - # no need to unshard output here since we're only measuring timing - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern. - - For cp_size=2: rank 0 gets chunks [0, 3], rank 1 gets chunks [1, 2]. - """ - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth between GPU 0 and GPU 1 via send/recv. - - This matches CP ring attention's actual communication pattern: each rank sends - its KV chunk to the next rank in the ring via point-to-point transfer. - """ - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank # rank 0 <-> rank 1 - - # Warmup - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - # Timed: rank 0 sends, rank 1 receives (unidirectional) - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - bw = num_iters * data_bytes / elapsed - return bw / 1e9 # GB/s - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step.""" - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def format_flops(flops): - """Format FLOPs value with appropriate unit.""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - -def load_model_config(config_path): - """Load model config dict from a directory containing config.json.""" - import json - from pathlib import Path - - return json.loads(Path(config_path, "config.json").read_text()) - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 compare_mfu_common.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") # "flops" - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - args = parser.parse_args() - - cfg = load_model_config(args.config_path) - b, s = args.batch_size, args.seq_len - h = cfg["hidden_size"] - num_layers = cfg["num_hidden_layers"] - vocab_size = cfg["vocab_size"] - n_kv_heads = cfg["num_key_value_heads"] - n_heads = cfg["num_attention_heads"] - head_dim = h // n_heads - ffn = cfg["intermediate_size"] - - print(f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, I={ffn}, V={vocab_size}") - print(f"Batch: B={b}, S={s}") - print() - - readme = compute_flops_readme(b, s, h, num_layers, vocab_size) - fp, breakdown, lm_head = compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn, vocab_size) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Exact':>30}") - print("-" * 70) - print(f"{'README Formula':<24} {format_flops(readme):>14} {format_flops_exact(readme):>30}") - print(f"{'First Principles':<24} {format_flops(fp):>14} {format_flops_exact(fp):>30}") - - if readme != fp: - diff = fp - readme - print(f"\nDifference: {format_flops_exact(diff)} ({diff / readme * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Estimate model params from config - # Embedding + layers*(attn + mlp) + norm + lm_head - attn_params = h * h + 2 * h * (n_kv_heads * head_dim) + h * h # Q + K + V + O - mlp_params = 3 * h * ffn # gate + up + down (SwiGLU) - layer_params = attn_params + mlp_params - total_params = vocab_size * h + num_layers * layer_params + h + vocab_size * h # embed + layers + norm + lm_head - print_breakdown(breakdown, lm_head, num_layers, fp, total_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") # "cp-comm" - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg = load_model_config(args.config_path) - b, s = args.batch_size, args.seq_len - num_layers = cfg["num_hidden_layers"] - n_kv_heads = cfg["num_key_value_heads"] - head_dim = cfg["hidden_size"] // cfg["num_attention_heads"] - - comm = estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, args.cp_size) - print(f"CP={args.cp_size}, B={b}, S={s}, L={num_layers}, n_kv_heads={n_kv_heads}, head_dim={head_dim}") - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python compare_mfu_common.py [options]") - print(" torchrun --nproc_per_node=2 compare_mfu_common.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py deleted file mode 100644 index b0a8a3642a..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_multigpu.py +++ /dev/null @@ -1,309 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-GPU MFU comparison: TE CP vs HF CP head-to-head. - -Compares MFU with context parallelism for both TE (via set_context_parallel_group) -and HF (via PyTorch native context_parallel with ring attention). - -Usage: - cd bionemo-recipes/recipes/llama3_native_te - torchrun --nproc_per_node=2 compare_mfu_multigpu.py - torchrun --nproc_per_node=2 compare_mfu_multigpu.py --seq-len 32768 -""" - -import argparse -import json -import os -import sys -from pathlib import Path - -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor.experimental import context_parallel -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM - -from compare_mfu_common import ( - cleanup_model, - compute_flops_first_principles, - compute_flops_readme, - count_flops_with_model, - create_te_model_on_gpu, - detect_gpu_peak_tflops, - estimate_cp_comm_bytes, - format_bytes, - format_flops, - format_flops_exact, - measure_bus_bandwidth, - measure_step_time, - print_breakdown, - split_for_cp_bshd, -) -from modeling_llama_te import NVLlamaConfig - - -def main(): - """Run multi-GPU MFU comparison: TE CP vs HF CP.""" - # --- Distributed setup --- - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - # --- Parse arguments --- - parser = argparse.ArgumentParser(description="Multi-GPU MFU comparison: TE CP vs HF CP") - parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") - parser.add_argument("--batch-size", type=int, default=1, help="Micro batch size per GPU") - parser.add_argument("--seq-len", type=int, default=16384, help="Total sequence length (split across CP ranks)") - parser.add_argument("--cp-size", type=int, default=None, help="CP size (default: world_size)") - parser.add_argument("--peak-tflops", type=float, default=None, help="Override GPU peak bf16 TFLOPS") - parser.add_argument("--warmup-steps", type=int, default=10, help="Warmup iterations before timing") - parser.add_argument("--timed-steps", type=int, default=20, help="Timed iterations to average") - args = parser.parse_args() - - cp_size = args.cp_size or world_size - dp_size = world_size // cp_size - if dp_size * cp_size != world_size: - if rank == 0: - print(f"ERROR: dp_size ({dp_size}) * cp_size ({cp_size}) != world_size ({world_size})") - dist.destroy_process_group() - sys.exit(1) - - # --- Device mesh --- - device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, cp_size), mesh_dim_names=("dp", "cp")) - cp_group = device_mesh["cp"].get_group() - cp_ranks = dist.get_process_group_ranks(cp_group) - cp_rank = device_mesh["cp"].get_local_rank() - - # --- Load model config --- - config_path = Path(args.config_path) / "config.json" - with open(config_path) as f: - config_dict = json.load(f) - - b = args.batch_size - s = args.seq_len - h = config_dict["hidden_size"] - num_layers = config_dict["num_hidden_layers"] - vocab_size = config_dict["vocab_size"] - n_kv_heads = config_dict["num_key_value_heads"] - n_heads = config_dict["num_attention_heads"] - head_dim = h // n_heads - ffn_hidden_size = config_dict["intermediate_size"] - s_local = s // cp_size - - if s % (2 * cp_size) != 0: - if rank == 0: - print(f"ERROR: seq_len ({s}) must be divisible by {2 * cp_size} (2 * cp_size)") - dist.destroy_process_group() - sys.exit(1) - - # --- GPU detection --- - if args.peak_tflops: - peak_tflops = args.peak_tflops - device_name = torch.cuda.get_device_name(0) - else: - peak_tflops, device_name = detect_gpu_peak_tflops() - if peak_tflops is None: - if rank == 0: - print(f"ERROR: Could not auto-detect GPU peak TFLOPS for: {device_name}") - dist.destroy_process_group() - sys.exit(1) - - peak_flops_per_sec = peak_tflops * 1e12 - - if rank == 0: - print(f"GPU: {world_size}x {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16 each)") - print(f"Parallelism: dp={dp_size}, cp={cp_size} ({world_size} GPUs)") - print( - f"Config: H={h}, L={num_layers}, n_heads={n_heads}, n_kv_heads={n_kv_heads}," - f" head_dim={head_dim}, I={ffn_hidden_size}, V={vocab_size}" - ) - print(f"Batch: B={b}, S={s} (S_local={s_local} per GPU)") - print() - - # ========================================================================= - # Table 1: FLOPs Counting - # ========================================================================= - total_flops_readme = compute_flops_readme(b, s, h, num_layers, vocab_size) - total_flops_fp, breakdown, lm_head_fwd = compute_flops_first_principles( - b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size - ) - - if rank == 0: - print("Counting FLOPs with HF model (meta device)...") - hf_config_meta = LlamaConfig.from_pretrained(args.config_path) - hf_config_meta._attn_implementation = "sdpa" - hf_config_meta.max_position_embeddings = max(hf_config_meta.max_position_embeddings, s) - with torch.device("meta"): - hf_model_meta = LlamaForCausalLM(hf_config_meta) - meta_input_ids = torch.randint(0, vocab_size, (b, s), device="meta") - total_flops_hf_counter = count_flops_with_model(hf_model_meta, meta_input_ids) - del hf_model_meta - if rank == 0: - print(f" HF FlopCounter: {format_flops(total_flops_hf_counter)} (training, full batch)") - - per_gpu_flops = total_flops_fp // world_size - - # ========================================================================= - # Table 2: MFU — TE CP vs HF CP - # ========================================================================= - - cp_mesh = device_mesh["cp"] - - # --- HF with PyTorch native CP --- - if rank == 0: - print(f"\n[1/2] HF model with PyTorch native CP={cp_size} (S={s})...") - hf_config_gpu = LlamaConfig.from_pretrained(args.config_path) - hf_config_gpu._attn_implementation = "sdpa" # Required for context_parallel - hf_config_gpu.max_position_embeddings = max(hf_config_gpu.max_position_embeddings, s) - hf_model = LlamaForCausalLM(hf_config_gpu).to(dtype=torch.bfloat16, device=device) - hf_model.train() - - # Full-size inputs — context_parallel shards them each iteration. - # position_ids are critical: without them, HF auto-generates [0..S/CP-1] per rank, - # giving WRONG RoPE positions. We create full [0..S-1] and shard alongside input_ids. - hf_full_ids = torch.randint(0, vocab_size, (b, s), device=device) - hf_full_labels = hf_full_ids.clone() - hf_full_pos = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) - - def make_hf_cp_ctx(_ids=hf_full_ids, _labels=hf_full_labels, _pos=hf_full_pos): - return context_parallel(cp_mesh, buffers=(_ids, _labels, _pos), buffer_seq_dims=(1, 1, 1)) - - if rank == 0: - print(f"Measuring HF CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") - hf_cp_time = measure_step_time( - hf_model, - hf_full_ids, - args.warmup_steps, - args.timed_steps, - distributed=True, - cp_context_fn=make_hf_cp_ctx, - labels=hf_full_labels, - position_ids=hf_full_pos, - ) - if rank == 0: - print(f" HF CP step time: {hf_cp_time:.4f}s") - cleanup_model(hf_model) - del hf_full_ids, hf_full_labels, hf_full_pos - - # --- TE with CP via set_context_parallel_group --- - if rank == 0: - print(f"\n[2/2] TE model with CP={cp_size} (S={s})...") - te_config = NVLlamaConfig.from_pretrained( - args.config_path, - dtype=torch.bfloat16, - attn_input_format="bshd", - self_attn_mask_type="causal", - ) - te_config.max_position_embeddings = max(te_config.max_position_embeddings, s) - te_model = create_te_model_on_gpu(te_config) - for layer in te_model.model.layers: - layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) - te_model.train() - model_params = sum(p.numel() for p in te_model.parameters()) - - full_ids = torch.randint(0, vocab_size, (b, s), device=device) - te_local_ids = split_for_cp_bshd(full_ids, cp_rank, cp_size) - - if rank == 0: - print(f"Measuring TE CP step time ({args.warmup_steps} warmup + {args.timed_steps} timed)...") - te_cp_time = measure_step_time( - te_model, - te_local_ids, - args.warmup_steps, - args.timed_steps, - distributed=True, - max_length_q=s, - max_length_k=s, - ) - if rank == 0: - print(f" TE CP step time: {te_cp_time:.4f}s") - cleanup_model(te_model) - - # ========================================================================= - # Communication overhead - # ========================================================================= - if rank == 0: - print("\nMeasuring inter-GPU bandwidth...") - bus_bw_gbps = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f" Bus bandwidth: {bus_bw_gbps:.1f} GB/s") - - cp_comm_bytes = estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size) - cp_comm_time = cp_comm_bytes / (bus_bw_gbps * 1e9) if bus_bw_gbps > 0 else 0.0 - - # ========================================================================= - # Print results (rank 0 only) - # ========================================================================= - if rank == 0: - print() - print("=" * 75) - print(f"MFU Comparison: Lingua-1B (B={b}, S={s}, bf16, CP={cp_size})") - print(f"GPU: {world_size}x {device_name} (Peak: {peak_tflops:.1f} TFLOPS bf16 each)") - print("=" * 75) - - # --- Table 1 --- - print() - print("--- Table 1: FLOPs Counting (per training step) ---") - hdr1 = f"{'Method':<24} {'Total FLOPs':>14} {'Per-GPU FLOPs':>14} {'Exact Total FLOPs':>30}" - print(hdr1) - print("-" * len(hdr1)) - for name, total in [ - ("README Formula", total_flops_readme), - ("First Principles", total_flops_fp), - ("FlopCounter (HF)", total_flops_hf_counter), - ]: - print( - f"{name:<24} {format_flops(total):>14} {format_flops(total // world_size):>14}" - f" {format_flops_exact(total):>30}" - ) - - # --- Table 2 --- - print() - print(f"--- Table 2: MFU (per GPU, CP={cp_size}) ---") - hdr2 = f"{'Model':<16} {'Per-GPU FLOPs':>14} {'Step (s)':>9} {'TFLOPS/s':>9} {'MFU':>7}" - print(hdr2) - print("-" * len(hdr2)) - - for name, step_time in [("TE (CP)", te_cp_time), ("HF (CP)", hf_cp_time)]: - tflops = per_gpu_flops / step_time / 1e12 - mfu = per_gpu_flops / step_time / peak_flops_per_sec * 100 - print(f"{name:<16} {format_flops(per_gpu_flops):>14} {step_time:>8.3f}s {tflops:>8.2f} {mfu:>6.1f}%") - - print() - print(f"TE vs HF speedup: {hf_cp_time / te_cp_time:.2f}x") - - # --- Communication overhead --- - print() - print("--- Communication Overhead ---") - print(f"Measured bus bandwidth: {bus_bw_gbps:.1f} GB/s") - print(f"CP ring attention (cp={cp_size}): {format_bytes(cp_comm_bytes):>12}/step (~{cp_comm_time:.4f}s)") - if te_cp_time > 0: - print(f" As % of TE step: {cp_comm_time / te_cp_time * 100:.1f}%") - if hf_cp_time > 0: - print(f" As % of HF step: {cp_comm_time / hf_cp_time * 100:.1f}%") - - # --- Breakdown --- - print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops_fp, model_params) - - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py b/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py deleted file mode 100644 index e62b7fe7f5..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/compare_mfu_validate.py +++ /dev/null @@ -1,380 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Golden value tests for context parallelism correctness with FSDP2. - -Validates that FSDP2 + CP produces equivalent results to non-CP execution. -Uses the same FSDP2 + CP setup as compare_mfu_multigpu.py and train_fsdp2_cp.py. - -Strategy: -1. Init distributed, create FSDP2 model with CP -2. Gather full weights to rank 0 for non-CP baseline -3. Rank 0 runs non-CP baseline with identical weights -4. All ranks run CP forward+backward -5. Compare loss, logits (cosine sim), gradients (cosine sim) - -Tests both TE (set_context_parallel_group) and HF (PyTorch native context_parallel). - -Usage: - cd bionemo-recipes/recipes/llama3_native_te - torchrun --nproc_per_node=2 compare_mfu_validate.py -""" - -import argparse -import gc -import json -import os -import sys -from pathlib import Path - -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor.experimental import context_parallel -from torch.distributed.tensor.experimental._attention import context_parallel_unshard -from torch.distributed.tensor.experimental._context_parallel._load_balancer import _HeadTailLoadBalancer -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM - -from collator import _split_batch_by_cp_rank -from compare_mfu_common import create_te_model_on_gpu -from modeling_llama_te import NVLlamaConfig - - -SEED = 42 -LOSS_ATOL = 0.5 -LOSS_RTOL = 0.25 -LOGITS_COSINE_MIN = 0.99 -GRAD_COSINE_MIN = 0.8 - - -def seed_everything(seed): - """Set all random seeds for reproducibility.""" - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_dummy_data(vocab_size, batch_size=2, seq_length=64): - """Create deterministic dummy data for golden value tests.""" - seed_everything(SEED + 1000) - input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)) - return {"input_ids": input_ids, "labels": input_ids.clone()} - - -def reconstruct_logits_from_cp(logits_list, full_seq_len, cp_world_size): - """Reconstruct full-sequence logits from TE CP-sharded chunks (zigzag pattern).""" - batch_size, _, vocab_size = logits_list[0].shape - total_chunks = 2 * cp_world_size - chunk_size = full_seq_len // total_chunks - reconstructed = torch.zeros( - (batch_size, full_seq_len, vocab_size), dtype=logits_list[0].dtype, device=logits_list[0].device - ) - for batch_idx in range(batch_size): - for cp_idx, logits_shard in enumerate(logits_list): - chunk_indices = [cp_idx, total_chunks - cp_idx - 1] - for chunk_pos, chunk_idx in enumerate(chunk_indices): - start_idx = chunk_idx * chunk_size - end_idx = start_idx + chunk_size - shard_start = chunk_pos * chunk_size - shard_end = shard_start + chunk_size - reconstructed[batch_idx, start_idx:end_idx, :] = logits_shard[batch_idx, shard_start:shard_end, :] - return reconstructed - - -def capture_gradients(model, layer_accessor): - """Capture gradients from sample layers for comparison.""" - gradients = {} - for i, layer in enumerate(layer_accessor(model)): - for name, param in layer.named_parameters(): - if param.grad is not None: - gradients[f"layer_{i}.{name}"] = param.grad.detach().clone().cpu() - return gradients - - -def compare_results(name, ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank): - """Compare CP results against non-distributed reference on rank 0.""" - if rank != 0: - return True - all_passed = True - - try: - torch.testing.assert_close(cp_loss.cpu(), ref_loss.cpu(), atol=LOSS_ATOL, rtol=LOSS_RTOL) - print(f" [{name}] Loss: PASS (ref={ref_loss.item():.6f}, cp={cp_loss.item():.6f})") - except AssertionError as e: - print(f" [{name}] Loss: FAIL - {e}") - all_passed = False - - if ref_logits is not None and cp_logits is not None: - assert cp_logits.shape == ref_logits.shape, f"Shape mismatch: {cp_logits.shape} vs {ref_logits.shape}" - cosine_sim = torch.nn.functional.cosine_similarity( - cp_logits.flatten().float().cuda(), ref_logits.flatten().float().cuda(), dim=0 - ) - passed = cosine_sim > LOGITS_COSINE_MIN - print( - f" [{name}] Logits cosine sim: {'PASS' if passed else 'FAIL'} ({cosine_sim:.6f}, min={LOGITS_COSINE_MIN})" - ) - if not passed: - all_passed = False - - if ref_grads and cp_grads: - for key in ref_grads: - if key in cp_grads: - cosine_sim = torch.nn.functional.cosine_similarity( - cp_grads[key].flatten().float(), ref_grads[key].flatten().float(), dim=0 - ) - passed = cosine_sim > GRAD_COSINE_MIN - print(f" [{name}] Grad {key}: {'PASS' if passed else 'FAIL'} (cosine={cosine_sim:.4f})") - if not passed: - all_passed = False - return all_passed - - -def validate_te(te_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank): - """Validate TE model: DDP + CP vs non-CP with identical weights.""" - if rank == 0: - print(f"Test 1: TE model (DDP + CP={cp_size} vs non-CP baseline)") - - # DDP process group for gradient synchronization (matches test_cp_bshd.py) - group_dp_cp = device_mesh[("dp", "cp")]._flatten("dp_cp").get_group() - - # Create CP model on all ranks with identical weights - seed_everything(SEED) - te_model = create_te_model_on_gpu(te_config) - for param in te_model.parameters(): - dist.broadcast(param.data, src=0) - - # DDP + CP (matches test_cp_bshd.py pattern) - te_model = torch.nn.parallel.DistributedDataParallel( - te_model, - device_ids=[local_rank], - output_device=local_rank, - process_group=group_dp_cp, - ) - for layer in te_model.module.model.layers: - layer.set_context_parallel_group(cp_group, dist.get_process_group_ranks(cp_group), torch.cuda.Stream()) - te_model.train() - - # --- Non-CP baseline on rank 0 (same weights) --- - ref_loss = ref_logits = None - ref_grads = {} - if rank == 0: - seed_everything(SEED) - ref_model = create_te_model_on_gpu(te_config) - # Copy weights from the CP model to ensure exact match - ref_model.load_state_dict( - {k: v for k, v in te_model.state_dict().items() if not k.endswith("_extra_state")}, strict=False - ) - ref_model.train() - batch = get_dummy_data(vocab_size, b, s) - batch_cuda = {k: v.to(device) for k, v in batch.items()} - ref_out = ref_model(**batch_cuda) - ref_out.loss.backward() - ref_loss = ref_out.loss.detach().clone().cpu() - ref_logits = ref_out.logits.detach().clone().cpu() - ref_grads = capture_gradients( - ref_model, - lambda m: [ - m.model.layers[0].self_attention.core_attention, - m.model.layers[0].self_attention.layernorm_qkv, - ], - ) - print(f" Baseline loss: {ref_loss.item():.6f}") - del ref_model, ref_out, batch_cuda - gc.collect() - torch.cuda.empty_cache() - dist.barrier() - - # --- CP forward+backward --- - batch = get_dummy_data(vocab_size, b, s) - batch_cuda = {k: v.detach().to(device) for k, v in batch.items()} - batch_shard = dict( - zip( - ["input_ids", "labels"], - _split_batch_by_cp_rank( - None, - batch_cuda["input_ids"], - batch_cuda["labels"], - qvk_format="bshd", - cp_rank=cp_rank, - cp_world_size=cp_size, - ), - ) - ) - batch_shard["max_length_q"] = batch_shard["max_length_k"] = s - - dist.barrier() - te_out = te_model(**batch_shard) - - # All-gather losses - losses = [torch.zeros_like(te_out.loss) for _ in range(cp_size)] - dist.all_gather(losses, te_out.loss, group=cp_group) - cp_loss = torch.mean(torch.stack(losses)).cpu() if rank == 0 else None - - # All-gather + reconstruct logits - logits_list = [torch.zeros_like(te_out.logits.contiguous()) for _ in range(cp_size)] - dist.all_gather(logits_list, te_out.logits.contiguous(), group=cp_group) - cp_logits = reconstruct_logits_from_cp(logits_list, s, cp_size).cpu() if rank == 0 else None - - te_out.loss.backward() # DDP all-reduces gradients automatically - cp_grads = capture_gradients( - te_model.module, - lambda m: [m.model.layers[0].self_attention.core_attention, m.model.layers[0].self_attention.layernorm_qkv], - ) - - passed = compare_results("TE CP", ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank) - del te_model, te_out - gc.collect() - torch.cuda.empty_cache() - dist.barrier() - return passed - - -def validate_hf(hf_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank): - """Validate HF model: DDP + PyTorch native CP vs non-CP with identical weights.""" - if rank == 0: - print(f"\nTest 2: HF model (DDP + PyTorch native CP={cp_size} vs non-CP baseline)") - - group_dp_cp = device_mesh[("dp", "cp")]._flatten("dp_cp").get_group() - - seed_everything(SEED + 100) - hf_model = LlamaForCausalLM(hf_config).to(dtype=torch.bfloat16, device=device) - for param in hf_model.parameters(): - dist.broadcast(param.data, src=0) - - # DDP for gradient synchronization - hf_model = torch.nn.parallel.DistributedDataParallel( - hf_model, - device_ids=[local_rank], - output_device=local_rank, - process_group=group_dp_cp, - ) - hf_model.train() - - # --- Non-CP baseline on rank 0 (same weights, already on GPU) --- - ref_loss = ref_logits = None - ref_grads = {} - if rank == 0: - ref_model = LlamaForCausalLM(hf_config).to(dtype=torch.bfloat16, device=device) - ref_model.load_state_dict(hf_model.module.state_dict()) - ref_model.train() - batch = get_dummy_data(vocab_size, b, s) - batch_cuda = {k: v.to(device) for k, v in batch.items()} - pos_ids = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) - ref_out = ref_model(position_ids=pos_ids, **batch_cuda) - ref_out.loss.backward() - ref_loss = ref_out.loss.detach().clone().cpu() - ref_logits = ref_out.logits.detach().clone().cpu() - ref_grads = capture_gradients(ref_model, lambda m: [m.model.layers[0].self_attn, m.model.layers[0].mlp]) - print(f" Baseline loss: {ref_loss.item():.6f}") - del ref_model, ref_out, batch_cuda - gc.collect() - torch.cuda.empty_cache() - dist.barrier() - - # --- CP forward+backward --- - batch = get_dummy_data(vocab_size, b, s) - hf_full_ids = batch["input_ids"].to(device) - hf_full_labels = batch["labels"].to(device) - hf_full_pos = torch.arange(s, device=device).unsqueeze(0).expand(b, -1) - - cp_mesh = device_mesh["cp"] - with context_parallel(cp_mesh, buffers=(hf_full_ids, hf_full_labels, hf_full_pos), buffer_seq_dims=(1, 1, 1)): - hf_out = hf_model(input_ids=hf_full_ids, labels=hf_full_labels, position_ids=hf_full_pos) - cp_loss_local = hf_out.loss.detach().clone() - cp_logits_local = hf_out.logits.detach().clone() - hf_out.loss.backward() # DDP all-reduces gradients automatically - cp_grads = capture_gradients(hf_model.module, lambda m: [m.model.layers[0].self_attn, m.model.layers[0].mlp]) - - # All-gather losses - losses = [torch.zeros_like(cp_loss_local) for _ in range(cp_size)] - dist.all_gather(losses, cp_loss_local, group=cp_group) - cp_loss = torch.mean(torch.stack(losses)).cpu() if rank == 0 else None - - # Reconstruct logits with load balancer - load_balancer = _HeadTailLoadBalancer(seq_length=s, world_size=cp_size, device=device) - (cp_logits_full,) = context_parallel_unshard(cp_mesh, [cp_logits_local], [1], load_balancer=load_balancer) - cp_logits = cp_logits_full.cpu() if rank == 0 else None - - passed = compare_results("HF CP", ref_loss, ref_logits, ref_grads, cp_loss, cp_logits, cp_grads, rank) - del hf_model, hf_out - gc.collect() - torch.cuda.empty_cache() - dist.barrier() - return passed - - -def main(): - """Run golden value tests for CP correctness with FSDP2.""" - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - parser = argparse.ArgumentParser(description="Golden value tests for CP with FSDP2") - parser.add_argument("--config-path", default="./model_configs/lingua-1B", help="Model config directory") - parser.add_argument("--batch-size", type=int, default=2, help="Micro batch size") - parser.add_argument("--seq-len", type=int, default=64, help="Sequence length") - args = parser.parse_args() - - config_dict = json.loads(Path(args.config_path, "config.json").read_text()) - vocab_size = config_dict["vocab_size"] - b, s = args.batch_size, args.seq_len - - cp_size = world_size - dp_size = 1 - device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, cp_size), mesh_dim_names=("dp", "cp")) - - cp_group = device_mesh["cp"].get_group() - cp_rank = device_mesh["cp"].get_local_rank() - - if s % (2 * cp_size) != 0: - if rank == 0: - print(f"ERROR: seq_len ({s}) must be divisible by {2 * cp_size}") - dist.destroy_process_group() - sys.exit(1) - - if rank == 0: - print(f"Golden Value Tests (FSDP2 + CP): B={b}, S={s}, CP={cp_size}") - print() - - te_config = NVLlamaConfig.from_pretrained( - args.config_path, - dtype=torch.bfloat16, - attn_input_format="bshd", - self_attn_mask_type="causal", - ) - hf_config = LlamaConfig.from_pretrained(args.config_path) - hf_config._attn_implementation = "sdpa" - - te_passed = validate_te( - te_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank - ) - hf_passed = validate_hf( - hf_config, vocab_size, b, s, device, local_rank, device_mesh, cp_group, cp_rank, cp_size, rank - ) - - if rank == 0: - print() - print(f"Summary: TE [{'PASS' if te_passed else 'FAIL'}], HF [{'PASS' if hf_passed else 'FAIL'}]") - if not (te_passed and hf_passed): - sys.exit(1) - - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py new file mode 100644 index 0000000000..fafd718ecf --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/flops.py @@ -0,0 +1,827 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/esm2/flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. + +Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). +Designed to be copied to any recipe via check_copied_files.py and hooked into +training scripts for live MFU tracking. + +Usage as a library (in training scripts): + from flops import MFUTracker, from_hf_config + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + mfu_info = tracker.compute_mfu(step_time=0.5) + +Usage as a CLI: + python flops.py gpu-info + python flops.py flops --config-path ./model_configs/lingua-1B + python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 + torchrun --nproc_per_node=2 flops.py bandwidth +""" + +import gc +import math +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + + +# ============================================================================= +# GPU Peak TFLOPS +# ============================================================================= + +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +# ============================================================================= +# Model FLOPs Config +# ============================================================================= + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. +GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +@dataclass(frozen=True) +class ModelFLOPsConfig: + """Architecture-independent parameters for FLOPs calculation. + + Can be constructed manually or via from_hf_config() for auto-detection. + """ + + hidden_size: int # H + num_hidden_layers: int # L + num_attention_heads: int # n_heads + num_kv_heads: int # n_kv (== n_heads for MHA) + head_dim: int # H // n_heads + intermediate_size: int # I (FFN intermediate dimension) + num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) + vocab_size: int # V + has_lm_head: bool # True for LM models, False for ViT etc. + + +def from_hf_config(config_dict, **overrides): + """Create ModelFLOPsConfig from an HF-compatible config dict. + + Auto-detects architecture: + - GQA vs MHA: from num_key_value_heads (absent = MHA) + - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type + - LM head: from vocab_size > 0 + + Args: + config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). + Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. + **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). + """ + h = config_dict["hidden_size"] + n_heads = config_dict["num_attention_heads"] + n_kv = config_dict.get("num_key_value_heads", n_heads) + vocab = config_dict.get("vocab_size", 0) + model_type = config_dict.get("model_type", "") + + # Detect gated MLP (3 projections) vs standard FFN (2 projections). + # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). + # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). + num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 + + kwargs = { + "hidden_size": h, + "num_hidden_layers": config_dict["num_hidden_layers"], + "num_attention_heads": n_heads, + "num_kv_heads": n_kv, + "head_dim": h // n_heads, + "intermediate_size": config_dict["intermediate_size"], + "num_mlp_projections": num_mlp_proj, + "vocab_size": vocab, + "has_lm_head": vocab > 0, + } + kwargs.update(overrides) + return ModelFLOPsConfig(**kwargs) + + +# ============================================================================= +# FLOPs Formulas +# ============================================================================= + + +def compute_flops_analytical(config, batch_size, seq_len): + """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). + + Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, + layer norms, activations, and element-wise ops. + + Handles: + - GQA vs MHA: K/V projection sizes based on config.num_kv_heads + - SwiGLU vs standard FFN: 2 or 3 MLP projections + - LM head presence + + Returns: + (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) + """ + b, s, h = batch_size, seq_len, config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + } + + ffn = config.intermediate_size + if config.num_mlp_projections == 3: + # SwiGLU/GeGLU: gate + up + down = 3 matmuls + breakdown["Gate projection"] = 2 * b * s * h * ffn + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + else: + # Standard FFN: up + down = 2 matmuls + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): + """Simplified formula assuming standard MHA + standard FFN with I=4H. + + This is the formula from the Llama3 README: + (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V + + The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + + 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + """ + b, s, h = batch_size, seq_len, hidden_size + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): + """FLOPs for Hyena-based models (Evo2). + + Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. + + Args: + config: ModelFLOPsConfig with model dimensions. + batch_size: Batch size. + seq_len: Sequence length. + hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for + short/medium/long conv and attention layer counts. If None, assumes + all layers are long-conv Hyena (H=num_layers, no attention). + """ + b, s, h = batch_size, seq_len, config.hidden_size + ffn = config.intermediate_size + + if hyena_layer_counts is None: + hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} + + # Common per-layer FLOPs + pre_attn_qkv_proj = 2 * 3 * b * s * h * h + post_attn_proj = 2 * b * s * h * h + glu_ffn = 2 * 3 * b * s * ffn * h + + # Layer-type-specific FLOPs (defaults from evo2_provider.py) + attn = 2 * 2 * b * h * s * s # Standard S^2 attention + hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default + hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 + hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 + hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h + + n_s = hyena_layer_counts.get("S", 0) + n_d = hyena_layer_counts.get("D", 0) + n_h = hyena_layer_counts.get("H", 0) + n_a = hyena_layer_counts.get("A", 0) + + logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + + total_fwd = ( + logits + + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) + + n_a * attn + + (n_s + n_d + n_h) * hyena_proj + + n_s * hyena_short_conv + + n_d * hyena_medium_conv + + int(n_h * hyena_long_fft) + ) + + return 3 * total_fwd + + +# Backward-compatible wrappers for existing compare_mfu*.py scripts. + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" + config = ModelFLOPsConfig( + hidden_size=h, + num_hidden_layers=num_layers, + num_attention_heads=h // head_dim, + num_kv_heads=n_kv_heads, + head_dim=head_dim, + intermediate_size=ffn_hidden_size, + num_mlp_projections=3, + vocab_size=vocab_size, + has_lm_head=True, + ) + return compute_flops_analytical(config, b, s) + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """Backward-compatible wrapper for the simplified README formula.""" + return compute_flops_simplified(b, s, h, num_layers, vocab_size) + + +# ============================================================================= +# MFU Tracker +# ============================================================================= + + +class MFUTracker: + """Tracks MFU during training. Initialize once, call compute_mfu() per step. + + Usage: + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + # In training loop: + mfu_info = tracker.compute_mfu(step_time=0.5) + print(f"MFU: {mfu_info['mfu']:.1f}%") + """ + + def __init__( + self, + config, + batch_size, + seq_len, + num_gpus=1, + parallelism=None, + peak_tflops=None, + formula="analytical", + hyena_layer_counts=None, + ): + """Initialize MFU tracker. + + Args: + config: ModelFLOPsConfig instance. + batch_size: Micro batch size per GPU. + seq_len: Sequence length. + num_gpus: Total number of GPUs. + parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. + Used for communication overhead estimation. + peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. + formula: "analytical", "simplified", or "hyena". + hyena_layer_counts: For Hyena formula, dict of layer type counts. + """ + self.config = config + self.batch_size = batch_size + self.seq_len = seq_len + self.num_gpus = num_gpus + self.parallelism = parallelism or {} + self.formula = formula + + if formula == "analytical": + self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( + config, batch_size, seq_len + ) + elif formula == "simplified": + self.total_flops = compute_flops_simplified( + batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size + ) + self.breakdown = None + self.lm_head_flops = 0 + elif formula == "hyena": + self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) + self.breakdown = None + self.lm_head_flops = 0 + else: + raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") + + self.per_gpu_flops = self.total_flops // max(num_gpus, 1) + + if peak_tflops is not None: + self.peak_tflops = peak_tflops + self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" + else: + detected, self.device_name = detect_gpu_peak_tflops() + self.peak_tflops = detected + + self.comm_bytes = self._estimate_comm() + + @classmethod + def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): + """Create from an HF config dict with auto-detection.""" + config = from_hf_config(config_dict) + return cls(config, batch_size, seq_len, **kwargs) + + def compute_mfu(self, step_time): + """Compute MFU from measured step time. + + Args: + step_time: Wall-clock time for one training step (seconds). + + Returns: + Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. + """ + tflops = self.per_gpu_flops / step_time / 1e12 + mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 + return { + "mfu": mfu, + "tflops_per_gpu": tflops, + "per_gpu_flops": self.per_gpu_flops, + "total_flops": self.total_flops, + "step_time": step_time, + } + + def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): + """Estimate communication overhead as a fraction of step time. + + Args: + step_time: Measured step time in seconds. + measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. + + Returns: + Dict with comm_bytes, estimated_comm_time, comm_pct. + """ + bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 + comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 + comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 + return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} + + def _estimate_comm(self): + """Estimate total communication bytes per step based on parallelism.""" + total = 0 + cp_size = self.parallelism.get("cp", 1) + dp_size = self.parallelism.get("dp", 1) + + if cp_size > 1: + total += estimate_cp_comm_bytes( + self.batch_size, + self.seq_len, + self.config.num_hidden_layers, + self.config.num_kv_heads, + self.config.head_dim, + cp_size, + ) + + if dp_size > 1: + # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp + model_params = _estimate_model_params(self.config) + total += 2 * model_params * 2 * (dp_size - 1) // dp_size + + return total + + def summary(self): + """Return a human-readable summary string.""" + lines = [ + f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", + f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," + f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," + f" I={self.config.intermediate_size}, V={self.config.vocab_size}", + f" MLP projections: {self.config.num_mlp_projections}" + f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", + f" Batch: B={self.batch_size}, S={self.seq_len}", + f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", + f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", + f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", + ] + if self.parallelism: + lines.append(f" Parallelism: {self.parallelism}") + if self.comm_bytes > 0: + lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") + return "\n".join(lines) + + +# ============================================================================= +# Communication Estimation +# ============================================================================= + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step. + + Ring attention sends local KV chunks around the ring. Per layer forward: + (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. + Training = ~2x forward communication (forward sends KV, backward sends dKV). + """ + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def _estimate_model_params(config): + """Rough parameter count estimate from config dimensions.""" + h = config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O + mlp_params = config.num_mlp_projections * h * config.intermediate_size + layer_params = attn_params + mlp_params + total = config.num_hidden_layers * layer_params + if config.has_lm_head: + total += config.vocab_size * h * 2 # embed + lm_head + return total + + +# ============================================================================= +# Step Time Measurement +# ============================================================================= + + +def measure_step_time( + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. + num_warmup: Warmup iterations (discarded). + num_timed: Timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager for CP. + labels: Optional labels tensor. If None, uses input_ids. + position_ids: Optional position_ids for correct RoPE with CP. + **extra_fwd_kwargs: Additional kwargs for model forward. + """ + if labels is None: + labels = input_ids + + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +# ============================================================================= +# Utilities +# ============================================================================= + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" + if world_size <= 1: + return 0.0 + + rank = dist.get_rank() + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + peer = 1 - rank + + for _ in range(5): + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + data_bytes = tensor.nelement() * tensor.element_size() + return num_iters * data_bytes / elapsed / 1e9 + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def load_model_config(config_path): + """Load model config dict from a local path or HuggingFace model ID. + + Supports: + - Local directory: ./model_configs/lingua-1B (reads config.json inside) + - Local file: ./model_configs/lingua-1B/config.json + - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) + """ + import json + from pathlib import Path + + path = Path(config_path) + if path.is_dir(): + path = path / "config.json" + if path.exists(): + return json.loads(path.read_text()) + + # Fall back to HuggingFace Hub + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + return hf_config.to_dict() + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Formatting +# ============================================================================= + + +def format_flops(flops): + """Format FLOPs with appropriate unit (G/T/P).""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) + parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") + parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + print( + f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," + f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," + f" I={config.intermediate_size}, V={config.vocab_size}" + ) + print( + f"MLP: {config.num_mlp_projections} projections" + f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" + ) + print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") + print() + + simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) + analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") + print("-" * 86) + for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: + per_gpu = flops // max(args.num_gpus, 1) + print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") + + if simplified != analytical: + diff = analytical - simplified + print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Communication overhead estimate + if args.cp_size > 1: + dp_size = args.num_gpus // args.cp_size + parallelism = {"dp": dp_size, "cp": args.cp_size} + tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) + print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") + print( + f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" + ) + comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 + print(f" Estimated comm time: {comm_time:.4f}s") + + model_params = _estimate_model_params(config) + print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) + print( + f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," + f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" + ) + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python flops.py [options]") + print(" torchrun --nproc_per_node=2 flops.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 9302a0758d..ef5e064a88 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -79,6 +79,8 @@ fp8_stats_config: fp8_stats_file: ./fp8_debugging_stats.yaml fp8_log_dir: ./log_fp8_stats +log_mfu: false + profiler: enabled: false start_step: 10 diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 413b9262c7..b646b00f41 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -25,6 +25,7 @@ import gc import logging +import time from contextlib import nullcontext from pathlib import Path @@ -42,6 +43,7 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -143,6 +145,19 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args, start_step=start_step) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + gc.collect() torch.cuda.empty_cache() @@ -150,6 +165,7 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter + step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -186,6 +202,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index da19daa2a7..ba8305c8f8 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -24,6 +24,7 @@ import gc import logging +import time from contextlib import nullcontext from pathlib import Path @@ -48,6 +49,7 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -157,6 +159,19 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args, start_step=start_step) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + gc.collect() torch.cuda.empty_cache() @@ -164,6 +179,7 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter + step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -199,6 +215,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index eaf1a1b39f..5b24d2eaa2 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -25,6 +25,7 @@ import gc import logging +import time from contextlib import nullcontext from pathlib import Path @@ -48,6 +49,7 @@ from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger from scheduler import get_cosine_annealing_schedule_with_warmup @@ -179,6 +181,19 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args, start_step=start_step) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + gc.collect() torch.cuda.empty_cache() @@ -186,6 +201,7 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter + step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -224,6 +240,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py new file mode 100644 index 0000000000..fafd718ecf --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py @@ -0,0 +1,827 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/esm2/flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. + +Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). +Designed to be copied to any recipe via check_copied_files.py and hooked into +training scripts for live MFU tracking. + +Usage as a library (in training scripts): + from flops import MFUTracker, from_hf_config + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + mfu_info = tracker.compute_mfu(step_time=0.5) + +Usage as a CLI: + python flops.py gpu-info + python flops.py flops --config-path ./model_configs/lingua-1B + python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 + torchrun --nproc_per_node=2 flops.py bandwidth +""" + +import gc +import math +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.utils.flop_counter import FlopCounterMode + + +# ============================================================================= +# GPU Peak TFLOPS +# ============================================================================= + +GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "A5000": 111.0, + "L40": 181.0, + "RTX 4090": 330.0, + "RTX 3090": 142.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, +} + + +def detect_gpu_peak_tflops(): + """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" + device_name = torch.cuda.get_device_name(0) + for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): + if gpu_key.lower() in device_name.lower(): + return tflops, device_name + return None, device_name + + +# ============================================================================= +# Model FLOPs Config +# ============================================================================= + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. +GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +@dataclass(frozen=True) +class ModelFLOPsConfig: + """Architecture-independent parameters for FLOPs calculation. + + Can be constructed manually or via from_hf_config() for auto-detection. + """ + + hidden_size: int # H + num_hidden_layers: int # L + num_attention_heads: int # n_heads + num_kv_heads: int # n_kv (== n_heads for MHA) + head_dim: int # H // n_heads + intermediate_size: int # I (FFN intermediate dimension) + num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) + vocab_size: int # V + has_lm_head: bool # True for LM models, False for ViT etc. + + +def from_hf_config(config_dict, **overrides): + """Create ModelFLOPsConfig from an HF-compatible config dict. + + Auto-detects architecture: + - GQA vs MHA: from num_key_value_heads (absent = MHA) + - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type + - LM head: from vocab_size > 0 + + Args: + config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). + Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. + **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). + """ + h = config_dict["hidden_size"] + n_heads = config_dict["num_attention_heads"] + n_kv = config_dict.get("num_key_value_heads", n_heads) + vocab = config_dict.get("vocab_size", 0) + model_type = config_dict.get("model_type", "") + + # Detect gated MLP (3 projections) vs standard FFN (2 projections). + # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). + # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). + num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 + + kwargs = { + "hidden_size": h, + "num_hidden_layers": config_dict["num_hidden_layers"], + "num_attention_heads": n_heads, + "num_kv_heads": n_kv, + "head_dim": h // n_heads, + "intermediate_size": config_dict["intermediate_size"], + "num_mlp_projections": num_mlp_proj, + "vocab_size": vocab, + "has_lm_head": vocab > 0, + } + kwargs.update(overrides) + return ModelFLOPsConfig(**kwargs) + + +# ============================================================================= +# FLOPs Formulas +# ============================================================================= + + +def compute_flops_analytical(config, batch_size, seq_len): + """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). + + Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, + layer norms, activations, and element-wise ops. + + Handles: + - GQA vs MHA: K/V projection sizes based on config.num_kv_heads + - SwiGLU vs standard FFN: 2 or 3 MLP projections + - LM head presence + + Returns: + (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) + """ + b, s, h = batch_size, seq_len, config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + + breakdown = { + "Q projection": 2 * b * s * h * h, + "K projection": 2 * b * s * h * kv_dim, + "V projection": 2 * b * s * h * kv_dim, + "O projection": 2 * b * s * h * h, + "Attn logits": 2 * b * s * s * h, + "Attn values": 2 * b * s * s * h, + } + + ffn = config.intermediate_size + if config.num_mlp_projections == 3: + # SwiGLU/GeGLU: gate + up + down = 3 matmuls + breakdown["Gate projection"] = 2 * b * s * h * ffn + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + else: + # Standard FFN: up + down = 2 matmuls + breakdown["Up projection"] = 2 * b * s * h * ffn + breakdown["Down projection"] = 2 * b * s * ffn * h + + per_layer_fwd = sum(breakdown.values()) + lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd + total_training = 3 * total_fwd + + return total_training, breakdown, lm_head_fwd + + +def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): + """Simplified formula assuming standard MHA + standard FFN with I=4H. + + This is the formula from the Llama3 README: + (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V + + The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + + 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + """ + b, s, h = batch_size, seq_len, hidden_size + return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) + + +def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): + """FLOPs for Hyena-based models (Evo2). + + Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. + + Args: + config: ModelFLOPsConfig with model dimensions. + batch_size: Batch size. + seq_len: Sequence length. + hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for + short/medium/long conv and attention layer counts. If None, assumes + all layers are long-conv Hyena (H=num_layers, no attention). + """ + b, s, h = batch_size, seq_len, config.hidden_size + ffn = config.intermediate_size + + if hyena_layer_counts is None: + hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} + + # Common per-layer FLOPs + pre_attn_qkv_proj = 2 * 3 * b * s * h * h + post_attn_proj = 2 * b * s * h * h + glu_ffn = 2 * 3 * b * s * ffn * h + + # Layer-type-specific FLOPs (defaults from evo2_provider.py) + attn = 2 * 2 * b * h * s * s # Standard S^2 attention + hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default + hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 + hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 + hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h + + n_s = hyena_layer_counts.get("S", 0) + n_d = hyena_layer_counts.get("D", 0) + n_h = hyena_layer_counts.get("H", 0) + n_a = hyena_layer_counts.get("A", 0) + + logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 + + total_fwd = ( + logits + + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) + + n_a * attn + + (n_s + n_d + n_h) * hyena_proj + + n_s * hyena_short_conv + + n_d * hyena_medium_conv + + int(n_h * hyena_long_fft) + ) + + return 3 * total_fwd + + +# Backward-compatible wrappers for existing compare_mfu*.py scripts. + + +def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): + """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" + config = ModelFLOPsConfig( + hidden_size=h, + num_hidden_layers=num_layers, + num_attention_heads=h // head_dim, + num_kv_heads=n_kv_heads, + head_dim=head_dim, + intermediate_size=ffn_hidden_size, + num_mlp_projections=3, + vocab_size=vocab_size, + has_lm_head=True, + ) + return compute_flops_analytical(config, b, s) + + +def compute_flops_readme(b, s, h, num_layers, vocab_size): + """Backward-compatible wrapper for the simplified README formula.""" + return compute_flops_simplified(b, s, h, num_layers, vocab_size) + + +# ============================================================================= +# MFU Tracker +# ============================================================================= + + +class MFUTracker: + """Tracks MFU during training. Initialize once, call compute_mfu() per step. + + Usage: + tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) + # In training loop: + mfu_info = tracker.compute_mfu(step_time=0.5) + print(f"MFU: {mfu_info['mfu']:.1f}%") + """ + + def __init__( + self, + config, + batch_size, + seq_len, + num_gpus=1, + parallelism=None, + peak_tflops=None, + formula="analytical", + hyena_layer_counts=None, + ): + """Initialize MFU tracker. + + Args: + config: ModelFLOPsConfig instance. + batch_size: Micro batch size per GPU. + seq_len: Sequence length. + num_gpus: Total number of GPUs. + parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. + Used for communication overhead estimation. + peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. + formula: "analytical", "simplified", or "hyena". + hyena_layer_counts: For Hyena formula, dict of layer type counts. + """ + self.config = config + self.batch_size = batch_size + self.seq_len = seq_len + self.num_gpus = num_gpus + self.parallelism = parallelism or {} + self.formula = formula + + if formula == "analytical": + self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( + config, batch_size, seq_len + ) + elif formula == "simplified": + self.total_flops = compute_flops_simplified( + batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size + ) + self.breakdown = None + self.lm_head_flops = 0 + elif formula == "hyena": + self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) + self.breakdown = None + self.lm_head_flops = 0 + else: + raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") + + self.per_gpu_flops = self.total_flops // max(num_gpus, 1) + + if peak_tflops is not None: + self.peak_tflops = peak_tflops + self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" + else: + detected, self.device_name = detect_gpu_peak_tflops() + self.peak_tflops = detected + + self.comm_bytes = self._estimate_comm() + + @classmethod + def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): + """Create from an HF config dict with auto-detection.""" + config = from_hf_config(config_dict) + return cls(config, batch_size, seq_len, **kwargs) + + def compute_mfu(self, step_time): + """Compute MFU from measured step time. + + Args: + step_time: Wall-clock time for one training step (seconds). + + Returns: + Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. + """ + tflops = self.per_gpu_flops / step_time / 1e12 + mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 + return { + "mfu": mfu, + "tflops_per_gpu": tflops, + "per_gpu_flops": self.per_gpu_flops, + "total_flops": self.total_flops, + "step_time": step_time, + } + + def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): + """Estimate communication overhead as a fraction of step time. + + Args: + step_time: Measured step time in seconds. + measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. + + Returns: + Dict with comm_bytes, estimated_comm_time, comm_pct. + """ + bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 + comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 + comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 + return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} + + def _estimate_comm(self): + """Estimate total communication bytes per step based on parallelism.""" + total = 0 + cp_size = self.parallelism.get("cp", 1) + dp_size = self.parallelism.get("dp", 1) + + if cp_size > 1: + total += estimate_cp_comm_bytes( + self.batch_size, + self.seq_len, + self.config.num_hidden_layers, + self.config.num_kv_heads, + self.config.head_dim, + cp_size, + ) + + if dp_size > 1: + # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp + model_params = _estimate_model_params(self.config) + total += 2 * model_params * 2 * (dp_size - 1) // dp_size + + return total + + def summary(self): + """Return a human-readable summary string.""" + lines = [ + f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", + f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," + f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," + f" I={self.config.intermediate_size}, V={self.config.vocab_size}", + f" MLP projections: {self.config.num_mlp_projections}" + f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", + f" Batch: B={self.batch_size}, S={self.seq_len}", + f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", + f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", + f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", + ] + if self.parallelism: + lines.append(f" Parallelism: {self.parallelism}") + if self.comm_bytes > 0: + lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") + return "\n".join(lines) + + +# ============================================================================= +# Communication Estimation +# ============================================================================= + + +def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): + """Estimate total bytes transferred for CP ring attention per training step. + + Ring attention sends local KV chunks around the ring. Per layer forward: + (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. + Training = ~2x forward communication (forward sends KV, backward sends dKV). + """ + if cp_size <= 1: + return 0 + s_local = s // cp_size + kv_dim = n_kv_heads * head_dim + per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes + return 2 * num_layers * per_layer_fwd + + +def _estimate_model_params(config): + """Rough parameter count estimate from config dimensions.""" + h = config.hidden_size + kv_dim = config.num_kv_heads * config.head_dim + attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O + mlp_params = config.num_mlp_projections * h * config.intermediate_size + layer_params = attn_params + mlp_params + total = config.num_hidden_layers * layer_params + if config.has_lm_head: + total += config.vocab_size * h * 2 # embed + lm_head + return total + + +# ============================================================================= +# Step Time Measurement +# ============================================================================= + + +def measure_step_time( + model, + input_ids, + num_warmup=10, + num_timed=20, + distributed=False, + cp_context_fn=None, + labels=None, + position_ids=None, + **extra_fwd_kwargs, +): + """Measure average training step time (forward + backward). + + Args: + model: The model to benchmark. + input_ids: Input tensor. + num_warmup: Warmup iterations (discarded). + num_timed: Timed iterations to average. + distributed: Whether to use dist.barrier() for synchronization. + cp_context_fn: Optional callable returning a context manager for CP. + labels: Optional labels tensor. If None, uses input_ids. + position_ids: Optional position_ids for correct RoPE with CP. + **extra_fwd_kwargs: Additional kwargs for model forward. + """ + if labels is None: + labels = input_ids + + fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} + if position_ids is not None: + fwd_kwargs["position_ids"] = position_ids + + for _ in range(num_warmup): + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + if distributed: + dist.barrier() + torch.cuda.synchronize() + + times = [] + for _ in range(num_timed): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + ctx = cp_context_fn() if cp_context_fn else nullcontext() + with ctx: + output = model(**fwd_kwargs) + output.loss.backward() + model.zero_grad(set_to_none=True) + + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1000.0) + + return sum(times) / len(times) + + +# ============================================================================= +# Utilities +# ============================================================================= + + +def split_for_cp_bshd(tensor, cp_rank, cp_size): + """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" + if cp_size <= 1: + return tensor + total_chunks = 2 * cp_size + seq_len = tensor.size(1) + chunk_size = seq_len // total_chunks + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] + return torch.cat(slices, dim=1) + + +def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): + """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" + if world_size <= 1: + return 0.0 + + rank = dist.get_rank() + tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) + peer = 1 - rank + + for _ in range(5): + if rank == 0: + dist.send(tensor, dst=peer) + dist.recv(tensor, src=peer) + else: + dist.recv(tensor, src=peer) + dist.send(tensor, dst=peer) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + if rank == 0: + dist.send(tensor, dst=peer) + else: + dist.recv(tensor, src=peer) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + data_bytes = tensor.nelement() * tensor.element_size() + return num_iters * data_bytes / elapsed / 1e9 + + +def count_flops_with_model(model, input_ids): + """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" + flop_counter = FlopCounterMode(display=False) + with flop_counter: + model(input_ids=input_ids) + return flop_counter.get_total_flops() * 3 + + +def load_model_config(config_path): + """Load model config dict from a local path or HuggingFace model ID. + + Supports: + - Local directory: ./model_configs/lingua-1B (reads config.json inside) + - Local file: ./model_configs/lingua-1B/config.json + - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) + """ + import json + from pathlib import Path + + path = Path(config_path) + if path.is_dir(): + path = path / "config.json" + if path.exists(): + return json.loads(path.read_text()) + + # Fall back to HuggingFace Hub + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + return hf_config.to_dict() + + +def cleanup_model(model): + """Delete a model and free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Formatting +# ============================================================================= + + +def format_flops(flops): + """Format FLOPs with appropriate unit (G/T/P).""" + if flops >= 1e15: + return f"{flops / 1e15:.2f} P" + elif flops >= 1e12: + return f"{flops / 1e12:.2f} T" + elif flops >= 1e9: + return f"{flops / 1e9:.2f} G" + else: + return f"{flops:.2e}" + + +def format_flops_exact(flops): + """Format FLOPs as the full integer with commas.""" + return f"{int(flops):,}" + + +def format_bytes(num_bytes): + """Format bytes with appropriate unit.""" + if num_bytes >= 1e9: + return f"{num_bytes / 1e9:.2f} GB" + elif num_bytes >= 1e6: + return f"{num_bytes / 1e6:.2f} MB" + elif num_bytes >= 1e3: + return f"{num_bytes / 1e3:.2f} KB" + else: + return f"{num_bytes:.0f} B" + + +def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): + """Print first-principles FLOPs breakdown.""" + print() + print("--- First Principles Breakdown (forward pass, per layer) ---") + per_layer_total = sum(breakdown.values()) + for component, flops_val in breakdown.items(): + pct = flops_val / per_layer_total * 100 + print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") + total_fwd = num_layers * per_layer_total + lm_head_fwd + print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") + print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") + print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") + print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") + print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") + print(f" {'Model params':<20} {model_params / 1e9:.2f}B") + + +# ============================================================================= +# CLI +# ============================================================================= + + +def _cli_bandwidth(): + """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" + import os + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + if rank == 0: + print(f"Measuring P2P bandwidth between {world_size} GPUs...") + for i in range(world_size): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + + bw = measure_bus_bandwidth(device, world_size) + if rank == 0: + print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") + + dist.destroy_process_group() + + +def _cli_gpu_info(): + """Print GPU info and peak TFLOPS.""" + peak, name = detect_gpu_peak_tflops() + print(f"GPU: {name}") + if peak: + print(f"Peak bf16 TFLOPS: {peak:.1f}") + else: + print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") + print() + print("Known GPUs:") + for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): + print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") + + +def _cli_flops(): + """Compute FLOPs for a model config.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) + parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") + parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + print( + f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," + f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," + f" I={config.intermediate_size}, V={config.vocab_size}" + ) + print( + f"MLP: {config.num_mlp_projections} projections" + f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" + ) + print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") + print() + + simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) + analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) + + print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") + print("-" * 86) + for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: + per_gpu = flops // max(args.num_gpus, 1) + print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") + + if simplified != analytical: + diff = analytical - simplified + print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") + else: + print("\nFormulas agree exactly for this config.") + + # Communication overhead estimate + if args.cp_size > 1: + dp_size = args.num_gpus // args.cp_size + parallelism = {"dp": dp_size, "cp": args.cp_size} + tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) + print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") + print( + f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" + ) + comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 + print(f" Estimated comm time: {comm_time:.4f}s") + + model_params = _estimate_model_params(config) + print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) + + +def _cli_cp_comm(): + """Estimate CP communication volume.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command") + parser.add_argument("--config-path", default="./model_configs/lingua-1B") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=16384) + parser.add_argument("--cp-size", type=int, default=2) + args = parser.parse_args() + + cfg_dict = load_model_config(args.config_path) + config = from_hf_config(cfg_dict) + b, s = args.batch_size, args.seq_len + + comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) + print( + f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," + f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" + ) + print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") + + +if __name__ == "__main__": + import sys + + commands = { + "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), + "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), + "flops": ("Compute FLOPs for a model config", _cli_flops), + "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), + } + + if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: + print("Usage: python flops.py [options]") + print(" torchrun --nproc_per_node=2 flops.py bandwidth") + print() + print("Commands:") + for cmd, (desc, _) in commands.items(): + print(f" {cmd:<16} {desc}") + sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) + + commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml index 4295017f26..d532a83f06 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml @@ -103,6 +103,8 @@ fp8_stats_config: fp8_stats_file: ./fp8_debugging_stats.yaml fp8_log_dir: ./log_fp8_stats +log_mfu: false + profiler: enabled: false start_step: 10 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py index 15d173b955..1963bab24d 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py @@ -28,6 +28,7 @@ import gc import logging import random +import time from contextlib import nullcontext from pathlib import Path @@ -62,6 +63,7 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from opengenome_modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from optimizer import get_parameter_groups_with_weight_decay @@ -260,6 +262,19 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + # Setup validation if enabled val_config = getattr(args, "validation", None) val_enabled = val_config is not None and getattr(val_config, "enabled", False) @@ -301,6 +316,7 @@ def main(args: DictConfig) -> float | None: logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") step = start_step micro_step = 0 + step_start_time = time.perf_counter() if train_dataloader is None: raise RuntimeError("Expected train_dataloader to be initialized before training.") @@ -335,6 +351,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py index 3319fb5d25..a500adb8a2 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py @@ -33,6 +33,7 @@ import gc import logging +import time from contextlib import nullcontext from pathlib import Path @@ -69,6 +70,7 @@ from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from opengenome_modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from optimizer import get_parameter_groups_with_weight_decay @@ -300,6 +302,19 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + # --- MFU Tracking (optional) --- + mfu_tracker = None + if args.get("log_mfu", False): + mfu_tracker = MFUTracker( + config=from_hf_config(config.to_dict()), + batch_size=args.dataset.micro_batch_size, + seq_len=args.dataset.max_seq_length, + num_gpus=dist_config.world_size, + parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, + ) + if dist_config.is_main_process(): + logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + gc.collect() torch.cuda.empty_cache() @@ -307,6 +322,7 @@ def main(args: DictConfig) -> float | None: logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") step = start_step micro_step = 0 + step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: @@ -343,6 +359,13 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) + if mfu_tracker is not None: + step_time = time.perf_counter() - step_start_time + mfu_info = mfu_tracker.compute_mfu(step_time) + if dist_config.is_main_process(): + logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) + step_start_time = time.perf_counter() + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 92d65903cb..3a5094f6bf 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -199,6 +199,13 @@ def _compare_file_contents(source_file: Path, dest_file: Path, source_display: s "bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [ "bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py", ], + # FLOPs / MFU module - synced to recipes + "bionemo-recipes/models/esm2/flops.py": [ + "bionemo-recipes/recipes/llama3_native_te/flops.py", + "bionemo-recipes/recipes/esm2_native_te/flops.py", + "bionemo-recipes/recipes/codonfm_native_te/flops.py", + "bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py", + ], # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", From 1930644d06053d22eb1b795b0c7c3b6126462ee7 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 11 Apr 2026 10:27:33 +0000 Subject: [PATCH 06/24] Remove first_principles.md from repository Signed-off-by: Gagan Kaushik --- .../models/esm2/first_principles.md | 348 ------------------ 1 file changed, 348 deletions(-) delete mode 100644 bionemo-recipes/models/esm2/first_principles.md diff --git a/bionemo-recipes/models/esm2/first_principles.md b/bionemo-recipes/models/esm2/first_principles.md deleted file mode 100644 index 27e249f963..0000000000 --- a/bionemo-recipes/models/esm2/first_principles.md +++ /dev/null @@ -1,348 +0,0 @@ -# First-Principles FLOPs Derivation for Llama 3 (GQA + SwiGLU) - -This document derives the per-training-step FLOPs formula used in `compute_flops_first_principles()`, explains each component, and contrasts it with the simplified README formula. - -## Counting convention - -We count **multiply-accumulate operations (MACs)** and report them as **2 FLOPs per MAC** (one multiply, one add). For a matrix multiplication of shapes `(M, K) @ (K, N)`, the FLOPs are: - -``` -FLOPs = 2 * M * K * N -``` - -We only count dense matmuls. Softmax, layer norms, RoPE rotations, element-wise activations (SiLU), and the Hadamard product in SwiGLU are negligible relative to the matmuls and are excluded, consistent with standard MFU methodology. - -## Notation - -| Symbol | Meaning | Lingua-1B value | -| ------ | ------------------------------------------------- | --------------- | -| B | Batch size | 1 | -| S | Sequence length | varies | -| H | Hidden size (`hidden_size`) | 2048 | -| L | Number of layers (`num_hidden_layers`) | 25 | -| n_h | Number of attention heads (`num_attention_heads`) | 16 | -| n_kv | Number of KV heads (`num_key_value_heads`) | 8 | -| d | Head dimension (H / n_h) | 128 | -| d_kv | KV dimension (n_kv * d) | 1024 | -| I | FFN intermediate size (`intermediate_size`) | 6144 | -| V | Vocabulary size (`vocab_size`) | 128256 | - -## Per-layer forward FLOPs - -### Attention projections - -Each attention layer projects the hidden states into queries, keys, values, and then projects the attention output back. - -**Q projection**: Each token's hidden state (H) is projected to the query space (H = n_h * d). - -``` -input: (B, S, H) -weight: (H, H) -output: (B, S, H) -FLOPs = 2 * B * S * H * H -``` - -**K projection**: With Grouped Query Attention (GQA), keys are projected to a smaller space (d_kv = n_kv * d) instead of the full H. This is the key difference from standard Multi-Head Attention (MHA). - -``` -input: (B, S, H) -weight: (H, d_kv) -output: (B, S, d_kv) -FLOPs = 2 * B * S * H * d_kv -``` - -**V projection**: Same dimensions as K projection. - -``` -FLOPs = 2 * B * S * H * d_kv -``` - -**O projection**: The concatenated attention output (H) is projected back to hidden size (H). - -``` -input: (B, S, H) -weight: (H, H) -output: (B, S, H) -FLOPs = 2 * B * S * H * H -``` - -**Total attention projections:** - -``` -attn_proj = 2 * B * S * H * (2*H + 2*d_kv) -``` - -For MHA (d_kv = H), this simplifies to `2 * B * S * H * 4H = 8 * B * S * H^2`. For GQA with d_kv < H, the K and V projections are smaller. - -### Attention scores - -After projection, attention computes Q @ K^T and then attn_weights @ V. Even with GQA (fewer KV heads), the KV heads are **broadcast** to match the query heads, so the effective computation uses all n_h query heads attending to S key positions. - -**Attention logits (Q @ K^T)**: For each head, the query (S, d) is multiplied by key^T (d, S). - -``` -Per head: 2 * B * S * d * S = 2 * B * S^2 * d -All n_h heads: 2 * B * S^2 * d * n_h = 2 * B * S^2 * H -``` - -Note: with GQA, each KV head is shared across (n_h / n_kv) query heads. The total FLOPs remain `2 * B * S^2 * H` because we still have n_h query heads each doing S\*d work against S keys. - -**Attention values (attn_weights @ V)**: Same shape — attention weights (S, S) multiplied by values (S, d) per head. - -``` -FLOPs = 2 * B * S^2 * H -``` - -**Total attention scores:** - -``` -attn_score = 4 * B * S^2 * H -``` - -### MLP (SwiGLU) - -Llama 3 uses SwiGLU activation, which has **three** linear projections instead of the standard MLP's two: - -``` -SwiGLU(x) = (x @ W_gate * SiLU(x @ W_up)) @ W_down -``` - -Standard MLP has two projections (up: H -> I, down: I -> H) with I = 4H typically. SwiGLU adds a third (gate) projection. - -**Gate projection**: H -> I - -``` -FLOPs = 2 * B * S * H * I -``` - -**Up projection**: H -> I - -``` -FLOPs = 2 * B * S * H * I -``` - -**Down projection**: I -> H - -``` -FLOPs = 2 * B * S * I * H -``` - -The element-wise SiLU activation and the Hadamard product (gate * up) are O(B * S * I) — negligible compared to the matmuls. - -**Total MLP:** - -``` -mlp = 6 * B * S * H * I -``` - -### Per-layer total - -``` -per_layer_fwd = attn_proj + attn_score + mlp - = 2*B*S*H*(2*H + 2*d_kv) + 4*B*S^2*H + 6*B*S*H*I -``` - -## LM head - -The language model head projects hidden states to vocabulary logits: - -``` -input: (B, S, H) -weight: (H, V) -output: (B, S, V) -FLOPs = 2 * B * S * H * V -``` - -## Total forward FLOPs - -``` -total_fwd = L * per_layer_fwd + lm_head - = L * [2*B*S*H*(2*H + 2*d_kv) + 4*B*S^2*H + 6*B*S*H*I] + 2*B*S*H*V -``` - -## Total training FLOPs (forward + backward) - -The standard approximation for training is that backward costs 2x the forward (one pass to compute dL/dW, another to compute dL/dX for each matmul). Total training = 3x forward. - -``` -total_training = 3 * total_fwd -``` - -## Comparison with the README formula - -The README uses a simplified formula for a standard transformer: - -```python -total = (24 * B * S * H * H + 4 * B * S * S * H) * (3 * L) + (6 * B * S * H * V) -``` - -The `3*L` folds the 3x training multiplier into the layer count, and `6*B*S*H*V = 3 * 2*B*S*H*V` does the same for the LM head. Extracting the per-layer **forward** FLOPs implicit in the README: - -``` -readme_per_layer_fwd = (24*B*S*H^2 + 4*B*S^2*H) / 3 - = 8*B*S*H^2 + (4/3)*B*S^2*H -``` - -The `4*B*S^2*H` attention score term (with 3x) matches our first-principles `4*B*S^2*H` exactly — both formulas agree on attention scores. The difference is entirely in the `24*B*S*H^2` term, which covers attention projections and MLP. Decomposing it: - -### Decomposition of the README's `24*B*S*H^2` - -The coefficient 24 encodes two assumptions about the per-layer linear projections: - -**Attention projections (coefficient = 8):** Four projections (Q, K, V, O) each of size H -> H, assuming standard Multi-Head Attention (MHA): - -``` -4 projections * 2*B*S*H*H = 8*B*S*H^2 -``` - -**MLP (coefficient = 16):** Two projections with intermediate size I = 4H, assuming a standard Feed-Forward Network: - -``` -Up: 2*B*S*H*(4H) = 8*B*S*H^2 -Down: 2*B*S*(4H)*H = 8*B*S*H^2 -Total: = 16*B*S*H^2 -``` - -Combined: `8 + 16 = 24`. - -### How our first-principles formula differs - -Our formula replaces both assumptions with the actual Llama 3 architecture: - -**Attention projections with GQA:** K and V project to d_kv (not H): - -``` -Q: 2*B*S*H*H K: 2*B*S*H*d_kv V: 2*B*S*H*d_kv O: 2*B*S*H*H -Total = 2*B*S*H*(2*H + 2*d_kv) -``` - -**MLP with SwiGLU:** Three projections (gate, up, down) with actual I: - -``` -Gate: 2*B*S*H*I Up: 2*B*S*H*I Down: 2*B*S*I*H -Total = 6*B*S*H*I -``` - -Side by side, per layer forward, factoring out `2*B*S*H`: - -| Component | README | First principles | -| ---------------- | ---------------------- | ------------------------------ | -| Attention proj | `4*H` (MHA) | `2*H + 2*d_kv` (GQA) | -| MLP | `2*4H = 8*H` (std FFN) | `3*I` (SwiGLU) | -| Attention scores | `2*S` (same) | `2*S` (same) | -| **Total coeff** | **`4*H + 8*H + 2*S`** | **`2*H + 2*d_kv + 3*I + 2*S`** | - -Setting them equal: `12*H + 2*S = 2*H + 2*d_kv + 3*I + 2*S`, which simplifies to: - -``` -10*H = 2*d_kv + 3*I -``` - -### Where the assumptions break - -| Component | README assumes | Llama 3 actual | Direction | -| ---------------- | --------------------- | ------------------------------ | --------------------- | -| K, V projections | H -> H (MHA) | H -> d_kv (GQA, d_kv < H) | README **overcounts** | -| MLP | 2 projections, I = 4H | 3 projections (SwiGLU), I < 4H | Depends on model dims | - -The errors go in opposite directions: - -- **GQA** makes K/V projections cheaper (d_kv < H), so the README overcounts attention -- **SwiGLU** adds a third MLP projection, so the README undercounts MLP (despite using a larger I=4H) - -### Why they cancel exactly for Lingua-1B - -For the Lingua-1B config (H=2048, d_kv=1024, I=6144): - -``` -README linear cost per layer: 12*H = 12 * 2048 = 24,576 -First-principles linear cost: 2*H + 2*d_kv + 3*I = 4096 + 2048 + 18432 = 24,576 -``` - -They match. Breaking down why: - -- README assumes attn proj cost: `4*H = 4 * 2048 = 8,192` - -- Actual attn proj cost: `2*H + 2*d_kv = 4096 + 2048 = 6,144` - -- **GQA saves: 2,048** - -- README assumes MLP cost: `8*H = 8 * 2048 = 16,384` - -- Actual MLP cost: `3*I = 3 * 6144 = 18,432` - -- **SwiGLU adds: 2,048** - -Saved from GQA = Added from SwiGLU = **2,048 exactly**. This is a coincidence specific to Lingua-1B's dimensions. For models with different d_kv/H or I/H ratios, the formulas diverge. - -### When would they diverge? - -For Llama 3.1 70B (H=8192, n_kv=8, d=128, d_kv=1024, I=28672): - -``` -README linear: 12*H = 98,304 -First-principles: 2*8192 + 2*1024 + 3*28672 = 16384 + 2048 + 86016 = 104,448 -Difference: +6.2% (README undercounts by ~6%) -``` - -The README would **undercount** FLOPs for Llama 70B because SwiGLU's third projection with the large I=28672 dominates the GQA savings. - -## Code - -```python -def compute_flops_first_principles( - b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size -): - kv_dim = n_kv_heads * head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - "Gate proj": 2 * b * s * h * ffn_hidden_size, - "Up proj": 2 * b * s * h * ffn_hidden_size, - "Down proj": 2 * b * s * ffn_hidden_size * h, - } - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * vocab_size - total_fwd = num_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd -``` - -## Numerical example (Lingua-1B, B=1, S=4096) - -``` -Per layer forward: - Q proj: 2 * 1 * 4096 * 2048 * 2048 = 34,359,738,368 - K proj: 2 * 1 * 4096 * 2048 * 1024 = 17,179,869,184 - V proj: 2 * 1 * 4096 * 2048 * 1024 = 17,179,869,184 - O proj: 2 * 1 * 4096 * 2048 * 2048 = 34,359,738,368 - Attn logits: 2 * 1 * 4096 * 4096 * 2048 = 68,719,476,736 - Attn values: 2 * 1 * 4096 * 4096 * 2048 = 68,719,476,736 - Gate proj: 2 * 1 * 4096 * 2048 * 6144 = 103,079,215,104 - Up proj: 2 * 1 * 4096 * 2048 * 6144 = 103,079,215,104 - Down proj: 2 * 1 * 4096 * 6144 * 2048 = 103,079,215,104 - ───────────────────────────────────────────────────────────────── - Per-layer total: 549,755,813,888 - -LM head: 2 * 1 * 4096 * 2048 * 128256 = 2,152,726,528,000 - -Total forward: 25 * 549,755,813,888 + 2,152,726,528,000 - = 13,743,895,347,200 + 2,152,726,528,000 - = 15,896,621,875,200 - -Total training (3x): 3 * 15,896,621,875,200 = 47,689,865,625,600 -``` - -Note: the code uses integer arithmetic and reports `47,687,021,887,488` — the small difference is from the embedding layer not being counted here (it's a lookup, not a matmul) and the LM head weight tying configuration. - -## References - -- Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models" (2022) — establishes the 3x forward approximation for training FLOPs -- Chowdhery et al., "PaLM: Scaling Language Modeling with Pathways" (2022) — defines MFU as model_flops / (step_time * peak_hardware_flops) From 7094d17157bf49b224b0fd395a5e4cf430724894 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 11 Apr 2026 10:30:25 +0000 Subject: [PATCH 07/24] Remove first_principles.md reference from flops.py docstring Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/flops.py | 2 +- bionemo-recipes/recipes/codonfm_native_te/flops.py | 2 +- bionemo-recipes/recipes/esm2_native_te/flops.py | 2 +- bionemo-recipes/recipes/llama3_native_te/flops.py | 2 +- bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/models/esm2/flops.py b/bionemo-recipes/models/esm2/flops.py index 31e2988700..e1c44a74e7 100644 --- a/bionemo-recipes/models/esm2/flops.py +++ b/bionemo-recipes/models/esm2/flops.py @@ -192,7 +192,7 @@ def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + 16*H^2 for 2 MLP projections with I=4H. """ b, s, h = batch_size, seq_len, hidden_size return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py index fafd718ecf..d94b599f81 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/flops.py +++ b/bionemo-recipes/recipes/codonfm_native_te/flops.py @@ -198,7 +198,7 @@ def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + 16*H^2 for 2 MLP projections with I=4H. """ b, s, h = batch_size, seq_len, hidden_size return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py index fafd718ecf..d94b599f81 100644 --- a/bionemo-recipes/recipes/esm2_native_te/flops.py +++ b/bionemo-recipes/recipes/esm2_native_te/flops.py @@ -198,7 +198,7 @@ def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + 16*H^2 for 2 MLP projections with I=4H. """ b, s, h = batch_size, seq_len, hidden_size return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py index fafd718ecf..d94b599f81 100644 --- a/bionemo-recipes/recipes/llama3_native_te/flops.py +++ b/bionemo-recipes/recipes/llama3_native_te/flops.py @@ -198,7 +198,7 @@ def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + 16*H^2 for 2 MLP projections with I=4H. """ b, s, h = batch_size, seq_len, hidden_size return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py index fafd718ecf..d94b599f81 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py @@ -198,7 +198,7 @@ def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. See first_principles.md for details. + 16*H^2 for 2 MLP projections with I=4H. """ b, s, h = batch_size, seq_len, hidden_size return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) From 8ea45ac7ac94e2a9a13dd67e729f3fc6e5838a1a Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 11 Apr 2026 11:03:06 +0000 Subject: [PATCH 08/24] Update GPU TFLOPS table: fix RTX values, add B300/GB300, add sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix RTX 4090 (330→165) and RTX 3090 (142→71): were using sparse values - Remove A5000, RTX 4090, RTX 3090 (not relevant for data center workloads) - Add B300 and GB300 Blackwell Ultra at 2,500 TFLOPS dense BF16 - Add authoritative NVIDIA source URLs for all GPU specs as comments Sources: B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS BF16 sparse / 72 GPUs / 2 = 2,500 TFLOPS dense per GPU) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/flops.py | 14 +++++++++++--- bionemo-recipes/recipes/codonfm_native_te/flops.py | 14 +++++++++++--- bionemo-recipes/recipes/esm2_native_te/flops.py | 14 +++++++++++--- bionemo-recipes/recipes/llama3_native_te/flops.py | 14 +++++++++++--- .../recipes/opengenome2_llama_native_te/flops.py | 14 +++++++++++--- 5 files changed, 55 insertions(+), 15 deletions(-) diff --git a/bionemo-recipes/models/esm2/flops.py b/bionemo-recipes/models/esm2/flops.py index e1c44a74e7..4fbf8775dd 100644 --- a/bionemo-recipes/models/esm2/flops.py +++ b/bionemo-recipes/models/esm2/flops.py @@ -46,18 +46,26 @@ # GPU Peak TFLOPS # ============================================================================= +# Dense (without sparsity) BF16 tensor core peak TFLOPS. +# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. +# Sources: +# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) +# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf +# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf +# L40: https://www.nvidia.com/en-us/data-center/l40/ +# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ +# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) GPU_PEAK_TFLOPS_BF16 = { "H100": 989.0, "H200": 989.0, "A100": 312.0, "A6000": 155.0, - "A5000": 111.0, "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, "GH200": 989.0, "B200": 2250.0, "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, } diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py index d94b599f81..b6f9ed4e1e 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/flops.py +++ b/bionemo-recipes/recipes/codonfm_native_te/flops.py @@ -52,18 +52,26 @@ # GPU Peak TFLOPS # ============================================================================= +# Dense (without sparsity) BF16 tensor core peak TFLOPS. +# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. +# Sources: +# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) +# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf +# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf +# L40: https://www.nvidia.com/en-us/data-center/l40/ +# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ +# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) GPU_PEAK_TFLOPS_BF16 = { "H100": 989.0, "H200": 989.0, "A100": 312.0, "A6000": 155.0, - "A5000": 111.0, "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, "GH200": 989.0, "B200": 2250.0, "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, } diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py index d94b599f81..b6f9ed4e1e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/flops.py +++ b/bionemo-recipes/recipes/esm2_native_te/flops.py @@ -52,18 +52,26 @@ # GPU Peak TFLOPS # ============================================================================= +# Dense (without sparsity) BF16 tensor core peak TFLOPS. +# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. +# Sources: +# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) +# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf +# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf +# L40: https://www.nvidia.com/en-us/data-center/l40/ +# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ +# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) GPU_PEAK_TFLOPS_BF16 = { "H100": 989.0, "H200": 989.0, "A100": 312.0, "A6000": 155.0, - "A5000": 111.0, "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, "GH200": 989.0, "B200": 2250.0, "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, } diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py index d94b599f81..b6f9ed4e1e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/flops.py +++ b/bionemo-recipes/recipes/llama3_native_te/flops.py @@ -52,18 +52,26 @@ # GPU Peak TFLOPS # ============================================================================= +# Dense (without sparsity) BF16 tensor core peak TFLOPS. +# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. +# Sources: +# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) +# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf +# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf +# L40: https://www.nvidia.com/en-us/data-center/l40/ +# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ +# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) GPU_PEAK_TFLOPS_BF16 = { "H100": 989.0, "H200": 989.0, "A100": 312.0, "A6000": 155.0, - "A5000": 111.0, "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, "GH200": 989.0, "B200": 2250.0, "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, } diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py index d94b599f81..b6f9ed4e1e 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py @@ -52,18 +52,26 @@ # GPU Peak TFLOPS # ============================================================================= +# Dense (without sparsity) BF16 tensor core peak TFLOPS. +# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. +# Sources: +# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) +# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf +# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf +# L40: https://www.nvidia.com/en-us/data-center/l40/ +# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ +# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) GPU_PEAK_TFLOPS_BF16 = { "H100": 989.0, "H200": 989.0, "A100": 312.0, "A6000": 155.0, - "A5000": 111.0, "L40": 181.0, - "RTX 4090": 330.0, - "RTX 3090": 142.0, "GH200": 989.0, "B200": 2250.0, "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, } From d853f9622b308cfa950f72c631fa4b3c7c7415a6 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 13 Apr 2026 20:03:56 +0000 Subject: [PATCH 09/24] Clean up flops.py: remove dead code from deleted benchmark scripts Remove functions that were only used by the now-deleted compare_mfu*.py standalone benchmark scripts: - compute_flops_first_principles(), compute_flops_readme() (backward compat wrappers) - measure_step_time(), split_for_cp_bshd(), count_flops_with_model(), cleanup_model() - Unused imports: gc, nullcontext, FlopCounterMode Retained: MFUTracker, from_hf_config, compute_flops_analytical, compute_flops_simplified, compute_flops_hyena, CLI, and all utilities needed by the training script log_mfu integration. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/flops.py | 121 ------------------ .../recipes/codonfm_native_te/flops.py | 121 ------------------ .../recipes/esm2_native_te/flops.py | 121 ------------------ .../recipes/llama3_native_te/flops.py | 121 ------------------ .../meta-llama/Llama-3.1-8B/config.json | 35 +++++ .../opengenome2_llama_native_te/flops.py | 121 ------------------ 6 files changed, 35 insertions(+), 605 deletions(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json diff --git a/bionemo-recipes/models/esm2/flops.py b/bionemo-recipes/models/esm2/flops.py index 4fbf8775dd..bce8388069 100644 --- a/bionemo-recipes/models/esm2/flops.py +++ b/bionemo-recipes/models/esm2/flops.py @@ -31,15 +31,12 @@ torchrun --nproc_per_node=2 flops.py bandwidth """ -import gc import math import time -from contextlib import nullcontext from dataclasses import dataclass import torch import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode # ============================================================================= @@ -257,30 +254,6 @@ def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): return 3 * total_fwd -# Backward-compatible wrappers for existing compare_mfu*.py scripts. - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" - config = ModelFLOPsConfig( - hidden_size=h, - num_hidden_layers=num_layers, - num_attention_heads=h // head_dim, - num_kv_heads=n_kv_heads, - head_dim=head_dim, - intermediate_size=ffn_hidden_size, - num_mlp_projections=3, - vocab_size=vocab_size, - has_lm_head=True, - ) - return compute_flops_analytical(config, b, s) - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """Backward-compatible wrapper for the simplified README formula.""" - return compute_flops_simplified(b, s, h, num_layers, vocab_size) - - # ============================================================================= # MFU Tracker # ============================================================================= @@ -472,90 +445,11 @@ def _estimate_model_params(config): return total -# ============================================================================= -# Step Time Measurement -# ============================================================================= - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. - num_warmup: Warmup iterations (discarded). - num_timed: Timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager for CP. - labels: Optional labels tensor. If None, uses input_ids. - position_ids: Optional position_ids for correct RoPE with CP. - **extra_fwd_kwargs: Additional kwargs for model forward. - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - # ============================================================================= # Utilities # ============================================================================= -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" if world_size <= 1: @@ -589,14 +483,6 @@ def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_ return num_iters * data_bytes / elapsed / 1e9 -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - def load_model_config(config_path): """Load model config dict from a local path or HuggingFace model ID. @@ -621,13 +507,6 @@ def load_model_config(config_path): return hf_config.to_dict() -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - # ============================================================================= # Formatting # ============================================================================= diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py index b6f9ed4e1e..f4fb7411fc 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/flops.py +++ b/bionemo-recipes/recipes/codonfm_native_te/flops.py @@ -37,15 +37,12 @@ torchrun --nproc_per_node=2 flops.py bandwidth """ -import gc import math import time -from contextlib import nullcontext from dataclasses import dataclass import torch import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode # ============================================================================= @@ -263,30 +260,6 @@ def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): return 3 * total_fwd -# Backward-compatible wrappers for existing compare_mfu*.py scripts. - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" - config = ModelFLOPsConfig( - hidden_size=h, - num_hidden_layers=num_layers, - num_attention_heads=h // head_dim, - num_kv_heads=n_kv_heads, - head_dim=head_dim, - intermediate_size=ffn_hidden_size, - num_mlp_projections=3, - vocab_size=vocab_size, - has_lm_head=True, - ) - return compute_flops_analytical(config, b, s) - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """Backward-compatible wrapper for the simplified README formula.""" - return compute_flops_simplified(b, s, h, num_layers, vocab_size) - - # ============================================================================= # MFU Tracker # ============================================================================= @@ -478,90 +451,11 @@ def _estimate_model_params(config): return total -# ============================================================================= -# Step Time Measurement -# ============================================================================= - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. - num_warmup: Warmup iterations (discarded). - num_timed: Timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager for CP. - labels: Optional labels tensor. If None, uses input_ids. - position_ids: Optional position_ids for correct RoPE with CP. - **extra_fwd_kwargs: Additional kwargs for model forward. - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - # ============================================================================= # Utilities # ============================================================================= -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" if world_size <= 1: @@ -595,14 +489,6 @@ def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_ return num_iters * data_bytes / elapsed / 1e9 -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - def load_model_config(config_path): """Load model config dict from a local path or HuggingFace model ID. @@ -627,13 +513,6 @@ def load_model_config(config_path): return hf_config.to_dict() -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - # ============================================================================= # Formatting # ============================================================================= diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py index b6f9ed4e1e..f4fb7411fc 100644 --- a/bionemo-recipes/recipes/esm2_native_te/flops.py +++ b/bionemo-recipes/recipes/esm2_native_te/flops.py @@ -37,15 +37,12 @@ torchrun --nproc_per_node=2 flops.py bandwidth """ -import gc import math import time -from contextlib import nullcontext from dataclasses import dataclass import torch import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode # ============================================================================= @@ -263,30 +260,6 @@ def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): return 3 * total_fwd -# Backward-compatible wrappers for existing compare_mfu*.py scripts. - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" - config = ModelFLOPsConfig( - hidden_size=h, - num_hidden_layers=num_layers, - num_attention_heads=h // head_dim, - num_kv_heads=n_kv_heads, - head_dim=head_dim, - intermediate_size=ffn_hidden_size, - num_mlp_projections=3, - vocab_size=vocab_size, - has_lm_head=True, - ) - return compute_flops_analytical(config, b, s) - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """Backward-compatible wrapper for the simplified README formula.""" - return compute_flops_simplified(b, s, h, num_layers, vocab_size) - - # ============================================================================= # MFU Tracker # ============================================================================= @@ -478,90 +451,11 @@ def _estimate_model_params(config): return total -# ============================================================================= -# Step Time Measurement -# ============================================================================= - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. - num_warmup: Warmup iterations (discarded). - num_timed: Timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager for CP. - labels: Optional labels tensor. If None, uses input_ids. - position_ids: Optional position_ids for correct RoPE with CP. - **extra_fwd_kwargs: Additional kwargs for model forward. - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - # ============================================================================= # Utilities # ============================================================================= -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" if world_size <= 1: @@ -595,14 +489,6 @@ def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_ return num_iters * data_bytes / elapsed / 1e9 -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - def load_model_config(config_path): """Load model config dict from a local path or HuggingFace model ID. @@ -627,13 +513,6 @@ def load_model_config(config_path): return hf_config.to_dict() -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - # ============================================================================= # Formatting # ============================================================================= diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py index b6f9ed4e1e..f4fb7411fc 100644 --- a/bionemo-recipes/recipes/llama3_native_te/flops.py +++ b/bionemo-recipes/recipes/llama3_native_te/flops.py @@ -37,15 +37,12 @@ torchrun --nproc_per_node=2 flops.py bandwidth """ -import gc import math import time -from contextlib import nullcontext from dataclasses import dataclass import torch import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode # ============================================================================= @@ -263,30 +260,6 @@ def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): return 3 * total_fwd -# Backward-compatible wrappers for existing compare_mfu*.py scripts. - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" - config = ModelFLOPsConfig( - hidden_size=h, - num_hidden_layers=num_layers, - num_attention_heads=h // head_dim, - num_kv_heads=n_kv_heads, - head_dim=head_dim, - intermediate_size=ffn_hidden_size, - num_mlp_projections=3, - vocab_size=vocab_size, - has_lm_head=True, - ) - return compute_flops_analytical(config, b, s) - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """Backward-compatible wrapper for the simplified README formula.""" - return compute_flops_simplified(b, s, h, num_layers, vocab_size) - - # ============================================================================= # MFU Tracker # ============================================================================= @@ -478,90 +451,11 @@ def _estimate_model_params(config): return total -# ============================================================================= -# Step Time Measurement -# ============================================================================= - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. - num_warmup: Warmup iterations (discarded). - num_timed: Timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager for CP. - labels: Optional labels tensor. If None, uses input_ids. - position_ids: Optional position_ids for correct RoPE with CP. - **extra_fwd_kwargs: Additional kwargs for model forward. - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - # ============================================================================= # Utilities # ============================================================================= -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" if world_size <= 1: @@ -595,14 +489,6 @@ def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_ return num_iters * data_bytes / elapsed / 1e9 -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - def load_model_config(config_path): """Load model config dict from a local path or HuggingFace model ID. @@ -627,13 +513,6 @@ def load_model_config(config_path): return hf_config.to_dict() -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - # ============================================================================= # Formatting # ============================================================================= diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json new file mode 100644 index 0000000000..460f2f1b71 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py index b6f9ed4e1e..f4fb7411fc 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py @@ -37,15 +37,12 @@ torchrun --nproc_per_node=2 flops.py bandwidth """ -import gc import math import time -from contextlib import nullcontext from dataclasses import dataclass import torch import torch.distributed as dist -from torch.utils.flop_counter import FlopCounterMode # ============================================================================= @@ -263,30 +260,6 @@ def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): return 3 * total_fwd -# Backward-compatible wrappers for existing compare_mfu*.py scripts. - - -def compute_flops_first_principles(b, s, h, num_layers, n_kv_heads, head_dim, ffn_hidden_size, vocab_size): - """Backward-compatible wrapper. Assumes SwiGLU (3 MLP projections).""" - config = ModelFLOPsConfig( - hidden_size=h, - num_hidden_layers=num_layers, - num_attention_heads=h // head_dim, - num_kv_heads=n_kv_heads, - head_dim=head_dim, - intermediate_size=ffn_hidden_size, - num_mlp_projections=3, - vocab_size=vocab_size, - has_lm_head=True, - ) - return compute_flops_analytical(config, b, s) - - -def compute_flops_readme(b, s, h, num_layers, vocab_size): - """Backward-compatible wrapper for the simplified README formula.""" - return compute_flops_simplified(b, s, h, num_layers, vocab_size) - - # ============================================================================= # MFU Tracker # ============================================================================= @@ -478,90 +451,11 @@ def _estimate_model_params(config): return total -# ============================================================================= -# Step Time Measurement -# ============================================================================= - - -def measure_step_time( - model, - input_ids, - num_warmup=10, - num_timed=20, - distributed=False, - cp_context_fn=None, - labels=None, - position_ids=None, - **extra_fwd_kwargs, -): - """Measure average training step time (forward + backward). - - Args: - model: The model to benchmark. - input_ids: Input tensor. - num_warmup: Warmup iterations (discarded). - num_timed: Timed iterations to average. - distributed: Whether to use dist.barrier() for synchronization. - cp_context_fn: Optional callable returning a context manager for CP. - labels: Optional labels tensor. If None, uses input_ids. - position_ids: Optional position_ids for correct RoPE with CP. - **extra_fwd_kwargs: Additional kwargs for model forward. - """ - if labels is None: - labels = input_ids - - fwd_kwargs = {"input_ids": input_ids, "labels": labels, **extra_fwd_kwargs} - if position_ids is not None: - fwd_kwargs["position_ids"] = position_ids - - for _ in range(num_warmup): - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - if distributed: - dist.barrier() - torch.cuda.synchronize() - - times = [] - for _ in range(num_timed): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - - ctx = cp_context_fn() if cp_context_fn else nullcontext() - with ctx: - output = model(**fwd_kwargs) - output.loss.backward() - model.zero_grad(set_to_none=True) - - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end) / 1000.0) - - return sum(times) / len(times) - - # ============================================================================= # Utilities # ============================================================================= -def split_for_cp_bshd(tensor, cp_rank, cp_size): - """Split a BSHD tensor for CP using the dual-chunk zigzag pattern.""" - if cp_size <= 1: - return tensor - total_chunks = 2 * cp_size - seq_len = tensor.size(1) - chunk_size = seq_len // total_chunks - chunk_indices = [cp_rank, total_chunks - cp_rank - 1] - slices = [tensor[:, idx * chunk_size : (idx + 1) * chunk_size] for idx in chunk_indices] - return torch.cat(slices, dim=1) - - def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" if world_size <= 1: @@ -595,14 +489,6 @@ def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_ return num_iters * data_bytes / elapsed / 1e9 -def count_flops_with_model(model, input_ids): - """Count forward FLOPs using PyTorch's FlopCounterMode, return 3x for training.""" - flop_counter = FlopCounterMode(display=False) - with flop_counter: - model(input_ids=input_ids) - return flop_counter.get_total_flops() * 3 - - def load_model_config(config_path): """Load model config dict from a local path or HuggingFace model ID. @@ -627,13 +513,6 @@ def load_model_config(config_path): return hf_config.to_dict() -def cleanup_model(model): - """Delete a model and free GPU memory.""" - del model - gc.collect() - torch.cuda.empty_cache() - - # ============================================================================= # Formatting # ============================================================================= From 1635aab12a668399f34916f22d818e193f347709 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 13 Apr 2026 20:22:45 +0000 Subject: [PATCH 10/24] Add flops tests and move source from models/esm2 to llama3_native_te - Move flops.py source from models/esm2/ to recipes/llama3_native_te/ (it's a training utility, not a model utility) - Add test_flops.py with 27 tests covering config auto-detection, analytical/simplified/hyena formulas, MFUTracker, and CP comm estimation - Update check_copied_files.py: source is now llama3_native_te, copied to esm2, codonfm, opengenome2 recipes (both flops.py and test_flops.py) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/flops.py | 708 ------------------ .../recipes/codonfm_native_te/flops.py | 2 +- .../codonfm_native_te/tests/test_flops.py | 323 ++++++++ .../recipes/esm2_native_te/flops.py | 2 +- .../esm2_native_te/tests/test_flops.py | 323 ++++++++ .../recipes/llama3_native_te/flops.py | 6 - .../llama3_native_te/tests/test_flops.py | 317 ++++++++ .../opengenome2_llama_native_te/flops.py | 2 +- .../tests/test_flops.py | 323 ++++++++ ci/scripts/check_copied_files.py | 11 +- 10 files changed, 1297 insertions(+), 720 deletions(-) delete mode 100644 bionemo-recipes/models/esm2/flops.py create mode 100644 bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py create mode 100644 bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py diff --git a/bionemo-recipes/models/esm2/flops.py b/bionemo-recipes/models/esm2/flops.py deleted file mode 100644 index bce8388069..0000000000 --- a/bionemo-recipes/models/esm2/flops.py +++ /dev/null @@ -1,708 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. - -Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). -Designed to be copied to any recipe via check_copied_files.py and hooked into -training scripts for live MFU tracking. - -Usage as a library (in training scripts): - from flops import MFUTracker, from_hf_config - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - mfu_info = tracker.compute_mfu(step_time=0.5) - -Usage as a CLI: - python flops.py gpu-info - python flops.py flops --config-path ./model_configs/lingua-1B - python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 - torchrun --nproc_per_node=2 flops.py bandwidth -""" - -import math -import time -from dataclasses import dataclass - -import torch -import torch.distributed as dist - - -# ============================================================================= -# GPU Peak TFLOPS -# ============================================================================= - -# Dense (without sparsity) BF16 tensor core peak TFLOPS. -# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. -# Sources: -# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) -# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf -# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf -# L40: https://www.nvidia.com/en-us/data-center/l40/ -# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ -# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "L40": 181.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, - "B300": 2500.0, - "GB300": 2500.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -# ============================================================================= -# Model FLOPs Config -# ============================================================================= - -# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. -GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) - - -@dataclass(frozen=True) -class ModelFLOPsConfig: - """Architecture-independent parameters for FLOPs calculation. - - Can be constructed manually or via from_hf_config() for auto-detection. - """ - - hidden_size: int # H - num_hidden_layers: int # L - num_attention_heads: int # n_heads - num_kv_heads: int # n_kv (== n_heads for MHA) - head_dim: int # H // n_heads - intermediate_size: int # I (FFN intermediate dimension) - num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) - vocab_size: int # V - has_lm_head: bool # True for LM models, False for ViT etc. - - -def from_hf_config(config_dict, **overrides): - """Create ModelFLOPsConfig from an HF-compatible config dict. - - Auto-detects architecture: - - GQA vs MHA: from num_key_value_heads (absent = MHA) - - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type - - LM head: from vocab_size > 0 - - Args: - config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). - Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. - **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). - """ - h = config_dict["hidden_size"] - n_heads = config_dict["num_attention_heads"] - n_kv = config_dict.get("num_key_value_heads", n_heads) - vocab = config_dict.get("vocab_size", 0) - model_type = config_dict.get("model_type", "") - - # Detect gated MLP (3 projections) vs standard FFN (2 projections). - # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). - # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). - num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 - - kwargs = { - "hidden_size": h, - "num_hidden_layers": config_dict["num_hidden_layers"], - "num_attention_heads": n_heads, - "num_kv_heads": n_kv, - "head_dim": h // n_heads, - "intermediate_size": config_dict["intermediate_size"], - "num_mlp_projections": num_mlp_proj, - "vocab_size": vocab, - "has_lm_head": vocab > 0, - } - kwargs.update(overrides) - return ModelFLOPsConfig(**kwargs) - - -# ============================================================================= -# FLOPs Formulas -# ============================================================================= - - -def compute_flops_analytical(config, batch_size, seq_len): - """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). - - Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, - layer norms, activations, and element-wise ops. - - Handles: - - GQA vs MHA: K/V projection sizes based on config.num_kv_heads - - SwiGLU vs standard FFN: 2 or 3 MLP projections - - LM head presence - - Returns: - (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) - """ - b, s, h = batch_size, seq_len, config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - } - - ffn = config.intermediate_size - if config.num_mlp_projections == 3: - # SwiGLU/GeGLU: gate + up + down = 3 matmuls - breakdown["Gate projection"] = 2 * b * s * h * ffn - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - else: - # Standard FFN: up + down = 2 matmuls - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): - """Simplified formula assuming standard MHA + standard FFN with I=4H. - - This is the formula from the Llama3 README: - (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V - - The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. - """ - b, s, h = batch_size, seq_len, hidden_size - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): - """FLOPs for Hyena-based models (Evo2). - - Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. - - Args: - config: ModelFLOPsConfig with model dimensions. - batch_size: Batch size. - seq_len: Sequence length. - hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for - short/medium/long conv and attention layer counts. If None, assumes - all layers are long-conv Hyena (H=num_layers, no attention). - """ - b, s, h = batch_size, seq_len, config.hidden_size - ffn = config.intermediate_size - - if hyena_layer_counts is None: - hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} - - # Common per-layer FLOPs - pre_attn_qkv_proj = 2 * 3 * b * s * h * h - post_attn_proj = 2 * b * s * h * h - glu_ffn = 2 * 3 * b * s * ffn * h - - # Layer-type-specific FLOPs (defaults from evo2_provider.py) - attn = 2 * 2 * b * h * s * s # Standard S^2 attention - hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default - hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 - hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 - hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h - - n_s = hyena_layer_counts.get("S", 0) - n_d = hyena_layer_counts.get("D", 0) - n_h = hyena_layer_counts.get("H", 0) - n_a = hyena_layer_counts.get("A", 0) - - logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - - total_fwd = ( - logits - + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) - + n_a * attn - + (n_s + n_d + n_h) * hyena_proj - + n_s * hyena_short_conv - + n_d * hyena_medium_conv - + int(n_h * hyena_long_fft) - ) - - return 3 * total_fwd - - -# ============================================================================= -# MFU Tracker -# ============================================================================= - - -class MFUTracker: - """Tracks MFU during training. Initialize once, call compute_mfu() per step. - - Usage: - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - # In training loop: - mfu_info = tracker.compute_mfu(step_time=0.5) - print(f"MFU: {mfu_info['mfu']:.1f}%") - """ - - def __init__( - self, - config, - batch_size, - seq_len, - num_gpus=1, - parallelism=None, - peak_tflops=None, - formula="analytical", - hyena_layer_counts=None, - ): - """Initialize MFU tracker. - - Args: - config: ModelFLOPsConfig instance. - batch_size: Micro batch size per GPU. - seq_len: Sequence length. - num_gpus: Total number of GPUs. - parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. - Used for communication overhead estimation. - peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. - formula: "analytical", "simplified", or "hyena". - hyena_layer_counts: For Hyena formula, dict of layer type counts. - """ - self.config = config - self.batch_size = batch_size - self.seq_len = seq_len - self.num_gpus = num_gpus - self.parallelism = parallelism or {} - self.formula = formula - - if formula == "analytical": - self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( - config, batch_size, seq_len - ) - elif formula == "simplified": - self.total_flops = compute_flops_simplified( - batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size - ) - self.breakdown = None - self.lm_head_flops = 0 - elif formula == "hyena": - self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) - self.breakdown = None - self.lm_head_flops = 0 - else: - raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") - - self.per_gpu_flops = self.total_flops // max(num_gpus, 1) - - if peak_tflops is not None: - self.peak_tflops = peak_tflops - self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" - else: - detected, self.device_name = detect_gpu_peak_tflops() - self.peak_tflops = detected - - self.comm_bytes = self._estimate_comm() - - @classmethod - def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): - """Create from an HF config dict with auto-detection.""" - config = from_hf_config(config_dict) - return cls(config, batch_size, seq_len, **kwargs) - - def compute_mfu(self, step_time): - """Compute MFU from measured step time. - - Args: - step_time: Wall-clock time for one training step (seconds). - - Returns: - Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. - """ - tflops = self.per_gpu_flops / step_time / 1e12 - mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 - return { - "mfu": mfu, - "tflops_per_gpu": tflops, - "per_gpu_flops": self.per_gpu_flops, - "total_flops": self.total_flops, - "step_time": step_time, - } - - def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): - """Estimate communication overhead as a fraction of step time. - - Args: - step_time: Measured step time in seconds. - measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. - - Returns: - Dict with comm_bytes, estimated_comm_time, comm_pct. - """ - bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 - comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 - comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 - return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} - - def _estimate_comm(self): - """Estimate total communication bytes per step based on parallelism.""" - total = 0 - cp_size = self.parallelism.get("cp", 1) - dp_size = self.parallelism.get("dp", 1) - - if cp_size > 1: - total += estimate_cp_comm_bytes( - self.batch_size, - self.seq_len, - self.config.num_hidden_layers, - self.config.num_kv_heads, - self.config.head_dim, - cp_size, - ) - - if dp_size > 1: - # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp - model_params = _estimate_model_params(self.config) - total += 2 * model_params * 2 * (dp_size - 1) // dp_size - - return total - - def summary(self): - """Return a human-readable summary string.""" - lines = [ - f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", - f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," - f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," - f" I={self.config.intermediate_size}, V={self.config.vocab_size}", - f" MLP projections: {self.config.num_mlp_projections}" - f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", - f" Batch: B={self.batch_size}, S={self.seq_len}", - f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", - f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", - f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", - ] - if self.parallelism: - lines.append(f" Parallelism: {self.parallelism}") - if self.comm_bytes > 0: - lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") - return "\n".join(lines) - - -# ============================================================================= -# Communication Estimation -# ============================================================================= - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step. - - Ring attention sends local KV chunks around the ring. Per layer forward: - (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. - Training = ~2x forward communication (forward sends KV, backward sends dKV). - """ - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def _estimate_model_params(config): - """Rough parameter count estimate from config dimensions.""" - h = config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O - mlp_params = config.num_mlp_projections * h * config.intermediate_size - layer_params = attn_params + mlp_params - total = config.num_hidden_layers * layer_params - if config.has_lm_head: - total += config.vocab_size * h * 2 # embed + lm_head - return total - - -# ============================================================================= -# Utilities -# ============================================================================= - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank - - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - return num_iters * data_bytes / elapsed / 1e9 - - -def load_model_config(config_path): - """Load model config dict from a local path or HuggingFace model ID. - - Supports: - - Local directory: ./model_configs/lingua-1B (reads config.json inside) - - Local file: ./model_configs/lingua-1B/config.json - - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) - """ - import json - from pathlib import Path - - path = Path(config_path) - if path.is_dir(): - path = path / "config.json" - if path.exists(): - return json.loads(path.read_text()) - - # Fall back to HuggingFace Hub - from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) - return hf_config.to_dict() - - -# ============================================================================= -# Formatting -# ============================================================================= - - -def format_flops(flops): - """Format FLOPs with appropriate unit (G/T/P).""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") - parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") - parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - print( - f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," - f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," - f" I={config.intermediate_size}, V={config.vocab_size}" - ) - print( - f"MLP: {config.num_mlp_projections} projections" - f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" - ) - print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") - print() - - simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) - analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") - print("-" * 86) - for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: - per_gpu = flops // max(args.num_gpus, 1) - print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") - - if simplified != analytical: - diff = analytical - simplified - print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Communication overhead estimate - if args.cp_size > 1: - dp_size = args.num_gpus // args.cp_size - parallelism = {"dp": dp_size, "cp": args.cp_size} - tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) - print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") - print( - f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" - ) - comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 - print(f" Estimated comm time: {comm_time:.4f}s") - - model_params = _estimate_model_params(config) - print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) - print( - f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," - f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" - ) - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python flops.py [options]") - print(" torchrun --nproc_per_node=2 flops.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py index f4fb7411fc..ade8d8a12c 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/flops.py +++ b/bionemo-recipes/recipes/codonfm_native_te/flops.py @@ -14,7 +14,7 @@ # limitations under the License. # --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/models/esm2/flops.py +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py # Do not modify this file directly. Instead, modify the source and run: # python ci/scripts/check_copied_files.py --fix # --- END COPIED FILE NOTICE --- diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py new file mode 100644 index 0000000000..f514205fef --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Tests for the flops.py FLOPs counting and MFU module.""" + +import sys +from pathlib import Path + +import pytest + + +# Add parent directory so we can import flops +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from flops import ( + MFUTracker, + ModelFLOPsConfig, + compute_flops_analytical, + compute_flops_hyena, + compute_flops_simplified, + estimate_cp_comm_bytes, + from_hf_config, +) + + +# ============================================================================ +# Test configs matching real models +# ============================================================================ + +LLAMA_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 25, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 6144, + "vocab_size": 128256, + "model_type": "llama", + "hidden_act": "silu", +} + +ESM2_8M_CONFIG = { + "hidden_size": 320, + "num_hidden_layers": 6, + "num_attention_heads": 20, + "intermediate_size": 1280, + "vocab_size": 33, + "model_type": "nv_esm", + "hidden_act": "gelu", +} + +CODONFM_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 18, + "num_attention_heads": 16, + "intermediate_size": 8192, +} + + +# ============================================================================ +# Config auto-detection +# ============================================================================ + + +class TestFromHfConfig: + """Test auto-detection of model architecture from config dicts.""" + + def test_llama_detects_gqa_and_swiglu(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + assert cfg.num_kv_heads == 8 + assert cfg.num_mlp_projections == 3 + assert cfg.head_dim == 128 + + def test_esm2_detects_mha_and_standard_ffn(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + assert cfg.num_kv_heads == 20 + assert cfg.num_mlp_projections == 2 + + def test_codonfm_defaults_to_mha_and_2_proj(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + assert cfg.num_kv_heads == 16 + assert cfg.num_mlp_projections == 2 + + def test_missing_vocab_defaults_to_no_lm_head(self): + cfg = from_hf_config( + {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} + ) + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + def test_overrides_take_precedence(self): + cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) + assert cfg.num_mlp_projections == 3 + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + +# ============================================================================ +# Analytical FLOPs formula +# ============================================================================ + + +class TestComputeFlopsAnalytical: + """Test the first-principles analytical FLOPs formula.""" + + def test_training_is_3x_forward(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) + forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head + assert total == 3 * forward + + def test_swiglu_has_3_mlp_projections(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_standard_ffn_has_2_mlp_projections(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" not in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_no_lm_head_when_vocab_zero(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) + assert lm_head == 0 + + def test_flops_scale_linearly_with_batch(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) + flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) + assert flops_b4 == 4 * flops_b1 + + def test_known_value_llama_lingua_1b(self): + """Golden value: validated against PyTorch FlopCounterMode and README formula.""" + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, _, _ = compute_flops_analytical(cfg, 1, 4096) + assert total == 47_687_021_887_488 + + +# ============================================================================ +# Simplified formula +# ============================================================================ + + +class TestComputeFlopsSimplified: + """Test the simplified README formula and its relationship to analytical.""" + + def test_matches_analytical_when_mha_and_i_equals_4h(self): + """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" + cfg = from_hf_config(ESM2_8M_CONFIG) + analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) + simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical == simplified + + def test_differs_when_gqa_or_swiglu(self): + """GQA + SwiGLU breaks the simplified formula's assumptions.""" + cfg_dict = { + **LLAMA_1B_CONFIG, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "intermediate_size": 8192, + "num_hidden_layers": 16, + } + cfg = from_hf_config(cfg_dict) + analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) + simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical != simplified + + +# ============================================================================ +# Hyena formula +# ============================================================================ + + +class TestComputeFlopsHyena: + """Test the Hyena (Evo2) FLOPs formula.""" + + @pytest.fixture() + def hyena_config(self): + return ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + + def test_scales_subquadratically(self, hyena_config): + """Hyena uses O(S log S) convolution, not O(S^2) attention.""" + flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + ratio = flops_2k / flops_1k + assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below + + def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): + """Adding standard attention layers increases FLOPs due to S^2 term.""" + all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) + assert with_attn > all_hyena + + +# ============================================================================ +# MFUTracker +# ============================================================================ + + +class TestMFUTracker: + """Test the MFUTracker class used by training scripts.""" + + def test_from_config_dict(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.total_flops == 47_687_021_887_488 + assert tracker.per_gpu_flops == tracker.total_flops + + def test_multi_gpu_divides_flops(self): + single = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 + ) + multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) + assert multi.per_gpu_flops == single.total_flops // 2 + + def test_compute_mfu_correctness(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + result = tracker.compute_mfu(step_time=0.5) + expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 + expected_mfu = expected_tflops / 155.0 * 100 + assert abs(result["mfu"] - expected_mfu) < 0.01 + assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 + + def test_mfu_inversely_proportional_to_step_time(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + fast = tracker.compute_mfu(step_time=0.5) + slow = tracker.compute_mfu(step_time=1.0) + assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 + + def test_all_formula_options(self): + for formula in ["analytical", "simplified", "hyena"]: + if formula == "hyena": + cfg = ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) + else: + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula + ) + assert tracker.total_flops > 0 + + def test_invalid_formula_raises(self): + with pytest.raises(ValueError, match="Unknown formula"): + MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") + + def test_cp_communication_estimate(self): + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 + ) + assert tracker.comm_bytes > 0 + overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) + assert overhead["estimated_comm_time"] > 0 + assert 0 < overhead["comm_pct"] < 100 + + def test_no_comm_single_gpu(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.comm_bytes == 0 + + +# ============================================================================ +# Communication estimation +# ============================================================================ + + +class TestCPCommEstimation: + """Test CP ring attention communication byte estimates.""" + + def test_zero_without_cp(self): + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 + + def test_scales_linearly_with_seq_len(self): + comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) + assert comm_8k == 2 * comm_4k + + def test_scales_linearly_with_batch(self): + comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) + assert comm_b4 == 4 * comm_b1 + + def test_known_value_lingua_1b(self): + """Golden value for lingua-1B at S=4096, CP=2.""" + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py index f4fb7411fc..ade8d8a12c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/flops.py +++ b/bionemo-recipes/recipes/esm2_native_te/flops.py @@ -14,7 +14,7 @@ # limitations under the License. # --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/models/esm2/flops.py +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py # Do not modify this file directly. Instead, modify the source and run: # python ci/scripts/check_copied_files.py --fix # --- END COPIED FILE NOTICE --- diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py new file mode 100644 index 0000000000..f514205fef --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Tests for the flops.py FLOPs counting and MFU module.""" + +import sys +from pathlib import Path + +import pytest + + +# Add parent directory so we can import flops +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from flops import ( + MFUTracker, + ModelFLOPsConfig, + compute_flops_analytical, + compute_flops_hyena, + compute_flops_simplified, + estimate_cp_comm_bytes, + from_hf_config, +) + + +# ============================================================================ +# Test configs matching real models +# ============================================================================ + +LLAMA_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 25, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 6144, + "vocab_size": 128256, + "model_type": "llama", + "hidden_act": "silu", +} + +ESM2_8M_CONFIG = { + "hidden_size": 320, + "num_hidden_layers": 6, + "num_attention_heads": 20, + "intermediate_size": 1280, + "vocab_size": 33, + "model_type": "nv_esm", + "hidden_act": "gelu", +} + +CODONFM_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 18, + "num_attention_heads": 16, + "intermediate_size": 8192, +} + + +# ============================================================================ +# Config auto-detection +# ============================================================================ + + +class TestFromHfConfig: + """Test auto-detection of model architecture from config dicts.""" + + def test_llama_detects_gqa_and_swiglu(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + assert cfg.num_kv_heads == 8 + assert cfg.num_mlp_projections == 3 + assert cfg.head_dim == 128 + + def test_esm2_detects_mha_and_standard_ffn(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + assert cfg.num_kv_heads == 20 + assert cfg.num_mlp_projections == 2 + + def test_codonfm_defaults_to_mha_and_2_proj(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + assert cfg.num_kv_heads == 16 + assert cfg.num_mlp_projections == 2 + + def test_missing_vocab_defaults_to_no_lm_head(self): + cfg = from_hf_config( + {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} + ) + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + def test_overrides_take_precedence(self): + cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) + assert cfg.num_mlp_projections == 3 + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + +# ============================================================================ +# Analytical FLOPs formula +# ============================================================================ + + +class TestComputeFlopsAnalytical: + """Test the first-principles analytical FLOPs formula.""" + + def test_training_is_3x_forward(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) + forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head + assert total == 3 * forward + + def test_swiglu_has_3_mlp_projections(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_standard_ffn_has_2_mlp_projections(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" not in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_no_lm_head_when_vocab_zero(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) + assert lm_head == 0 + + def test_flops_scale_linearly_with_batch(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) + flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) + assert flops_b4 == 4 * flops_b1 + + def test_known_value_llama_lingua_1b(self): + """Golden value: validated against PyTorch FlopCounterMode and README formula.""" + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, _, _ = compute_flops_analytical(cfg, 1, 4096) + assert total == 47_687_021_887_488 + + +# ============================================================================ +# Simplified formula +# ============================================================================ + + +class TestComputeFlopsSimplified: + """Test the simplified README formula and its relationship to analytical.""" + + def test_matches_analytical_when_mha_and_i_equals_4h(self): + """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" + cfg = from_hf_config(ESM2_8M_CONFIG) + analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) + simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical == simplified + + def test_differs_when_gqa_or_swiglu(self): + """GQA + SwiGLU breaks the simplified formula's assumptions.""" + cfg_dict = { + **LLAMA_1B_CONFIG, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "intermediate_size": 8192, + "num_hidden_layers": 16, + } + cfg = from_hf_config(cfg_dict) + analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) + simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical != simplified + + +# ============================================================================ +# Hyena formula +# ============================================================================ + + +class TestComputeFlopsHyena: + """Test the Hyena (Evo2) FLOPs formula.""" + + @pytest.fixture() + def hyena_config(self): + return ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + + def test_scales_subquadratically(self, hyena_config): + """Hyena uses O(S log S) convolution, not O(S^2) attention.""" + flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + ratio = flops_2k / flops_1k + assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below + + def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): + """Adding standard attention layers increases FLOPs due to S^2 term.""" + all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) + assert with_attn > all_hyena + + +# ============================================================================ +# MFUTracker +# ============================================================================ + + +class TestMFUTracker: + """Test the MFUTracker class used by training scripts.""" + + def test_from_config_dict(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.total_flops == 47_687_021_887_488 + assert tracker.per_gpu_flops == tracker.total_flops + + def test_multi_gpu_divides_flops(self): + single = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 + ) + multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) + assert multi.per_gpu_flops == single.total_flops // 2 + + def test_compute_mfu_correctness(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + result = tracker.compute_mfu(step_time=0.5) + expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 + expected_mfu = expected_tflops / 155.0 * 100 + assert abs(result["mfu"] - expected_mfu) < 0.01 + assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 + + def test_mfu_inversely_proportional_to_step_time(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + fast = tracker.compute_mfu(step_time=0.5) + slow = tracker.compute_mfu(step_time=1.0) + assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 + + def test_all_formula_options(self): + for formula in ["analytical", "simplified", "hyena"]: + if formula == "hyena": + cfg = ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) + else: + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula + ) + assert tracker.total_flops > 0 + + def test_invalid_formula_raises(self): + with pytest.raises(ValueError, match="Unknown formula"): + MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") + + def test_cp_communication_estimate(self): + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 + ) + assert tracker.comm_bytes > 0 + overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) + assert overhead["estimated_comm_time"] > 0 + assert 0 < overhead["comm_pct"] < 100 + + def test_no_comm_single_gpu(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.comm_bytes == 0 + + +# ============================================================================ +# Communication estimation +# ============================================================================ + + +class TestCPCommEstimation: + """Test CP ring attention communication byte estimates.""" + + def test_zero_without_cp(self): + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 + + def test_scales_linearly_with_seq_len(self): + comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) + assert comm_8k == 2 * comm_4k + + def test_scales_linearly_with_batch(self): + comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) + assert comm_b4 == 4 * comm_b1 + + def test_known_value_lingua_1b(self): + """Golden value for lingua-1B at S=4096, CP=2.""" + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py index f4fb7411fc..bce8388069 100644 --- a/bionemo-recipes/recipes/llama3_native_te/flops.py +++ b/bionemo-recipes/recipes/llama3_native_te/flops.py @@ -13,12 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/models/esm2/flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - """Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py new file mode 100644 index 0000000000..e9e3943658 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the flops.py FLOPs counting and MFU module.""" + +import sys +from pathlib import Path + +import pytest + + +# Add parent directory so we can import flops +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from flops import ( + MFUTracker, + ModelFLOPsConfig, + compute_flops_analytical, + compute_flops_hyena, + compute_flops_simplified, + estimate_cp_comm_bytes, + from_hf_config, +) + + +# ============================================================================ +# Test configs matching real models +# ============================================================================ + +LLAMA_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 25, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 6144, + "vocab_size": 128256, + "model_type": "llama", + "hidden_act": "silu", +} + +ESM2_8M_CONFIG = { + "hidden_size": 320, + "num_hidden_layers": 6, + "num_attention_heads": 20, + "intermediate_size": 1280, + "vocab_size": 33, + "model_type": "nv_esm", + "hidden_act": "gelu", +} + +CODONFM_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 18, + "num_attention_heads": 16, + "intermediate_size": 8192, +} + + +# ============================================================================ +# Config auto-detection +# ============================================================================ + + +class TestFromHfConfig: + """Test auto-detection of model architecture from config dicts.""" + + def test_llama_detects_gqa_and_swiglu(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + assert cfg.num_kv_heads == 8 + assert cfg.num_mlp_projections == 3 + assert cfg.head_dim == 128 + + def test_esm2_detects_mha_and_standard_ffn(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + assert cfg.num_kv_heads == 20 + assert cfg.num_mlp_projections == 2 + + def test_codonfm_defaults_to_mha_and_2_proj(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + assert cfg.num_kv_heads == 16 + assert cfg.num_mlp_projections == 2 + + def test_missing_vocab_defaults_to_no_lm_head(self): + cfg = from_hf_config( + {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} + ) + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + def test_overrides_take_precedence(self): + cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) + assert cfg.num_mlp_projections == 3 + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + +# ============================================================================ +# Analytical FLOPs formula +# ============================================================================ + + +class TestComputeFlopsAnalytical: + """Test the first-principles analytical FLOPs formula.""" + + def test_training_is_3x_forward(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) + forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head + assert total == 3 * forward + + def test_swiglu_has_3_mlp_projections(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_standard_ffn_has_2_mlp_projections(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" not in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_no_lm_head_when_vocab_zero(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) + assert lm_head == 0 + + def test_flops_scale_linearly_with_batch(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) + flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) + assert flops_b4 == 4 * flops_b1 + + def test_known_value_llama_lingua_1b(self): + """Golden value: validated against PyTorch FlopCounterMode and README formula.""" + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, _, _ = compute_flops_analytical(cfg, 1, 4096) + assert total == 47_687_021_887_488 + + +# ============================================================================ +# Simplified formula +# ============================================================================ + + +class TestComputeFlopsSimplified: + """Test the simplified README formula and its relationship to analytical.""" + + def test_matches_analytical_when_mha_and_i_equals_4h(self): + """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" + cfg = from_hf_config(ESM2_8M_CONFIG) + analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) + simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical == simplified + + def test_differs_when_gqa_or_swiglu(self): + """GQA + SwiGLU breaks the simplified formula's assumptions.""" + cfg_dict = { + **LLAMA_1B_CONFIG, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "intermediate_size": 8192, + "num_hidden_layers": 16, + } + cfg = from_hf_config(cfg_dict) + analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) + simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical != simplified + + +# ============================================================================ +# Hyena formula +# ============================================================================ + + +class TestComputeFlopsHyena: + """Test the Hyena (Evo2) FLOPs formula.""" + + @pytest.fixture() + def hyena_config(self): + return ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + + def test_scales_subquadratically(self, hyena_config): + """Hyena uses O(S log S) convolution, not O(S^2) attention.""" + flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + ratio = flops_2k / flops_1k + assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below + + def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): + """Adding standard attention layers increases FLOPs due to S^2 term.""" + all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) + assert with_attn > all_hyena + + +# ============================================================================ +# MFUTracker +# ============================================================================ + + +class TestMFUTracker: + """Test the MFUTracker class used by training scripts.""" + + def test_from_config_dict(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.total_flops == 47_687_021_887_488 + assert tracker.per_gpu_flops == tracker.total_flops + + def test_multi_gpu_divides_flops(self): + single = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 + ) + multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) + assert multi.per_gpu_flops == single.total_flops // 2 + + def test_compute_mfu_correctness(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + result = tracker.compute_mfu(step_time=0.5) + expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 + expected_mfu = expected_tflops / 155.0 * 100 + assert abs(result["mfu"] - expected_mfu) < 0.01 + assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 + + def test_mfu_inversely_proportional_to_step_time(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + fast = tracker.compute_mfu(step_time=0.5) + slow = tracker.compute_mfu(step_time=1.0) + assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 + + def test_all_formula_options(self): + for formula in ["analytical", "simplified", "hyena"]: + if formula == "hyena": + cfg = ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) + else: + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula + ) + assert tracker.total_flops > 0 + + def test_invalid_formula_raises(self): + with pytest.raises(ValueError, match="Unknown formula"): + MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") + + def test_cp_communication_estimate(self): + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 + ) + assert tracker.comm_bytes > 0 + overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) + assert overhead["estimated_comm_time"] > 0 + assert 0 < overhead["comm_pct"] < 100 + + def test_no_comm_single_gpu(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.comm_bytes == 0 + + +# ============================================================================ +# Communication estimation +# ============================================================================ + + +class TestCPCommEstimation: + """Test CP ring attention communication byte estimates.""" + + def test_zero_without_cp(self): + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 + + def test_scales_linearly_with_seq_len(self): + comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) + assert comm_8k == 2 * comm_4k + + def test_scales_linearly_with_batch(self): + comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) + assert comm_b4 == 4 * comm_b1 + + def test_known_value_lingua_1b(self): + """Golden value for lingua-1B at S=4096, CP=2.""" + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py index f4fb7411fc..ade8d8a12c 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py @@ -14,7 +14,7 @@ # limitations under the License. # --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/models/esm2/flops.py +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py # Do not modify this file directly. Instead, modify the source and run: # python ci/scripts/check_copied_files.py --fix # --- END COPIED FILE NOTICE --- diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py new file mode 100644 index 0000000000..f514205fef --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Tests for the flops.py FLOPs counting and MFU module.""" + +import sys +from pathlib import Path + +import pytest + + +# Add parent directory so we can import flops +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from flops import ( + MFUTracker, + ModelFLOPsConfig, + compute_flops_analytical, + compute_flops_hyena, + compute_flops_simplified, + estimate_cp_comm_bytes, + from_hf_config, +) + + +# ============================================================================ +# Test configs matching real models +# ============================================================================ + +LLAMA_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 25, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 6144, + "vocab_size": 128256, + "model_type": "llama", + "hidden_act": "silu", +} + +ESM2_8M_CONFIG = { + "hidden_size": 320, + "num_hidden_layers": 6, + "num_attention_heads": 20, + "intermediate_size": 1280, + "vocab_size": 33, + "model_type": "nv_esm", + "hidden_act": "gelu", +} + +CODONFM_1B_CONFIG = { + "hidden_size": 2048, + "num_hidden_layers": 18, + "num_attention_heads": 16, + "intermediate_size": 8192, +} + + +# ============================================================================ +# Config auto-detection +# ============================================================================ + + +class TestFromHfConfig: + """Test auto-detection of model architecture from config dicts.""" + + def test_llama_detects_gqa_and_swiglu(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + assert cfg.num_kv_heads == 8 + assert cfg.num_mlp_projections == 3 + assert cfg.head_dim == 128 + + def test_esm2_detects_mha_and_standard_ffn(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + assert cfg.num_kv_heads == 20 + assert cfg.num_mlp_projections == 2 + + def test_codonfm_defaults_to_mha_and_2_proj(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + assert cfg.num_kv_heads == 16 + assert cfg.num_mlp_projections == 2 + + def test_missing_vocab_defaults_to_no_lm_head(self): + cfg = from_hf_config( + {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} + ) + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + def test_overrides_take_precedence(self): + cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) + assert cfg.num_mlp_projections == 3 + assert cfg.vocab_size == 0 + assert cfg.has_lm_head is False + + +# ============================================================================ +# Analytical FLOPs formula +# ============================================================================ + + +class TestComputeFlopsAnalytical: + """Test the first-principles analytical FLOPs formula.""" + + def test_training_is_3x_forward(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) + forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head + assert total == 3 * forward + + def test_swiglu_has_3_mlp_projections(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_standard_ffn_has_2_mlp_projections(self): + cfg = from_hf_config(ESM2_8M_CONFIG) + _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) + assert "Gate projection" not in breakdown + assert "Up projection" in breakdown + assert "Down projection" in breakdown + + def test_no_lm_head_when_vocab_zero(self): + cfg = from_hf_config(CODONFM_1B_CONFIG) + _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) + assert lm_head == 0 + + def test_flops_scale_linearly_with_batch(self): + cfg = from_hf_config(LLAMA_1B_CONFIG) + flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) + flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) + assert flops_b4 == 4 * flops_b1 + + def test_known_value_llama_lingua_1b(self): + """Golden value: validated against PyTorch FlopCounterMode and README formula.""" + cfg = from_hf_config(LLAMA_1B_CONFIG) + total, _, _ = compute_flops_analytical(cfg, 1, 4096) + assert total == 47_687_021_887_488 + + +# ============================================================================ +# Simplified formula +# ============================================================================ + + +class TestComputeFlopsSimplified: + """Test the simplified README formula and its relationship to analytical.""" + + def test_matches_analytical_when_mha_and_i_equals_4h(self): + """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" + cfg = from_hf_config(ESM2_8M_CONFIG) + analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) + simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical == simplified + + def test_differs_when_gqa_or_swiglu(self): + """GQA + SwiGLU breaks the simplified formula's assumptions.""" + cfg_dict = { + **LLAMA_1B_CONFIG, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "intermediate_size": 8192, + "num_hidden_layers": 16, + } + cfg = from_hf_config(cfg_dict) + analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) + simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) + assert analytical != simplified + + +# ============================================================================ +# Hyena formula +# ============================================================================ + + +class TestComputeFlopsHyena: + """Test the Hyena (Evo2) FLOPs formula.""" + + @pytest.fixture() + def hyena_config(self): + return ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + + def test_scales_subquadratically(self, hyena_config): + """Hyena uses O(S log S) convolution, not O(S^2) attention.""" + flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + ratio = flops_2k / flops_1k + assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below + + def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): + """Adding standard attention layers increases FLOPs due to S^2 term.""" + all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) + with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) + assert with_attn > all_hyena + + +# ============================================================================ +# MFUTracker +# ============================================================================ + + +class TestMFUTracker: + """Test the MFUTracker class used by training scripts.""" + + def test_from_config_dict(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.total_flops == 47_687_021_887_488 + assert tracker.per_gpu_flops == tracker.total_flops + + def test_multi_gpu_divides_flops(self): + single = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 + ) + multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) + assert multi.per_gpu_flops == single.total_flops // 2 + + def test_compute_mfu_correctness(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + result = tracker.compute_mfu(step_time=0.5) + expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 + expected_mfu = expected_tflops / 155.0 * 100 + assert abs(result["mfu"] - expected_mfu) < 0.01 + assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 + + def test_mfu_inversely_proportional_to_step_time(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + fast = tracker.compute_mfu(step_time=0.5) + slow = tracker.compute_mfu(step_time=1.0) + assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 + + def test_all_formula_options(self): + for formula in ["analytical", "simplified", "hyena"]: + if formula == "hyena": + cfg = ModelFLOPsConfig( + hidden_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_kv_heads=8, + head_dim=128, + intermediate_size=4096, + num_mlp_projections=3, + vocab_size=512, + has_lm_head=True, + ) + tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) + else: + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula + ) + assert tracker.total_flops > 0 + + def test_invalid_formula_raises(self): + with pytest.raises(ValueError, match="Unknown formula"): + MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") + + def test_cp_communication_estimate(self): + tracker = MFUTracker.from_config_dict( + LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 + ) + assert tracker.comm_bytes > 0 + overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) + assert overhead["estimated_comm_time"] > 0 + assert 0 < overhead["comm_pct"] < 100 + + def test_no_comm_single_gpu(self): + tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) + assert tracker.comm_bytes == 0 + + +# ============================================================================ +# Communication estimation +# ============================================================================ + + +class TestCPCommEstimation: + """Test CP ring attention communication byte estimates.""" + + def test_zero_without_cp(self): + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 + + def test_scales_linearly_with_seq_len(self): + comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) + assert comm_8k == 2 * comm_4k + + def test_scales_linearly_with_batch(self): + comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) + comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) + assert comm_b4 == 4 * comm_b1 + + def test_known_value_lingua_1b(self): + """Golden value for lingua-1B at S=4096, CP=2.""" + assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 3a5094f6bf..21c573fec5 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -199,13 +199,18 @@ def _compare_file_contents(source_file: Path, dest_file: Path, source_display: s "bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [ "bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py", ], - # FLOPs / MFU module - synced to recipes - "bionemo-recipes/models/esm2/flops.py": [ - "bionemo-recipes/recipes/llama3_native_te/flops.py", + # FLOPs / MFU module - synced across recipes + "bionemo-recipes/recipes/llama3_native_te/flops.py": [ "bionemo-recipes/recipes/esm2_native_te/flops.py", "bionemo-recipes/recipes/codonfm_native_te/flops.py", "bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py", ], + # FLOPs tests - synced across recipes + "bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py": [ + "bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py", + "bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py", + "bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py", + ], # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", From 27255a169f3b1d17e9ad77ae52f59ed9c0753f96 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 13 Apr 2026 20:30:25 +0000 Subject: [PATCH 11/24] Add MFU tracking documentation to recipe READMEs Document the log_mfu=true flag and flops.py CLI utilities in all 4 native_te recipe READMEs: llama3, esm2, codonfm, opengenome2. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/README.md | 17 +++++++++++++++++ .../recipes/esm2_native_te/README.md | 19 +++++++++++++++++++ .../recipes/llama3_native_te/README.md | 19 +++++++++++++++++++ .../opengenome2_llama_native_te/README.md | 19 +++++++++++++++++++ 4 files changed, 74 insertions(+) diff --git a/bionemo-recipes/recipes/codonfm_native_te/README.md b/bionemo-recipes/recipes/codonfm_native_te/README.md index de31d5b49d..71d3a59ce7 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/README.md +++ b/bionemo-recipes/recipes/codonfm_native_te/README.md @@ -177,6 +177,23 @@ python train_fsdp2.py \ A final model suitable for uploading to the Hugging Face Hub can be exported at the end of training by setting `checkpoint.save_final_model=true`. +## MFU Tracking + +Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true +``` + +This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture from the model config. + +The `flops.py` CLI provides standalone utilities: + +```bash +python flops.py gpu-info # Show GPU and peak TFLOPS +torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth +``` + ## Developer Guide ### Running Tests diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index 330b8152db..5e6efab3f1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -374,6 +374,25 @@ output = model(**inputs) - [ESM-2 Training with Accelerate](../esm2_accelerate_te/README.md) +## MFU Tracking + +Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2.py --config-name L1_3B log_mfu=true +``` + +This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (MHA, standard FFN, etc.) from the model config. + +The `flops.py` CLI provides standalone utilities: + +```bash +python flops.py gpu-info # Show GPU and peak TFLOPS +python flops.py flops --config-path ./model_configs/nvidia/esm2_t6_8M_UR50D # Compute FLOPs +python flops.py flops --config-path nvidia/esm2_t36_3B_UR50D # FLOPs from HF Hub config +torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth +``` + ## Developer Guide ### Running Tests diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 2be3b0f11e..71c68ad78a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -412,6 +412,25 @@ Once converted, the model can be loaded by any library that supports Llama 3, su vllm serve path/to/hf_converted_model ``` +## MFU Tracking + +Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2_cp.py --config-name L2_lingua_1b log_mfu=true +``` + +This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (GQA, SwiGLU, etc.) from the model config. + +The `flops.py` CLI provides standalone utilities: + +```bash +python flops.py gpu-info # Show GPU and peak TFLOPS +python flops.py flops --config-path ./model_configs/lingua-1B # Compute FLOPs for a config +python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 # CP comm estimate +torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth +``` + ## Developer Guide ### Running tests diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md index 52e1b45986..435b497573 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md @@ -411,6 +411,25 @@ Validation logging during training can be enabled with `validation.enabled=true` validation data (e.g. a JSONL file). The `og2_7b_thd_gqa` config enables validation by default. Control evaluation frequency with `validation.eval_interval` and `validation.num_batches`.This can be helpful when debugging training convergence. +## MFU Tracking + +Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2_cp.py log_mfu=true +``` + +This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (GQA, SwiGLU, etc.) from the model config. + +The `flops.py` CLI provides standalone utilities: + +```bash +python flops.py gpu-info # Show GPU and peak TFLOPS +python flops.py flops --config-path ./model_configs/meta-llama/Llama-3.1-8B # Compute FLOPs +python flops.py cp-comm --config-path ./model_configs/meta-llama/Llama-3.1-8B --cp-size 2 # CP comm estimate +torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth +``` + ## Developer Guide ### Running tests From e6468fc6dbae0ab9ff45fe405cb8cb223545805e Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 18 Apr 2026 00:48:56 +0000 Subject: [PATCH 12/24] Consolidate MFU tracking into perf_logger, address PR review feedback Reworks MFU tracking per reviewer feedback on #1548: - Delete per-recipe flops.py, test_flops.py, and the CLI entirely - Inline ~30-line FLOPs helper into each recipe's existing perf_logger.py - MFU metrics (train/tflops_per_gpu, train/mfu_pct) flow through the existing torchmetrics -> WANDB path, respecting logging_frequency - Drop comm-overhead estimation; will be a separate future PR The new formula is per_token_flops(seq_len) * num_unpadded_tokens_on_rank. The unpadded-tokens counter (already used by tokens_per_second_per_gpu) is per-rank after DP/CP sharding and accumulated across grad-acc micro-batches, so the formula works uniformly across DDP/FSDP2/FSDP2+CP/DDP+CP/mFSDP and across BSHD and THD (sequence packing) with no per-strategy factors. Net: -3000 / +300 lines. Training scripts lose all MFU scaffolding; the only change per script is one extra kwarg on the PerfLogger constructor. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/README.md | 14 +- .../recipes/codonfm_native_te/flops.py | 714 ------------------ .../recipes/codonfm_native_te/perf_logger.py | 107 ++- .../codonfm_native_te/tests/test_flops.py | 323 -------- .../recipes/codonfm_native_te/train_fsdp2.py | 29 +- .../recipes/esm2_native_te/README.md | 16 +- .../recipes/esm2_native_te/flops.py | 714 ------------------ .../recipes/esm2_native_te/perf_logger.py | 102 ++- .../esm2_native_te/tests/test_flops.py | 323 -------- .../recipes/esm2_native_te/train_ddp.py | 30 +- .../recipes/esm2_native_te/train_ddp_cp.py | 30 +- .../recipes/esm2_native_te/train_fsdp2.py | 30 +- .../recipes/esm2_native_te/train_fsdp2_cp.py | 30 +- .../recipes/esm2_native_te/train_mfsdp.py | 30 +- .../recipes/llama3_native_te/README.md | 16 +- .../recipes/llama3_native_te/flops.py | 708 ----------------- .../recipes/llama3_native_te/perf_logger.py | 108 ++- .../llama3_native_te/tests/test_flops.py | 317 -------- .../tests/test_perf_logger.py | 104 ++- .../recipes/llama3_native_te/train_ddp.py | 30 +- .../recipes/llama3_native_te/train_fsdp2.py | 30 +- .../llama3_native_te/train_fsdp2_cp.py | 30 +- .../opengenome2_llama_native_te/README.md | 16 +- .../opengenome2_llama_native_te/flops.py | 714 ------------------ .../perf_logger.py | 101 ++- .../tests/test_flops.py | 323 -------- .../train_fsdp2.py | 29 +- .../train_fsdp2_cp.py | 29 +- ci/scripts/check_copied_files.py | 12 - 29 files changed, 601 insertions(+), 4458 deletions(-) delete mode 100644 bionemo-recipes/recipes/codonfm_native_te/flops.py delete mode 100644 bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/flops.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/flops.py delete mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py delete mode 100644 bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py delete mode 100644 bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/README.md b/bionemo-recipes/recipes/codonfm_native_te/README.md index 71d3a59ce7..e2c9a05089 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/README.md +++ b/bionemo-recipes/recipes/codonfm_native_te/README.md @@ -185,14 +185,16 @@ Enable per-step Model FLOPs Utilization (MFU) logging during training by adding torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true ``` -This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture from the model config. +This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and +stdout: -The `flops.py` CLI provides standalone utilities: +- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU +- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS -```bash -python flops.py gpu-info # Show GPU and peak TFLOPS -torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth -``` +The FLOPs formula auto-detects model architecture from the model config (MHA, standard FFN, +vocabulary size) and scales with the actual unpadded token count on each rank. This means it +naturally handles gradient accumulation, data parallelism, BSHD, and THD (sequence packing) +without per-strategy code paths. The implementation lives in `perf_logger.py`. ## Developer Guide diff --git a/bionemo-recipes/recipes/codonfm_native_te/flops.py b/bionemo-recipes/recipes/codonfm_native_te/flops.py deleted file mode 100644 index ade8d8a12c..0000000000 --- a/bionemo-recipes/recipes/codonfm_native_te/flops.py +++ /dev/null @@ -1,714 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. - -Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). -Designed to be copied to any recipe via check_copied_files.py and hooked into -training scripts for live MFU tracking. - -Usage as a library (in training scripts): - from flops import MFUTracker, from_hf_config - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - mfu_info = tracker.compute_mfu(step_time=0.5) - -Usage as a CLI: - python flops.py gpu-info - python flops.py flops --config-path ./model_configs/lingua-1B - python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 - torchrun --nproc_per_node=2 flops.py bandwidth -""" - -import math -import time -from dataclasses import dataclass - -import torch -import torch.distributed as dist - - -# ============================================================================= -# GPU Peak TFLOPS -# ============================================================================= - -# Dense (without sparsity) BF16 tensor core peak TFLOPS. -# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. -# Sources: -# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) -# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf -# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf -# L40: https://www.nvidia.com/en-us/data-center/l40/ -# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ -# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "L40": 181.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, - "B300": 2500.0, - "GB300": 2500.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -# ============================================================================= -# Model FLOPs Config -# ============================================================================= - -# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. -GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) - - -@dataclass(frozen=True) -class ModelFLOPsConfig: - """Architecture-independent parameters for FLOPs calculation. - - Can be constructed manually or via from_hf_config() for auto-detection. - """ - - hidden_size: int # H - num_hidden_layers: int # L - num_attention_heads: int # n_heads - num_kv_heads: int # n_kv (== n_heads for MHA) - head_dim: int # H // n_heads - intermediate_size: int # I (FFN intermediate dimension) - num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) - vocab_size: int # V - has_lm_head: bool # True for LM models, False for ViT etc. - - -def from_hf_config(config_dict, **overrides): - """Create ModelFLOPsConfig from an HF-compatible config dict. - - Auto-detects architecture: - - GQA vs MHA: from num_key_value_heads (absent = MHA) - - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type - - LM head: from vocab_size > 0 - - Args: - config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). - Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. - **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). - """ - h = config_dict["hidden_size"] - n_heads = config_dict["num_attention_heads"] - n_kv = config_dict.get("num_key_value_heads", n_heads) - vocab = config_dict.get("vocab_size", 0) - model_type = config_dict.get("model_type", "") - - # Detect gated MLP (3 projections) vs standard FFN (2 projections). - # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). - # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). - num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 - - kwargs = { - "hidden_size": h, - "num_hidden_layers": config_dict["num_hidden_layers"], - "num_attention_heads": n_heads, - "num_kv_heads": n_kv, - "head_dim": h // n_heads, - "intermediate_size": config_dict["intermediate_size"], - "num_mlp_projections": num_mlp_proj, - "vocab_size": vocab, - "has_lm_head": vocab > 0, - } - kwargs.update(overrides) - return ModelFLOPsConfig(**kwargs) - - -# ============================================================================= -# FLOPs Formulas -# ============================================================================= - - -def compute_flops_analytical(config, batch_size, seq_len): - """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). - - Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, - layer norms, activations, and element-wise ops. - - Handles: - - GQA vs MHA: K/V projection sizes based on config.num_kv_heads - - SwiGLU vs standard FFN: 2 or 3 MLP projections - - LM head presence - - Returns: - (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) - """ - b, s, h = batch_size, seq_len, config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - } - - ffn = config.intermediate_size - if config.num_mlp_projections == 3: - # SwiGLU/GeGLU: gate + up + down = 3 matmuls - breakdown["Gate projection"] = 2 * b * s * h * ffn - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - else: - # Standard FFN: up + down = 2 matmuls - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): - """Simplified formula assuming standard MHA + standard FFN with I=4H. - - This is the formula from the Llama3 README: - (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V - - The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. - """ - b, s, h = batch_size, seq_len, hidden_size - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): - """FLOPs for Hyena-based models (Evo2). - - Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. - - Args: - config: ModelFLOPsConfig with model dimensions. - batch_size: Batch size. - seq_len: Sequence length. - hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for - short/medium/long conv and attention layer counts. If None, assumes - all layers are long-conv Hyena (H=num_layers, no attention). - """ - b, s, h = batch_size, seq_len, config.hidden_size - ffn = config.intermediate_size - - if hyena_layer_counts is None: - hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} - - # Common per-layer FLOPs - pre_attn_qkv_proj = 2 * 3 * b * s * h * h - post_attn_proj = 2 * b * s * h * h - glu_ffn = 2 * 3 * b * s * ffn * h - - # Layer-type-specific FLOPs (defaults from evo2_provider.py) - attn = 2 * 2 * b * h * s * s # Standard S^2 attention - hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default - hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 - hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 - hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h - - n_s = hyena_layer_counts.get("S", 0) - n_d = hyena_layer_counts.get("D", 0) - n_h = hyena_layer_counts.get("H", 0) - n_a = hyena_layer_counts.get("A", 0) - - logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - - total_fwd = ( - logits - + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) - + n_a * attn - + (n_s + n_d + n_h) * hyena_proj - + n_s * hyena_short_conv - + n_d * hyena_medium_conv - + int(n_h * hyena_long_fft) - ) - - return 3 * total_fwd - - -# ============================================================================= -# MFU Tracker -# ============================================================================= - - -class MFUTracker: - """Tracks MFU during training. Initialize once, call compute_mfu() per step. - - Usage: - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - # In training loop: - mfu_info = tracker.compute_mfu(step_time=0.5) - print(f"MFU: {mfu_info['mfu']:.1f}%") - """ - - def __init__( - self, - config, - batch_size, - seq_len, - num_gpus=1, - parallelism=None, - peak_tflops=None, - formula="analytical", - hyena_layer_counts=None, - ): - """Initialize MFU tracker. - - Args: - config: ModelFLOPsConfig instance. - batch_size: Micro batch size per GPU. - seq_len: Sequence length. - num_gpus: Total number of GPUs. - parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. - Used for communication overhead estimation. - peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. - formula: "analytical", "simplified", or "hyena". - hyena_layer_counts: For Hyena formula, dict of layer type counts. - """ - self.config = config - self.batch_size = batch_size - self.seq_len = seq_len - self.num_gpus = num_gpus - self.parallelism = parallelism or {} - self.formula = formula - - if formula == "analytical": - self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( - config, batch_size, seq_len - ) - elif formula == "simplified": - self.total_flops = compute_flops_simplified( - batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size - ) - self.breakdown = None - self.lm_head_flops = 0 - elif formula == "hyena": - self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) - self.breakdown = None - self.lm_head_flops = 0 - else: - raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") - - self.per_gpu_flops = self.total_flops // max(num_gpus, 1) - - if peak_tflops is not None: - self.peak_tflops = peak_tflops - self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" - else: - detected, self.device_name = detect_gpu_peak_tflops() - self.peak_tflops = detected - - self.comm_bytes = self._estimate_comm() - - @classmethod - def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): - """Create from an HF config dict with auto-detection.""" - config = from_hf_config(config_dict) - return cls(config, batch_size, seq_len, **kwargs) - - def compute_mfu(self, step_time): - """Compute MFU from measured step time. - - Args: - step_time: Wall-clock time for one training step (seconds). - - Returns: - Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. - """ - tflops = self.per_gpu_flops / step_time / 1e12 - mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 - return { - "mfu": mfu, - "tflops_per_gpu": tflops, - "per_gpu_flops": self.per_gpu_flops, - "total_flops": self.total_flops, - "step_time": step_time, - } - - def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): - """Estimate communication overhead as a fraction of step time. - - Args: - step_time: Measured step time in seconds. - measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. - - Returns: - Dict with comm_bytes, estimated_comm_time, comm_pct. - """ - bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 - comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 - comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 - return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} - - def _estimate_comm(self): - """Estimate total communication bytes per step based on parallelism.""" - total = 0 - cp_size = self.parallelism.get("cp", 1) - dp_size = self.parallelism.get("dp", 1) - - if cp_size > 1: - total += estimate_cp_comm_bytes( - self.batch_size, - self.seq_len, - self.config.num_hidden_layers, - self.config.num_kv_heads, - self.config.head_dim, - cp_size, - ) - - if dp_size > 1: - # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp - model_params = _estimate_model_params(self.config) - total += 2 * model_params * 2 * (dp_size - 1) // dp_size - - return total - - def summary(self): - """Return a human-readable summary string.""" - lines = [ - f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", - f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," - f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," - f" I={self.config.intermediate_size}, V={self.config.vocab_size}", - f" MLP projections: {self.config.num_mlp_projections}" - f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", - f" Batch: B={self.batch_size}, S={self.seq_len}", - f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", - f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", - f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", - ] - if self.parallelism: - lines.append(f" Parallelism: {self.parallelism}") - if self.comm_bytes > 0: - lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") - return "\n".join(lines) - - -# ============================================================================= -# Communication Estimation -# ============================================================================= - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step. - - Ring attention sends local KV chunks around the ring. Per layer forward: - (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. - Training = ~2x forward communication (forward sends KV, backward sends dKV). - """ - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def _estimate_model_params(config): - """Rough parameter count estimate from config dimensions.""" - h = config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O - mlp_params = config.num_mlp_projections * h * config.intermediate_size - layer_params = attn_params + mlp_params - total = config.num_hidden_layers * layer_params - if config.has_lm_head: - total += config.vocab_size * h * 2 # embed + lm_head - return total - - -# ============================================================================= -# Utilities -# ============================================================================= - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank - - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - return num_iters * data_bytes / elapsed / 1e9 - - -def load_model_config(config_path): - """Load model config dict from a local path or HuggingFace model ID. - - Supports: - - Local directory: ./model_configs/lingua-1B (reads config.json inside) - - Local file: ./model_configs/lingua-1B/config.json - - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) - """ - import json - from pathlib import Path - - path = Path(config_path) - if path.is_dir(): - path = path / "config.json" - if path.exists(): - return json.loads(path.read_text()) - - # Fall back to HuggingFace Hub - from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) - return hf_config.to_dict() - - -# ============================================================================= -# Formatting -# ============================================================================= - - -def format_flops(flops): - """Format FLOPs with appropriate unit (G/T/P).""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") - parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") - parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - print( - f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," - f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," - f" I={config.intermediate_size}, V={config.vocab_size}" - ) - print( - f"MLP: {config.num_mlp_projections} projections" - f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" - ) - print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") - print() - - simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) - analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") - print("-" * 86) - for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: - per_gpu = flops // max(args.num_gpus, 1) - print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") - - if simplified != analytical: - diff = analytical - simplified - print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Communication overhead estimate - if args.cp_size > 1: - dp_size = args.num_gpus // args.cp_size - parallelism = {"dp": dp_size, "cp": args.cp_size} - tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) - print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") - print( - f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" - ) - comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 - print(f" Estimated comm time: {comm_time:.4f}s") - - model_params = _estimate_model_params(config) - print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) - print( - f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," - f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" - ) - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python flops.py [options]") - print(" torchrun --nproc_per_node=2 flops.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index a0b5a21b70..340451d7cc 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -36,6 +36,69 @@ PAD_TOKEN_ID = 3 +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: + """Training FLOPs per token for a transformer (forward + backward = 3x forward). + + First-principles matmul count: Q/K/V/O projections (GQA-aware), attention + logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection + MLP (SwiGLU detected via model_type), and LM head. The returned value is + multiplied by the actual unpadded token count at log time, so it naturally + handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: + unpadded tokens on each rank already reflect that rank's share of work. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 4 * seq_len * h # attention logits + values (S^2 -> S per token) + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + per_token_fwd = num_layers * per_layer + lm_head + return 3 * per_token_fwd + + class PerfLogger: """Performance logger for CodonFM training. @@ -44,17 +107,39 @@ class PerfLogger: Args: dist_config: The distributed configuration. args: The Hydra arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP and sequence packing. + self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._per_token_flops = 0 + self._peak_tflops: float | None = None + if self._log_mfu: + self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._per_token_flops), + args.dataset.max_seq_length, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -66,9 +151,13 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) - self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.metrics.to(self._device) self.previous_step_time = time.perf_counter() if self._dist_config.is_main_process(): @@ -79,7 +168,6 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.quant_stats_config = args.quant_stats_config.enabled # Gradient accumulation tracking - self._device = torch.device(f"cuda:{dist_config.local_rank}") self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) @@ -155,6 +243,17 @@ def log_step( self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) + if self._log_mfu: + # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one + # optimizer step (the last step in the logging window), so this yields FLOPs + # per optimizer step per rank. step_time is already the per-step average. + tokens_on_rank = self.num_unpadded_tokens.item() + flops_per_step = self._per_token_flops * tokens_on_rank + tflops_per_gpu = flops_per_step / step_time / 1e12 + self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + memory_allocated = torch.cuda.memory_allocated() / (1024**3) self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py deleted file mode 100644 index f514205fef..0000000000 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Tests for the flops.py FLOPs counting and MFU module.""" - -import sys -from pathlib import Path - -import pytest - - -# Add parent directory so we can import flops -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from flops import ( - MFUTracker, - ModelFLOPsConfig, - compute_flops_analytical, - compute_flops_hyena, - compute_flops_simplified, - estimate_cp_comm_bytes, - from_hf_config, -) - - -# ============================================================================ -# Test configs matching real models -# ============================================================================ - -LLAMA_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 25, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 6144, - "vocab_size": 128256, - "model_type": "llama", - "hidden_act": "silu", -} - -ESM2_8M_CONFIG = { - "hidden_size": 320, - "num_hidden_layers": 6, - "num_attention_heads": 20, - "intermediate_size": 1280, - "vocab_size": 33, - "model_type": "nv_esm", - "hidden_act": "gelu", -} - -CODONFM_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 18, - "num_attention_heads": 16, - "intermediate_size": 8192, -} - - -# ============================================================================ -# Config auto-detection -# ============================================================================ - - -class TestFromHfConfig: - """Test auto-detection of model architecture from config dicts.""" - - def test_llama_detects_gqa_and_swiglu(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - assert cfg.num_kv_heads == 8 - assert cfg.num_mlp_projections == 3 - assert cfg.head_dim == 128 - - def test_esm2_detects_mha_and_standard_ffn(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - assert cfg.num_kv_heads == 20 - assert cfg.num_mlp_projections == 2 - - def test_codonfm_defaults_to_mha_and_2_proj(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - assert cfg.num_kv_heads == 16 - assert cfg.num_mlp_projections == 2 - - def test_missing_vocab_defaults_to_no_lm_head(self): - cfg = from_hf_config( - {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} - ) - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - def test_overrides_take_precedence(self): - cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) - assert cfg.num_mlp_projections == 3 - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - -# ============================================================================ -# Analytical FLOPs formula -# ============================================================================ - - -class TestComputeFlopsAnalytical: - """Test the first-principles analytical FLOPs formula.""" - - def test_training_is_3x_forward(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) - forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head - assert total == 3 * forward - - def test_swiglu_has_3_mlp_projections(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_standard_ffn_has_2_mlp_projections(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" not in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_no_lm_head_when_vocab_zero(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) - assert lm_head == 0 - - def test_flops_scale_linearly_with_batch(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) - flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) - assert flops_b4 == 4 * flops_b1 - - def test_known_value_llama_lingua_1b(self): - """Golden value: validated against PyTorch FlopCounterMode and README formula.""" - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, _, _ = compute_flops_analytical(cfg, 1, 4096) - assert total == 47_687_021_887_488 - - -# ============================================================================ -# Simplified formula -# ============================================================================ - - -class TestComputeFlopsSimplified: - """Test the simplified README formula and its relationship to analytical.""" - - def test_matches_analytical_when_mha_and_i_equals_4h(self): - """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" - cfg = from_hf_config(ESM2_8M_CONFIG) - analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) - simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical == simplified - - def test_differs_when_gqa_or_swiglu(self): - """GQA + SwiGLU breaks the simplified formula's assumptions.""" - cfg_dict = { - **LLAMA_1B_CONFIG, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "intermediate_size": 8192, - "num_hidden_layers": 16, - } - cfg = from_hf_config(cfg_dict) - analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) - simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical != simplified - - -# ============================================================================ -# Hyena formula -# ============================================================================ - - -class TestComputeFlopsHyena: - """Test the Hyena (Evo2) FLOPs formula.""" - - @pytest.fixture() - def hyena_config(self): - return ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - - def test_scales_subquadratically(self, hyena_config): - """Hyena uses O(S log S) convolution, not O(S^2) attention.""" - flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - ratio = flops_2k / flops_1k - assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below - - def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): - """Adding standard attention layers increases FLOPs due to S^2 term.""" - all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) - assert with_attn > all_hyena - - -# ============================================================================ -# MFUTracker -# ============================================================================ - - -class TestMFUTracker: - """Test the MFUTracker class used by training scripts.""" - - def test_from_config_dict(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.total_flops == 47_687_021_887_488 - assert tracker.per_gpu_flops == tracker.total_flops - - def test_multi_gpu_divides_flops(self): - single = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 - ) - multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) - assert multi.per_gpu_flops == single.total_flops // 2 - - def test_compute_mfu_correctness(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - result = tracker.compute_mfu(step_time=0.5) - expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 - expected_mfu = expected_tflops / 155.0 * 100 - assert abs(result["mfu"] - expected_mfu) < 0.01 - assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 - - def test_mfu_inversely_proportional_to_step_time(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - fast = tracker.compute_mfu(step_time=0.5) - slow = tracker.compute_mfu(step_time=1.0) - assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 - - def test_all_formula_options(self): - for formula in ["analytical", "simplified", "hyena"]: - if formula == "hyena": - cfg = ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) - else: - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula - ) - assert tracker.total_flops > 0 - - def test_invalid_formula_raises(self): - with pytest.raises(ValueError, match="Unknown formula"): - MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") - - def test_cp_communication_estimate(self): - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 - ) - assert tracker.comm_bytes > 0 - overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) - assert overhead["estimated_comm_time"] > 0 - assert 0 < overhead["comm_pct"] < 100 - - def test_no_comm_single_gpu(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.comm_bytes == 0 - - -# ============================================================================ -# Communication estimation -# ============================================================================ - - -class TestCPCommEstimation: - """Test CP ring attention communication byte estimates.""" - - def test_zero_without_cp(self): - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 - - def test_scales_linearly_with_seq_len(self): - comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) - assert comm_8k == 2 * comm_4k - - def test_scales_linearly_with_batch(self): - comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) - assert comm_b4 == 4 * comm_b1 - - def test_known_value_lingua_1b(self): - """Golden value for lingua-1B at S=4096, CP=2.""" - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index 4b2608c88f..f41e85ad8e 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -16,7 +16,6 @@ """FSDP2 training script for CodonFM with TransformerEngine layers.""" import logging -import time from contextlib import nullcontext from pathlib import Path @@ -26,7 +25,6 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM from omegaconf import DictConfig, OmegaConf from perf_logger import PerfLogger @@ -165,25 +163,15 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step micro_step = 0 # Gradient accumulation step counter - step_start_time = time.perf_counter() while step < args.num_train_steps: batches_in_epoch = 0 for batch in train_dataloader: @@ -219,13 +207,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index 5e6efab3f1..e97d13eb0e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -382,16 +382,16 @@ Enable per-step Model FLOPs Utilization (MFU) logging during training by adding torchrun --nproc_per_node=2 train_fsdp2.py --config-name L1_3B log_mfu=true ``` -This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (MHA, standard FFN, etc.) from the model config. +This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and +stdout: -The `flops.py` CLI provides standalone utilities: +- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU +- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS -```bash -python flops.py gpu-info # Show GPU and peak TFLOPS -python flops.py flops --config-path ./model_configs/nvidia/esm2_t6_8M_UR50D # Compute FLOPs -python flops.py flops --config-path nvidia/esm2_t36_3B_UR50D # FLOPs from HF Hub config -torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth -``` +The FLOPs formula auto-detects model architecture from the HF config (MHA vs. GQA, gated vs. +standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This +means it naturally handles data parallelism, context parallelism, BSHD, and THD (sequence packing) +without per-strategy code paths. The implementation lives in `perf_logger.py`. ## Developer Guide diff --git a/bionemo-recipes/recipes/esm2_native_te/flops.py b/bionemo-recipes/recipes/esm2_native_te/flops.py deleted file mode 100644 index ade8d8a12c..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/flops.py +++ /dev/null @@ -1,714 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. - -Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). -Designed to be copied to any recipe via check_copied_files.py and hooked into -training scripts for live MFU tracking. - -Usage as a library (in training scripts): - from flops import MFUTracker, from_hf_config - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - mfu_info = tracker.compute_mfu(step_time=0.5) - -Usage as a CLI: - python flops.py gpu-info - python flops.py flops --config-path ./model_configs/lingua-1B - python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 - torchrun --nproc_per_node=2 flops.py bandwidth -""" - -import math -import time -from dataclasses import dataclass - -import torch -import torch.distributed as dist - - -# ============================================================================= -# GPU Peak TFLOPS -# ============================================================================= - -# Dense (without sparsity) BF16 tensor core peak TFLOPS. -# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. -# Sources: -# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) -# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf -# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf -# L40: https://www.nvidia.com/en-us/data-center/l40/ -# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ -# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "L40": 181.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, - "B300": 2500.0, - "GB300": 2500.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -# ============================================================================= -# Model FLOPs Config -# ============================================================================= - -# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. -GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) - - -@dataclass(frozen=True) -class ModelFLOPsConfig: - """Architecture-independent parameters for FLOPs calculation. - - Can be constructed manually or via from_hf_config() for auto-detection. - """ - - hidden_size: int # H - num_hidden_layers: int # L - num_attention_heads: int # n_heads - num_kv_heads: int # n_kv (== n_heads for MHA) - head_dim: int # H // n_heads - intermediate_size: int # I (FFN intermediate dimension) - num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) - vocab_size: int # V - has_lm_head: bool # True for LM models, False for ViT etc. - - -def from_hf_config(config_dict, **overrides): - """Create ModelFLOPsConfig from an HF-compatible config dict. - - Auto-detects architecture: - - GQA vs MHA: from num_key_value_heads (absent = MHA) - - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type - - LM head: from vocab_size > 0 - - Args: - config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). - Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. - **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). - """ - h = config_dict["hidden_size"] - n_heads = config_dict["num_attention_heads"] - n_kv = config_dict.get("num_key_value_heads", n_heads) - vocab = config_dict.get("vocab_size", 0) - model_type = config_dict.get("model_type", "") - - # Detect gated MLP (3 projections) vs standard FFN (2 projections). - # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). - # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). - num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 - - kwargs = { - "hidden_size": h, - "num_hidden_layers": config_dict["num_hidden_layers"], - "num_attention_heads": n_heads, - "num_kv_heads": n_kv, - "head_dim": h // n_heads, - "intermediate_size": config_dict["intermediate_size"], - "num_mlp_projections": num_mlp_proj, - "vocab_size": vocab, - "has_lm_head": vocab > 0, - } - kwargs.update(overrides) - return ModelFLOPsConfig(**kwargs) - - -# ============================================================================= -# FLOPs Formulas -# ============================================================================= - - -def compute_flops_analytical(config, batch_size, seq_len): - """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). - - Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, - layer norms, activations, and element-wise ops. - - Handles: - - GQA vs MHA: K/V projection sizes based on config.num_kv_heads - - SwiGLU vs standard FFN: 2 or 3 MLP projections - - LM head presence - - Returns: - (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) - """ - b, s, h = batch_size, seq_len, config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - } - - ffn = config.intermediate_size - if config.num_mlp_projections == 3: - # SwiGLU/GeGLU: gate + up + down = 3 matmuls - breakdown["Gate projection"] = 2 * b * s * h * ffn - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - else: - # Standard FFN: up + down = 2 matmuls - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): - """Simplified formula assuming standard MHA + standard FFN with I=4H. - - This is the formula from the Llama3 README: - (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V - - The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. - """ - b, s, h = batch_size, seq_len, hidden_size - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): - """FLOPs for Hyena-based models (Evo2). - - Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. - - Args: - config: ModelFLOPsConfig with model dimensions. - batch_size: Batch size. - seq_len: Sequence length. - hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for - short/medium/long conv and attention layer counts. If None, assumes - all layers are long-conv Hyena (H=num_layers, no attention). - """ - b, s, h = batch_size, seq_len, config.hidden_size - ffn = config.intermediate_size - - if hyena_layer_counts is None: - hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} - - # Common per-layer FLOPs - pre_attn_qkv_proj = 2 * 3 * b * s * h * h - post_attn_proj = 2 * b * s * h * h - glu_ffn = 2 * 3 * b * s * ffn * h - - # Layer-type-specific FLOPs (defaults from evo2_provider.py) - attn = 2 * 2 * b * h * s * s # Standard S^2 attention - hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default - hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 - hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 - hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h - - n_s = hyena_layer_counts.get("S", 0) - n_d = hyena_layer_counts.get("D", 0) - n_h = hyena_layer_counts.get("H", 0) - n_a = hyena_layer_counts.get("A", 0) - - logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - - total_fwd = ( - logits - + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) - + n_a * attn - + (n_s + n_d + n_h) * hyena_proj - + n_s * hyena_short_conv - + n_d * hyena_medium_conv - + int(n_h * hyena_long_fft) - ) - - return 3 * total_fwd - - -# ============================================================================= -# MFU Tracker -# ============================================================================= - - -class MFUTracker: - """Tracks MFU during training. Initialize once, call compute_mfu() per step. - - Usage: - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - # In training loop: - mfu_info = tracker.compute_mfu(step_time=0.5) - print(f"MFU: {mfu_info['mfu']:.1f}%") - """ - - def __init__( - self, - config, - batch_size, - seq_len, - num_gpus=1, - parallelism=None, - peak_tflops=None, - formula="analytical", - hyena_layer_counts=None, - ): - """Initialize MFU tracker. - - Args: - config: ModelFLOPsConfig instance. - batch_size: Micro batch size per GPU. - seq_len: Sequence length. - num_gpus: Total number of GPUs. - parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. - Used for communication overhead estimation. - peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. - formula: "analytical", "simplified", or "hyena". - hyena_layer_counts: For Hyena formula, dict of layer type counts. - """ - self.config = config - self.batch_size = batch_size - self.seq_len = seq_len - self.num_gpus = num_gpus - self.parallelism = parallelism or {} - self.formula = formula - - if formula == "analytical": - self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( - config, batch_size, seq_len - ) - elif formula == "simplified": - self.total_flops = compute_flops_simplified( - batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size - ) - self.breakdown = None - self.lm_head_flops = 0 - elif formula == "hyena": - self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) - self.breakdown = None - self.lm_head_flops = 0 - else: - raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") - - self.per_gpu_flops = self.total_flops // max(num_gpus, 1) - - if peak_tflops is not None: - self.peak_tflops = peak_tflops - self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" - else: - detected, self.device_name = detect_gpu_peak_tflops() - self.peak_tflops = detected - - self.comm_bytes = self._estimate_comm() - - @classmethod - def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): - """Create from an HF config dict with auto-detection.""" - config = from_hf_config(config_dict) - return cls(config, batch_size, seq_len, **kwargs) - - def compute_mfu(self, step_time): - """Compute MFU from measured step time. - - Args: - step_time: Wall-clock time for one training step (seconds). - - Returns: - Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. - """ - tflops = self.per_gpu_flops / step_time / 1e12 - mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 - return { - "mfu": mfu, - "tflops_per_gpu": tflops, - "per_gpu_flops": self.per_gpu_flops, - "total_flops": self.total_flops, - "step_time": step_time, - } - - def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): - """Estimate communication overhead as a fraction of step time. - - Args: - step_time: Measured step time in seconds. - measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. - - Returns: - Dict with comm_bytes, estimated_comm_time, comm_pct. - """ - bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 - comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 - comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 - return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} - - def _estimate_comm(self): - """Estimate total communication bytes per step based on parallelism.""" - total = 0 - cp_size = self.parallelism.get("cp", 1) - dp_size = self.parallelism.get("dp", 1) - - if cp_size > 1: - total += estimate_cp_comm_bytes( - self.batch_size, - self.seq_len, - self.config.num_hidden_layers, - self.config.num_kv_heads, - self.config.head_dim, - cp_size, - ) - - if dp_size > 1: - # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp - model_params = _estimate_model_params(self.config) - total += 2 * model_params * 2 * (dp_size - 1) // dp_size - - return total - - def summary(self): - """Return a human-readable summary string.""" - lines = [ - f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", - f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," - f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," - f" I={self.config.intermediate_size}, V={self.config.vocab_size}", - f" MLP projections: {self.config.num_mlp_projections}" - f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", - f" Batch: B={self.batch_size}, S={self.seq_len}", - f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", - f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", - f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", - ] - if self.parallelism: - lines.append(f" Parallelism: {self.parallelism}") - if self.comm_bytes > 0: - lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") - return "\n".join(lines) - - -# ============================================================================= -# Communication Estimation -# ============================================================================= - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step. - - Ring attention sends local KV chunks around the ring. Per layer forward: - (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. - Training = ~2x forward communication (forward sends KV, backward sends dKV). - """ - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def _estimate_model_params(config): - """Rough parameter count estimate from config dimensions.""" - h = config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O - mlp_params = config.num_mlp_projections * h * config.intermediate_size - layer_params = attn_params + mlp_params - total = config.num_hidden_layers * layer_params - if config.has_lm_head: - total += config.vocab_size * h * 2 # embed + lm_head - return total - - -# ============================================================================= -# Utilities -# ============================================================================= - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank - - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - return num_iters * data_bytes / elapsed / 1e9 - - -def load_model_config(config_path): - """Load model config dict from a local path or HuggingFace model ID. - - Supports: - - Local directory: ./model_configs/lingua-1B (reads config.json inside) - - Local file: ./model_configs/lingua-1B/config.json - - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) - """ - import json - from pathlib import Path - - path = Path(config_path) - if path.is_dir(): - path = path / "config.json" - if path.exists(): - return json.loads(path.read_text()) - - # Fall back to HuggingFace Hub - from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) - return hf_config.to_dict() - - -# ============================================================================= -# Formatting -# ============================================================================= - - -def format_flops(flops): - """Format FLOPs with appropriate unit (G/T/P).""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") - parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") - parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - print( - f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," - f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," - f" I={config.intermediate_size}, V={config.vocab_size}" - ) - print( - f"MLP: {config.num_mlp_projections} projections" - f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" - ) - print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") - print() - - simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) - analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") - print("-" * 86) - for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: - per_gpu = flops // max(args.num_gpus, 1) - print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") - - if simplified != analytical: - diff = analytical - simplified - print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Communication overhead estimate - if args.cp_size > 1: - dp_size = args.num_gpus // args.cp_size - parallelism = {"dp": dp_size, "cp": args.cp_size} - tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) - print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") - print( - f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" - ) - comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 - print(f" Estimated comm time: {comm_time:.4f}s") - - model_params = _estimate_model_params(config) - print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) - print( - f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," - f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" - ) - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python flops.py [options]") - print(" torchrun --nproc_per_node=2 flops.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index dd4070ecc9..fa78b15327 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -32,18 +32,84 @@ logger = logging.getLogger(__name__) +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: + """Training FLOPs per token for a transformer (forward + backward = 3x forward). + + First-principles matmul count: Q/K/V/O projections (GQA-aware), attention + logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection + MLP (SwiGLU detected via model_type), and LM head. The returned value is + multiplied by the actual unpadded token count at log time, so it naturally + handles BSHD, THD (sequence packing), DP, and CP: unpadded tokens on each + rank already reflect that rank's share of work. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 4 * seq_len * h # attention logits + values (S^2 -> S per token) + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + per_token_fwd = num_layers * per_layer + lm_head + return 3 * per_token_fwd + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. Args: dist_config: The distributed configuration. args: The arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -53,6 +119,24 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.logging_frequency = args.logger.frequency # Track whether to collect memory stats (disabled by default for max performance) + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the current batch's unpadded token count, which + # already reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._per_token_flops = 0 + self._peak_tflops: float | None = None + if self._log_mfu: + self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._per_token_flops), + args.dataset.max_seq_length, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -65,12 +149,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) self.previous_step_time = time.perf_counter() - self.last_step_time = None # Set after each logged step for MFU tracking if self._dist_config.is_main_process(): # Log the entire args object to wandb for experiment tracking and reproducibility. @@ -116,7 +203,6 @@ def log_step( time.perf_counter(), ) step_time = elapsed_time / self.logging_frequency - self.last_step_time = step_time self.metrics["train/loss"].update(outputs.loss) self.metrics["train/learning_rate"].update(lr) @@ -126,6 +212,16 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) + if self._log_mfu: + # Current batch's unpadded tokens already reflect this rank's share (CP + # shards the batch; DP replicates the model across ranks on distinct + # micro-batches). step_time is the per-step average over the logging window. + flops_per_step = self._per_token_flops * num_unpadded_tokens + tflops_per_gpu = flops_per_step / step_time / 1e12 + self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + # Handle sequence packing for torchmetrics calculation. if outputs.logits.dim() < 3: outputs.logits = outputs.logits.unsqueeze(0) diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py deleted file mode 100644 index f514205fef..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Tests for the flops.py FLOPs counting and MFU module.""" - -import sys -from pathlib import Path - -import pytest - - -# Add parent directory so we can import flops -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from flops import ( - MFUTracker, - ModelFLOPsConfig, - compute_flops_analytical, - compute_flops_hyena, - compute_flops_simplified, - estimate_cp_comm_bytes, - from_hf_config, -) - - -# ============================================================================ -# Test configs matching real models -# ============================================================================ - -LLAMA_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 25, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 6144, - "vocab_size": 128256, - "model_type": "llama", - "hidden_act": "silu", -} - -ESM2_8M_CONFIG = { - "hidden_size": 320, - "num_hidden_layers": 6, - "num_attention_heads": 20, - "intermediate_size": 1280, - "vocab_size": 33, - "model_type": "nv_esm", - "hidden_act": "gelu", -} - -CODONFM_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 18, - "num_attention_heads": 16, - "intermediate_size": 8192, -} - - -# ============================================================================ -# Config auto-detection -# ============================================================================ - - -class TestFromHfConfig: - """Test auto-detection of model architecture from config dicts.""" - - def test_llama_detects_gqa_and_swiglu(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - assert cfg.num_kv_heads == 8 - assert cfg.num_mlp_projections == 3 - assert cfg.head_dim == 128 - - def test_esm2_detects_mha_and_standard_ffn(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - assert cfg.num_kv_heads == 20 - assert cfg.num_mlp_projections == 2 - - def test_codonfm_defaults_to_mha_and_2_proj(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - assert cfg.num_kv_heads == 16 - assert cfg.num_mlp_projections == 2 - - def test_missing_vocab_defaults_to_no_lm_head(self): - cfg = from_hf_config( - {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} - ) - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - def test_overrides_take_precedence(self): - cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) - assert cfg.num_mlp_projections == 3 - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - -# ============================================================================ -# Analytical FLOPs formula -# ============================================================================ - - -class TestComputeFlopsAnalytical: - """Test the first-principles analytical FLOPs formula.""" - - def test_training_is_3x_forward(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) - forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head - assert total == 3 * forward - - def test_swiglu_has_3_mlp_projections(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_standard_ffn_has_2_mlp_projections(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" not in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_no_lm_head_when_vocab_zero(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) - assert lm_head == 0 - - def test_flops_scale_linearly_with_batch(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) - flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) - assert flops_b4 == 4 * flops_b1 - - def test_known_value_llama_lingua_1b(self): - """Golden value: validated against PyTorch FlopCounterMode and README formula.""" - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, _, _ = compute_flops_analytical(cfg, 1, 4096) - assert total == 47_687_021_887_488 - - -# ============================================================================ -# Simplified formula -# ============================================================================ - - -class TestComputeFlopsSimplified: - """Test the simplified README formula and its relationship to analytical.""" - - def test_matches_analytical_when_mha_and_i_equals_4h(self): - """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" - cfg = from_hf_config(ESM2_8M_CONFIG) - analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) - simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical == simplified - - def test_differs_when_gqa_or_swiglu(self): - """GQA + SwiGLU breaks the simplified formula's assumptions.""" - cfg_dict = { - **LLAMA_1B_CONFIG, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "intermediate_size": 8192, - "num_hidden_layers": 16, - } - cfg = from_hf_config(cfg_dict) - analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) - simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical != simplified - - -# ============================================================================ -# Hyena formula -# ============================================================================ - - -class TestComputeFlopsHyena: - """Test the Hyena (Evo2) FLOPs formula.""" - - @pytest.fixture() - def hyena_config(self): - return ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - - def test_scales_subquadratically(self, hyena_config): - """Hyena uses O(S log S) convolution, not O(S^2) attention.""" - flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - ratio = flops_2k / flops_1k - assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below - - def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): - """Adding standard attention layers increases FLOPs due to S^2 term.""" - all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) - assert with_attn > all_hyena - - -# ============================================================================ -# MFUTracker -# ============================================================================ - - -class TestMFUTracker: - """Test the MFUTracker class used by training scripts.""" - - def test_from_config_dict(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.total_flops == 47_687_021_887_488 - assert tracker.per_gpu_flops == tracker.total_flops - - def test_multi_gpu_divides_flops(self): - single = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 - ) - multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) - assert multi.per_gpu_flops == single.total_flops // 2 - - def test_compute_mfu_correctness(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - result = tracker.compute_mfu(step_time=0.5) - expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 - expected_mfu = expected_tflops / 155.0 * 100 - assert abs(result["mfu"] - expected_mfu) < 0.01 - assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 - - def test_mfu_inversely_proportional_to_step_time(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - fast = tracker.compute_mfu(step_time=0.5) - slow = tracker.compute_mfu(step_time=1.0) - assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 - - def test_all_formula_options(self): - for formula in ["analytical", "simplified", "hyena"]: - if formula == "hyena": - cfg = ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) - else: - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula - ) - assert tracker.total_flops > 0 - - def test_invalid_formula_raises(self): - with pytest.raises(ValueError, match="Unknown formula"): - MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") - - def test_cp_communication_estimate(self): - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 - ) - assert tracker.comm_bytes > 0 - overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) - assert overhead["estimated_comm_time"] > 0 - assert 0 < overhead["comm_pct"] < 100 - - def test_no_comm_single_gpu(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.comm_bytes == 0 - - -# ============================================================================ -# Communication estimation -# ============================================================================ - - -class TestCPCommEstimation: - """Test CP ring attention communication byte estimates.""" - - def test_zero_without_cp(self): - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 - - def test_scales_linearly_with_seq_len(self): - comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) - assert comm_8k == 2 * comm_4k - - def test_scales_linearly_with_batch(self): - comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) - assert comm_b4 == 4 * comm_b1 - - def test_known_value_lingua_1b(self): - """Golden value for lingua-1B at S=4096, CP=2.""" - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 5cb3837a74..ebf36d9d47 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -29,7 +29,6 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import initialize_quant_stats_logging, resolve_layer_precision @@ -157,19 +156,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step @@ -200,17 +191,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None and perf_logger.last_step_time is not None: - mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) - if dist_config.is_main_process(): - logger.info( - "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", - step, - mfu_info["mfu"], - mfu_info["tflops_per_gpu"], - mfu_info["step_time"], - ) - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index 3be83938b0..25829c64cf 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -26,7 +26,6 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -166,19 +165,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step @@ -209,17 +200,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None and perf_logger.last_step_time is not None: - mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) - if dist_config.is_main_process(): - logger.info( - "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", - step, - mfu_info["mfu"], - mfu_info["tflops_per_gpu"], - mfu_info["step_time"], - ) - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index f01dcfcc91..fcfbf17fa6 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -32,7 +32,6 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import WandBQuantLogger, initialize_quant_stats_logging, resolve_layer_precision @@ -183,19 +182,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step @@ -227,17 +218,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None and perf_logger.last_step_time is not None: - mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) - if dist_config.is_main_process(): - logger.info( - "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", - step, - mfu_info["mfu"], - mfu_info["tflops_per_gpu"], - mfu_info["step_time"], - ) - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 6aaa0a8f2f..06573112f8 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -28,7 +28,6 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -178,19 +177,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step @@ -221,17 +212,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None and perf_logger.last_step_time is not None: - mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) - if dist_config.is_main_process(): - logger.info( - "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", - step, - mfu_info["mfu"], - mfu_info["tflops_per_gpu"], - mfu_info["step_time"], - ) - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index 3afc23a0f7..2d61d14d39 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -31,7 +31,6 @@ from checkpoint import load_checkpoint_mfsdp, save_checkpoint_mfsdp, save_final_model_mfsdp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger from quantization import resolve_layer_precision @@ -164,19 +163,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Training loop step = start_step @@ -208,17 +199,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None and perf_logger.last_step_time is not None: - mfu_info = mfu_tracker.compute_mfu(perf_logger.last_step_time) - if dist_config.is_main_process(): - logger.info( - "Step %d MFU: %.1f%% (%.1f TFLOPS/gpu, step_time=%.3fs)", - step, - mfu_info["mfu"], - mfu_info["tflops_per_gpu"], - mfu_info["step_time"], - ) - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_mfsdp( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 71c68ad78a..39dc8d8ac8 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -420,16 +420,16 @@ Enable per-step Model FLOPs Utilization (MFU) logging during training by adding torchrun --nproc_per_node=2 train_fsdp2_cp.py --config-name L2_lingua_1b log_mfu=true ``` -This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (GQA, SwiGLU, etc.) from the model config. +This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and +stdout: -The `flops.py` CLI provides standalone utilities: +- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU +- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS -```bash -python flops.py gpu-info # Show GPU and peak TFLOPS -python flops.py flops --config-path ./model_configs/lingua-1B # Compute FLOPs for a config -python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 # CP comm estimate -torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth -``` +The FLOPs formula auto-detects model architecture from the HF config (GQA vs. MHA, SwiGLU vs. +standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This +means it naturally handles gradient accumulation, data parallelism, context parallelism, BSHD, and +THD (sequence packing) without per-strategy code paths. The implementation lives in `perf_logger.py`. ## Developer Guide diff --git a/bionemo-recipes/recipes/llama3_native_te/flops.py b/bionemo-recipes/recipes/llama3_native_te/flops.py deleted file mode 100644 index bce8388069..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/flops.py +++ /dev/null @@ -1,708 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. - -Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). -Designed to be copied to any recipe via check_copied_files.py and hooked into -training scripts for live MFU tracking. - -Usage as a library (in training scripts): - from flops import MFUTracker, from_hf_config - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - mfu_info = tracker.compute_mfu(step_time=0.5) - -Usage as a CLI: - python flops.py gpu-info - python flops.py flops --config-path ./model_configs/lingua-1B - python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 - torchrun --nproc_per_node=2 flops.py bandwidth -""" - -import math -import time -from dataclasses import dataclass - -import torch -import torch.distributed as dist - - -# ============================================================================= -# GPU Peak TFLOPS -# ============================================================================= - -# Dense (without sparsity) BF16 tensor core peak TFLOPS. -# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. -# Sources: -# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) -# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf -# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf -# L40: https://www.nvidia.com/en-us/data-center/l40/ -# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ -# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "L40": 181.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, - "B300": 2500.0, - "GB300": 2500.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -# ============================================================================= -# Model FLOPs Config -# ============================================================================= - -# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. -GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) - - -@dataclass(frozen=True) -class ModelFLOPsConfig: - """Architecture-independent parameters for FLOPs calculation. - - Can be constructed manually or via from_hf_config() for auto-detection. - """ - - hidden_size: int # H - num_hidden_layers: int # L - num_attention_heads: int # n_heads - num_kv_heads: int # n_kv (== n_heads for MHA) - head_dim: int # H // n_heads - intermediate_size: int # I (FFN intermediate dimension) - num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) - vocab_size: int # V - has_lm_head: bool # True for LM models, False for ViT etc. - - -def from_hf_config(config_dict, **overrides): - """Create ModelFLOPsConfig from an HF-compatible config dict. - - Auto-detects architecture: - - GQA vs MHA: from num_key_value_heads (absent = MHA) - - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type - - LM head: from vocab_size > 0 - - Args: - config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). - Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. - **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). - """ - h = config_dict["hidden_size"] - n_heads = config_dict["num_attention_heads"] - n_kv = config_dict.get("num_key_value_heads", n_heads) - vocab = config_dict.get("vocab_size", 0) - model_type = config_dict.get("model_type", "") - - # Detect gated MLP (3 projections) vs standard FFN (2 projections). - # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). - # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). - num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 - - kwargs = { - "hidden_size": h, - "num_hidden_layers": config_dict["num_hidden_layers"], - "num_attention_heads": n_heads, - "num_kv_heads": n_kv, - "head_dim": h // n_heads, - "intermediate_size": config_dict["intermediate_size"], - "num_mlp_projections": num_mlp_proj, - "vocab_size": vocab, - "has_lm_head": vocab > 0, - } - kwargs.update(overrides) - return ModelFLOPsConfig(**kwargs) - - -# ============================================================================= -# FLOPs Formulas -# ============================================================================= - - -def compute_flops_analytical(config, batch_size, seq_len): - """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). - - Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, - layer norms, activations, and element-wise ops. - - Handles: - - GQA vs MHA: K/V projection sizes based on config.num_kv_heads - - SwiGLU vs standard FFN: 2 or 3 MLP projections - - LM head presence - - Returns: - (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) - """ - b, s, h = batch_size, seq_len, config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - } - - ffn = config.intermediate_size - if config.num_mlp_projections == 3: - # SwiGLU/GeGLU: gate + up + down = 3 matmuls - breakdown["Gate projection"] = 2 * b * s * h * ffn - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - else: - # Standard FFN: up + down = 2 matmuls - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): - """Simplified formula assuming standard MHA + standard FFN with I=4H. - - This is the formula from the Llama3 README: - (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V - - The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. - """ - b, s, h = batch_size, seq_len, hidden_size - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): - """FLOPs for Hyena-based models (Evo2). - - Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. - - Args: - config: ModelFLOPsConfig with model dimensions. - batch_size: Batch size. - seq_len: Sequence length. - hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for - short/medium/long conv and attention layer counts. If None, assumes - all layers are long-conv Hyena (H=num_layers, no attention). - """ - b, s, h = batch_size, seq_len, config.hidden_size - ffn = config.intermediate_size - - if hyena_layer_counts is None: - hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} - - # Common per-layer FLOPs - pre_attn_qkv_proj = 2 * 3 * b * s * h * h - post_attn_proj = 2 * b * s * h * h - glu_ffn = 2 * 3 * b * s * ffn * h - - # Layer-type-specific FLOPs (defaults from evo2_provider.py) - attn = 2 * 2 * b * h * s * s # Standard S^2 attention - hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default - hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 - hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 - hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h - - n_s = hyena_layer_counts.get("S", 0) - n_d = hyena_layer_counts.get("D", 0) - n_h = hyena_layer_counts.get("H", 0) - n_a = hyena_layer_counts.get("A", 0) - - logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - - total_fwd = ( - logits - + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) - + n_a * attn - + (n_s + n_d + n_h) * hyena_proj - + n_s * hyena_short_conv - + n_d * hyena_medium_conv - + int(n_h * hyena_long_fft) - ) - - return 3 * total_fwd - - -# ============================================================================= -# MFU Tracker -# ============================================================================= - - -class MFUTracker: - """Tracks MFU during training. Initialize once, call compute_mfu() per step. - - Usage: - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - # In training loop: - mfu_info = tracker.compute_mfu(step_time=0.5) - print(f"MFU: {mfu_info['mfu']:.1f}%") - """ - - def __init__( - self, - config, - batch_size, - seq_len, - num_gpus=1, - parallelism=None, - peak_tflops=None, - formula="analytical", - hyena_layer_counts=None, - ): - """Initialize MFU tracker. - - Args: - config: ModelFLOPsConfig instance. - batch_size: Micro batch size per GPU. - seq_len: Sequence length. - num_gpus: Total number of GPUs. - parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. - Used for communication overhead estimation. - peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. - formula: "analytical", "simplified", or "hyena". - hyena_layer_counts: For Hyena formula, dict of layer type counts. - """ - self.config = config - self.batch_size = batch_size - self.seq_len = seq_len - self.num_gpus = num_gpus - self.parallelism = parallelism or {} - self.formula = formula - - if formula == "analytical": - self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( - config, batch_size, seq_len - ) - elif formula == "simplified": - self.total_flops = compute_flops_simplified( - batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size - ) - self.breakdown = None - self.lm_head_flops = 0 - elif formula == "hyena": - self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) - self.breakdown = None - self.lm_head_flops = 0 - else: - raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") - - self.per_gpu_flops = self.total_flops // max(num_gpus, 1) - - if peak_tflops is not None: - self.peak_tflops = peak_tflops - self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" - else: - detected, self.device_name = detect_gpu_peak_tflops() - self.peak_tflops = detected - - self.comm_bytes = self._estimate_comm() - - @classmethod - def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): - """Create from an HF config dict with auto-detection.""" - config = from_hf_config(config_dict) - return cls(config, batch_size, seq_len, **kwargs) - - def compute_mfu(self, step_time): - """Compute MFU from measured step time. - - Args: - step_time: Wall-clock time for one training step (seconds). - - Returns: - Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. - """ - tflops = self.per_gpu_flops / step_time / 1e12 - mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 - return { - "mfu": mfu, - "tflops_per_gpu": tflops, - "per_gpu_flops": self.per_gpu_flops, - "total_flops": self.total_flops, - "step_time": step_time, - } - - def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): - """Estimate communication overhead as a fraction of step time. - - Args: - step_time: Measured step time in seconds. - measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. - - Returns: - Dict with comm_bytes, estimated_comm_time, comm_pct. - """ - bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 - comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 - comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 - return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} - - def _estimate_comm(self): - """Estimate total communication bytes per step based on parallelism.""" - total = 0 - cp_size = self.parallelism.get("cp", 1) - dp_size = self.parallelism.get("dp", 1) - - if cp_size > 1: - total += estimate_cp_comm_bytes( - self.batch_size, - self.seq_len, - self.config.num_hidden_layers, - self.config.num_kv_heads, - self.config.head_dim, - cp_size, - ) - - if dp_size > 1: - # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp - model_params = _estimate_model_params(self.config) - total += 2 * model_params * 2 * (dp_size - 1) // dp_size - - return total - - def summary(self): - """Return a human-readable summary string.""" - lines = [ - f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", - f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," - f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," - f" I={self.config.intermediate_size}, V={self.config.vocab_size}", - f" MLP projections: {self.config.num_mlp_projections}" - f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", - f" Batch: B={self.batch_size}, S={self.seq_len}", - f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", - f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", - f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", - ] - if self.parallelism: - lines.append(f" Parallelism: {self.parallelism}") - if self.comm_bytes > 0: - lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") - return "\n".join(lines) - - -# ============================================================================= -# Communication Estimation -# ============================================================================= - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step. - - Ring attention sends local KV chunks around the ring. Per layer forward: - (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. - Training = ~2x forward communication (forward sends KV, backward sends dKV). - """ - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def _estimate_model_params(config): - """Rough parameter count estimate from config dimensions.""" - h = config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O - mlp_params = config.num_mlp_projections * h * config.intermediate_size - layer_params = attn_params + mlp_params - total = config.num_hidden_layers * layer_params - if config.has_lm_head: - total += config.vocab_size * h * 2 # embed + lm_head - return total - - -# ============================================================================= -# Utilities -# ============================================================================= - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank - - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - return num_iters * data_bytes / elapsed / 1e9 - - -def load_model_config(config_path): - """Load model config dict from a local path or HuggingFace model ID. - - Supports: - - Local directory: ./model_configs/lingua-1B (reads config.json inside) - - Local file: ./model_configs/lingua-1B/config.json - - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) - """ - import json - from pathlib import Path - - path = Path(config_path) - if path.is_dir(): - path = path / "config.json" - if path.exists(): - return json.loads(path.read_text()) - - # Fall back to HuggingFace Hub - from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) - return hf_config.to_dict() - - -# ============================================================================= -# Formatting -# ============================================================================= - - -def format_flops(flops): - """Format FLOPs with appropriate unit (G/T/P).""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") - parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") - parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - print( - f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," - f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," - f" I={config.intermediate_size}, V={config.vocab_size}" - ) - print( - f"MLP: {config.num_mlp_projections} projections" - f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" - ) - print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") - print() - - simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) - analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") - print("-" * 86) - for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: - per_gpu = flops // max(args.num_gpus, 1) - print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") - - if simplified != analytical: - diff = analytical - simplified - print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Communication overhead estimate - if args.cp_size > 1: - dp_size = args.num_gpus // args.cp_size - parallelism = {"dp": dp_size, "cp": args.cp_size} - tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) - print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") - print( - f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" - ) - comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 - print(f" Estimated comm time: {comm_time:.4f}s") - - model_params = _estimate_model_params(config) - print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) - print( - f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," - f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" - ) - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python flops.py [options]") - print(" torchrun --nproc_per_node=2 flops.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 726eb19e8e..3bebdee347 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -33,18 +33,91 @@ logger = logging.getLogger(__name__) +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: + """Training FLOPs per token for a transformer (forward + backward = 3x forward). + + First-principles matmul count: Q/K/V/O projections (GQA-aware), attention + logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection + MLP (SwiGLU detected via model_type), and LM head. The returned value is + multiplied by the actual unpadded token count at log time, so it naturally + handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: + unpadded tokens on each rank already reflect that rank's share of work. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 4 * seq_len * h # attention logits + values (S^2 -> S per token) + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + per_token_fwd = num_layers * per_layer + lm_head + return 3 * per_token_fwd + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. Args: dist_config: The distributed configuration. args: The arguments. + start_step: The step to resume progress-bar counting from. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int): + def __init__( + self, + dist_config: DistributedConfig, + args: DictConfig, + start_step: int, + model_config_dict: dict | None = None, + ): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -54,6 +127,24 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._per_token_flops = 0 + self._peak_tflops: float | None = None + if self._log_mfu: + self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._per_token_flops), + args.dataset.max_seq_length, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -65,6 +156,10 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -173,6 +268,17 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + if self._log_mfu: + # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one + # optimizer step (the last step in the logging window), so this yields FLOPs + # per optimizer step per rank. step_time is already the per-step average. + tokens_on_rank = self.num_unpadded_tokens.item() + flops_per_step = self._per_token_flops * tokens_on_rank + tflops_per_gpu = flops_per_step / step_time / 1e12 + self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + memory_allocated = torch.cuda.memory_allocated() / (1024**3) self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py deleted file mode 100644 index e9e3943658..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py +++ /dev/null @@ -1,317 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the flops.py FLOPs counting and MFU module.""" - -import sys -from pathlib import Path - -import pytest - - -# Add parent directory so we can import flops -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from flops import ( - MFUTracker, - ModelFLOPsConfig, - compute_flops_analytical, - compute_flops_hyena, - compute_flops_simplified, - estimate_cp_comm_bytes, - from_hf_config, -) - - -# ============================================================================ -# Test configs matching real models -# ============================================================================ - -LLAMA_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 25, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 6144, - "vocab_size": 128256, - "model_type": "llama", - "hidden_act": "silu", -} - -ESM2_8M_CONFIG = { - "hidden_size": 320, - "num_hidden_layers": 6, - "num_attention_heads": 20, - "intermediate_size": 1280, - "vocab_size": 33, - "model_type": "nv_esm", - "hidden_act": "gelu", -} - -CODONFM_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 18, - "num_attention_heads": 16, - "intermediate_size": 8192, -} - - -# ============================================================================ -# Config auto-detection -# ============================================================================ - - -class TestFromHfConfig: - """Test auto-detection of model architecture from config dicts.""" - - def test_llama_detects_gqa_and_swiglu(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - assert cfg.num_kv_heads == 8 - assert cfg.num_mlp_projections == 3 - assert cfg.head_dim == 128 - - def test_esm2_detects_mha_and_standard_ffn(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - assert cfg.num_kv_heads == 20 - assert cfg.num_mlp_projections == 2 - - def test_codonfm_defaults_to_mha_and_2_proj(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - assert cfg.num_kv_heads == 16 - assert cfg.num_mlp_projections == 2 - - def test_missing_vocab_defaults_to_no_lm_head(self): - cfg = from_hf_config( - {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} - ) - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - def test_overrides_take_precedence(self): - cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) - assert cfg.num_mlp_projections == 3 - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - -# ============================================================================ -# Analytical FLOPs formula -# ============================================================================ - - -class TestComputeFlopsAnalytical: - """Test the first-principles analytical FLOPs formula.""" - - def test_training_is_3x_forward(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) - forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head - assert total == 3 * forward - - def test_swiglu_has_3_mlp_projections(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_standard_ffn_has_2_mlp_projections(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" not in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_no_lm_head_when_vocab_zero(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) - assert lm_head == 0 - - def test_flops_scale_linearly_with_batch(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) - flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) - assert flops_b4 == 4 * flops_b1 - - def test_known_value_llama_lingua_1b(self): - """Golden value: validated against PyTorch FlopCounterMode and README formula.""" - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, _, _ = compute_flops_analytical(cfg, 1, 4096) - assert total == 47_687_021_887_488 - - -# ============================================================================ -# Simplified formula -# ============================================================================ - - -class TestComputeFlopsSimplified: - """Test the simplified README formula and its relationship to analytical.""" - - def test_matches_analytical_when_mha_and_i_equals_4h(self): - """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" - cfg = from_hf_config(ESM2_8M_CONFIG) - analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) - simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical == simplified - - def test_differs_when_gqa_or_swiglu(self): - """GQA + SwiGLU breaks the simplified formula's assumptions.""" - cfg_dict = { - **LLAMA_1B_CONFIG, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "intermediate_size": 8192, - "num_hidden_layers": 16, - } - cfg = from_hf_config(cfg_dict) - analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) - simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical != simplified - - -# ============================================================================ -# Hyena formula -# ============================================================================ - - -class TestComputeFlopsHyena: - """Test the Hyena (Evo2) FLOPs formula.""" - - @pytest.fixture() - def hyena_config(self): - return ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - - def test_scales_subquadratically(self, hyena_config): - """Hyena uses O(S log S) convolution, not O(S^2) attention.""" - flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - ratio = flops_2k / flops_1k - assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below - - def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): - """Adding standard attention layers increases FLOPs due to S^2 term.""" - all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) - assert with_attn > all_hyena - - -# ============================================================================ -# MFUTracker -# ============================================================================ - - -class TestMFUTracker: - """Test the MFUTracker class used by training scripts.""" - - def test_from_config_dict(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.total_flops == 47_687_021_887_488 - assert tracker.per_gpu_flops == tracker.total_flops - - def test_multi_gpu_divides_flops(self): - single = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 - ) - multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) - assert multi.per_gpu_flops == single.total_flops // 2 - - def test_compute_mfu_correctness(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - result = tracker.compute_mfu(step_time=0.5) - expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 - expected_mfu = expected_tflops / 155.0 * 100 - assert abs(result["mfu"] - expected_mfu) < 0.01 - assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 - - def test_mfu_inversely_proportional_to_step_time(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - fast = tracker.compute_mfu(step_time=0.5) - slow = tracker.compute_mfu(step_time=1.0) - assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 - - def test_all_formula_options(self): - for formula in ["analytical", "simplified", "hyena"]: - if formula == "hyena": - cfg = ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) - else: - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula - ) - assert tracker.total_flops > 0 - - def test_invalid_formula_raises(self): - with pytest.raises(ValueError, match="Unknown formula"): - MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") - - def test_cp_communication_estimate(self): - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 - ) - assert tracker.comm_bytes > 0 - overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) - assert overhead["estimated_comm_time"] > 0 - assert 0 < overhead["comm_pct"] < 100 - - def test_no_comm_single_gpu(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.comm_bytes == 0 - - -# ============================================================================ -# Communication estimation -# ============================================================================ - - -class TestCPCommEstimation: - """Test CP ring attention communication byte estimates.""" - - def test_zero_without_cp(self): - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 - - def test_scales_linearly_with_seq_len(self): - comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) - assert comm_8k == 2 * comm_4k - - def test_scales_linearly_with_batch(self): - comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) - assert comm_b4 == 4 * comm_b1 - - def test_known_value_lingua_1b(self): - """Golden value for lingua-1B at S=4096, CP=2.""" - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index aebdfe17ef..5fd6e70ad0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -23,10 +23,14 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from distributed_config import DistributedConfig -from perf_logger import PerfLogger +from perf_logger import ( + PerfLogger, + _compute_per_token_flops, + _detect_peak_tflops_bf16, +) -def _make_args(logging_frequency=1, num_train_steps=100): +def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False, max_seq_length=128): """Create a minimal args config for PerfLogger.""" return OmegaConf.create( { @@ -35,6 +39,8 @@ def _make_args(logging_frequency=1, num_train_steps=100): "num_train_steps": num_train_steps, "profiler": {"enabled": False}, "fp8_stats_config": {"enabled": False}, + "log_mfu": log_mfu, + "dataset": {"max_seq_length": max_seq_length}, } ) @@ -208,3 +214,97 @@ def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): _run_steps(perf_logger, losses) assert perf_logger.min_loss.item() == pytest.approx(1.0) + + +class TestComputePerTokenFlops: + """Test that the per-token training FLOPs formula matches hand-calculated values.""" + + def test_llama_gqa_swiglu(self): + """Llama-style config: GQA (n_kv=8 < n_heads=32) + SwiGLU (3 MLP projections).""" + config = { + "model_type": "llama", + "hidden_size": 4096, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 8, # GQA + "intermediate_size": 14336, + "vocab_size": 128256, + } + seq_len = 8192 + h, i, v, kv_dim, layers = 4096, 14336, 128256, 8 * 128, 32 + # Per-layer: Q (2h^2) + K+V (4h*kv_dim) + O (2h^2) + attn (4*S*h) + MLP (2*3*h*i) + per_layer = 2 * h * h + 4 * h * kv_dim + 2 * h * h + 4 * seq_len * h + 2 * 3 * h * i + expected_fwd = layers * per_layer + 2 * h * v + assert _compute_per_token_flops(config, seq_len) == 3 * expected_fwd + + def test_esm_mha_gelu(self): + """ESM2-style config: MHA (no num_key_value_heads) + standard FFN (2 MLP projections).""" + config = { + "model_type": "esm", + "hidden_size": 1280, + "num_hidden_layers": 33, + "num_attention_heads": 20, + "intermediate_size": 5120, + "vocab_size": 33, + } + seq_len = 1024 + h, i, v, kv_dim, layers = 1280, 5120, 33, (1280 // 20) * 20, 33 # kv_dim=h for MHA + per_layer = 2 * h * h + 4 * h * kv_dim + 2 * h * h + 4 * seq_len * h + 2 * 2 * h * i + expected_fwd = layers * per_layer + 2 * h * v + assert _compute_per_token_flops(config, seq_len) == 3 * expected_fwd + + def test_scales_with_seq_len(self): + """Only the attention S^2 term should vary with seq_len.""" + config = { + "model_type": "llama", + "hidden_size": 2048, + "num_hidden_layers": 16, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 8192, + "vocab_size": 32000, + } + h, layers = 2048, 16 + # Difference per token between seq_len=1024 and seq_len=2048: + # layers * 4 * (2048 - 1024) * h, times 3 (forward+backward) + diff = _compute_per_token_flops(config, 2048) - _compute_per_token_flops(config, 1024) + assert diff == 3 * layers * 4 * 1024 * h + + def test_linear_in_unpadded_tokens(self): + """Multiplying per-token FLOPs by N tokens is linear (MFU formula relies on this).""" + config = { + "model_type": "llama", + "hidden_size": 1024, + "num_hidden_layers": 8, + "num_attention_heads": 8, + "intermediate_size": 4096, + "vocab_size": 32000, + } + per_token = _compute_per_token_flops(config, seq_len=512) + assert per_token * 100 == 100 * per_token + # Sanity: doubling unpadded token count doubles total FLOPs + assert per_token * 200 == 2 * (per_token * 100) + + def test_no_lm_head_when_vocab_zero(self): + """vocab_size=0 should drop the LM head term.""" + config_base = { + "model_type": "llama", + "hidden_size": 512, + "num_hidden_layers": 4, + "num_attention_heads": 8, + "intermediate_size": 2048, + } + with_vocab = _compute_per_token_flops({**config_base, "vocab_size": 32000}, seq_len=256) + no_vocab = _compute_per_token_flops({**config_base, "vocab_size": 0}, seq_len=256) + # Difference = 3 (training) * 2 * h * vocab + assert with_vocab - no_vocab == 3 * 2 * 512 * 32000 + + +class TestDetectPeakTflops: + """Smoke test for GPU peak TFLOPS detection.""" + + def test_returns_tuple_shape(self): + """Returns (peak_tflops_or_none, device_name_str).""" + peak, name = _detect_peak_tflops_bf16() + assert isinstance(name, str) + assert peak is None or isinstance(peak, float) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index b646b00f41..cb3251f590 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -25,7 +25,6 @@ import gc import logging -import time from contextlib import nullcontext from pathlib import Path @@ -43,7 +42,6 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -143,20 +141,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) gc.collect() torch.cuda.empty_cache() @@ -165,7 +155,6 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter - step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -202,13 +191,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_ddp( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index ba8305c8f8..1230ed82ac 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -24,7 +24,6 @@ import gc import logging -import time from contextlib import nullcontext from pathlib import Path @@ -49,7 +48,6 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -157,20 +155,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) gc.collect() torch.cuda.empty_cache() @@ -179,7 +169,6 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter - step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -215,13 +204,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 5b24d2eaa2..e0327ef60f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -25,7 +25,6 @@ import gc import logging -import time from contextlib import nullcontext from pathlib import Path @@ -49,7 +48,6 @@ from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger from scheduler import get_cosine_annealing_schedule_with_warmup @@ -179,20 +177,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) gc.collect() torch.cuda.empty_cache() @@ -201,7 +191,6 @@ def main(args: DictConfig) -> float | None: logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter - step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -240,13 +229,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md index 435b497573..85be87bded 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md @@ -419,16 +419,16 @@ Enable per-step Model FLOPs Utilization (MFU) logging during training by adding torchrun --nproc_per_node=2 train_fsdp2_cp.py log_mfu=true ``` -This logs MFU (%), TFLOPS/GPU, and step time at each optimizer step. The module auto-detects model architecture (GQA, SwiGLU, etc.) from the model config. +This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and +stdout: -The `flops.py` CLI provides standalone utilities: +- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU +- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS -```bash -python flops.py gpu-info # Show GPU and peak TFLOPS -python flops.py flops --config-path ./model_configs/meta-llama/Llama-3.1-8B # Compute FLOPs -python flops.py cp-comm --config-path ./model_configs/meta-llama/Llama-3.1-8B --cp-size 2 # CP comm estimate -torchrun --nproc_per_node=2 flops.py bandwidth # Measure P2P GPU bandwidth -``` +The FLOPs formula auto-detects model architecture from the HF config (GQA vs. MHA, SwiGLU vs. +standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This +means it naturally handles gradient accumulation, data parallelism, context parallelism, BSHD, and +THD (sequence packing) without per-strategy code paths. The implementation lives in `perf_logger.py`. ## Developer Guide diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py deleted file mode 100644 index ade8d8a12c..0000000000 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py +++ /dev/null @@ -1,714 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Architecture-independent FLOPs counting, MFU calculation, and communication overhead estimation. - -Supports transformer architectures (Llama, ESM2, CodonFM, etc.) and Hyena (Evo2). -Designed to be copied to any recipe via check_copied_files.py and hooked into -training scripts for live MFU tracking. - -Usage as a library (in training scripts): - from flops import MFUTracker, from_hf_config - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - mfu_info = tracker.compute_mfu(step_time=0.5) - -Usage as a CLI: - python flops.py gpu-info - python flops.py flops --config-path ./model_configs/lingua-1B - python flops.py cp-comm --config-path ./model_configs/lingua-1B --cp-size 2 - torchrun --nproc_per_node=2 flops.py bandwidth -""" - -import math -import time -from dataclasses import dataclass - -import torch -import torch.distributed as dist - - -# ============================================================================= -# GPU Peak TFLOPS -# ============================================================================= - -# Dense (without sparsity) BF16 tensor core peak TFLOPS. -# NVIDIA product pages often list the 2x sparse value; dense = sparse / 2. -# Sources: -# H100/H200/GH200: https://www.nvidia.com/en-us/data-center/h100/ (1,979 TFLOPS sparse) -# A100: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf -# A6000: https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/proviz-print-nvidia-rtx-a6000-datasheet-us-nvidia-1454980-r9-web%20(1).pdf -# L40: https://www.nvidia.com/en-us/data-center/l40/ -# B200/GB200: https://www.nvidia.com/en-us/data-center/dgx-b200/ -# B300/GB300: https://www.nvidia.com/en-us/data-center/gb300-nvl72/ (360 PFLOPS sparse / 72 GPUs / 2) -GPU_PEAK_TFLOPS_BF16 = { - "H100": 989.0, - "H200": 989.0, - "A100": 312.0, - "A6000": 155.0, - "L40": 181.0, - "GH200": 989.0, - "B200": 2250.0, - "GB200": 2250.0, - "B300": 2500.0, - "GB300": 2500.0, -} - - -def detect_gpu_peak_tflops(): - """Auto-detect GPU peak bf16 TFLOPS from device name via substring match.""" - device_name = torch.cuda.get_device_name(0) - for gpu_key, tflops in GPU_PEAK_TFLOPS_BF16.items(): - if gpu_key.lower() in device_name.lower(): - return tflops, device_name - return None, device_name - - -# ============================================================================= -# Model FLOPs Config -# ============================================================================= - -# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections instead of 2. -GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) - - -@dataclass(frozen=True) -class ModelFLOPsConfig: - """Architecture-independent parameters for FLOPs calculation. - - Can be constructed manually or via from_hf_config() for auto-detection. - """ - - hidden_size: int # H - num_hidden_layers: int # L - num_attention_heads: int # n_heads - num_kv_heads: int # n_kv (== n_heads for MHA) - head_dim: int # H // n_heads - intermediate_size: int # I (FFN intermediate dimension) - num_mlp_projections: int # 2 (standard FFN) or 3 (SwiGLU/GLU) - vocab_size: int # V - has_lm_head: bool # True for LM models, False for ViT etc. - - -def from_hf_config(config_dict, **overrides): - """Create ModelFLOPsConfig from an HF-compatible config dict. - - Auto-detects architecture: - - GQA vs MHA: from num_key_value_heads (absent = MHA) - - Gated MLP (3 proj) vs standard FFN (2 proj): from model_type - - LM head: from vocab_size > 0 - - Args: - config_dict: Dict with standard HF config keys (hidden_size, num_hidden_layers, etc.). - Works with config.json dicts, config.to_dict(), or MODEL_PRESETS dicts. - **overrides: Explicit overrides for any field (e.g., num_mlp_projections=3). - """ - h = config_dict["hidden_size"] - n_heads = config_dict["num_attention_heads"] - n_kv = config_dict.get("num_key_value_heads", n_heads) - vocab = config_dict.get("vocab_size", 0) - model_type = config_dict.get("model_type", "") - - # Detect gated MLP (3 projections) vs standard FFN (2 projections). - # Llama/Mistral/Qwen use SwiGLU (gate + up + down = 3 projections). - # ESM2/CodonFM/Geneformer/BERT use standard FFN (up + down = 2 projections). - num_mlp_proj = 3 if model_type in GATED_MLP_MODEL_TYPES else 2 - - kwargs = { - "hidden_size": h, - "num_hidden_layers": config_dict["num_hidden_layers"], - "num_attention_heads": n_heads, - "num_kv_heads": n_kv, - "head_dim": h // n_heads, - "intermediate_size": config_dict["intermediate_size"], - "num_mlp_projections": num_mlp_proj, - "vocab_size": vocab, - "has_lm_head": vocab > 0, - } - kwargs.update(overrides) - return ModelFLOPsConfig(**kwargs) - - -# ============================================================================= -# FLOPs Formulas -# ============================================================================= - - -def compute_flops_analytical(config, batch_size, seq_len): - """First-principles FLOPs for any transformer (GQA/MHA, SwiGLU/GELU). - - Counts matmul FLOPs only (2 FLOPs per multiply-accumulate). Excludes softmax, - layer norms, activations, and element-wise ops. - - Handles: - - GQA vs MHA: K/V projection sizes based on config.num_kv_heads - - SwiGLU vs standard FFN: 2 or 3 MLP projections - - LM head presence - - Returns: - (total_training_flops, per_layer_breakdown_dict, lm_head_forward_flops) - """ - b, s, h = batch_size, seq_len, config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - - breakdown = { - "Q projection": 2 * b * s * h * h, - "K projection": 2 * b * s * h * kv_dim, - "V projection": 2 * b * s * h * kv_dim, - "O projection": 2 * b * s * h * h, - "Attn logits": 2 * b * s * s * h, - "Attn values": 2 * b * s * s * h, - } - - ffn = config.intermediate_size - if config.num_mlp_projections == 3: - # SwiGLU/GeGLU: gate + up + down = 3 matmuls - breakdown["Gate projection"] = 2 * b * s * h * ffn - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - else: - # Standard FFN: up + down = 2 matmuls - breakdown["Up projection"] = 2 * b * s * h * ffn - breakdown["Down projection"] = 2 * b * s * ffn * h - - per_layer_fwd = sum(breakdown.values()) - lm_head_fwd = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - total_fwd = config.num_hidden_layers * per_layer_fwd + lm_head_fwd - total_training = 3 * total_fwd - - return total_training, breakdown, lm_head_fwd - - -def compute_flops_simplified(batch_size, seq_len, hidden_size, num_layers, vocab_size): - """Simplified formula assuming standard MHA + standard FFN with I=4H. - - This is the formula from the Llama3 README: - (24*B*S*H^2 + 4*B*S^2*H) * 3*L + 6*B*S*H*V - - The 24*H^2 coefficient assumes: 8*H^2 for 4 attention projections (MHA) + - 16*H^2 for 2 MLP projections with I=4H. - """ - b, s, h = batch_size, seq_len, hidden_size - return (24 * b * s * h * h + 4 * b * s * s * h) * (3 * num_layers) + (6 * b * s * h * vocab_size) - - -def compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts=None): - """FLOPs for Hyena-based models (Evo2). - - Based on evo2_provider.py. Hyena replaces attention with O(N) FFT convolution. - - Args: - config: ModelFLOPsConfig with model dimensions. - batch_size: Batch size. - seq_len: Sequence length. - hyena_layer_counts: Optional dict {"S": n, "D": n, "H": n, "A": n} for - short/medium/long conv and attention layer counts. If None, assumes - all layers are long-conv Hyena (H=num_layers, no attention). - """ - b, s, h = batch_size, seq_len, config.hidden_size - ffn = config.intermediate_size - - if hyena_layer_counts is None: - hyena_layer_counts = {"S": 0, "D": 0, "H": config.num_hidden_layers, "A": 0} - - # Common per-layer FLOPs - pre_attn_qkv_proj = 2 * 3 * b * s * h * h - post_attn_proj = 2 * b * s * h * h - glu_ffn = 2 * 3 * b * s * ffn * h - - # Layer-type-specific FLOPs (defaults from evo2_provider.py) - attn = 2 * 2 * b * h * s * s # Standard S^2 attention - hyena_proj = 2 * 3 * b * s * 3 * h # short_conv_L=3 default - hyena_short_conv = 2 * b * s * 7 * h # short_conv_len=7 - hyena_medium_conv = 2 * b * s * 128 * h # medium_conv_len=128 - hyena_long_fft = b * 10 * s * math.log2(max(s, 2)) * h - - n_s = hyena_layer_counts.get("S", 0) - n_d = hyena_layer_counts.get("D", 0) - n_h = hyena_layer_counts.get("H", 0) - n_a = hyena_layer_counts.get("A", 0) - - logits = 2 * b * s * h * config.vocab_size if config.has_lm_head else 0 - - total_fwd = ( - logits - + config.num_hidden_layers * (pre_attn_qkv_proj + post_attn_proj + glu_ffn) - + n_a * attn - + (n_s + n_d + n_h) * hyena_proj - + n_s * hyena_short_conv - + n_d * hyena_medium_conv - + int(n_h * hyena_long_fft) - ) - - return 3 * total_fwd - - -# ============================================================================= -# MFU Tracker -# ============================================================================= - - -class MFUTracker: - """Tracks MFU during training. Initialize once, call compute_mfu() per step. - - Usage: - tracker = MFUTracker.from_config_dict(config_dict, batch_size=4, seq_len=4096, num_gpus=2) - # In training loop: - mfu_info = tracker.compute_mfu(step_time=0.5) - print(f"MFU: {mfu_info['mfu']:.1f}%") - """ - - def __init__( - self, - config, - batch_size, - seq_len, - num_gpus=1, - parallelism=None, - peak_tflops=None, - formula="analytical", - hyena_layer_counts=None, - ): - """Initialize MFU tracker. - - Args: - config: ModelFLOPsConfig instance. - batch_size: Micro batch size per GPU. - seq_len: Sequence length. - num_gpus: Total number of GPUs. - parallelism: Dict of parallelism dimensions, e.g. {"dp": 1, "cp": 2, "tp": 1}. - Used for communication overhead estimation. - peak_tflops: GPU peak bf16 TFLOPS. Auto-detected if None. - formula: "analytical", "simplified", or "hyena". - hyena_layer_counts: For Hyena formula, dict of layer type counts. - """ - self.config = config - self.batch_size = batch_size - self.seq_len = seq_len - self.num_gpus = num_gpus - self.parallelism = parallelism or {} - self.formula = formula - - if formula == "analytical": - self.total_flops, self.breakdown, self.lm_head_flops = compute_flops_analytical( - config, batch_size, seq_len - ) - elif formula == "simplified": - self.total_flops = compute_flops_simplified( - batch_size, seq_len, config.hidden_size, config.num_hidden_layers, config.vocab_size - ) - self.breakdown = None - self.lm_head_flops = 0 - elif formula == "hyena": - self.total_flops = compute_flops_hyena(config, batch_size, seq_len, hyena_layer_counts) - self.breakdown = None - self.lm_head_flops = 0 - else: - raise ValueError(f"Unknown formula: {formula!r}. Use 'analytical', 'simplified', or 'hyena'.") - - self.per_gpu_flops = self.total_flops // max(num_gpus, 1) - - if peak_tflops is not None: - self.peak_tflops = peak_tflops - self.device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown" - else: - detected, self.device_name = detect_gpu_peak_tflops() - self.peak_tflops = detected - - self.comm_bytes = self._estimate_comm() - - @classmethod - def from_config_dict(cls, config_dict, batch_size, seq_len, **kwargs): - """Create from an HF config dict with auto-detection.""" - config = from_hf_config(config_dict) - return cls(config, batch_size, seq_len, **kwargs) - - def compute_mfu(self, step_time): - """Compute MFU from measured step time. - - Args: - step_time: Wall-clock time for one training step (seconds). - - Returns: - Dict with mfu (%), tflops_per_gpu, per_gpu_flops, total_flops, step_time. - """ - tflops = self.per_gpu_flops / step_time / 1e12 - mfu = tflops / self.peak_tflops * 100 if self.peak_tflops else 0.0 - return { - "mfu": mfu, - "tflops_per_gpu": tflops, - "per_gpu_flops": self.per_gpu_flops, - "total_flops": self.total_flops, - "step_time": step_time, - } - - def estimate_comm_overhead(self, step_time, measured_bw_gbps=None): - """Estimate communication overhead as a fraction of step time. - - Args: - step_time: Measured step time in seconds. - measured_bw_gbps: Measured P2P bandwidth in GB/s. If None, uses a default estimate. - - Returns: - Dict with comm_bytes, estimated_comm_time, comm_pct. - """ - bw = measured_bw_gbps or 6.0 # Default ~PCIe Gen3 x8 - comm_time = self.comm_bytes / (bw * 1e9) if bw > 0 else 0.0 - comm_pct = comm_time / step_time * 100 if step_time > 0 else 0.0 - return {"comm_bytes": self.comm_bytes, "estimated_comm_time": comm_time, "comm_pct": comm_pct} - - def _estimate_comm(self): - """Estimate total communication bytes per step based on parallelism.""" - total = 0 - cp_size = self.parallelism.get("cp", 1) - dp_size = self.parallelism.get("dp", 1) - - if cp_size > 1: - total += estimate_cp_comm_bytes( - self.batch_size, - self.seq_len, - self.config.num_hidden_layers, - self.config.num_kv_heads, - self.config.head_dim, - cp_size, - ) - - if dp_size > 1: - # FSDP reduce-scatter estimate: ~2 * model_params * dtype_bytes * (dp-1)/dp - model_params = _estimate_model_params(self.config) - total += 2 * model_params * 2 * (dp_size - 1) // dp_size - - return total - - def summary(self): - """Return a human-readable summary string.""" - lines = [ - f"MFUTracker: {self.formula} formula, {self.num_gpus} GPU(s)", - f" Model: H={self.config.hidden_size}, L={self.config.num_hidden_layers}," - f" heads={self.config.num_attention_heads}, kv_heads={self.config.num_kv_heads}," - f" I={self.config.intermediate_size}, V={self.config.vocab_size}", - f" MLP projections: {self.config.num_mlp_projections}" - f" ({'SwiGLU/GLU' if self.config.num_mlp_projections == 3 else 'standard FFN'})", - f" Batch: B={self.batch_size}, S={self.seq_len}", - f" Total FLOPs/step: {format_flops(self.total_flops)} ({format_flops_exact(self.total_flops)})", - f" Per-GPU FLOPs: {format_flops(self.per_gpu_flops)}", - f" GPU: {self.device_name} (Peak: {self.peak_tflops} TFLOPS)" if self.peak_tflops else " GPU: unknown", - ] - if self.parallelism: - lines.append(f" Parallelism: {self.parallelism}") - if self.comm_bytes > 0: - lines.append(f" Estimated comm: {format_bytes(self.comm_bytes)}/step") - return "\n".join(lines) - - -# ============================================================================= -# Communication Estimation -# ============================================================================= - - -def estimate_cp_comm_bytes(b, s, num_layers, n_kv_heads, head_dim, cp_size, dtype_bytes=2): - """Estimate total bytes transferred for CP ring attention per training step. - - Ring attention sends local KV chunks around the ring. Per layer forward: - (cp-1) steps, each sending B * (S/cp) * 2 * kv_dim * dtype_bytes. - Training = ~2x forward communication (forward sends KV, backward sends dKV). - """ - if cp_size <= 1: - return 0 - s_local = s // cp_size - kv_dim = n_kv_heads * head_dim - per_layer_fwd = (cp_size - 1) * b * s_local * 2 * kv_dim * dtype_bytes - return 2 * num_layers * per_layer_fwd - - -def _estimate_model_params(config): - """Rough parameter count estimate from config dimensions.""" - h = config.hidden_size - kv_dim = config.num_kv_heads * config.head_dim - attn_params = h * h + 2 * h * kv_dim + h * h # Q + K + V + O - mlp_params = config.num_mlp_projections * h * config.intermediate_size - layer_params = attn_params + mlp_params - total = config.num_hidden_layers * layer_params - if config.has_lm_head: - total += config.vocab_size * h * 2 # embed + lm_head - return total - - -# ============================================================================= -# Utilities -# ============================================================================= - - -def measure_bus_bandwidth(device, world_size, num_iters=20, num_elements=10_000_000): - """Measure unidirectional P2P bandwidth via send/recv (matches CP ring pattern).""" - if world_size <= 1: - return 0.0 - - rank = dist.get_rank() - tensor = torch.randn(num_elements, device=device, dtype=torch.bfloat16) - peer = 1 - rank - - for _ in range(5): - if rank == 0: - dist.send(tensor, dst=peer) - dist.recv(tensor, src=peer) - else: - dist.recv(tensor, src=peer) - dist.send(tensor, dst=peer) - torch.cuda.synchronize() - - dist.barrier() - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(num_iters): - if rank == 0: - dist.send(tensor, dst=peer) - else: - dist.recv(tensor, src=peer) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - - data_bytes = tensor.nelement() * tensor.element_size() - return num_iters * data_bytes / elapsed / 1e9 - - -def load_model_config(config_path): - """Load model config dict from a local path or HuggingFace model ID. - - Supports: - - Local directory: ./model_configs/lingua-1B (reads config.json inside) - - Local file: ./model_configs/lingua-1B/config.json - - HF model ID: nvidia/esm2_t36_3B_UR50D (fetches from HuggingFace Hub) - """ - import json - from pathlib import Path - - path = Path(config_path) - if path.is_dir(): - path = path / "config.json" - if path.exists(): - return json.loads(path.read_text()) - - # Fall back to HuggingFace Hub - from transformers import AutoConfig - - hf_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) - return hf_config.to_dict() - - -# ============================================================================= -# Formatting -# ============================================================================= - - -def format_flops(flops): - """Format FLOPs with appropriate unit (G/T/P).""" - if flops >= 1e15: - return f"{flops / 1e15:.2f} P" - elif flops >= 1e12: - return f"{flops / 1e12:.2f} T" - elif flops >= 1e9: - return f"{flops / 1e9:.2f} G" - else: - return f"{flops:.2e}" - - -def format_flops_exact(flops): - """Format FLOPs as the full integer with commas.""" - return f"{int(flops):,}" - - -def format_bytes(num_bytes): - """Format bytes with appropriate unit.""" - if num_bytes >= 1e9: - return f"{num_bytes / 1e9:.2f} GB" - elif num_bytes >= 1e6: - return f"{num_bytes / 1e6:.2f} MB" - elif num_bytes >= 1e3: - return f"{num_bytes / 1e3:.2f} KB" - else: - return f"{num_bytes:.0f} B" - - -def print_breakdown(breakdown, lm_head_fwd, num_layers, total_flops, model_params): - """Print first-principles FLOPs breakdown.""" - print() - print("--- First Principles Breakdown (forward pass, per layer) ---") - per_layer_total = sum(breakdown.values()) - for component, flops_val in breakdown.items(): - pct = flops_val / per_layer_total * 100 - print(f" {component:<20} {format_flops(flops_val):>12} ({pct:>5.1f}%)") - total_fwd = num_layers * per_layer_total + lm_head_fwd - print(f" {'LM head':<20} {format_flops(lm_head_fwd):>12}") - print(f" {'Per-layer total':<20} {format_flops(per_layer_total):>12}") - print(f" {'All layers (x' + str(num_layers) + ')':<20} {format_flops(num_layers * per_layer_total):>12}") - print(f" {'Total forward':<20} {format_flops(total_fwd):>12}") - print(f" {'Total training (3x)':<20} {format_flops(total_flops):>12}") - print(f" {'Model params':<20} {model_params / 1e9:.2f}B") - - -# ============================================================================= -# CLI -# ============================================================================= - - -def _cli_bandwidth(): - """Measure P2P bandwidth. Launch with: torchrun --nproc_per_node=2 flops.py bandwidth.""" - import os - - dist.init_process_group(backend="cpu:gloo,cuda:nccl") - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - torch.cuda.set_device(local_rank) - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{local_rank}") - - if rank == 0: - print(f"Measuring P2P bandwidth between {world_size} GPUs...") - for i in range(world_size): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - - bw = measure_bus_bandwidth(device, world_size) - if rank == 0: - print(f"\nUnidirectional P2P bandwidth: {bw:.2f} GB/s") - - dist.destroy_process_group() - - -def _cli_gpu_info(): - """Print GPU info and peak TFLOPS.""" - peak, name = detect_gpu_peak_tflops() - print(f"GPU: {name}") - if peak: - print(f"Peak bf16 TFLOPS: {peak:.1f}") - else: - print("Peak bf16 TFLOPS: unknown (use --peak-tflops to override)") - print() - print("Known GPUs:") - for gpu, tflops in GPU_PEAK_TFLOPS_BF16.items(): - print(f" {gpu:<16} {tflops:>8.1f} TFLOPS") - - -def _cli_flops(): - """Compute FLOPs for a model config.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--formula", default="analytical", choices=["analytical", "simplified"]) - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs for per-GPU FLOPs and comm estimates") - parser.add_argument("--cp-size", type=int, default=1, help="Context parallelism size for comm overhead estimate") - parser.add_argument("--p2p-bw", type=float, default=6.0, help="P2P bandwidth in GB/s for comm time estimate") - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - print( - f"Config: H={config.hidden_size}, L={config.num_hidden_layers}," - f" n_heads={config.num_attention_heads}, n_kv={config.num_kv_heads}," - f" I={config.intermediate_size}, V={config.vocab_size}" - ) - print( - f"MLP: {config.num_mlp_projections} projections" - f" ({'SwiGLU/GLU' if config.num_mlp_projections == 3 else 'standard FFN'})" - ) - print(f"Batch: B={b}, S={s}, GPUs={args.num_gpus}, CP={args.cp_size}") - print() - - simplified = compute_flops_simplified(b, s, config.hidden_size, config.num_hidden_layers, config.vocab_size) - analytical, breakdown, lm_head = compute_flops_analytical(config, b, s) - - print(f"{'Method':<24} {'FLOPs/step':>14} {'Per-GPU':>14} {'Exact':>30}") - print("-" * 86) - for name, flops in [("Simplified (README)", simplified), ("Analytical", analytical)]: - per_gpu = flops // max(args.num_gpus, 1) - print(f"{name:<24} {format_flops(flops):>14} {format_flops(per_gpu):>14} {format_flops_exact(flops):>30}") - - if simplified != analytical: - diff = analytical - simplified - print(f"\nDifference: {format_flops_exact(diff)} ({diff / simplified * 100:+.2f}%)") - else: - print("\nFormulas agree exactly for this config.") - - # Communication overhead estimate - if args.cp_size > 1: - dp_size = args.num_gpus // args.cp_size - parallelism = {"dp": dp_size, "cp": args.cp_size} - tracker = MFUTracker(config, b, s, num_gpus=args.num_gpus, parallelism=parallelism) - print(f"\n--- Communication Overhead (CP={args.cp_size}, P2P BW={args.p2p_bw} GB/s) ---") - print( - f" CP ring attention: {format_bytes(tracker.comm_bytes)}/step ({format_flops_exact(tracker.comm_bytes)} bytes)" - ) - comm_time = tracker.comm_bytes / (args.p2p_bw * 1e9) if args.p2p_bw > 0 else 0 - print(f" Estimated comm time: {comm_time:.4f}s") - - model_params = _estimate_model_params(config) - print_breakdown(breakdown, lm_head, config.num_hidden_layers, analytical, model_params) - - -def _cli_cp_comm(): - """Estimate CP communication volume.""" - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("command") - parser.add_argument("--config-path", default="./model_configs/lingua-1B") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--seq-len", type=int, default=16384) - parser.add_argument("--cp-size", type=int, default=2) - args = parser.parse_args() - - cfg_dict = load_model_config(args.config_path) - config = from_hf_config(cfg_dict) - b, s = args.batch_size, args.seq_len - - comm = estimate_cp_comm_bytes(b, s, config.num_hidden_layers, config.num_kv_heads, config.head_dim, args.cp_size) - print( - f"CP={args.cp_size}, B={b}, S={s}, L={config.num_hidden_layers}," - f" n_kv_heads={config.num_kv_heads}, head_dim={config.head_dim}" - ) - print(f"Estimated CP ring attention communication: {format_bytes(comm)}/step ({format_flops_exact(comm)} bytes)") - - -if __name__ == "__main__": - import sys - - commands = { - "bandwidth": ("Measure P2P bandwidth (requires torchrun --nproc_per_node=2)", _cli_bandwidth), - "gpu-info": ("Print GPU info and peak TFLOPS", _cli_gpu_info), - "flops": ("Compute FLOPs for a model config", _cli_flops), - "cp-comm": ("Estimate CP communication volume", _cli_cp_comm), - } - - if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help") or sys.argv[1] not in commands: - print("Usage: python flops.py [options]") - print(" torchrun --nproc_per_node=2 flops.py bandwidth") - print() - print("Commands:") - for cmd, (desc, _) in commands.items(): - print(f" {cmd:<16} {desc}") - sys.exit(0 if len(sys.argv) >= 2 and sys.argv[1] in ("-h", "--help") else 1) - - commands[sys.argv[1]][1]() diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 081103beb5..0633193d3f 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -41,18 +41,84 @@ logger = logging.getLogger(__name__) +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: + """Training FLOPs per token for a transformer (forward + backward = 3x forward). + + First-principles matmul count: Q/K/V/O projections (GQA-aware), attention + logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection + MLP (SwiGLU detected via model_type), and LM head. The returned value is + multiplied by the actual unpadded token count at log time, so it naturally + handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: + unpadded tokens on each rank already reflect that rank's share of work. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 4 * seq_len * h # attention logits + values (S^2 -> S per token) + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + per_token_fwd = num_layers * per_layer + lm_head + return 3 * per_token_fwd + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. Args: dist_config: The distributed configuration. args: The arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -62,6 +128,24 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._per_token_flops = 0 + self._peak_tflops: float | None = None + if self._log_mfu: + self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._per_token_flops), + args.dataset.max_seq_length, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -73,6 +157,10 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -181,6 +269,17 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + if self._log_mfu: + # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one + # optimizer step (the last step in the logging window), so this yields FLOPs + # per optimizer step per rank. step_time is already the per-step average. + tokens_on_rank = self.num_unpadded_tokens.item() + flops_per_step = self._per_token_flops * tokens_on_rank + tflops_per_gpu = flops_per_step / step_time / 1e12 + self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + memory_allocated = torch.cuda.memory_allocated() / (1024**3) self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py deleted file mode 100644 index f514205fef..0000000000 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# --- BEGIN COPIED FILE NOTICE --- -# This file is copied from: bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py -# Do not modify this file directly. Instead, modify the source and run: -# python ci/scripts/check_copied_files.py --fix -# --- END COPIED FILE NOTICE --- - -"""Tests for the flops.py FLOPs counting and MFU module.""" - -import sys -from pathlib import Path - -import pytest - - -# Add parent directory so we can import flops -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from flops import ( - MFUTracker, - ModelFLOPsConfig, - compute_flops_analytical, - compute_flops_hyena, - compute_flops_simplified, - estimate_cp_comm_bytes, - from_hf_config, -) - - -# ============================================================================ -# Test configs matching real models -# ============================================================================ - -LLAMA_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 25, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 6144, - "vocab_size": 128256, - "model_type": "llama", - "hidden_act": "silu", -} - -ESM2_8M_CONFIG = { - "hidden_size": 320, - "num_hidden_layers": 6, - "num_attention_heads": 20, - "intermediate_size": 1280, - "vocab_size": 33, - "model_type": "nv_esm", - "hidden_act": "gelu", -} - -CODONFM_1B_CONFIG = { - "hidden_size": 2048, - "num_hidden_layers": 18, - "num_attention_heads": 16, - "intermediate_size": 8192, -} - - -# ============================================================================ -# Config auto-detection -# ============================================================================ - - -class TestFromHfConfig: - """Test auto-detection of model architecture from config dicts.""" - - def test_llama_detects_gqa_and_swiglu(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - assert cfg.num_kv_heads == 8 - assert cfg.num_mlp_projections == 3 - assert cfg.head_dim == 128 - - def test_esm2_detects_mha_and_standard_ffn(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - assert cfg.num_kv_heads == 20 - assert cfg.num_mlp_projections == 2 - - def test_codonfm_defaults_to_mha_and_2_proj(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - assert cfg.num_kv_heads == 16 - assert cfg.num_mlp_projections == 2 - - def test_missing_vocab_defaults_to_no_lm_head(self): - cfg = from_hf_config( - {"hidden_size": 128, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 512} - ) - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - def test_overrides_take_precedence(self): - cfg = from_hf_config(ESM2_8M_CONFIG, num_mlp_projections=3, vocab_size=0, has_lm_head=False) - assert cfg.num_mlp_projections == 3 - assert cfg.vocab_size == 0 - assert cfg.has_lm_head is False - - -# ============================================================================ -# Analytical FLOPs formula -# ============================================================================ - - -class TestComputeFlopsAnalytical: - """Test the first-principles analytical FLOPs formula.""" - - def test_training_is_3x_forward(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, breakdown, lm_head = compute_flops_analytical(cfg, 1, 4096) - forward = cfg.num_hidden_layers * sum(breakdown.values()) + lm_head - assert total == 3 * forward - - def test_swiglu_has_3_mlp_projections(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_standard_ffn_has_2_mlp_projections(self): - cfg = from_hf_config(ESM2_8M_CONFIG) - _, breakdown, _ = compute_flops_analytical(cfg, 1, 1024) - assert "Gate projection" not in breakdown - assert "Up projection" in breakdown - assert "Down projection" in breakdown - - def test_no_lm_head_when_vocab_zero(self): - cfg = from_hf_config(CODONFM_1B_CONFIG) - _, _, lm_head = compute_flops_analytical(cfg, 1, 1024) - assert lm_head == 0 - - def test_flops_scale_linearly_with_batch(self): - cfg = from_hf_config(LLAMA_1B_CONFIG) - flops_b1, _, _ = compute_flops_analytical(cfg, 1, 1024) - flops_b4, _, _ = compute_flops_analytical(cfg, 4, 1024) - assert flops_b4 == 4 * flops_b1 - - def test_known_value_llama_lingua_1b(self): - """Golden value: validated against PyTorch FlopCounterMode and README formula.""" - cfg = from_hf_config(LLAMA_1B_CONFIG) - total, _, _ = compute_flops_analytical(cfg, 1, 4096) - assert total == 47_687_021_887_488 - - -# ============================================================================ -# Simplified formula -# ============================================================================ - - -class TestComputeFlopsSimplified: - """Test the simplified README formula and its relationship to analytical.""" - - def test_matches_analytical_when_mha_and_i_equals_4h(self): - """ESM2 has MHA + I=4H + 2 projections: same assumptions as simplified formula.""" - cfg = from_hf_config(ESM2_8M_CONFIG) - analytical, _, _ = compute_flops_analytical(cfg, 1, 1024) - simplified = compute_flops_simplified(1, 1024, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical == simplified - - def test_differs_when_gqa_or_swiglu(self): - """GQA + SwiGLU breaks the simplified formula's assumptions.""" - cfg_dict = { - **LLAMA_1B_CONFIG, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "intermediate_size": 8192, - "num_hidden_layers": 16, - } - cfg = from_hf_config(cfg_dict) - analytical, _, _ = compute_flops_analytical(cfg, 1, 4096) - simplified = compute_flops_simplified(1, 4096, cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size) - assert analytical != simplified - - -# ============================================================================ -# Hyena formula -# ============================================================================ - - -class TestComputeFlopsHyena: - """Test the Hyena (Evo2) FLOPs formula.""" - - @pytest.fixture() - def hyena_config(self): - return ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - - def test_scales_subquadratically(self, hyena_config): - """Hyena uses O(S log S) convolution, not O(S^2) attention.""" - flops_1k = compute_flops_hyena(hyena_config, 1, 1024, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - flops_2k = compute_flops_hyena(hyena_config, 1, 2048, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - ratio = flops_2k / flops_1k - assert ratio < 3.0 # S^2 would give ~4x; Hyena should be well below - - def test_hybrid_attention_adds_quadratic_cost(self, hyena_config): - """Adding standard attention layers increases FLOPs due to S^2 term.""" - all_hyena = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 4, "A": 0}) - with_attn = compute_flops_hyena(hyena_config, 1, 4096, hyena_layer_counts={"S": 0, "D": 0, "H": 2, "A": 2}) - assert with_attn > all_hyena - - -# ============================================================================ -# MFUTracker -# ============================================================================ - - -class TestMFUTracker: - """Test the MFUTracker class used by training scripts.""" - - def test_from_config_dict(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.total_flops == 47_687_021_887_488 - assert tracker.per_gpu_flops == tracker.total_flops - - def test_multi_gpu_divides_flops(self): - single = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=1, peak_tflops=155.0 - ) - multi = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, num_gpus=2, peak_tflops=155.0) - assert multi.per_gpu_flops == single.total_flops // 2 - - def test_compute_mfu_correctness(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - result = tracker.compute_mfu(step_time=0.5) - expected_tflops = tracker.per_gpu_flops / 0.5 / 1e12 - expected_mfu = expected_tflops / 155.0 * 100 - assert abs(result["mfu"] - expected_mfu) < 0.01 - assert abs(result["tflops_per_gpu"] - expected_tflops) < 0.01 - - def test_mfu_inversely_proportional_to_step_time(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - fast = tracker.compute_mfu(step_time=0.5) - slow = tracker.compute_mfu(step_time=1.0) - assert abs(fast["mfu"] - 2 * slow["mfu"]) < 0.01 - - def test_all_formula_options(self): - for formula in ["analytical", "simplified", "hyena"]: - if formula == "hyena": - cfg = ModelFLOPsConfig( - hidden_size=1024, - num_hidden_layers=4, - num_attention_heads=8, - num_kv_heads=8, - head_dim=128, - intermediate_size=4096, - num_mlp_projections=3, - vocab_size=512, - has_lm_head=True, - ) - tracker = MFUTracker(cfg, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula) - else: - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula=formula - ) - assert tracker.total_flops > 0 - - def test_invalid_formula_raises(self): - with pytest.raises(ValueError, match="Unknown formula"): - MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0, formula="bad") - - def test_cp_communication_estimate(self): - tracker = MFUTracker.from_config_dict( - LLAMA_1B_CONFIG, batch_size=1, seq_len=16384, num_gpus=2, parallelism={"dp": 1, "cp": 2}, peak_tflops=155.0 - ) - assert tracker.comm_bytes > 0 - overhead = tracker.estimate_comm_overhead(step_time=1.0, measured_bw_gbps=6.6) - assert overhead["estimated_comm_time"] > 0 - assert 0 < overhead["comm_pct"] < 100 - - def test_no_comm_single_gpu(self): - tracker = MFUTracker.from_config_dict(LLAMA_1B_CONFIG, batch_size=1, seq_len=4096, peak_tflops=155.0) - assert tracker.comm_bytes == 0 - - -# ============================================================================ -# Communication estimation -# ============================================================================ - - -class TestCPCommEstimation: - """Test CP ring attention communication byte estimates.""" - - def test_zero_without_cp(self): - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=1) == 0 - - def test_scales_linearly_with_seq_len(self): - comm_4k = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_8k = estimate_cp_comm_bytes(1, 8192, 25, 8, 128, cp_size=2) - assert comm_8k == 2 * comm_4k - - def test_scales_linearly_with_batch(self): - comm_b1 = estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) - comm_b4 = estimate_cp_comm_bytes(4, 4096, 25, 8, 128, cp_size=2) - assert comm_b4 == 4 * comm_b1 - - def test_known_value_lingua_1b(self): - """Golden value for lingua-1B at S=4096, CP=2.""" - assert estimate_cp_comm_bytes(1, 4096, 25, 8, 128, cp_size=2) == 419_430_400 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py index 1963bab24d..8578e14136 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py @@ -28,7 +28,6 @@ import gc import logging import random -import time from contextlib import nullcontext from pathlib import Path @@ -63,7 +62,6 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from opengenome_modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from optimizer import get_parameter_groups_with_weight_decay @@ -260,20 +258,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) # Setup validation if enabled val_config = getattr(args, "validation", None) @@ -316,7 +305,6 @@ def main(args: DictConfig) -> float | None: logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") step = start_step micro_step = 0 - step_start_time = time.perf_counter() if train_dataloader is None: raise RuntimeError("Expected train_dataloader to be initialized before training.") @@ -351,13 +339,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py index a500adb8a2..f8d3c19757 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py @@ -33,7 +33,6 @@ import gc import logging -import time from contextlib import nullcontext from pathlib import Path @@ -70,7 +69,6 @@ from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from flops import MFUTracker, from_hf_config from fp8_debugging import initialize_fp8_debugging from opengenome_modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from optimizer import get_parameter_groups_with_weight_decay @@ -300,20 +298,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) - - # --- MFU Tracking (optional) --- - mfu_tracker = None - if args.get("log_mfu", False): - mfu_tracker = MFUTracker( - config=from_hf_config(config.to_dict()), - batch_size=args.dataset.micro_batch_size, - seq_len=args.dataset.max_seq_length, - num_gpus=dist_config.world_size, - parallelism={"dp": dist_config.world_size // args.cp_size, "cp": args.cp_size}, - ) - if dist_config.is_main_process(): - logger.info("MFU tracking enabled:\n%s", mfu_tracker.summary()) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + ) gc.collect() torch.cuda.empty_cache() @@ -322,7 +311,6 @@ def main(args: DictConfig) -> float | None: logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") step = start_step micro_step = 0 - step_start_time = time.perf_counter() while step < args.num_train_steps: for batch in train_dataloader: @@ -359,13 +347,6 @@ def main(args: DictConfig) -> float | None: lr=optimizer.param_groups[0]["lr"], ) - if mfu_tracker is not None: - step_time = time.perf_counter() - step_start_time - mfu_info = mfu_tracker.compute_mfu(step_time) - if dist_config.is_main_process(): - logger.info("MFU: %.1f%% (%.2f TFLOPS/GPU)", mfu_info["mfu"], mfu_info["tflops_per_gpu"]) - step_start_time = time.perf_counter() - if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): save_checkpoint_fsdp2( model=model, diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 21c573fec5..92d65903cb 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -199,18 +199,6 @@ def _compare_file_contents(source_file: Path, dest_file: Path, source_display: s "bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [ "bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py", ], - # FLOPs / MFU module - synced across recipes - "bionemo-recipes/recipes/llama3_native_te/flops.py": [ - "bionemo-recipes/recipes/esm2_native_te/flops.py", - "bionemo-recipes/recipes/codonfm_native_te/flops.py", - "bionemo-recipes/recipes/opengenome2_llama_native_te/flops.py", - ], - # FLOPs tests - synced across recipes - "bionemo-recipes/recipes/llama3_native_te/tests/test_flops.py": [ - "bionemo-recipes/recipes/esm2_native_te/tests/test_flops.py", - "bionemo-recipes/recipes/codonfm_native_te/tests/test_flops.py", - "bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_flops.py", - ], # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", From e2f493419b0093fffae2c4a5dbdb1bdcfbf45929 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Sat, 18 Apr 2026 02:38:36 +0000 Subject: [PATCH 13/24] MFU: count configured-shape tokens, not attention_mask.sum() The prior implementation used `self.num_unpadded_tokens` (attention_mask.sum()) as the token count in the MFU formula, which produces an "effective-work MFU" that under-reports vs published numbers when the batch has padding. Switch to `self.num_tokens` (input_ids.numel()) to match the PaLM / Megatron-LM / MosaicML llm-foundry / torchtitan convention of counting the configured token shape the hardware actually runs matmuls over, regardless of attention masking. For fully-packed THD (no inter-sequence padding) the two counters are equivalent; for BSHD with padding, this restores standard MFU semantics. References for the convention: - PaLM (Chowdhery et al. 2022, App. B): "tokens per second ... peak throughput R = P / (6N + 12LHQT)" where T is the configured/padded sequence length. - Megatron-LM num_floating_point_operations() uses global batch * args.seq_length (configured/padded). - MosaicML llm-foundry and torchtitan both use configured seq_len. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 12 +++++++----- .../recipes/esm2_native_te/perf_logger.py | 9 +++++---- .../recipes/llama3_native_te/perf_logger.py | 12 +++++++----- .../opengenome2_llama_native_te/perf_logger.py | 12 +++++++----- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index 340451d7cc..f7e0aaefae 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -244,11 +244,13 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) if self._log_mfu: - # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one - # optimizer step (the last step in the logging window), so this yields FLOPs - # per optimizer step per rank. step_time is already the per-step average. - tokens_on_rank = self.num_unpadded_tokens.item() - flops_per_step = self._per_token_flops * tokens_on_rank + # PaLM/Megatron/MosaicML convention: count the configured-shape token budget + # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), + # not attention_mask.sum(). The hardware executes matmuls over every position + # regardless of masking, and this matches published MFU numbers. + # num_tokens is accumulated over the grad-acc micro-batches of one optimizer + # step (the last step in the logging window). step_time is per-step average. + flops_per_step = self._per_token_flops * self.num_tokens tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index fa78b15327..c107d93f42 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -213,10 +213,11 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) if self._log_mfu: - # Current batch's unpadded tokens already reflect this rank's share (CP - # shards the batch; DP replicates the model across ranks on distinct - # micro-batches). step_time is the per-step average over the logging window. - flops_per_step = self._per_token_flops * num_unpadded_tokens + # PaLM/Megatron/MosaicML convention: count the configured-shape token budget + # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), + # not the attention-mask count. The hardware executes matmuls over every + # position regardless of masking, and this matches published MFU numbers. + flops_per_step = self._per_token_flops * num_tokens tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 3bebdee347..9503eb09a0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -269,11 +269,13 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one - # optimizer step (the last step in the logging window), so this yields FLOPs - # per optimizer step per rank. step_time is already the per-step average. - tokens_on_rank = self.num_unpadded_tokens.item() - flops_per_step = self._per_token_flops * tokens_on_rank + # PaLM/Megatron/MosaicML convention: count the configured-shape token budget + # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), + # not attention_mask.sum(). The hardware executes matmuls over every position + # regardless of masking, and this matches published MFU numbers. + # num_tokens is accumulated over the grad-acc micro-batches of one optimizer + # step (the last step in the logging window). step_time is per-step average. + flops_per_step = self._per_token_flops * self.num_tokens tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 0633193d3f..7b09d47d8f 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -270,11 +270,13 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # num_unpadded_tokens is accumulated over the grad-acc micro-batches of one - # optimizer step (the last step in the logging window), so this yields FLOPs - # per optimizer step per rank. step_time is already the per-step average. - tokens_on_rank = self.num_unpadded_tokens.item() - flops_per_step = self._per_token_flops * tokens_on_rank + # PaLM/Megatron/MosaicML convention: count the configured-shape token budget + # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), + # not attention_mask.sum(). The hardware executes matmuls over every position + # regardless of masking, and this matches published MFU numbers. + # num_tokens is accumulated over the grad-acc micro-batches of one optimizer + # step (the last step in the logging window). step_time is per-step average. + flops_per_step = self._per_token_flops * self.num_tokens tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: From f4858eb9b5a67d27000855e363ee1eb2f0c6bafb Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 22 Apr 2026 12:25:53 -0700 Subject: [PATCH 14/24] =?UTF-8?q?MFU:=20use=20=CE=A3(L=E1=B5=A2=C2=B2)=20f?= =?UTF-8?q?or=20attention=20work;=20fix=20ESM-2=20grad-acc=20undercount?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous formula baked `4·S·H` attention per token at init-time using `args.dataset.max_seq_length`, then multiplied by num_tokens at log time. For a packed THD batch of total length S with docs L₁, L₂, …, Lₙ, this yields `4·H·S²` attention per rank — overstating work by `S²/Σ(Lᵢ²)` because Flash-Attention actually computes one Lᵢ×Lᵢ score matrix per packed segment. Split per-token FLOPs into non-attention (QKV/MLP/LM-head) and attention coefficient, then at each logging window: flops = non_attn_per_token · num_tokens + coeff · Σ(Lᵢ²) / cp_size `Σ(Lᵢ²)` is accumulated across grad-acc micro-batches from `cu_seq_lens_q` (real per-doc lengths) in log_micro_step, falling back to `cu_seq_lens_q_padded` or `B·S²` from the batch shape for BSHD. The algebraic identity `non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)` is locked in by a unit test in every recipe, so BSHD and synthetic-single-doc-THD benchmarks produce identical numbers to the legacy formula. Also fixes ESM-2's latent grad-acc undercount: its PerfLogger previously had no log_micro_step and read num_tokens from only the last micro-batch at log time, silently dividing reported FLOPs by grad_acc_steps if grad-acc were ever enabled. Retrofitted with the llama3/og2/codonfm accumulator pattern: all five ESM-2 training scripts (train_{fsdp2, fsdp2_cp,ddp,ddp_cp,mfsdp}.py) now call log_micro_step inside the inner loop and log_step(step, grad_norm, lr) after the optimizer step. Applied identically to llama3_native_te, opengenome2_llama_native_te, codonfm_native_te, esm2_native_te per the per-recipe isolation rule. Six invariant unit tests land in each recipe's test_perf_logger.py (algebraic identity, BSHD no-op, THD single-doc/multi-doc/CP division, unpadded-preferred, padded-fallback) plus ESM-2-specific grad-acc accumulator tests. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 116 ++++++++- .../tests/test_perf_logger.py | 94 ++++++- .../recipes/esm2_native_te/perf_logger.py | 216 +++++++++++++--- .../esm2_native_te/tests/test_perf_logger.py | 234 ++++++++++++++++++ .../recipes/esm2_native_te/train_ddp.py | 5 +- .../recipes/esm2_native_te/train_ddp_cp.py | 5 +- .../recipes/esm2_native_te/train_fsdp2.py | 5 +- .../recipes/esm2_native_te/train_fsdp2_cp.py | 6 +- .../recipes/esm2_native_te/train_mfsdp.py | 5 +- .../recipes/llama3_native_te/perf_logger.py | 116 ++++++++- .../tests/test_perf_logger.py | 95 +++++++ .../perf_logger.py | 116 ++++++++- .../tests/test_perf_logger.py | 117 +++++++++ 13 files changed, 1039 insertions(+), 91 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py create mode 100644 bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index f7e0aaefae..cdeb46df1b 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -70,11 +70,15 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: """Training FLOPs per token for a transformer (forward + backward = 3x forward). First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection - MLP (SwiGLU detected via model_type), and LM head. The returned value is - multiplied by the actual unpadded token count at log time, so it naturally - handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: - unpadded tokens on each rank already reflect that rank's share of work. + logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform + BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via + model_type), and LM head. + + Kept for back-compat. For accurate per-step accounting use + ``_compute_non_attn_per_token_flops`` (applied to the total token count) + together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from + cu_seq_lens), since a packed THD batch of total length S containing docs + L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -99,6 +103,69 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: return 3 * per_token_fwd +def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff`` so that + ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: + """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. + + THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding + that FA processes as cycles but contributes nothing to the training signal). + Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. + BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, + so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at + L ≈ 46k otherwise). + """ + cu = batch.get("cu_seq_lens_q") + if cu is None: + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len = int(shape[0]), int(shape[-1]) + return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + + class PerfLogger: """Performance logger for CodonFM training. @@ -127,17 +194,26 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # reflects each rank's share under DP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._per_token_flops = 0 + self._non_attn_per_token_flops = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", float(self._per_token_flops), args.dataset.max_seq_length, + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + self._cp_size, ) metrics_dict = { @@ -170,6 +246,8 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. + self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -191,6 +269,10 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas self.num_tokens += batch["input_ids"].numel() num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != PAD_TOKEN_ID].numel() self.num_unpadded_tokens += num_unpadded_tokens + if self._log_mfu: + # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulates across grad-acc micro-batches; drained in log_step. + self._attn_work_accum += _attn_work_from_batch(batch, self._device) # Update perplexity per micro-batch since it needs logits + labels logits = outputs.logits @@ -244,13 +326,20 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) if self._log_mfu: - # PaLM/Megatron/MosaicML convention: count the configured-shape token budget - # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), - # not attention_mask.sum(). The hardware executes matmuls over every position - # regardless of masking, and this matches published MFU numbers. - # num_tokens is accumulated over the grad-acc micro-batches of one optimizer - # step (the last step in the logging window). step_time is per-step average. - flops_per_step = self._per_token_flops * self.num_tokens + # Two-term FLOP accounting: + # non_attn_flops = non_attn_per_token * num_tokens + # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size + # num_tokens follows the PaLM/Megatron/MosaicML convention of counting + # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over + # padded positions too). The attention term uses the real per-doc + # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc + # boundaries, so a packed batch with fragmented docs does genuinely + # less attention work than a uniform B·S² would imply — the old + # single-term formula overstated MFU in exactly that regime. + attn_work = int(self._attn_work_accum.item()) + non_attn_flops = self._non_attn_per_token_flops * self.num_tokens + attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size + flops_per_step = non_attn_flops + attn_flops tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: @@ -280,6 +369,7 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index b2e6f651eb..d417a283d5 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -21,7 +21,13 @@ import torch from distributed_config import DistributedConfig from omegaconf import OmegaConf -from perf_logger import PerfLogger +from perf_logger import ( + PerfLogger, + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, + _compute_per_token_flops, +) from transformers.modeling_outputs import MaskedLMOutput @@ -210,3 +216,89 @@ def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): _run_steps(perf_logger, losses) assert perf_logger.min_loss.item() == pytest.approx(1.0) + + +def _codon_cfg(): + """CodonFM-like config for the split-formula tests (MLM encoder).""" + return { + "model_type": "codonfm", # not in _GATED_MLP_MODEL_TYPES → standard 2-proj MLP + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "vocab_size": VOCAB_SIZE, + } + + +class TestFlopSplitAndAttention: + """Verify the split non-attn + Σ(Lᵢ²) attention formula.""" + + def test_algebraic_identity(self): + """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" + cfg = _codon_cfg() + for s in (256, 512, 1024, 8192): + lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s + rhs = _compute_per_token_flops(cfg, s) + assert lhs == rhs, f"S={s}: {lhs} != {rhs}" + + def test_bshd_no_op(self): + """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" + cfg = _codon_cfg() + b, s = 4, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq + legacy_flops = _compute_per_token_flops(cfg, s) * b * s + assert new_flops == legacy_flops + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)².""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 + assert work < 15 * 15 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attn stays untouched. + Codonfm doesn't support CP, but the formula must still respect cp_size=1 default.""" + cfg = _codon_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 + assert work != 8**2 + 8**2 + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index c107d93f42..2f068dfef2 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -32,6 +32,9 @@ logger = logging.getLogger(__name__) +# ESM-2 uses token id 1 for the token. Unpadded-token counting filters this id out. +ESM2_PAD_TOKEN_ID = 1 + # Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list # the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. _GPU_PEAK_TFLOPS_BF16 = { @@ -66,11 +69,15 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: """Training FLOPs per token for a transformer (forward + backward = 3x forward). First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection - MLP (SwiGLU detected via model_type), and LM head. The returned value is - multiplied by the actual unpadded token count at log time, so it naturally - handles BSHD, THD (sequence packing), DP, and CP: unpadded tokens on each - rank already reflect that rank's share of work. + logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform + BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via + model_type), and LM head. + + Kept for back-compat. For accurate per-step accounting use + ``_compute_non_attn_per_token_flops`` (applied to the total token count) + together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from + cu_seq_lens), since a packed THD batch of total length S containing docs + L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -95,9 +102,77 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: return 3 * per_token_fwd +def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff`` so that + ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: + """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. + + THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding + that FA processes as cycles but contributes nothing to the training signal). + Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. + BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, + so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at + L ≈ 46k otherwise). + """ + cu = batch.get("cu_seq_lens_q") + if cu is None: + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len = int(shape[0]), int(shape[-1]) + return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + Uses the ``log_micro_step`` / ``log_step`` accumulator pattern (shared with the + llama3/og2/codonfm recipes) so gradient accumulation is correctly handled: + token counts, Σ(Lᵢ²), perplexity updates, and loss accumulate across every + micro-batch of an optimizer step; metrics are reported once per logging window. + Args: dist_config: The distributed configuration. args: The arguments. @@ -114,27 +189,36 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) self.logging_frequency = args.logger.frequency - # Track whether to collect memory stats (disabled by default for max performance) # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per - # step are derived at log time from the current batch's unpadded token count, which + # step are derived at log time from the accumulated token count + Σ(Lᵢ²), which # already reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._per_token_flops = 0 + self._non_attn_per_token_flops = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", float(self._per_token_flops), args.dataset.max_seq_length, + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + self._cp_size, ) metrics_dict = { @@ -156,7 +240,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. - self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.metrics.to(self._device) self.previous_step_time = time.perf_counter() if self._dist_config.is_main_process(): @@ -167,24 +251,63 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Whether to step debug_api.step() after each step self.quant_stats_config = args.quant_stats_config.enabled + # Gradient accumulation tracking (accumulated over the grad-acc micro-batches of + # the last optimizer step in the logging window, then drained in log_step). + self.num_tokens = 0 + self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. + self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self.running_loss = torch.tensor(0.0, device=self._device) + self.grad_acc_step_count = 0 + + def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: MaskedLMOutput): + """Store data on micro step for gradient accumulation metrics. + + Args: + step: The current optimizer step number (shared across all micro-batches). + batch: The input batch for this micro-step. + outputs: Model outputs for this micro-step (with unscaled loss). + """ + assert outputs.loss is not None, "Loss is None" + + with torch.no_grad(): + self.grad_acc_step_count += 1 + self.running_loss += outputs.loss + + if step % self.logging_frequency == 0 and step > 0: + self.num_tokens += batch["input_ids"].numel() + num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != ESM2_PAD_TOKEN_ID].numel() + self.num_unpadded_tokens += num_unpadded_tokens + if self._log_mfu: + # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulates across grad-acc micro-batches; drained in log_step. + self._attn_work_accum += _attn_work_from_batch(batch, self._device) + + # Update perplexity per micro-batch since it needs logits + labels. + logits = outputs.logits + if logits.dim() < 3: + logits = logits.unsqueeze(0) + self.metrics["train/perplexity"].update(logits, batch["labels"]) + def log_step( self, step: int, - batch: dict[str, torch.Tensor], - outputs: MaskedLMOutput, - grad_norm: torch.Tensor | DTensor, + grad_norm: torch.Tensor | DTensor | float, lr: float, ): - """Log a step to the logger and wandb. + """Log a training step (called once per optimizer step). Args: - step: The step number. - batch: The batch of data for the step. - outputs: The outputs of the step. - grad_norm: The gradient norm of the step. - lr: The learning rate of the step. + step: Current optimizer step. + grad_norm: Gradient norm value. + lr: Current learning rate. """ with torch.no_grad(): + assert self.grad_acc_step_count > 0, ( + f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " + f"and can be incremented by log_micro_step()." + ) + # FSDP2's clip_grad_norm_ returns a DTensor; convert to local tensor for torchmetrics compatibility. if isinstance(grad_norm, DTensor): grad_norm = grad_norm.to_local() @@ -192,43 +315,47 @@ def log_step( if self.quant_stats_config: debug_api.step() - if step % self.logging_frequency == 0 and step > 0: - num_tokens = batch["input_ids"].numel() - # 1 is the padding token for ESM-2. - num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + # Calculate average loss over all micro steps in the logging window. + avg_loss = self.running_loss / self.grad_acc_step_count + self.min_loss = torch.minimum(self.min_loss, avg_loss) - self.min_loss = torch.minimum(self.min_loss, outputs.loss) + if step % self.logging_frequency == 0 and step > 0: elapsed_time, self.previous_step_time = ( time.perf_counter() - self.previous_step_time, time.perf_counter(), ) step_time = elapsed_time / self.logging_frequency - self.metrics["train/loss"].update(outputs.loss) + self.metrics["train/loss"].update(avg_loss) self.metrics["train/learning_rate"].update(lr) - self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/grad_norm"].update( + grad_norm if isinstance(grad_norm, torch.Tensor) else torch.tensor(grad_norm) + ) self.metrics["train/step_time"].update(step_time) - self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) - self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) - self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) + self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # PaLM/Megatron/MosaicML convention: count the configured-shape token budget - # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), - # not the attention-mask count. The hardware executes matmuls over every - # position regardless of masking, and this matches published MFU numbers. - flops_per_step = self._per_token_flops * num_tokens + # Two-term FLOP accounting: + # non_attn_flops = non_attn_per_token * num_tokens + # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size + # num_tokens follows the PaLM/Megatron/MosaicML convention of counting + # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over + # padded positions too). The attention term uses the real per-doc + # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc + # boundaries, so a packed batch with fragmented docs does genuinely + # less attention work than a uniform B·S² would imply — the old + # single-term formula overstated MFU in exactly that regime. + attn_work = int(self._attn_work_accum.item()) + non_attn_flops = self._non_attn_per_token_flops * self.num_tokens + attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size + flops_per_step = non_attn_flops + attn_flops tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) - # Handle sequence packing for torchmetrics calculation. - if outputs.logits.dim() < 3: - outputs.logits = outputs.logits.unsqueeze(0) - - self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) @@ -244,11 +371,18 @@ def log_step( if self._dist_config.is_main_process(): wandb.log(metrics, step=step) self._progress_bar.update(self.logging_frequency) - self._progress_bar.set_postfix({"loss": outputs.loss.item()}) + self._progress_bar.set_postfix({"loss": avg_loss.item()}) if self._dist_config.local_rank == 0: logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + # Reset running accumulators for next logging window. + self.running_loss.zero_() + self.num_tokens = 0 + self.num_unpadded_tokens.zero_() + self._attn_work_accum.zero_() + self.grad_acc_step_count = 0 + def finish(self): """Finish the logger and close the progress bar.""" if self.quant_stats_config: diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py new file mode 100644 index 0000000000..7d79bcede9 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ESM-2's PerfLogger: FLOP formula split + grad-acc accumulator pattern. + +ESM-2 was previously the odd-one-out: its PerfLogger read num_tokens from the *last* +micro-batch at log time, so any future gradient accumulation would have undercounted +FLOPs by 1/grad_acc_steps. This retrofit introduces the ``log_micro_step`` / +``log_step`` split shared with the other MFU-tracking recipes (llama3, og2, codonfm) +and fixes attention-FLOP overcounting on packed (THD) batches. +""" + +from unittest import mock + +import pytest +import torch +from omegaconf import OmegaConf +from transformers.modeling_outputs import MaskedLMOutput + +from distributed_config import DistributedConfig +from perf_logger import ( + PerfLogger, + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, + _compute_per_token_flops, +) + + +ESM2_VOCAB = 33 + + +def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False, max_seq_length=128): + """Create a minimal args config for PerfLogger.""" + return OmegaConf.create( + { + "logger": {"frequency": logging_frequency}, + "wandb_init_args": {"project": "test", "mode": "disabled"}, + "num_train_steps": num_train_steps, + "quant_stats_config": {"enabled": False}, + "log_mfu": log_mfu, + "dataset": {"max_seq_length": max_seq_length}, + } + ) + + +def _make_batch(seq_len=128, device="cuda:0"): + """Create a minimal batch dict.""" + return { + "input_ids": torch.ones(1, seq_len, dtype=torch.long, device=device), + "labels": torch.ones(1, seq_len, dtype=torch.long, device=device), + } + + +def _make_outputs(loss_value, seq_len=128, device="cuda:0"): + """Create MaskedLMOutput with loss + logits.""" + logits = torch.randn(1, seq_len, ESM2_VOCAB, device=device) + return MaskedLMOutput(loss=torch.tensor(loss_value, device=device), logits=logits) + + +@pytest.fixture +def mock_wandb(): + with mock.patch("perf_logger.wandb") as mocked: + mocked.init.return_value = mock.MagicMock() + yield mocked + + +@pytest.fixture +def mock_tqdm(): + with mock.patch("perf_logger.tqdm") as mocked: + yield mocked + + +def _esm_cfg(): + """ESM-2-like MLM encoder config (MHA, no GQA, gelu MLP).""" + return { + "model_type": "esm", # not in _GATED_MLP_MODEL_TYPES → 2-proj MLP + "hidden_size": 1280, + "num_hidden_layers": 33, + "num_attention_heads": 20, + "intermediate_size": 5120, + "vocab_size": ESM2_VOCAB, + } + + +class TestFlopSplitAndAttention: + """Verify the split non-attn + Σ(Lᵢ²) attention formula for ESM-2.""" + + def test_algebraic_identity(self): + """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" + cfg = _esm_cfg() + for s in (256, 1024, 8192, 131072): + lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s + rhs = _compute_per_token_flops(cfg, s) + assert lhs == rhs, f"S={s}: {lhs} != {rhs}" + + def test_bshd_no_op(self): + """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" + cfg = _esm_cfg() + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq + legacy_flops = _compute_per_token_flops(cfg, s) * b * s + assert new_flops == legacy_flops + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)².""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 + assert work < 15 * 15 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attn untouched.""" + cfg = _esm_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 + assert work != 8**2 + 8**2 + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + +class TestGradAccAccumulation: + """Lock in ESM-2's new log_micro_step/log_step split under gradient accumulation. + + Before this retrofit ESM-2 read num_tokens from only the last micro-batch of an + optimizer step, so with grad_acc_steps > 1 it would have reported 1/grad_acc the + true FLOP count. The new accumulator pattern sums across micro-batches. + """ + + def test_num_tokens_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): + """4 micro-batches of seq_len=128 → num_tokens = 4*128 at log boundary.""" + dist_config = DistributedConfig() + args = _make_args(logging_frequency=1, max_seq_length=128) + perf_logger = PerfLogger(dist_config, args) + device = perf_logger._device + + # One optimizer step with 4 micro-batches of shape (1, 128). + for _ in range(4): + batch = _make_batch(seq_len=128, device=device) + outputs = _make_outputs(1.0, seq_len=128, device=device) + perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) + + assert perf_logger.grad_acc_step_count == 4 + assert perf_logger.num_tokens == 4 * 128 # 4 micro-batches * 128 tokens each + # running_loss should sum 4 losses of 1.0 each + assert perf_logger.running_loss.item() == pytest.approx(4.0) + + def test_attn_work_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): + """_attn_work_accum sums Σ(Lᵢ²) over all micro-batches when log_mfu=True.""" + dist_config = DistributedConfig() + args = _make_args(logging_frequency=1, log_mfu=True, max_seq_length=128) + perf_logger = PerfLogger(dist_config, args, model_config_dict=_esm_cfg()) + device = perf_logger._device + + # 3 micro-batches of shape (2, 64) → each batch has Σ(Lᵢ²) = 2 * 64² = 8192 + for _ in range(3): + batch = { + "input_ids": torch.ones(2, 64, dtype=torch.long, device=device), + "labels": torch.ones(2, 64, dtype=torch.long, device=device), + } + outputs = _make_outputs(1.0, seq_len=64, device=device) + # Perplexity expects (B, S, V) logits + outputs.logits = torch.randn(2, 64, ESM2_VOCAB, device=device) + perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) + + # Accumulator should hold 3 * 2 * 64² = 24576 + assert perf_logger._attn_work_accum.item() == 3 * 2 * 64 * 64 + + def test_reset_on_log_boundary(self, mock_wandb, mock_tqdm): + """Calling log_step on a logging-boundary step drains all accumulators.""" + dist_config = DistributedConfig() + args = _make_args(logging_frequency=1, log_mfu=True, max_seq_length=128) + perf_logger = PerfLogger(dist_config, args, model_config_dict=_esm_cfg()) + device = perf_logger._device + + batch = _make_batch(seq_len=128, device=device) + outputs = _make_outputs(1.0, seq_len=128, device=device) + perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) + perf_logger.log_step(step=1, grad_norm=torch.tensor(1.0, device=device), lr=1e-4) + + assert perf_logger.grad_acc_step_count == 0 + assert perf_logger.num_tokens == 0 + assert perf_logger.num_unpadded_tokens.item() == 0 + assert perf_logger._attn_work_accum.item() == 0 + assert perf_logger.running_loss.item() == pytest.approx(0.0) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index ebf36d9d47..707652fb9e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -175,6 +175,9 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() + # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -185,8 +188,6 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, - batch=batch, - outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index 25829c64cf..8c3d5f029b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -184,6 +184,9 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() + # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -194,8 +197,6 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, - batch=batch, - outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index fcfbf17fa6..78a5c5055b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -201,6 +201,9 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() + # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + # --- Grad clip --- total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -212,8 +215,6 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, - batch=batch, - outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 06573112f8..b0d91f4a0d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -196,6 +196,10 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() + # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). + # With future gradient accumulation, this would be called once per micro-batch. + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -206,8 +210,6 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, - batch=batch, - outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index 2d61d14d39..0dcd5875a3 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -182,6 +182,9 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() + # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + # Compute and clip gradient norms. # This is causing training to hang in 26.01 torch base image for multi-process mFSDP. # total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -193,8 +196,6 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, - batch=batch, - outputs=outputs, grad_norm=0.0, # total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 9503eb09a0..72e500eb3d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -67,11 +67,15 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: """Training FLOPs per token for a transformer (forward + backward = 3x forward). First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection - MLP (SwiGLU detected via model_type), and LM head. The returned value is - multiplied by the actual unpadded token count at log time, so it naturally - handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: - unpadded tokens on each rank already reflect that rank's share of work. + logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform + BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via + model_type), and LM head. + + Kept for back-compat. For accurate per-step accounting use + ``_compute_non_attn_per_token_flops`` (applied to the total token count) + together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from + cu_seq_lens), since a packed THD batch of total length S containing docs + L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -96,6 +100,69 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: return 3 * per_token_fwd +def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff`` so that + ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: + """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. + + THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding + that FA processes as cycles but contributes nothing to the training signal). + Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. + BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, + so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at + L ≈ 46k otherwise). + """ + cu = batch.get("cu_seq_lens_q") + if cu is None: + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len = int(shape[0]), int(shape[-1]) + return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. @@ -132,17 +199,26 @@ def __init__( # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._per_token_flops = 0 + self._non_attn_per_token_flops = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", float(self._per_token_flops), args.dataset.max_seq_length, + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + self._cp_size, ) metrics_dict = { @@ -182,6 +258,8 @@ def __init__( # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. + self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -214,6 +292,10 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau else: # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() + if self._log_mfu: + # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulates across grad-acc micro-batches; drained in log_step. + self._attn_work_accum += _attn_work_from_batch(batch, self._device) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -269,13 +351,20 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # PaLM/Megatron/MosaicML convention: count the configured-shape token budget - # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), - # not attention_mask.sum(). The hardware executes matmuls over every position - # regardless of masking, and this matches published MFU numbers. - # num_tokens is accumulated over the grad-acc micro-batches of one optimizer - # step (the last step in the logging window). step_time is per-step average. - flops_per_step = self._per_token_flops * self.num_tokens + # Two-term FLOP accounting: + # non_attn_flops = non_attn_per_token * num_tokens + # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size + # num_tokens follows the PaLM/Megatron/MosaicML convention of counting + # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over + # padded positions too). The attention term uses the real per-doc + # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc + # boundaries, so a packed batch with fragmented docs does genuinely + # less attention work than a uniform B·S² would imply — the old + # single-term formula overstated MFU in exactly that regime. + attn_work = int(self._attn_work_accum.item()) + non_attn_flops = self._non_attn_per_token_flops * self.num_tokens + attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size + flops_per_step = non_attn_flops + attn_flops tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: @@ -305,6 +394,7 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 5fd6e70ad0..790e9f90e9 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -25,6 +25,9 @@ from distributed_config import DistributedConfig from perf_logger import ( PerfLogger, + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, _compute_per_token_flops, _detect_peak_tflops_bf16, ) @@ -308,3 +311,95 @@ def test_returns_tuple_shape(self): peak, name = _detect_peak_tflops_bf16() assert isinstance(name, str) assert peak is None or isinstance(peak, float) + + +def _llama_cfg(): + """Small llama-like config used by the split-formula tests.""" + return { + "model_type": "llama", + "hidden_size": 1024, + "num_hidden_layers": 8, + "num_attention_heads": 16, + "num_key_value_heads": 4, + "intermediate_size": 4096, + "vocab_size": 32000, + } + + +class TestFlopSplitAndAttention: + """Verify the split non-attn + Σ(Lᵢ²) attention formula. + + The old single-term ``_per_token_flops * num_tokens`` formula treats a packed + batch as one giant S*S attention. Real Flash-Attention work is Σ(Lᵢ²) over + packed segments. These tests lock in the new split and its invariants. + """ + + def test_algebraic_identity(self): + """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" + cfg = _llama_cfg() + for s in (256, 1024, 8192): + lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s + rhs = _compute_per_token_flops(cfg, s) + assert lhs == rhs, f"S={s}: {lhs} != {rhs}" + + def test_bshd_no_op(self): + """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" + cfg = _llama_cfg() + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq + legacy_flops = _compute_per_token_flops(cfg, s) * b * s + assert new_flops == legacy_flops + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)² — the whole point of the fix.""" + # Doc lengths 3, 5, 7 → cumulative [0, 3, 8, 15] + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 # 83 real QK pairs per layer + assert work < 15 * 15 # old formula would have said 225 + + def test_cp_size_divides_attention_only(self): + """Dividing attention by cp_size must leave the non-attention term untouched.""" + cfg = _llama_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 # 61 (unpadded doc lens 5 and 6) + assert work != 8**2 + 8**2 # 128 (padded slot lens 8 and 8) + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 7b09d47d8f..e9d1e6fdf0 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -75,11 +75,15 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: """Training FLOPs per token for a transformer (forward + backward = 3x forward). First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection - MLP (SwiGLU detected via model_type), and LM head. The returned value is - multiplied by the actual unpadded token count at log time, so it naturally - handles BSHD, THD (sequence packing), gradient accumulation, DP, and CP: - unpadded tokens on each rank already reflect that rank's share of work. + logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform + BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via + model_type), and LM head. + + Kept for back-compat. For accurate per-step accounting use + ``_compute_non_attn_per_token_flops`` (applied to the total token count) + together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from + cu_seq_lens), since a packed THD batch of total length S containing docs + L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -104,6 +108,69 @@ def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: return 3 * per_token_fwd +def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff`` so that + ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: + """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. + + THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding + that FA processes as cycles but contributes nothing to the training signal). + Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. + BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, + so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at + L ≈ 46k otherwise). + """ + cu = batch.get("cu_seq_lens_q") + if cu is None: + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len = int(shape[0]), int(shape[-1]) + return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. @@ -133,17 +200,26 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._per_token_flops = 0 + self._non_attn_per_token_flops = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", float(self._per_token_flops), args.dataset.max_seq_length, + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + self._cp_size, ) metrics_dict = { @@ -183,6 +259,8 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. + self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -215,6 +293,10 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau else: # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() + if self._log_mfu: + # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulates across grad-acc micro-batches; drained in log_step. + self._attn_work_accum += _attn_work_from_batch(batch, self._device) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -270,13 +352,20 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # PaLM/Megatron/MosaicML convention: count the configured-shape token budget - # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD), - # not attention_mask.sum(). The hardware executes matmuls over every position - # regardless of masking, and this matches published MFU numbers. - # num_tokens is accumulated over the grad-acc micro-batches of one optimizer - # step (the last step in the logging window). step_time is per-step average. - flops_per_step = self._per_token_flops * self.num_tokens + # Two-term FLOP accounting: + # non_attn_flops = non_attn_per_token * num_tokens + # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size + # num_tokens follows the PaLM/Megatron/MosaicML convention of counting + # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over + # padded positions too). The attention term uses the real per-doc + # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc + # boundaries, so a packed batch with fragmented docs does genuinely + # less attention work than a uniform B·S² would imply — the old + # single-term formula overstated MFU in exactly that regime. + attn_work = int(self._attn_work_accum.item()) + non_attn_flops = self._non_attn_per_token_flops * self.num_tokens + attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size + flops_per_step = non_attn_flops + attn_flops tflops_per_gpu = flops_per_step / step_time / 1e12 self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) if self._peak_tflops is not None: @@ -306,6 +395,7 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_accum.zero_() self.grad_acc_step_count = 0 def log_validation(self, step: int, val_metrics: dict): diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py new file mode 100644 index 0000000000..c5c0405d85 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the split non-attention + Σ(Lᵢ²) attention FLOP formula. + +The old single-term ``_per_token_flops * num_tokens`` formula treats a packed batch +as one giant S*S attention. Real Flash-Attention work is Σ(Lᵢ²) over packed +segments. These tests lock in the new split and its invariants so future drift +between sibling recipes is caught immediately. +""" + +import torch + +from perf_logger import ( + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, + _compute_per_token_flops, +) + + +def _llama_cfg(): + """Llama-like OG2 config used by the split-formula tests.""" + return { + "model_type": "llama", + "hidden_size": 4096, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 8, # GQA + "intermediate_size": 14336, + "vocab_size": 256, # OG2's nucleotide vocab + } + + +class TestFlopSplitAndAttention: + """Verify the split formula matches the legacy one and correctly handles THD.""" + + def test_algebraic_identity(self): + """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" + cfg = _llama_cfg() + for s in (256, 1024, 8192, 131072): + lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s + rhs = _compute_per_token_flops(cfg, s) + assert lhs == rhs, f"S={s}: {lhs} != {rhs}" + + def test_bshd_no_op(self): + """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" + cfg = _llama_cfg() + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq + legacy_flops = _compute_per_token_flops(cfg, s) * b * s + assert new_flops == legacy_flops + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)² — the whole point of the fix.""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 # 83 real QK pairs per layer + assert work < 15 * 15 # old formula would have said 225 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attention stays untouched.""" + cfg = _llama_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 # 61 (unpadded doc lens 5 and 6) + assert work != 8**2 + 8**2 # 128 (padded slot lens 8 and 8) + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 From b6685f51be6d84cecfb05e192cc1cca09408d257 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 22 Apr 2026 15:47:21 -0700 Subject: [PATCH 15/24] =?UTF-8?q?MFU:=20fix=20BSHD+CP=20attention-FLOP=20u?= =?UTF-8?q?ndercount=20(factor=20cp=C2=B2)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Real-data benchmark runs (llama3 dclm + og2 metagenome at cp=8 dp=1 S=131072) surfaced a bug in _attn_work_from_batch's BSHD branch. ContextParallelDataLoaderWrapper pre-splits the sequence along the CP axis before the batch reaches log_micro_step, so batch["input_ids"].shape per rank is (B, S/cp), not (B, S). The helper computed B*(S/cp)² from the per-rank shape, and log_step divided by cp_size, yielding per-rank attention FLOPs of B*S²/cp³ — wrong by cp². Correct per-rank attention work is B*S²/cp (Q_rank @ K_full costs (S/cp)*S per head summed). Evidence on og2 BSHD cp=8 S=131072 with a full-length doc: reported 10.9% MFU / 246 TFLOPS/GPU, but the actual compute corresponds to ~60% MFU (matches the THD run with the same S on the same data). Fix: pass cp_size into _attn_work_from_batch and multiply the BSHD synthesis by cp_size². The helper now always returns a GLOBAL (pre-shard) Σ(Lᵢ²), so the /cp_size in log_step consistently gives per-rank attention regardless of whether the input is cu_seq_lens_q (already global) or per-rank shape (needs the cp² correction). THD path unaffected (cu_seq_lens_q was always global). BSHD without CP unaffected (cp_size=1 makes the multiplier 1). Existing published benchmarks in ci/benchmarks/{og2,esm2}_benchmarks.md used synthetic cu_seq_lens=[0, S] and therefore flowed through the THD path — they remain valid. Added test_bshd_cp_correction to all four recipes asserting both the BSHD+CP scaling and the THD invariance under cp_size variation. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 38 +++++++++++++------ .../tests/test_perf_logger.py | 16 ++++++++ .../recipes/esm2_native_te/perf_logger.py | 36 ++++++++++++------ .../esm2_native_te/tests/test_perf_logger.py | 12 ++++++ .../recipes/llama3_native_te/perf_logger.py | 36 ++++++++++++------ .../tests/test_perf_logger.py | 25 ++++++++++++ .../perf_logger.py | 36 ++++++++++++------ .../tests/test_perf_logger.py | 18 +++++++++ 8 files changed, 169 insertions(+), 48 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index cdeb46df1b..292ea0c5b9 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -145,15 +145,24 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: - """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. - - THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding - that FA processes as cycles but contributes nothing to the training signal). - Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. - BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, - so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at - L ≈ 46k otherwise). +def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work, so this helper always returns a pre-CP-shard quantity. + + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator + emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` + if only the padded variant is present. Excludes CP zigzag padding that FA + processes as cycles but contributes nothing to the training signal. + * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active + because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To + recover the global B·S² from a per-rank shape, multiply by cp_size². When + cp_size=1 this is a no-op (per-rank shape == global shape). CodonFM runs FSDP + without CP, so cp_size=1 always here, but the formula stays correct if CP is + added later. + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ cu = batch.get("cu_seq_lens_q") if cu is None: @@ -162,8 +171,12 @@ def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() shape = batch["input_ids"].shape - batch_size, seq_len = int(shape[0]), int(shape[-1]) - return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) class PerfLogger: @@ -271,8 +284,9 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas self.num_unpadded_tokens += num_unpadded_tokens if self._log_mfu: # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device) + self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) # Update perplexity per micro-batch since it needs logits + labels logits = outputs.logits diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index d417a283d5..c782f4df24 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -302,3 +302,19 @@ def test_padded_fallback_when_unpadded_absent(self): "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), } assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + CodonFM currently runs FSDP without CP so this is latent defence, but the + formula must be correct if CP is added. + """ + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 2f068dfef2..b1d8abd865 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -144,15 +144,22 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: - """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. - - THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding - that FA processes as cycles but contributes nothing to the training signal). - Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. - BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, - so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at - L ≈ 46k otherwise). +def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work, so this helper always returns a pre-CP-shard quantity. + + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator + emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` + if only the padded variant is present. Excludes CP zigzag padding that FA + processes as cycles but contributes nothing to the training signal. + * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active + because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To + recover the global B·S² from a per-rank shape, multiply by cp_size². When + cp_size=1 this is a no-op (per-rank shape == global shape). + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ cu = batch.get("cu_seq_lens_q") if cu is None: @@ -161,8 +168,12 @@ def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() shape = batch["input_ids"].shape - batch_size, seq_len = int(shape[0]), int(shape[-1]) - return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) class PerfLogger: @@ -280,8 +291,9 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas self.num_unpadded_tokens += num_unpadded_tokens if self._log_mfu: # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device) + self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) # Update perplexity per micro-batch since it needs logits + labels. logits = outputs.logits diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py index 7d79bcede9..e0889d5976 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -167,6 +167,18 @@ def test_padded_fallback_when_unpadded_absent(self): } assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S².""" + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + class TestGradAccAccumulation: """Lock in ESM-2's new log_micro_step/log_step split under gradient accumulation. diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 72e500eb3d..deea719a7e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -142,15 +142,22 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: - """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. - - THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding - that FA processes as cycles but contributes nothing to the training signal). - Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. - BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, - so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at - L ≈ 46k otherwise). +def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work, so this helper always returns a pre-CP-shard quantity. + + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator + emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` + if only the padded variant is present. Excludes CP zigzag padding that FA + processes as cycles but contributes nothing to the training signal. + * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active + because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To + recover the global B·S² from a per-rank shape, multiply by cp_size². When + cp_size=1 this is a no-op (per-rank shape == global shape). + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ cu = batch.get("cu_seq_lens_q") if cu is None: @@ -159,8 +166,12 @@ def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() shape = batch["input_ids"].shape - batch_size, seq_len = int(shape[0]), int(shape[-1]) - return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) class PerfLogger: @@ -294,8 +305,9 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau self.num_unpadded_tokens += batch["input_ids"].numel() if self._log_mfu: # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device) + self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 790e9f90e9..99061d431c 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -403,3 +403,28 @@ def test_padded_fallback_when_unpadded_absent(self): "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), } assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + ContextParallelDataLoaderWrapper pre-splits the sequence so each rank's + input_ids.shape is (B, S/cp), not (B, S). The helper returns a GLOBAL + quantity (the caller divides by cp_size), so the BSHD synthesis branch + must multiply per-rank shape² by cp_size² to recover global B*S². + Without this correction, BSHD+CP attention FLOPs would be undercounted + by a factor of cp² (the bug surfaced when running real-data llama3/og2 + BSHD benchmarks at cp=8). + """ + # Pretend a rank has shape (1, 16) — this would correspond to global S=16*cp. + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + # cp_size=1 → per-rank shape == global shape: 1*16² = 256 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + # cp_size=8 → global S = 16*8 = 128, global B*S² = 128² = 16384 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + # THD path is unaffected by cp_size since cu_seq_lens_q is already global + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index e9d1e6fdf0..012c3496e6 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -150,15 +150,22 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: - """Return Σ(Lᵢ²) for this batch as a GPU int64 scalar tensor. - - THD: uses ``cu_seq_lens_q`` (real per-doc lengths — excludes CP zigzag padding - that FA processes as cycles but contributes nothing to the training signal). - Falls back to ``cu_seq_lens_q_padded`` if only the padded variant is present. - BSHD: no cu_seq_lens in batch → each of B rows is a single "doc" of length S, - so Σ(Lᵢ²) = B·S². Int32 lens are cast to int64 BEFORE squaring (overflow at - L ≈ 46k otherwise). +def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work, so this helper always returns a pre-CP-shard quantity. + + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator + emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` + if only the padded variant is present. Excludes CP zigzag padding that FA + processes as cycles but contributes nothing to the training signal. + * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active + because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To + recover the global B·S² from a per-rank shape, multiply by cp_size². When + cp_size=1 this is a no-op (per-rank shape == global shape). + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ cu = batch.get("cu_seq_lens_q") if cu is None: @@ -167,8 +174,12 @@ def _attn_work_from_batch(batch: dict, device: torch.device) -> torch.Tensor: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() shape = batch["input_ids"].shape - batch_size, seq_len = int(shape[0]), int(shape[-1]) - return torch.tensor(batch_size * seq_len * seq_len, dtype=torch.int64, device=device) + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) class PerfLogger: @@ -295,8 +306,9 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau self.num_unpadded_tokens += batch["input_ids"].numel() if self._log_mfu: # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device) + self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py index c5c0405d85..6e6b242a4d 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py @@ -115,3 +115,21 @@ def test_padded_fallback_when_unpadded_absent(self): "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), } assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + ContextParallelDataLoaderWrapper pre-splits the sequence so each rank's + input_ids.shape is (B, S/cp), not (B, S). The helper returns a GLOBAL + quantity (the caller divides by cp_size), so the BSHD synthesis branch + must multiply per-rank shape² by cp_size² to recover global B*S². + """ + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 From ab462398aac4031ff3778157baede469a1825b30 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 22 Apr 2026 16:09:40 -0700 Subject: [PATCH 16/24] perf_logger: report true peak memory, not post-step resting The memory reporting used torch.cuda.memory_allocated() (current) at log time, so both the "max" and "mean" metrics reflected the post-step resting footprint. Under FSDP that massively undercounts the transient peak: params get all-gathered briefly during forward/backward and activations balloon before backward frees them, but by the time log_step runs they're already released. Switch to torch.cuda.max_memory_allocated() (tracks the high-water mark of live allocations) paired with torch.cuda.reset_peak_memory_stats() so each logging window reports its own peak rather than a running max since process start. Both are host-side counter ops: no sync, no kernel launch. The "mean" metric keeps using the post-step current value as a "resting footprint" indicator (useful to see how much memory is held between steps), while "max" now reflects what the hardware actually peaks at. These are the values that belong in the benchmark tables. Applied identically to all four MFU-tracking recipes. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 13 ++++++++++--- .../recipes/esm2_native_te/perf_logger.py | 13 ++++++++++--- .../recipes/llama3_native_te/perf_logger.py | 13 ++++++++++--- .../opengenome2_llama_native_te/perf_logger.py | 13 ++++++++++--- 4 files changed, 40 insertions(+), 12 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index 292ea0c5b9..058b143024 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -359,9 +359,16 @@ def log_step( if self._peak_tflops is not None: self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index b1d8abd865..95e9594d1c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -368,9 +368,16 @@ def log_step( if self._peak_tflops is not None: self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index deea719a7e..96b23e903d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -382,9 +382,16 @@ def log_step( if self._peak_tflops is not None: self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 012c3496e6..2d86bdf15a 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -383,9 +383,16 @@ def log_step( if self._peak_tflops is not None: self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() From f8f84cb7fa00e4a3e5fb87fce6f4ccd5890564e9 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 22 Apr 2026 19:10:38 -0700 Subject: [PATCH 17/24] perf_logger: split MFU into unpadded vs padded variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Current reporting mixes conventions: non-attention flops count every slot (padded, PaLM convention), but attention flops come from cu_seq_lens_q which drops both CP-zigzag padding and BSHD row padding. For BSHD with dynamic-padded short docs this made the attention term tiny while non-attention counted all the pad positions — an inconsistent hybrid that is neither pure useful work nor pure hardware view. Split into two parallel metrics: train/mfu_pct / tflops_per_gpu Useful-work rate. Counts only real tokens: non_attn uses num_unpadded_tokens (attention_mask.sum()) attn uses Sigma(L_i^2) from cu_seq_lens_q (THD), or per-row attention_mask.sum(dim=-1) (BSHD). Excludes CP zigzag divisibility pad and BSHD row pad. train/mfu_padded_pct / tflops_per_gpu_padded Hardware view. Counts all slots the GPU actually burns cycles on: non_attn uses input_ids.numel() attn uses Sigma(L_i^2) from cu_seq_lens_q_padded (THD) or full batch shape (BSHD). Includes both padding types. For synthetic single-doc-per-pack THD runs (where Sigma L_i^2 == S^2 and there is no padding of either kind), unpadded == padded and the two metrics report identical values; published benchmarks in ci/benchmarks remain internally consistent either way. Added _attn_work_from_batch(include_padding=bool) kwarg. The function still returns a pre-CP-shard global quantity; the caller divides by cp_size in log_step. BSHD per-row real lengths come from attention_mask; note in docstring that for BSHD+CP this is exact under evenly distributed padding (common case) and may slightly underestimate when padding is concentrated on one end. Added two unit tests per recipe (test_include_padding_thd, test_include_padding_bshd_with_attention_mask) pinning the split. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 117 +++++++++++------ .../tests/test_perf_logger.py | 19 +++ .../recipes/esm2_native_te/perf_logger.py | 111 +++++++++++----- .../esm2_native_te/tests/test_perf_logger.py | 19 +++ .../recipes/llama3_native_te/perf_logger.py | 124 +++++++++++++----- .../tests/test_perf_logger.py | 22 ++++ .../perf_logger.py | 112 +++++++++++----- .../tests/test_perf_logger.py | 19 +++ 8 files changed, 400 insertions(+), 143 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index 058b143024..7623a0bf21 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -145,28 +145,50 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. The caller divides by cp_size in log_step to convert this global number into - per-rank attention work, so this helper always returns a pre-CP-shard quantity. - - * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator - emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` - if only the padded variant is present. Excludes CP zigzag padding that FA - processes as cycles but contributes nothing to the training signal. - * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active - because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To - recover the global B·S² from a per-rank shape, multiply by cp_size². When - cp_size=1 this is a no-op (per-rank shape == global shape). CodonFM runs FSDP - without CP, so cp_size=1 always here, but the formula stays correct if CP is - added later. + per-rank attention work; this helper always returns a pre-CP-shard quantity. + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. + + CodonFM currently runs FSDP without CP (cp_size=1), but the formula stays correct + if CP is added later. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ - cu = batch.get("cu_seq_lens_q") - if cu is None: + if include_padding: cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") if cu is not None: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() @@ -241,9 +263,14 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() if self._peak_tflops is not None: metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) self.metrics.to(self._device) @@ -259,8 +286,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) - # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. - self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -283,10 +313,14 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != PAD_TOKEN_ID].numel() self.num_unpadded_tokens += num_unpadded_tokens if self._log_mfu: - # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. - # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) # Update perplexity per micro-batch since it needs logits + labels logits = outputs.logits @@ -340,24 +374,30 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) if self._log_mfu: - # Two-term FLOP accounting: - # non_attn_flops = non_attn_per_token * num_tokens - # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size - # num_tokens follows the PaLM/Megatron/MosaicML convention of counting - # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over - # padded positions too). The attention term uses the real per-doc - # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc - # boundaries, so a packed batch with fragmented docs does genuinely - # less attention work than a uniform B·S² would imply — the old - # single-term formula overstated MFU in exactly that regime. - attn_work = int(self._attn_work_accum.item()) - non_attn_flops = self._non_attn_per_token_flops * self.num_tokens - attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size - flops_per_step = non_attn_flops + attn_flops - tflops_per_gpu = flops_per_step / step_time / 1e12 - self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) if self._peak_tflops is not None: - self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) # Report TRUE peak memory across the logging window (FSDP-gathered params + # activations held for backward), not just the post-step resting footprint. @@ -390,7 +430,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() - self._attn_work_accum.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index c782f4df24..41eb31ec24 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -318,3 +318,22 @@ def test_bshd_cp_correction(self): } assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 95e9594d1c..f5e64caf8c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -144,26 +144,48 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. The caller divides by cp_size in log_step to convert this global number into - per-rank attention work, so this helper always returns a pre-CP-shard quantity. + per-rank attention work; this helper always returns a pre-CP-shard quantity. - * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator - emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` - if only the padded variant is present. Excludes CP zigzag padding that FA - processes as cycles but contributes nothing to the training signal. - * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active - because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To - recover the global B·S² from a per-rank shape, multiply by cp_size². When - cp_size=1 this is a no-op (per-rank shape == global shape). + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ - cu = batch.get("cu_seq_lens_q") - if cu is None: + if include_padding: cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") if cu is not None: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() @@ -245,9 +267,14 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() if self._peak_tflops is not None: metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -266,8 +293,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # the last optimizer step in the logging window, then drained in log_step). self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) - # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. - self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -290,10 +320,14 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != ESM2_PAD_TOKEN_ID].numel() self.num_unpadded_tokens += num_unpadded_tokens if self._log_mfu: - # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. - # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) # Update perplexity per micro-batch since it needs logits + labels. logits = outputs.logits @@ -349,24 +383,30 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # Two-term FLOP accounting: - # non_attn_flops = non_attn_per_token * num_tokens - # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size - # num_tokens follows the PaLM/Megatron/MosaicML convention of counting - # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over - # padded positions too). The attention term uses the real per-doc - # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc - # boundaries, so a packed batch with fragmented docs does genuinely - # less attention work than a uniform B·S² would imply — the old - # single-term formula overstated MFU in exactly that regime. - attn_work = int(self._attn_work_accum.item()) - non_attn_flops = self._non_attn_per_token_flops * self.num_tokens - attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size - flops_per_step = non_attn_flops + attn_flops - tflops_per_gpu = flops_per_step / step_time / 1e12 - self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) if self._peak_tflops is not None: - self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) # Report TRUE peak memory across the logging window (FSDP-gathered params + # activations held for backward), not just the post-step resting footprint. @@ -399,7 +439,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() - self._attn_work_accum.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py index e0889d5976..8d6d1bbb49 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -179,6 +179,25 @@ def test_bshd_cp_correction(self): assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 + class TestGradAccAccumulation: """Lock in ESM-2's new log_micro_step/log_step split under gradient accumulation. diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 96b23e903d..dbe4469857 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -142,26 +142,54 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. The caller divides by cp_size in log_step to convert this global number into - per-rank attention work, so this helper always returns a pre-CP-shard quantity. + per-rank attention work; this helper always returns a pre-CP-shard quantity. - * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator - emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` - if only the padded variant is present. Excludes CP zigzag padding that FA - processes as cycles but contributes nothing to the training signal. - * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active - because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To - recover the global B·S² from a per-rank shape, multiply by cp_size². When - cp_size=1 this is a no-op (per-rank shape == global shape). + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row to get real per-row lengths, + scaled by ``cp_size²`` to recover global. NOTE: for BSHD+CP this is exact when + padding is evenly distributed across cp chunks (the common case); can + underestimate slightly when padding is all on one end of the sequence because + per-rank mask.sum² loses row-level correlation info. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape`` (includes dynamic-padding-to-longest slots), + scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ - cu = batch.get("cu_seq_lens_q") - if cu is None: + if include_padding: cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") # fall back if no padded variant present + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + # Unpadded (real work) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + # Fallback: no real-length signal present — try cu_seq_lens_q_padded, then shape. + cu = batch.get("cu_seq_lens_q_padded") if cu is not None: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() @@ -244,9 +272,14 @@ def __init__( "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() if self._peak_tflops is not None: metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -269,8 +302,11 @@ def __init__( # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) - # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. - self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -304,10 +340,14 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() if self._log_mfu: - # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. - # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -363,24 +403,37 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # Two-term FLOP accounting: - # non_attn_flops = non_attn_per_token * num_tokens - # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size - # num_tokens follows the PaLM/Megatron/MosaicML convention of counting - # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over - # padded positions too). The attention term uses the real per-doc - # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc - # boundaries, so a packed batch with fragmented docs does genuinely - # less attention work than a uniform B·S² would imply — the old - # single-term formula overstated MFU in exactly that regime. - attn_work = int(self._attn_work_accum.item()) - non_attn_flops = self._non_attn_per_token_flops * self.num_tokens - attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size - flops_per_step = non_attn_flops + attn_flops - tflops_per_gpu = flops_per_step / step_time / 1e12 - self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + # Two MFU flavors reported side-by-side: + # * mfu_pct = useful work rate. Non-attn over real tokens + # (num_unpadded_tokens), attn over Σ(Lᵢ²) from + # cu_seq_lens_q (THD) or per-row mask (BSHD). + # Drops padding from both terms — what the model + # actually learns from. + # * mfu_padded_pct = hardware view. Non-attn over all slots + # (num_tokens = input_ids.numel), attn over + # cu_seq_lens_q_padded / full B·S² — counts the + # cycles the HW actually burned, including + # CP-zigzag pad and BSHD row pad. + # Both divide the global Σ(Lᵢ²) by cp_size to get per-rank attn work. + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) if self._peak_tflops is not None: - self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) # Report TRUE peak memory across the logging window (FSDP-gathered params + # activations held for backward), not just the post-step resting footprint. @@ -413,7 +466,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() - self._attn_work_accum.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 99061d431c..3825fbe46f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -428,3 +428,25 @@ def test_bshd_cp_correction(self): } assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded (zigzag pad); False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 # 61 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 # 128 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask (real per-row lengths); True uses full shape.""" + # 2 rows, each padded to 8 slots; real lengths 5 and 3. + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + # Unpadded (real): 5² + 3² = 34 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 34 + # Padded (hardware view): 2 * 8² = 128 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 2d86bdf15a..6e8b41d108 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -150,26 +150,49 @@ def _compute_attn_flop_coeff(model_config_dict: dict) -> int: return 3 * num_layers * 4 * h -def _attn_work_from_batch(batch: dict, device: torch.device, cp_size: int = 1) -> torch.Tensor: +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. The caller divides by cp_size in log_step to convert this global number into - per-rank attention work, so this helper always returns a pre-CP-shard quantity. + per-rank attention work; this helper always returns a pre-CP-shard quantity. - * THD: uses ``cu_seq_lens_q`` (real per-doc lengths — already global, the collator - emits boundaries before per-rank sharding). Falls back to ``cu_seq_lens_q_padded`` - if only the padded variant is present. Excludes CP zigzag padding that FA - processes as cycles but contributes nothing to the training signal. - * BSHD: ``batch["input_ids"].shape`` per rank is ``(B, S/cp)`` when CP is active - because ``ContextParallelDataLoaderWrapper`` pre-splits along the seq dim. To - recover the global B·S² from a per-rank shape, multiply by cp_size². When - cp_size=1 this is a no-op (per-rank shape == global shape). + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. Exact when padding distributes evenly across cp chunks; + approximate when padding is concentrated on one end. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). """ - cu = batch.get("cu_seq_lens_q") - if cu is None: + if include_padding: cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") if cu is not None: lens = (cu[1:] - cu[:-1]).to(torch.int64) return (lens * lens).sum() @@ -245,9 +268,14 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() if self._peak_tflops is not None: metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -270,8 +298,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) - # Σ(Lᵢ²) over grad-acc micro-batches — drives the attention-FLOP term at log time. - self._attn_work_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -305,10 +336,14 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() if self._log_mfu: - # Σ(Lᵢ²) from cu_seq_lens (THD) or B·S² from input_ids shape (BSHD). + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. - # Accumulates across grad-acc micro-batches; drained in log_step. - self._attn_work_accum += _attn_work_from_batch(batch, self._device, self._cp_size) + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -364,24 +399,30 @@ def log_step( self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) if self._log_mfu: - # Two-term FLOP accounting: - # non_attn_flops = non_attn_per_token * num_tokens - # attn_flops = attn_flop_coeff * Σ(Lᵢ²) / cp_size - # num_tokens follows the PaLM/Megatron/MosaicML convention of counting - # every slot in input_ids (hardware runs QKV/MLP/LM-head matmuls over - # padded positions too). The attention term uses the real per-doc - # Σ(Lᵢ²) from cu_seq_lens because FA gates the score matrix on doc - # boundaries, so a packed batch with fragmented docs does genuinely - # less attention work than a uniform B·S² would imply — the old - # single-term formula overstated MFU in exactly that regime. - attn_work = int(self._attn_work_accum.item()) - non_attn_flops = self._non_attn_per_token_flops * self.num_tokens - attn_flops = (self._attn_flop_coeff * attn_work) // self._cp_size - flops_per_step = non_attn_flops + attn_flops - tflops_per_gpu = flops_per_step / step_time / 1e12 - self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu) + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) if self._peak_tflops is not None: - self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0) + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) # Report TRUE peak memory across the logging window (FSDP-gathered params + # activations held for backward), not just the post-step resting footprint. @@ -414,7 +455,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() - self._attn_work_accum.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def log_validation(self, step: int, val_metrics: dict): diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py index 6e6b242a4d..b64206db5d 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py @@ -133,3 +133,22 @@ def test_bshd_cp_correction(self): } assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 From 909c1d76ade99e79084a63663a7fa8e697be6b54 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Thu, 23 Apr 2026 11:10:48 -0700 Subject: [PATCH 18/24] perf_logger: remove legacy _compute_per_token_flops back-compat shim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The single-term per-token FLOP function was kept around after the two-term split (non_attn + coeff·Σ(Lᵢ²)) was introduced, solely so old unit tests and the startup log line could still reference it. Nothing in the runtime path was using it — log_step had already switched to the split formula. Remove the function, the self._per_token_flops attribute, and the test scaffolding that existed only to pin the split against the legacy reference: - _compute_per_token_flops(cfg, seq_len) deleted from all 4 recipes - self._per_token_flops init + assignment + log-line mention removed - logger.info startup format no longer includes "per-token FLOPs=%e" - TestComputePerTokenFlops test class deleted (llama3 only) - test_algebraic_identity deleted (its whole purpose was to pin non_attn + coeff·S against the legacy function) - test_bshd_no_op simplified to test_bshd_shape_synthesis — keeps the Σ(Lᵢ²)=B·S² shape-synthesis check, drops the legacy comparison - Docstrings no longer reference the old function The tests that exercise actual formula correctness (test_thd_multi_doc_uses_squared_sum, test_cp_size_divides_attention_only, test_bshd_cp_correction, test_include_padding_*, etc.) all stay — they verify behavior directly without going through the legacy. Net: -336 / +40 lines. Pure dead-code removal, no runtime change. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 50 +------- .../tests/test_perf_logger.py | 19 +-- .../recipes/esm2_native_te/perf_logger.py | 50 +------- .../esm2_native_te/tests/test_perf_logger.py | 19 +-- .../recipes/llama3_native_te/perf_logger.py | 50 +------- .../tests/test_perf_logger.py | 110 ++---------------- .../perf_logger.py | 50 +------- .../tests/test_perf_logger.py | 28 ++--- 8 files changed, 40 insertions(+), 336 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index 7623a0bf21..e45a1f037c 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -66,50 +66,13 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: - """Training FLOPs per token for a transformer (forward + backward = 3x forward). - - First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform - BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via - model_type), and LM head. - - Kept for back-compat. For accurate per-step accounting use - ``_compute_non_attn_per_token_flops`` (applied to the total token count) - together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from - cu_seq_lens), since a packed THD batch of total length S containing docs - L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². - """ - h = model_config_dict["hidden_size"] - n_heads = model_config_dict["num_attention_heads"] - n_kv = model_config_dict.get("num_key_value_heads", n_heads) - head_dim = h // n_heads - kv_dim = n_kv * head_dim - ffn = model_config_dict["intermediate_size"] - vocab = model_config_dict.get("vocab_size", 0) - num_layers = model_config_dict["num_hidden_layers"] - model_type = model_config_dict.get("model_type", "") - num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 - - per_layer = ( - 2 * h * h # Q projection - + 4 * h * kv_dim # K + V projections (GQA-aware) - + 2 * h * h # O projection - + 4 * seq_len * h # attention logits + values (S^2 -> S per token) - + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) - ) - lm_head = 2 * h * vocab if vocab > 0 else 0 - per_token_fwd = num_layers * per_layer + lm_head - return 3 * per_token_fwd - - def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the actual total token count of the batch to get per-step non-attention FLOPs. Pairs - with ``_compute_attn_flop_coeff`` so that - ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -228,26 +191,23 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None - self._per_token_flops = 0 self._non_attn_per_token_flops = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: - self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " - "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", - float(self._per_token_flops), - args.dataset.max_seq_length, float(self._non_attn_per_token_flops), float(self._attn_flop_coeff), + args.dataset.max_seq_length, self._cp_size, ) diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index 41eb31ec24..83f2358061 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -26,7 +26,6 @@ _attn_work_from_batch, _compute_attn_flop_coeff, _compute_non_attn_per_token_flops, - _compute_per_token_flops, ) from transformers.modeling_outputs import MaskedLMOutput @@ -231,26 +230,14 @@ def _codon_cfg(): class TestFlopSplitAndAttention: - """Verify the split non-attn + Σ(Lᵢ²) attention formula.""" + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed.""" - def test_algebraic_identity(self): - """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" - cfg = _codon_cfg() - for s in (256, 512, 1024, 8192): - lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s - rhs = _compute_per_token_flops(cfg, s) - assert lhs == rhs, f"S={s}: {lhs} != {rhs}" - - def test_bshd_no_op(self): - """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" - cfg = _codon_cfg() + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" b, s = 4, 512 batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() assert sigma_l_sq == b * s * s - new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq - legacy_flops = _compute_per_token_flops(cfg, s) * b * s - assert new_flops == legacy_flops def test_thd_single_doc_matches_bshd(self): """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index f5e64caf8c..bac8684e30 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -65,50 +65,13 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: - """Training FLOPs per token for a transformer (forward + backward = 3x forward). - - First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform - BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via - model_type), and LM head. - - Kept for back-compat. For accurate per-step accounting use - ``_compute_non_attn_per_token_flops`` (applied to the total token count) - together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from - cu_seq_lens), since a packed THD batch of total length S containing docs - L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². - """ - h = model_config_dict["hidden_size"] - n_heads = model_config_dict["num_attention_heads"] - n_kv = model_config_dict.get("num_key_value_heads", n_heads) - head_dim = h // n_heads - kv_dim = n_kv * head_dim - ffn = model_config_dict["intermediate_size"] - vocab = model_config_dict.get("vocab_size", 0) - num_layers = model_config_dict["num_hidden_layers"] - model_type = model_config_dict.get("model_type", "") - num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 - - per_layer = ( - 2 * h * h # Q projection - + 4 * h * kv_dim # K + V projections (GQA-aware) - + 2 * h * h # O projection - + 4 * seq_len * h # attention logits + values (S^2 -> S per token) - + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) - ) - lm_head = 2 * h * vocab if vocab > 0 else 0 - per_token_fwd = num_layers * per_layer + lm_head - return 3 * per_token_fwd - - def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the actual total token count of the batch to get per-step non-attention FLOPs. Pairs - with ``_compute_attn_flop_coeff`` so that - ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -231,26 +194,23 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # step are derived at log time from the accumulated token count + Σ(Lᵢ²), which # already reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None - self._per_token_flops = 0 self._non_attn_per_token_flops = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: - self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " - "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", - float(self._per_token_flops), - args.dataset.max_seq_length, float(self._non_attn_per_token_flops), float(self._attn_flop_coeff), + args.dataset.max_seq_length, self._cp_size, ) diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py index 8d6d1bbb49..b543e27c2e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -35,7 +35,6 @@ _attn_work_from_batch, _compute_attn_flop_coeff, _compute_non_attn_per_token_flops, - _compute_per_token_flops, ) @@ -96,26 +95,14 @@ def _esm_cfg(): class TestFlopSplitAndAttention: - """Verify the split non-attn + Σ(Lᵢ²) attention formula for ESM-2.""" + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed for ESM-2.""" - def test_algebraic_identity(self): - """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" - cfg = _esm_cfg() - for s in (256, 1024, 8192, 131072): - lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s - rhs = _compute_per_token_flops(cfg, s) - assert lhs == rhs, f"S={s}: {lhs} != {rhs}" - - def test_bshd_no_op(self): - """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" - cfg = _esm_cfg() + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" b, s = 2, 512 batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() assert sigma_l_sq == b * s * s - new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq - legacy_flops = _compute_per_token_flops(cfg, s) * b * s - assert new_flops == legacy_flops def test_thd_single_doc_matches_bshd(self): """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index dbe4469857..58136173f7 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -63,50 +63,13 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: - """Training FLOPs per token for a transformer (forward + backward = 3x forward). - - First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform - BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via - model_type), and LM head. - - Kept for back-compat. For accurate per-step accounting use - ``_compute_non_attn_per_token_flops`` (applied to the total token count) - together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from - cu_seq_lens), since a packed THD batch of total length S containing docs - L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². - """ - h = model_config_dict["hidden_size"] - n_heads = model_config_dict["num_attention_heads"] - n_kv = model_config_dict.get("num_key_value_heads", n_heads) - head_dim = h // n_heads - kv_dim = n_kv * head_dim - ffn = model_config_dict["intermediate_size"] - vocab = model_config_dict.get("vocab_size", 0) - num_layers = model_config_dict["num_hidden_layers"] - model_type = model_config_dict.get("model_type", "") - num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 - - per_layer = ( - 2 * h * h # Q projection - + 4 * h * kv_dim # K + V projections (GQA-aware) - + 2 * h * h # O projection - + 4 * seq_len * h # attention logits + values (S^2 -> S per token) - + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) - ) - lm_head = 2 * h * vocab if vocab > 0 else 0 - per_token_fwd = num_layers * per_layer + lm_head - return 3 * per_token_fwd - - def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the actual total token count of the batch to get per-step non-attention FLOPs. Pairs - with ``_compute_attn_flop_coeff`` so that - ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -237,26 +200,23 @@ def __init__( # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None - self._per_token_flops = 0 self._non_attn_per_token_flops = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: - self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " - "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", - float(self._per_token_flops), - args.dataset.max_seq_length, float(self._non_attn_per_token_flops), float(self._attn_flop_coeff), + args.dataset.max_seq_length, self._cp_size, ) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 3825fbe46f..dee275a9b1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -28,7 +28,6 @@ _attn_work_from_batch, _compute_attn_flop_coeff, _compute_non_attn_per_token_flops, - _compute_per_token_flops, _detect_peak_tflops_bf16, ) @@ -219,90 +218,6 @@ def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): assert perf_logger.min_loss.item() == pytest.approx(1.0) -class TestComputePerTokenFlops: - """Test that the per-token training FLOPs formula matches hand-calculated values.""" - - def test_llama_gqa_swiglu(self): - """Llama-style config: GQA (n_kv=8 < n_heads=32) + SwiGLU (3 MLP projections).""" - config = { - "model_type": "llama", - "hidden_size": 4096, - "num_hidden_layers": 32, - "num_attention_heads": 32, - "num_key_value_heads": 8, # GQA - "intermediate_size": 14336, - "vocab_size": 128256, - } - seq_len = 8192 - h, i, v, kv_dim, layers = 4096, 14336, 128256, 8 * 128, 32 - # Per-layer: Q (2h^2) + K+V (4h*kv_dim) + O (2h^2) + attn (4*S*h) + MLP (2*3*h*i) - per_layer = 2 * h * h + 4 * h * kv_dim + 2 * h * h + 4 * seq_len * h + 2 * 3 * h * i - expected_fwd = layers * per_layer + 2 * h * v - assert _compute_per_token_flops(config, seq_len) == 3 * expected_fwd - - def test_esm_mha_gelu(self): - """ESM2-style config: MHA (no num_key_value_heads) + standard FFN (2 MLP projections).""" - config = { - "model_type": "esm", - "hidden_size": 1280, - "num_hidden_layers": 33, - "num_attention_heads": 20, - "intermediate_size": 5120, - "vocab_size": 33, - } - seq_len = 1024 - h, i, v, kv_dim, layers = 1280, 5120, 33, (1280 // 20) * 20, 33 # kv_dim=h for MHA - per_layer = 2 * h * h + 4 * h * kv_dim + 2 * h * h + 4 * seq_len * h + 2 * 2 * h * i - expected_fwd = layers * per_layer + 2 * h * v - assert _compute_per_token_flops(config, seq_len) == 3 * expected_fwd - - def test_scales_with_seq_len(self): - """Only the attention S^2 term should vary with seq_len.""" - config = { - "model_type": "llama", - "hidden_size": 2048, - "num_hidden_layers": 16, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 8192, - "vocab_size": 32000, - } - h, layers = 2048, 16 - # Difference per token between seq_len=1024 and seq_len=2048: - # layers * 4 * (2048 - 1024) * h, times 3 (forward+backward) - diff = _compute_per_token_flops(config, 2048) - _compute_per_token_flops(config, 1024) - assert diff == 3 * layers * 4 * 1024 * h - - def test_linear_in_unpadded_tokens(self): - """Multiplying per-token FLOPs by N tokens is linear (MFU formula relies on this).""" - config = { - "model_type": "llama", - "hidden_size": 1024, - "num_hidden_layers": 8, - "num_attention_heads": 8, - "intermediate_size": 4096, - "vocab_size": 32000, - } - per_token = _compute_per_token_flops(config, seq_len=512) - assert per_token * 100 == 100 * per_token - # Sanity: doubling unpadded token count doubles total FLOPs - assert per_token * 200 == 2 * (per_token * 100) - - def test_no_lm_head_when_vocab_zero(self): - """vocab_size=0 should drop the LM head term.""" - config_base = { - "model_type": "llama", - "hidden_size": 512, - "num_hidden_layers": 4, - "num_attention_heads": 8, - "intermediate_size": 2048, - } - with_vocab = _compute_per_token_flops({**config_base, "vocab_size": 32000}, seq_len=256) - no_vocab = _compute_per_token_flops({**config_base, "vocab_size": 0}, seq_len=256) - # Difference = 3 (training) * 2 * h * vocab - assert with_vocab - no_vocab == 3 * 2 * 512 * 32000 - - class TestDetectPeakTflops: """Smoke test for GPU peak TFLOPS detection.""" @@ -327,31 +242,20 @@ def _llama_cfg(): class TestFlopSplitAndAttention: - """Verify the split non-attn + Σ(Lᵢ²) attention formula. + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed. - The old single-term ``_per_token_flops * num_tokens`` formula treats a packed - batch as one giant S*S attention. Real Flash-Attention work is Σ(Lᵢ²) over - packed segments. These tests lock in the new split and its invariants. + Non-attention FLOPs are tracked per real token; attention FLOPs are tracked as + coeff * Σ(Lᵢ²) over per-doc real lengths. These tests lock in the formula and + its invariants (shape synthesis for BSHD, cu_seq_lens handling for THD, CP + division, unpadded/padded behavior, fallbacks). """ - def test_algebraic_identity(self): - """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" - cfg = _llama_cfg() - for s in (256, 1024, 8192): - lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s - rhs = _compute_per_token_flops(cfg, s) - assert lhs == rhs, f"S={s}: {lhs} != {rhs}" - - def test_bshd_no_op(self): - """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" - cfg = _llama_cfg() + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" b, s = 2, 512 batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() assert sigma_l_sq == b * s * s - new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq - legacy_flops = _compute_per_token_flops(cfg, s) * b * s - assert new_flops == legacy_flops def test_thd_single_doc_matches_bshd(self): """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 6e8b41d108..c61ab44599 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -71,50 +71,13 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int: - """Training FLOPs per token for a transformer (forward + backward = 3x forward). - - First-principles matmul count: Q/K/V/O projections (GQA-aware), attention - logits/values (the S^2 cost expressed per-token as 4*S*H for a uniform - BSHD batch of length seq_len), 2-or-3-projection MLP (SwiGLU detected via - model_type), and LM head. - - Kept for back-compat. For accurate per-step accounting use - ``_compute_non_attn_per_token_flops`` (applied to the total token count) - together with ``_compute_attn_flop_coeff`` (applied to Σ(Lᵢ²) from - cu_seq_lens), since a packed THD batch of total length S containing docs - L₁, L₂, … has actual attention work Σ(Lᵢ²) ≤ S², not B·S². - """ - h = model_config_dict["hidden_size"] - n_heads = model_config_dict["num_attention_heads"] - n_kv = model_config_dict.get("num_key_value_heads", n_heads) - head_dim = h // n_heads - kv_dim = n_kv * head_dim - ffn = model_config_dict["intermediate_size"] - vocab = model_config_dict.get("vocab_size", 0) - num_layers = model_config_dict["num_hidden_layers"] - model_type = model_config_dict.get("model_type", "") - num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 - - per_layer = ( - 2 * h * h # Q projection - + 4 * h * kv_dim # K + V projections (GQA-aware) - + 2 * h * h # O projection - + 4 * seq_len * h # attention logits + values (S^2 -> S per token) - + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) - ) - lm_head = 2 * h * vocab if vocab > 0 else 0 - per_token_fwd = num_layers * per_layer + lm_head - return 3 * per_token_fwd - - def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the actual total token count of the batch to get per-step non-attention FLOPs. Pairs - with ``_compute_attn_flop_coeff`` so that - ``non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S)``. + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. """ h = model_config_dict["hidden_size"] n_heads = model_config_dict["num_attention_heads"] @@ -233,26 +196,23 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None - self._per_token_flops = 0 self._non_attn_per_token_flops = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: - self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length) self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: logger.info( - "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d, " - "non_attn_per_token=%.3e, attn_coeff=%.3e, cp_size=%d", + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", gpu_name, f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", - float(self._per_token_flops), - args.dataset.max_seq_length, float(self._non_attn_per_token_flops), float(self._attn_flop_coeff), + args.dataset.max_seq_length, self._cp_size, ) diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py index b64206db5d..7f21c4aeca 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the split non-attention + Σ(Lᵢ²) attention FLOP formula. +"""Tests for the non-attention + Σ(Lᵢ²) attention FLOP formula. -The old single-term ``_per_token_flops * num_tokens`` formula treats a packed batch -as one giant S*S attention. Real Flash-Attention work is Σ(Lᵢ²) over packed -segments. These tests lock in the new split and its invariants so future drift -between sibling recipes is caught immediately. +Non-attention FLOPs are tracked per real token; attention FLOPs are tracked as +coeff * Σ(Lᵢ²) over per-doc real lengths. These tests lock in the formula and its +invariants so future drift between sibling recipes is caught immediately. """ import torch @@ -27,7 +26,6 @@ _attn_work_from_batch, _compute_attn_flop_coeff, _compute_non_attn_per_token_flops, - _compute_per_token_flops, ) @@ -45,26 +43,14 @@ def _llama_cfg(): class TestFlopSplitAndAttention: - """Verify the split formula matches the legacy one and correctly handles THD.""" + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed.""" - def test_algebraic_identity(self): - """non_attn + coeff·S ≡ _compute_per_token_flops(cfg, S) for all S.""" - cfg = _llama_cfg() - for s in (256, 1024, 8192, 131072): - lhs = _compute_non_attn_per_token_flops(cfg) + _compute_attn_flop_coeff(cfg) * s - rhs = _compute_per_token_flops(cfg, s) - assert lhs == rhs, f"S={s}: {lhs} != {rhs}" - - def test_bshd_no_op(self): - """BSHD batch (no cu_seq_lens) with cp=1 matches legacy formula exactly.""" - cfg = _llama_cfg() + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" b, s = 2, 512 batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() assert sigma_l_sq == b * s * s - new_flops = _compute_non_attn_per_token_flops(cfg) * b * s + _compute_attn_flop_coeff(cfg) * sigma_l_sq - legacy_flops = _compute_per_token_flops(cfg, s) * b * s - assert new_flops == legacy_flops def test_thd_single_doc_matches_bshd(self): """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" From 423eab74233ff2d11e57fb1a9b4293a6076a8730 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Thu, 23 Apr 2026 11:17:33 -0700 Subject: [PATCH 19/24] docs: update MFU tracking sections in recipe READMEs Reflect the modern two-pair metric layout (unpadded useful-work vs padded hardware view) and the peak-memory reporting fix. Applied identically to all four MFU-tracking recipes. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/README.md | 20 ++++++++-------- .../recipes/esm2_native_te/README.md | 21 +++++++++-------- .../recipes/llama3_native_te/README.md | 23 +++++++++++-------- .../opengenome2_llama_native_te/README.md | 22 ++++++++++-------- 4 files changed, 50 insertions(+), 36 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/README.md b/bionemo-recipes/recipes/codonfm_native_te/README.md index e2c9a05089..ad8f04dd08 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/README.md +++ b/bionemo-recipes/recipes/codonfm_native_te/README.md @@ -179,22 +179,24 @@ A final model suitable for uploading to the Hugging Face Hub can be exported at ## MFU Tracking -Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: +Enable per-step MFU logging by adding `log_mfu=true`: ```bash torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true ``` -This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and -stdout: +Two pairs of metrics are emitted per logging interval: -- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU -- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view (HFU-like). Counts + every slot the GPU processes, including BSHD row padding. -The FLOPs formula auto-detects model architecture from the model config (MHA, standard FFN, -vocabulary size) and scales with the actual unpadded token count on each rank. This means it -naturally handles gradient accumulation, data parallelism, BSHD, and THD (sequence packing) -without per-strategy code paths. The implementation lives in `perf_logger.py`. +Non-attention uses the unpadded/padded token count respectively; attention uses `Σ(Lᵢ²)` from +`cu_seq_lens_q` (THD) or per-row `attention_mask.sum()` (BSHD) for the unpadded variant and +`cu_seq_lens_q_padded` / full `B·S²` for the padded variant. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window; `_mean_gb` is +the post-step resting footprint. ## Developer Guide diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index e97d13eb0e..067841259b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -376,22 +376,25 @@ output = model(**inputs) ## MFU Tracking -Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: +Enable per-step MFU logging by adding `log_mfu=true`: ```bash torchrun --nproc_per_node=2 train_fsdp2.py --config-name L1_3B log_mfu=true ``` -This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and -stdout: +Two pairs of metrics are emitted per logging interval: -- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU -- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. -The FLOPs formula auto-detects model architecture from the HF config (MHA vs. GQA, gated vs. -standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This -means it naturally handles data parallelism, context parallelism, BSHD, and THD (sequence packing) -without per-strategy code paths. The implementation lives in `perf_logger.py`. +The two pairs agree when the batch has no padding. The formula is CP-aware and auto-detects +MHA/GQA and FFN layout from the HF config. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window +(`torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is resting. ## Developer Guide diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 39dc8d8ac8..c4d2912521 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -414,22 +414,27 @@ vllm serve path/to/hf_converted_model ## MFU Tracking -Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: +Enable per-step MFU logging by adding `log_mfu=true`: ```bash torchrun --nproc_per_node=2 train_fsdp2_cp.py --config-name L2_lingua_1b log_mfu=true ``` -This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and -stdout: +Two pairs of metrics are emitted per logging interval: -- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU -- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. -The FLOPs formula auto-detects model architecture from the HF config (GQA vs. MHA, SwiGLU vs. -standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This -means it naturally handles gradient accumulation, data parallelism, context parallelism, BSHD, and -THD (sequence packing) without per-strategy code paths. The implementation lives in `perf_logger.py`. +The two pairs agree when the batch has no padding (e.g. dense single-doc THD packs). The formula +is CP-aware (global `Σ(Lᵢ²)` divided by `cp_size`) and auto-detects GQA/MHA and SwiGLU/standard +FFN from the HF config. Implementation in `perf_logger.py`. + +Memory metrics: `train/gpu_memory_allocated_max_gb` is the true transient peak per logging window +(via `torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is the +post-step resting footprint. ## Developer Guide diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md index 85be87bded..a7ff80e56b 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md @@ -413,22 +413,26 @@ Control evaluation frequency with `validation.eval_interval` and `validation.num ## MFU Tracking -Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`: +Enable per-step MFU logging by adding `log_mfu=true`: ```bash torchrun --nproc_per_node=2 train_fsdp2_cp.py log_mfu=true ``` -This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and -stdout: +Two pairs of metrics are emitted per logging interval: -- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU -- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. -The FLOPs formula auto-detects model architecture from the HF config (GQA vs. MHA, SwiGLU vs. -standard FFN, LM head presence) and scales with the actual unpadded token count on each rank. This -means it naturally handles gradient accumulation, data parallelism, context parallelism, BSHD, and -THD (sequence packing) without per-strategy code paths. The implementation lives in `perf_logger.py`. +The two pairs agree when the batch has no padding (e.g. dense single-doc THD packs — common for +genomic data windowed to `max_seq_length`). The formula is CP-aware and auto-detects GQA/SwiGLU +from the HF config. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window +(`torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is resting. ## Developer Guide From b979eed89dfa3560fe5b2cf74992a89a6a661abf Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Thu, 23 Apr 2026 11:59:11 -0700 Subject: [PATCH 20/24] MFU: use padded_vocab_size for mfu_padded_pct LM-head FLOPs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For configs with padded_vocab_size set (ESM-2: 33→64 for FP8/tensor-core friendliness), the LM-head matmul physically runs at padded width and the logits are sliced back afterward. Count the padded width in the hardware-view metric (mfu_padded_pct, tflops_per_gpu_padded) while continuing to count raw vocab_size in the useful-work metric (mfu_pct, tflops_per_gpu). For configs without padded_vocab_size (llama3, og2, codonfm) the two values collapse and nothing changes. Addresses review feedback from @trvachov on PR #1548. Signed-off-by: Gagan Kaushik --- .../recipes/codonfm_native_te/perf_logger.py | 12 ++++++++++-- .../recipes/esm2_native_te/perf_logger.py | 12 ++++++++++-- .../recipes/llama3_native_te/perf_logger.py | 12 ++++++++++-- .../opengenome2_llama_native_te/perf_logger.py | 12 ++++++++++-- 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index e45a1f037c..f0ae75bd69 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -66,7 +66,7 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the @@ -81,6 +81,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: kv_dim = n_kv * head_dim ffn = model_config_dict["intermediate_size"] vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab num_layers = model_config_dict["num_hidden_layers"] model_type = model_config_dict.get("model_type", "") num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 @@ -192,11 +196,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # reflects each rank's share under DP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: @@ -348,7 +356,7 @@ def log_step( flops_unpadded = non_attn_unpadded + attn_flops_unpadded tflops_unpadded = flops_unpadded / step_time / 1e12 - non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size flops_padded = non_attn_padded + attn_flops_padded tflops_padded = flops_padded / step_time / 1e12 diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index bac8684e30..ac646f8b5d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -65,7 +65,7 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the @@ -80,6 +80,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: kv_dim = n_kv * head_dim ffn = model_config_dict["intermediate_size"] vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab num_layers = model_config_dict["num_hidden_layers"] model_type = model_config_dict.get("model_type", "") num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 @@ -195,11 +199,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # already reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: @@ -357,7 +365,7 @@ def log_step( flops_unpadded = non_attn_unpadded + attn_flops_unpadded tflops_unpadded = flops_unpadded / step_time / 1e12 - non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size flops_padded = non_attn_padded + attn_flops_padded tflops_padded = flops_padded / step_time / 1e12 diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 58136173f7..d581dd413b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -63,7 +63,7 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the @@ -78,6 +78,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: kv_dim = n_kv * head_dim ffn = model_config_dict["intermediate_size"] vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab num_layers = model_config_dict["num_hidden_layers"] model_type = model_config_dict.get("model_type", "") num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 @@ -201,11 +205,15 @@ def __init__( # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: @@ -384,7 +392,7 @@ def log_step( flops_unpadded = non_attn_unpadded + attn_flops_unpadded tflops_unpadded = flops_unpadded / step_time / 1e12 - non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size flops_padded = non_attn_padded + attn_flops_padded tflops_padded = flops_padded / step_time / 1e12 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index c61ab44599..abb8891c93 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -71,7 +71,7 @@ def _detect_peak_tflops_bf16(): return None, name -def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: """Per-token FLOPs for everything EXCEPT the S² attention term. Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the @@ -86,6 +86,10 @@ def _compute_non_attn_per_token_flops(model_config_dict: dict) -> int: kv_dim = n_kv * head_dim ffn = model_config_dict["intermediate_size"] vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab num_layers = model_config_dict["num_hidden_layers"] model_type = model_config_dict.get("model_type", "") num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 @@ -197,11 +201,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # reflects each rank's share under DP/CP and sequence packing. self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 self._cp_size = int(args.get("cp_size", 1)) self._peak_tflops: float | None = None if self._log_mfu: self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() if dist_config.local_rank == 0: @@ -373,7 +381,7 @@ def log_step( flops_unpadded = non_attn_unpadded + attn_flops_unpadded tflops_unpadded = flops_unpadded / step_time / 1e12 - non_attn_padded = self._non_attn_per_token_flops * self.num_tokens + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size flops_padded = non_attn_padded + attn_flops_padded tflops_padded = flops_padded / step_time / 1e12 From 44172ae3d46fff1a19b32cfd40484b0c3a613edb Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Fri, 24 Apr 2026 12:20:13 -0700 Subject: [PATCH 21/24] test(esm2): update perf_logger tests for split _attn_work_*_accum buffers The single self._attn_work_accum was split into _attn_work_unpadded_accum and _attn_work_padded_accum to support the unpadded/padded MFU distinction, but two tests in esm2_native_te still referenced the old single name, failing in CI with AttributeError. Update the assertions to check both buffers. With no attention_mask and no cu_seq_lens on the test batch, both paths fall through to shape-synthesis and hold the same value, so each test now asserts both accumulators hold the expected amount. No changes needed in llama3 / opengenome2_llama / codonfm: their test files don't exercise _attn_work_accum lifecycle directly. Signed-off-by: Gagan Kaushik --- .../recipes/esm2_native_te/tests/test_perf_logger.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py index b543e27c2e..e8e7b61b7c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -213,7 +213,7 @@ def test_num_tokens_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): assert perf_logger.running_loss.item() == pytest.approx(4.0) def test_attn_work_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): - """_attn_work_accum sums Σ(Lᵢ²) over all micro-batches when log_mfu=True.""" + """Both _attn_work_*_accum buffers sum Σ(Lᵢ²) over all micro-batches when log_mfu=True.""" dist_config = DistributedConfig() args = _make_args(logging_frequency=1, log_mfu=True, max_seq_length=128) perf_logger = PerfLogger(dist_config, args, model_config_dict=_esm_cfg()) @@ -230,8 +230,11 @@ def test_attn_work_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): outputs.logits = torch.randn(2, 64, ESM2_VOCAB, device=device) perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) - # Accumulator should hold 3 * 2 * 64² = 24576 - assert perf_logger._attn_work_accum.item() == 3 * 2 * 64 * 64 + # With no attention_mask and no cu_seq_lens, both unpadded and padded paths fall + # through to the shape-synthesis branch, so both accumulators hold 3 * 2 * 64² = 24576. + expected = 3 * 2 * 64 * 64 + assert perf_logger._attn_work_unpadded_accum.item() == expected + assert perf_logger._attn_work_padded_accum.item() == expected def test_reset_on_log_boundary(self, mock_wandb, mock_tqdm): """Calling log_step on a logging-boundary step drains all accumulators.""" @@ -248,5 +251,6 @@ def test_reset_on_log_boundary(self, mock_wandb, mock_tqdm): assert perf_logger.grad_acc_step_count == 0 assert perf_logger.num_tokens == 0 assert perf_logger.num_unpadded_tokens.item() == 0 - assert perf_logger._attn_work_accum.item() == 0 + assert perf_logger._attn_work_unpadded_accum.item() == 0 + assert perf_logger._attn_work_padded_accum.item() == 0 assert perf_logger.running_loss.item() == pytest.approx(0.0) From 29121fe3d6d8686f196b7825b723ba82088b4d0c Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Fri, 24 Apr 2026 14:00:21 -0700 Subject: [PATCH 22/24] docs(perf_logger): note pad_to_multiple_of / cu_seq_lens_q collapse (#1561) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the collator's pad_to_multiple_of option is set (FP8/FP4 alignment), cu_seq_lens_q is mutated in place to include an appended mock pad sequence and no cu_seq_lens_q_padded key is written — that key is reserved for TE's per-sequence CP padding. In that path the unpadded and padded MFU metrics collapse, inflated by at most pad_to_multiple_of² of the real Σ(Lᵢ²) — typically <10⁻⁵, below measurement noise. Documented as a known limitation in _attn_work_from_batch's docstring in all four MFU-tracking recipes (esm2, llama3, opengenome2_llama, codonfm), with a pointer to issue #1561 for the full analysis and proposed fixes. No behavior change. Signed-off-by: Gagan Kaushik --- bionemo-recipes/recipes/codonfm_native_te/perf_logger.py | 9 +++++++++ bionemo-recipes/recipes/esm2_native_te/perf_logger.py | 8 ++++++++ bionemo-recipes/recipes/llama3_native_te/perf_logger.py | 8 ++++++++ .../recipes/opengenome2_llama_native_te/perf_logger.py | 8 ++++++++ 4 files changed, 33 insertions(+) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index f0ae75bd69..4884245815 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -132,6 +132,15 @@ def _attn_work_from_batch( CodonFM currently runs FSDP without CP (cp_size=1), but the formula stays correct if CP is added later. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment, inlined + in ``CodonTHDCollator.__call__`` in dataset.py), the cu_seq_lens_q tensor is mutated + in place to include one or more appended mock pad sequences and no + ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's per-sequence + CP padding). In that path the unpadded and padded metrics collapse, inflated by + ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically <10⁻⁵ and below + measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. """ if include_padding: cu = batch.get("cu_seq_lens_q_padded") diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index ac646f8b5d..ec8770a6fd 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -129,6 +129,14 @@ def _attn_work_from_batch( * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. """ if include_padding: cu = batch.get("cu_seq_lens_q_padded") diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index d581dd413b..db41c54a2f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -131,6 +131,14 @@ def _attn_work_from_batch( scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. """ if include_padding: cu = batch.get("cu_seq_lens_q_padded") diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index abb8891c93..cd57e7dc38 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -136,6 +136,14 @@ def _attn_work_from_batch( * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. """ if include_padding: cu = batch.get("cu_seq_lens_q_padded") From ff0410d0dcf25c1b865f5437add73810f48644e1 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 27 Apr 2026 11:41:00 -0700 Subject: [PATCH 23/24] esm2: revert log_micro_step split (no grad-acc in ESM-2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ESM-2's training loop performs a single forward+backward+optimizer-step per batch and never calls grad-accumulation. The log_micro_step / log_step split was over-engineered for it: the cross-step accumulators (running_loss, grad_acc_step_count, num_tokens, _attn_work_*_accum) trivially reset every iteration since the inner loop runs exactly once. Collapse the pair back into a single log_step(step, batch, outputs, grad_norm, lr) that reads the batch and outputs directly, drop the unused accumulators, and update all five ESM-2 train scripts (train_ddp, train_ddp_cp, train_fsdp2, train_fsdp2_cp, train_mfsdp) to call the unified entry point. The Σ(Lᵢ²) attention math and the dual mfu_pct / mfu_padded_pct flavors are orthogonal to grad-acc compatibility and remain unchanged. The other three MFU-tracked recipes (llama3, og2, codonfm) DO grad-accumulate and keep the log_micro_step / log_step split in their own perf_logger modules. Tests: drop the TestGradAccAccumulation class — its three tests exercised the log_micro_step path that no longer exists. TestFlopSplitAndAttention (the _attn_work_from_batch coverage, including BSHD/THD/CP variants) is unchanged. Addresses pstjohn review comments r3148145031 / r3148154309 / r3148219546 on PR #1548. Signed-off-by: Gagan Kaushik --- .../recipes/esm2_native_te/perf_logger.py | 111 +++++---------- .../esm2_native_te/tests/test_perf_logger.py | 130 +----------------- .../recipes/esm2_native_te/train_ddp.py | 5 +- .../recipes/esm2_native_te/train_ddp_cp.py | 5 +- .../recipes/esm2_native_te/train_fsdp2.py | 5 +- .../recipes/esm2_native_te/train_fsdp2_cp.py | 6 +- .../recipes/esm2_native_te/train_mfsdp.py | 5 +- 7 files changed, 51 insertions(+), 216 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index ec8770a6fd..0dcbfa997c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -176,10 +176,11 @@ def _attn_work_from_batch( class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. - Uses the ``log_micro_step`` / ``log_step`` accumulator pattern (shared with the - llama3/og2/codonfm recipes) so gradient accumulation is correctly handled: - token counts, Σ(Lᵢ²), perplexity updates, and loss accumulate across every - micro-batch of an optimizer step; metrics are reported once per logging window. + ESM-2 does not perform gradient accumulation — each optimizer step is a single + forward+backward — so ``log_step`` reads the batch and outputs directly without + cross-micro-batch accumulators. The other MFU-tracking recipes (llama3, og2, + codonfm) do grad-accumulate and use a separate ``log_micro_step`` / ``log_step`` + split in their own perf_logger modules. Args: dist_config: The distributed configuration. @@ -265,55 +266,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # Whether to step debug_api.step() after each step self.quant_stats_config = args.quant_stats_config.enabled - # Gradient accumulation tracking (accumulated over the grad-acc micro-batches of - # the last optimizer step in the logging window, then drained in log_step). - self.num_tokens = 0 - self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) - # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: - # unpadded: only real tokens (useful work), drives mfu_pct - # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct - self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) - self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) - self.running_loss = torch.tensor(0.0, device=self._device) - self.grad_acc_step_count = 0 - - def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: MaskedLMOutput): - """Store data on micro step for gradient accumulation metrics. - - Args: - step: The current optimizer step number (shared across all micro-batches). - batch: The input batch for this micro-step. - outputs: Model outputs for this micro-step (with unscaled loss). - """ - assert outputs.loss is not None, "Loss is None" - - with torch.no_grad(): - self.grad_acc_step_count += 1 - self.running_loss += outputs.loss - - if step % self.logging_frequency == 0 and step > 0: - self.num_tokens += batch["input_ids"].numel() - num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != ESM2_PAD_TOKEN_ID].numel() - self.num_unpadded_tokens += num_unpadded_tokens - if self._log_mfu: - # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). - # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. - self._attn_work_unpadded_accum += _attn_work_from_batch( - batch, self._device, self._cp_size, include_padding=False - ) - self._attn_work_padded_accum += _attn_work_from_batch( - batch, self._device, self._cp_size, include_padding=True - ) - - # Update perplexity per micro-batch since it needs logits + labels. - logits = outputs.logits - if logits.dim() < 3: - logits = logits.unsqueeze(0) - self.metrics["train/perplexity"].update(logits, batch["labels"]) - def log_step( self, step: int, + batch: dict[str, torch.Tensor], + outputs: MaskedLMOutput, grad_norm: torch.Tensor | DTensor | float, lr: float, ): @@ -321,15 +278,14 @@ def log_step( Args: step: Current optimizer step. + batch: The input batch for this step. + outputs: Model outputs for this step (with loss + logits). grad_norm: Gradient norm value. lr: Current learning rate. """ - with torch.no_grad(): - assert self.grad_acc_step_count > 0, ( - f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " - f"and can be incremented by log_micro_step()." - ) + assert outputs.loss is not None, "Loss is None" + with torch.no_grad(): # FSDP2's clip_grad_norm_ returns a DTensor; convert to local tensor for torchmetrics compatibility. if isinstance(grad_norm, DTensor): grad_norm = grad_norm.to_local() @@ -337,9 +293,7 @@ def log_step( if self.quant_stats_config: debug_api.step() - # Calculate average loss over all micro steps in the logging window. - avg_loss = self.running_loss / self.grad_acc_step_count - self.min_loss = torch.minimum(self.min_loss, avg_loss) + self.min_loss = torch.minimum(self.min_loss, outputs.loss) if step % self.logging_frequency == 0 and step > 0: elapsed_time, self.previous_step_time = ( @@ -348,15 +302,24 @@ def log_step( ) step_time = elapsed_time / self.logging_frequency - self.metrics["train/loss"].update(avg_loss) + num_tokens = batch["input_ids"].numel() + num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != ESM2_PAD_TOKEN_ID].numel() + + # Update perplexity from logits + labels (logits get a leading batch dim if absent). + logits = outputs.logits + if logits.dim() < 3: + logits = logits.unsqueeze(0) + self.metrics["train/perplexity"].update(logits, batch["labels"]) + + self.metrics["train/loss"].update(outputs.loss) self.metrics["train/learning_rate"].update(lr) self.metrics["train/grad_norm"].update( grad_norm if isinstance(grad_norm, torch.Tensor) else torch.tensor(grad_norm) ) self.metrics["train/step_time"].update(step_time) - self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) - self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) - self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) if self._log_mfu: # Two MFU flavors reported side-by-side: @@ -364,16 +327,20 @@ def log_step( # attn over real Σ(Lᵢ²). Drops both padding types. # mfu_padded_pct = hardware view. Non-attn over all slots, attn over # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). - attn_unpadded = int(self._attn_work_unpadded_accum.item()) - attn_padded = int(self._attn_work_padded_accum.item()) - num_unpadded = int(self.num_unpadded_tokens.item()) + # Helper returns GLOBAL Σ(Lᵢ²); divide by cp_size to convert to per-rank. + attn_unpadded = int( + _attn_work_from_batch(batch, self._device, self._cp_size, include_padding=False).item() + ) + attn_padded = int( + _attn_work_from_batch(batch, self._device, self._cp_size, include_padding=True).item() + ) - non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded_tokens attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size flops_unpadded = non_attn_unpadded + attn_flops_unpadded tflops_unpadded = flops_unpadded / step_time / 1e12 - non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens + non_attn_padded = self._non_attn_per_token_flops_padded * num_tokens attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size flops_padded = non_attn_padded + attn_flops_padded tflops_padded = flops_padded / step_time / 1e12 @@ -406,19 +373,11 @@ def log_step( if self._dist_config.is_main_process(): wandb.log(metrics, step=step) self._progress_bar.update(self.logging_frequency) - self._progress_bar.set_postfix({"loss": avg_loss.item()}) + self._progress_bar.set_postfix({"loss": outputs.loss.item()}) if self._dist_config.local_rank == 0: logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) - # Reset running accumulators for next logging window. - self.running_loss.zero_() - self.num_tokens = 0 - self.num_unpadded_tokens.zero_() - self._attn_work_unpadded_accum.zero_() - self._attn_work_padded_accum.zero_() - self.grad_acc_step_count = 0 - def finish(self): """Finish the logger and close the progress bar.""" if self.quant_stats_config: diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py index e8e7b61b7c..77b3609ed9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -13,25 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ESM-2's PerfLogger: FLOP formula split + grad-acc accumulator pattern. +"""Tests for ESM-2's PerfLogger: non-attn + Σ(Lᵢ²) attention FLOP formula. -ESM-2 was previously the odd-one-out: its PerfLogger read num_tokens from the *last* -micro-batch at log time, so any future gradient accumulation would have undercounted -FLOPs by 1/grad_acc_steps. This retrofit introduces the ``log_micro_step`` / -``log_step`` split shared with the other MFU-tracking recipes (llama3, og2, codonfm) -and fixes attention-FLOP overcounting on packed (THD) batches. +ESM-2 does not perform gradient accumulation — each optimizer step is a single +forward+backward — so PerfLogger has a single ``log_step`` entry point that reads +the batch and outputs directly. The other MFU-tracking recipes (llama3, og2, +codonfm) do grad-accumulate and use a ``log_micro_step`` / ``log_step`` split in +their own perf_logger modules. """ -from unittest import mock - -import pytest import torch -from omegaconf import OmegaConf -from transformers.modeling_outputs import MaskedLMOutput -from distributed_config import DistributedConfig from perf_logger import ( - PerfLogger, _attn_work_from_batch, _compute_attn_flop_coeff, _compute_non_attn_per_token_flops, @@ -41,47 +34,6 @@ ESM2_VOCAB = 33 -def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False, max_seq_length=128): - """Create a minimal args config for PerfLogger.""" - return OmegaConf.create( - { - "logger": {"frequency": logging_frequency}, - "wandb_init_args": {"project": "test", "mode": "disabled"}, - "num_train_steps": num_train_steps, - "quant_stats_config": {"enabled": False}, - "log_mfu": log_mfu, - "dataset": {"max_seq_length": max_seq_length}, - } - ) - - -def _make_batch(seq_len=128, device="cuda:0"): - """Create a minimal batch dict.""" - return { - "input_ids": torch.ones(1, seq_len, dtype=torch.long, device=device), - "labels": torch.ones(1, seq_len, dtype=torch.long, device=device), - } - - -def _make_outputs(loss_value, seq_len=128, device="cuda:0"): - """Create MaskedLMOutput with loss + logits.""" - logits = torch.randn(1, seq_len, ESM2_VOCAB, device=device) - return MaskedLMOutput(loss=torch.tensor(loss_value, device=device), logits=logits) - - -@pytest.fixture -def mock_wandb(): - with mock.patch("perf_logger.wandb") as mocked: - mocked.init.return_value = mock.MagicMock() - yield mocked - - -@pytest.fixture -def mock_tqdm(): - with mock.patch("perf_logger.tqdm") as mocked: - yield mocked - - def _esm_cfg(): """ESM-2-like MLM encoder config (MHA, no GQA, gelu MLP).""" return { @@ -184,73 +136,3 @@ def test_include_padding_bshd_with_attention_mask(self): dev = torch.device("cpu") assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 - - -class TestGradAccAccumulation: - """Lock in ESM-2's new log_micro_step/log_step split under gradient accumulation. - - Before this retrofit ESM-2 read num_tokens from only the last micro-batch of an - optimizer step, so with grad_acc_steps > 1 it would have reported 1/grad_acc the - true FLOP count. The new accumulator pattern sums across micro-batches. - """ - - def test_num_tokens_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): - """4 micro-batches of seq_len=128 → num_tokens = 4*128 at log boundary.""" - dist_config = DistributedConfig() - args = _make_args(logging_frequency=1, max_seq_length=128) - perf_logger = PerfLogger(dist_config, args) - device = perf_logger._device - - # One optimizer step with 4 micro-batches of shape (1, 128). - for _ in range(4): - batch = _make_batch(seq_len=128, device=device) - outputs = _make_outputs(1.0, seq_len=128, device=device) - perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) - - assert perf_logger.grad_acc_step_count == 4 - assert perf_logger.num_tokens == 4 * 128 # 4 micro-batches * 128 tokens each - # running_loss should sum 4 losses of 1.0 each - assert perf_logger.running_loss.item() == pytest.approx(4.0) - - def test_attn_work_accumulates_across_grad_acc(self, mock_wandb, mock_tqdm): - """Both _attn_work_*_accum buffers sum Σ(Lᵢ²) over all micro-batches when log_mfu=True.""" - dist_config = DistributedConfig() - args = _make_args(logging_frequency=1, log_mfu=True, max_seq_length=128) - perf_logger = PerfLogger(dist_config, args, model_config_dict=_esm_cfg()) - device = perf_logger._device - - # 3 micro-batches of shape (2, 64) → each batch has Σ(Lᵢ²) = 2 * 64² = 8192 - for _ in range(3): - batch = { - "input_ids": torch.ones(2, 64, dtype=torch.long, device=device), - "labels": torch.ones(2, 64, dtype=torch.long, device=device), - } - outputs = _make_outputs(1.0, seq_len=64, device=device) - # Perplexity expects (B, S, V) logits - outputs.logits = torch.randn(2, 64, ESM2_VOCAB, device=device) - perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) - - # With no attention_mask and no cu_seq_lens, both unpadded and padded paths fall - # through to the shape-synthesis branch, so both accumulators hold 3 * 2 * 64² = 24576. - expected = 3 * 2 * 64 * 64 - assert perf_logger._attn_work_unpadded_accum.item() == expected - assert perf_logger._attn_work_padded_accum.item() == expected - - def test_reset_on_log_boundary(self, mock_wandb, mock_tqdm): - """Calling log_step on a logging-boundary step drains all accumulators.""" - dist_config = DistributedConfig() - args = _make_args(logging_frequency=1, log_mfu=True, max_seq_length=128) - perf_logger = PerfLogger(dist_config, args, model_config_dict=_esm_cfg()) - device = perf_logger._device - - batch = _make_batch(seq_len=128, device=device) - outputs = _make_outputs(1.0, seq_len=128, device=device) - perf_logger.log_micro_step(step=1, batch=batch, outputs=outputs) - perf_logger.log_step(step=1, grad_norm=torch.tensor(1.0, device=device), lr=1e-4) - - assert perf_logger.grad_acc_step_count == 0 - assert perf_logger.num_tokens == 0 - assert perf_logger.num_unpadded_tokens.item() == 0 - assert perf_logger._attn_work_unpadded_accum.item() == 0 - assert perf_logger._attn_work_padded_accum.item() == 0 - assert perf_logger.running_loss.item() == pytest.approx(0.0) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 707652fb9e..ebf36d9d47 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -175,9 +175,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). - perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -188,6 +185,8 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, + batch=batch, + outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index 8c3d5f029b..25829c64cf 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -184,9 +184,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). - perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -197,6 +194,8 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, + batch=batch, + outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 78a5c5055b..fcfbf17fa6 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -201,9 +201,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). - perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # --- Grad clip --- total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -215,6 +212,8 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, + batch=batch, + outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index b0d91f4a0d..06573112f8 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -196,10 +196,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). - # With future gradient accumulation, this would be called once per micro-batch. - perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -210,6 +206,8 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, + batch=batch, + outputs=outputs, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index 0dcd5875a3..2d61d14d39 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -182,9 +182,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - # Record per-micro-batch metrics (loss, num_tokens, Σ(Lᵢ²), perplexity). - perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Compute and clip gradient norms. # This is causing training to hang in 26.01 torch base image for multi-process mFSDP. # total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -196,6 +193,8 @@ def main(args: DictConfig) -> float | None: perf_logger.log_step( step=step, + batch=batch, + outputs=outputs, grad_norm=0.0, # total_norm, lr=optimizer.param_groups[0]["lr"], ) From b9f31aec700f1af3d63675d5e9c27148315b0790 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 27 Apr 2026 11:44:18 -0700 Subject: [PATCH 24/24] perf_logger: move log_mfu gate inside PerfLogger across all recipes Train scripts no longer ternary-gate the model_config_dict at the call site based on args.log_mfu. They always pass model_config_dict=config.to_dict(); PerfLogger.__init__ checks args.log_mfu internally and skips the FLOPs/MFU machinery when disabled. Also drop args.get("log_mfu", False) in favor of args.log_mfu directly: the default lives in each recipe's hydra_config/defaults.yaml, so the explicit get-with-default is dead defensive code. Touches all four MFU-tracked recipes (esm2, llama3, og2, codonfm) and all eleven train scripts that init a PerfLogger: esm2_native_te: train_ddp, train_ddp_cp, train_fsdp2, train_fsdp2_cp, train_mfsdp llama3_native_te: train_ddp, train_fsdp2, train_fsdp2_cp opengenome2_llama_native: train_fsdp2, train_fsdp2_cp codonfm_native_te: train_fsdp2 codonfm's tests/test_perf_logger.py _make_args helper gains the "log_mfu": log_mfu key so PerfLogger can read it via attribute access (the other recipes' helpers already had it). Addresses pstjohn review comment r3148207804 on PR #1548. Signed-off-by: Gagan Kaushik --- bionemo-recipes/recipes/codonfm_native_te/perf_logger.py | 2 +- .../recipes/codonfm_native_te/tests/test_perf_logger.py | 3 ++- bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py | 2 +- bionemo-recipes/recipes/esm2_native_te/perf_logger.py | 2 +- bionemo-recipes/recipes/esm2_native_te/train_ddp.py | 2 +- bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py | 2 +- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 2 +- bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py | 2 +- bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py | 2 +- bionemo-recipes/recipes/llama3_native_te/perf_logger.py | 2 +- bionemo-recipes/recipes/llama3_native_te/train_ddp.py | 2 +- bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py | 2 +- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py | 2 +- .../recipes/opengenome2_llama_native_te/perf_logger.py | 2 +- .../recipes/opengenome2_llama_native_te/train_fsdp2.py | 2 +- .../recipes/opengenome2_llama_native_te/train_fsdp2_cp.py | 2 +- 16 files changed, 17 insertions(+), 16 deletions(-) diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index 4884245815..1ef272e81b 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -203,7 +203,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP and sequence packing. - self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._log_mfu = args.log_mfu and model_config_dict is not None self._non_attn_per_token_flops = 0 self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index 83f2358061..5d1301b8f3 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -33,7 +33,7 @@ VOCAB_SIZE = 69 # CodonFM vocabulary size -def _make_args(logging_frequency=1, num_train_steps=100): +def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False): """Create a minimal args config for PerfLogger.""" return OmegaConf.create( { @@ -41,6 +41,7 @@ def _make_args(logging_frequency=1, num_train_steps=100): "wandb_init_args": {"project": "test", "mode": "disabled"}, "num_train_steps": num_train_steps, "quant_stats_config": {"enabled": False}, + "log_mfu": log_mfu, } ) diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index f41e85ad8e..479cf59783 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -166,7 +166,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 0dcbfa997c..b3ef7e5609 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -206,7 +206,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per # step are derived at log time from the accumulated token count + Σ(Lᵢ²), which # already reflects each rank's share under DP/CP and sequence packing. - self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._log_mfu = args.log_mfu and model_config_dict is not None self._non_attn_per_token_flops = 0 self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index ebf36d9d47..b18b96f9fe 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -159,7 +159,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index 25829c64cf..6762e707c9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -168,7 +168,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index fcfbf17fa6..43e6b9f0ef 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -185,7 +185,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 06573112f8..9a0988f6bd 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -180,7 +180,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index 2d61d14d39..1f86aa9859 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -166,7 +166,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Training loop diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index db41c54a2f..eda6b66f5e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -211,7 +211,7 @@ def __init__( # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP/CP and sequence packing. - self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._log_mfu = args.log_mfu and model_config_dict is not None self._non_attn_per_token_flops = 0 self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index cb3251f590..8882cdd311 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -145,7 +145,7 @@ def main(args: DictConfig) -> float | None: dist_config, args, start_step=start_step, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) gc.collect() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 1230ed82ac..2c16daf557 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -159,7 +159,7 @@ def main(args: DictConfig) -> float | None: dist_config, args, start_step=start_step, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) gc.collect() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index e0327ef60f..17d3ece756 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -181,7 +181,7 @@ def main(args: DictConfig) -> float | None: dist_config, args, start_step=start_step, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) gc.collect() diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index cd57e7dc38..249f6ff0b1 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -207,7 +207,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_confi # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per # step are derived at log time from the tracked unpadded token count, which already # reflects each rank's share under DP/CP and sequence packing. - self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None + self._log_mfu = args.log_mfu and model_config_dict is not None self._non_attn_per_token_flops = 0 self._non_attn_per_token_flops_padded = 0 self._attn_flop_coeff = 0 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py index 8578e14136..701391c316 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py @@ -261,7 +261,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) # Setup validation if enabled diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py index f8d3c19757..e97462ebb6 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py @@ -301,7 +301,7 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger( dist_config, args, - model_config_dict=config.to_dict() if args.get("log_mfu", False) else None, + model_config_dict=config.to_dict(), ) gc.collect()