diff --git a/bionemo-recipes/recipes/codonfm_native_te/README.md b/bionemo-recipes/recipes/codonfm_native_te/README.md index de31d5b49d..ad8f04dd08 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/README.md +++ b/bionemo-recipes/recipes/codonfm_native_te/README.md @@ -177,6 +177,27 @@ python train_fsdp2.py \ A final model suitable for uploading to the Hugging Face Hub can be exported at the end of training by setting `checkpoint.save_final_model=true`. +## MFU Tracking + +Enable per-step MFU logging by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true +``` + +Two pairs of metrics are emitted per logging interval: + +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view (HFU-like). Counts + every slot the GPU processes, including BSHD row padding. + +Non-attention uses the unpadded/padded token count respectively; attention uses `Σ(Lᵢ²)` from +`cu_seq_lens_q` (THD) or per-row `attention_mask.sum()` (BSHD) for the unpadded variant and +`cu_seq_lens_q_padded` / full `B·S²` for the padded variant. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window; `_mean_gb` is +the post-step resting footprint. + ## Developer Guide ### Running Tests diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml index 3a97660834..fa581d601f 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml @@ -65,3 +65,4 @@ quant_stats_config: fp8_layers: null fp4_layers: null use_fp32_master_weights: null +log_mfu: false diff --git a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py index a0b5a21b70..1ef272e81b 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/perf_logger.py @@ -36,6 +36,147 @@ PAD_TOKEN_ID = 3 +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work; this helper always returns a pre-CP-shard quantity. + + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. + + CodonFM currently runs FSDP without CP (cp_size=1), but the formula stays correct + if CP is added later. + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment, inlined + in ``CodonTHDCollator.__call__`` in dataset.py), the cu_seq_lens_q tensor is mutated + in place to include one or more appended mock pad sequences and no + ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's per-sequence + CP padding). In that path the unpadded and padded metrics collapse, inflated by + ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically <10⁻⁵ and below + measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. + """ + if include_padding: + cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + + class PerfLogger: """Performance logger for CodonFM training. @@ -44,17 +185,49 @@ class PerfLogger: Args: dist_config: The distributed configuration. args: The Hydra arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP and sequence packing. + self._log_mfu = args.log_mfu and model_config_dict is not None + self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) + self._peak_tflops: float | None = None + if self._log_mfu: + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + args.dataset.max_seq_length, + self._cp_size, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -66,9 +239,18 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) - self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.metrics.to(self._device) self.previous_step_time = time.perf_counter() if self._dist_config.is_main_process(): @@ -79,9 +261,13 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.quant_stats_config = args.quant_stats_config.enabled # Gradient accumulation tracking - self._device = torch.device(f"cuda:{dist_config.local_rank}") self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -103,6 +289,15 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas self.num_tokens += batch["input_ids"].numel() num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != PAD_TOKEN_ID].numel() self.num_unpadded_tokens += num_unpadded_tokens + if self._log_mfu: + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) # Update perplexity per micro-batch since it needs logits + labels logits = outputs.logits @@ -155,9 +350,42 @@ def log_step( self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + if self._log_mfu: + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) + + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() @@ -179,6 +407,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py index b2e6f651eb..5d1301b8f3 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_perf_logger.py @@ -21,14 +21,19 @@ import torch from distributed_config import DistributedConfig from omegaconf import OmegaConf -from perf_logger import PerfLogger +from perf_logger import ( + PerfLogger, + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, +) from transformers.modeling_outputs import MaskedLMOutput VOCAB_SIZE = 69 # CodonFM vocabulary size -def _make_args(logging_frequency=1, num_train_steps=100): +def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False): """Create a minimal args config for PerfLogger.""" return OmegaConf.create( { @@ -36,6 +41,7 @@ def _make_args(logging_frequency=1, num_train_steps=100): "wandb_init_args": {"project": "test", "mode": "disabled"}, "num_train_steps": num_train_steps, "quant_stats_config": {"enabled": False}, + "log_mfu": log_mfu, } ) @@ -210,3 +216,112 @@ def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): _run_steps(perf_logger, losses) assert perf_logger.min_loss.item() == pytest.approx(1.0) + + +def _codon_cfg(): + """CodonFM-like config for the split-formula tests (MLM encoder).""" + return { + "model_type": "codonfm", # not in _GATED_MLP_MODEL_TYPES → standard 2-proj MLP + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "vocab_size": VOCAB_SIZE, + } + + +class TestFlopSplitAndAttention: + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed.""" + + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" + b, s = 4, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)².""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 + assert work < 15 * 15 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attn stays untouched. + Codonfm doesn't support CP, but the formula must still respect cp_size=1 default.""" + cfg = _codon_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 + assert work != 8**2 + 8**2 + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + CodonFM currently runs FSDP without CP so this is latent defence, but the + formula must be correct if CP is added. + """ + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py index 8b07f8954e..479cf59783 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py @@ -163,7 +163,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index 330b8152db..067841259b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -374,6 +374,28 @@ output = model(**inputs) - [ESM-2 Training with Accelerate](../esm2_accelerate_te/README.md) +## MFU Tracking + +Enable per-step MFU logging by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2.py --config-name L1_3B log_mfu=true +``` + +Two pairs of metrics are emitted per logging interval: + +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. + +The two pairs agree when the batch has no padding. The formula is CP-aware and auto-detects +MHA/GQA and FFN layout from the HF config. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window +(`torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is resting. + ## Developer Guide ### Running Tests diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 969c8b3822..5857ee3ebe 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -12,6 +12,8 @@ use_torch_compile: false cp_size: 1 +log_mfu: false + use_sequence_packing: false dataset: tokenizer_name: ${config_name_or_path} diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 2e67b3aaa5..b3ef7e5609 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -32,26 +32,204 @@ logger = logging.getLogger(__name__) +# ESM-2 uses token id 1 for the token. Unpadded-token counting filters this id out. +ESM2_PAD_TOKEN_ID = 1 + +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work; this helper always returns a pre-CP-shard quantity. + + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. + """ + if include_padding: + cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + ESM-2 does not perform gradient accumulation — each optimizer step is a single + forward+backward — so ``log_step`` reads the batch and outputs directly without + cross-micro-batch accumulators. The other MFU-tracking recipes (llama3, og2, + codonfm) do grad-accumulate and use a separate ``log_micro_step`` / ``log_step`` + split in their own perf_logger modules. + Args: dist_config: The distributed configuration. args: The arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) self.logging_frequency = args.logger.frequency - # Track whether to collect memory stats (disabled by default for max performance) + + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the accumulated token count + Σ(Lᵢ²), which + # already reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = args.log_mfu and model_config_dict is not None + self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) + self._peak_tflops: float | None = None + if self._log_mfu: + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + args.dataset.max_seq_length, + self._cp_size, + ) metrics_dict = { "train/loss": torchmetrics.MeanMetric(), @@ -65,10 +243,19 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. - self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.metrics.to(self._device) self.previous_step_time = time.perf_counter() if self._dist_config.is_main_process(): @@ -84,18 +271,20 @@ def log_step( step: int, batch: dict[str, torch.Tensor], outputs: MaskedLMOutput, - grad_norm: torch.Tensor | DTensor, + grad_norm: torch.Tensor | DTensor | float, lr: float, ): - """Log a step to the logger and wandb. + """Log a training step (called once per optimizer step). Args: - step: The step number. - batch: The batch of data for the step. - outputs: The outputs of the step. - grad_norm: The gradient norm of the step. - lr: The learning rate of the step. + step: Current optimizer step. + batch: The input batch for this step. + outputs: Model outputs for this step (with loss + logits). + grad_norm: Gradient norm value. + lr: Current learning rate. """ + assert outputs.loss is not None, "Loss is None" + with torch.no_grad(): # FSDP2's clip_grad_norm_ returns a DTensor; convert to local tensor for torchmetrics compatibility. if isinstance(grad_norm, DTensor): @@ -104,35 +293,74 @@ def log_step( if self.quant_stats_config: debug_api.step() - if step % self.logging_frequency == 0 and step > 0: - num_tokens = batch["input_ids"].numel() - # 1 is the padding token for ESM-2. - num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + self.min_loss = torch.minimum(self.min_loss, outputs.loss) - self.min_loss = torch.minimum(self.min_loss, outputs.loss) + if step % self.logging_frequency == 0 and step > 0: elapsed_time, self.previous_step_time = ( time.perf_counter() - self.previous_step_time, time.perf_counter(), ) step_time = elapsed_time / self.logging_frequency + num_tokens = batch["input_ids"].numel() + num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != ESM2_PAD_TOKEN_ID].numel() + + # Update perplexity from logits + labels (logits get a leading batch dim if absent). + logits = outputs.logits + if logits.dim() < 3: + logits = logits.unsqueeze(0) + self.metrics["train/perplexity"].update(logits, batch["labels"]) + self.metrics["train/loss"].update(outputs.loss) self.metrics["train/learning_rate"].update(lr) - self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/grad_norm"].update( + grad_norm if isinstance(grad_norm, torch.Tensor) else torch.tensor(grad_norm) + ) self.metrics["train/step_time"].update(step_time) self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) - # Handle sequence packing for torchmetrics calculation. - if outputs.logits.dim() < 3: - outputs.logits = outputs.logits.unsqueeze(0) + if self._log_mfu: + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + # Helper returns GLOBAL Σ(Lᵢ²); divide by cp_size to convert to per-rank. + attn_unpadded = int( + _attn_work_from_batch(batch, self._device, self._cp_size, include_padding=False).item() + ) + attn_padded = int( + _attn_work_from_batch(batch, self._device, self._cp_size, include_padding=True).item() + ) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded_tokens + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops_padded * num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 - self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py new file mode 100644 index 0000000000..77b3609ed9 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_perf_logger.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ESM-2's PerfLogger: non-attn + Σ(Lᵢ²) attention FLOP formula. + +ESM-2 does not perform gradient accumulation — each optimizer step is a single +forward+backward — so PerfLogger has a single ``log_step`` entry point that reads +the batch and outputs directly. The other MFU-tracking recipes (llama3, og2, +codonfm) do grad-accumulate and use a ``log_micro_step`` / ``log_step`` split in +their own perf_logger modules. +""" + +import torch + +from perf_logger import ( + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, +) + + +ESM2_VOCAB = 33 + + +def _esm_cfg(): + """ESM-2-like MLM encoder config (MHA, no GQA, gelu MLP).""" + return { + "model_type": "esm", # not in _GATED_MLP_MODEL_TYPES → 2-proj MLP + "hidden_size": 1280, + "num_hidden_layers": 33, + "num_attention_heads": 20, + "intermediate_size": 5120, + "vocab_size": ESM2_VOCAB, + } + + +class TestFlopSplitAndAttention: + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed for ESM-2.""" + + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)².""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 + assert work < 15 * 15 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attn untouched.""" + cfg = _esm_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 + assert work != 8**2 + 8**2 + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S².""" + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 65ad1fa2f3..b18b96f9fe 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -156,7 +156,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index aec8bb0a6d..6762e707c9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -165,7 +165,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 4cf5b6af6e..43e6b9f0ef 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -182,7 +182,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 5593a08721..9a0988f6bd 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -177,7 +177,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index b998616316..1f86aa9859 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -163,7 +163,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Training loop step = start_step diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 2be3b0f11e..c4d2912521 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -412,6 +412,30 @@ Once converted, the model can be loaded by any library that supports Llama 3, su vllm serve path/to/hf_converted_model ``` +## MFU Tracking + +Enable per-step MFU logging by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2_cp.py --config-name L2_lingua_1b log_mfu=true +``` + +Two pairs of metrics are emitted per logging interval: + +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. + +The two pairs agree when the batch has no padding (e.g. dense single-doc THD packs). The formula +is CP-aware (global `Σ(Lᵢ²)` divided by `cp_size`) and auto-detects GQA/MHA and SwiGLU/standard +FFN from the HF config. Implementation in `perf_logger.py`. + +Memory metrics: `train/gpu_memory_allocated_max_gb` is the true transient peak per logging window +(via `torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is the +post-step resting footprint. + ## Developer Guide ### Running tests diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 9302a0758d..ef5e064a88 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -79,6 +79,8 @@ fp8_stats_config: fp8_stats_file: ./fp8_debugging_stats.yaml fp8_log_dir: ./log_fp8_stats +log_mfu: false + profiler: enabled: false start_step: 10 diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json new file mode 100644 index 0000000000..460f2f1b71 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 726eb19e8e..eda6b66f5e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -33,18 +33,172 @@ logger = logging.getLogger(__name__) +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work; this helper always returns a pre-CP-shard quantity. + + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row to get real per-row lengths, + scaled by ``cp_size²`` to recover global. NOTE: for BSHD+CP this is exact when + padding is evenly distributed across cp chunks (the common case); can + underestimate slightly when padding is all on one end of the sequence because + per-rank mask.sum² loses row-level correlation info. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape`` (includes dynamic-padding-to-longest slots), + scaled by ``cp_size²``. + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. + """ + if include_padding: + cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") # fall back if no padded variant present + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + # Unpadded (real work) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + # Fallback: no real-length signal present — try cu_seq_lens_q_padded, then shape. + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. Args: dist_config: The distributed configuration. args: The arguments. + start_step: The step to resume progress-bar counting from. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int): + def __init__( + self, + dist_config: DistributedConfig, + args: DictConfig, + start_step: int, + model_config_dict: dict | None = None, + ): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -54,6 +208,34 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = args.log_mfu and model_config_dict is not None + self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) + self._peak_tflops: float | None = None + if self._log_mfu: + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + args.dataset.max_seq_length, + self._cp_size, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -65,6 +247,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -87,6 +278,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -119,6 +315,15 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau else: # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() + if self._log_mfu: + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -173,9 +378,49 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + if self._log_mfu: + # Two MFU flavors reported side-by-side: + # * mfu_pct = useful work rate. Non-attn over real tokens + # (num_unpadded_tokens), attn over Σ(Lᵢ²) from + # cu_seq_lens_q (THD) or per-row mask (BSHD). + # Drops padding from both terms — what the model + # actually learns from. + # * mfu_padded_pct = hardware view. Non-attn over all slots + # (num_tokens = input_ids.numel), attn over + # cu_seq_lens_q_padded / full B·S² — counts the + # cycles the HW actually burned, including + # CP-zigzag pad and BSHD row pad. + # Both divide the global Σ(Lᵢ²) by cp_size to get per-rank attn work. + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) + + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() @@ -197,6 +442,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def finish(self): diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index aebdfe17ef..dee275a9b1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -23,10 +23,16 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from distributed_config import DistributedConfig -from perf_logger import PerfLogger +from perf_logger import ( + PerfLogger, + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, + _detect_peak_tflops_bf16, +) -def _make_args(logging_frequency=1, num_train_steps=100): +def _make_args(logging_frequency=1, num_train_steps=100, log_mfu=False, max_seq_length=128): """Create a minimal args config for PerfLogger.""" return OmegaConf.create( { @@ -35,6 +41,8 @@ def _make_args(logging_frequency=1, num_train_steps=100): "num_train_steps": num_train_steps, "profiler": {"enabled": False}, "fp8_stats_config": {"enabled": False}, + "log_mfu": log_mfu, + "dataset": {"max_seq_length": max_seq_length}, } ) @@ -208,3 +216,141 @@ def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): _run_steps(perf_logger, losses) assert perf_logger.min_loss.item() == pytest.approx(1.0) + + +class TestDetectPeakTflops: + """Smoke test for GPU peak TFLOPS detection.""" + + def test_returns_tuple_shape(self): + """Returns (peak_tflops_or_none, device_name_str).""" + peak, name = _detect_peak_tflops_bf16() + assert isinstance(name, str) + assert peak is None or isinstance(peak, float) + + +def _llama_cfg(): + """Small llama-like config used by the split-formula tests.""" + return { + "model_type": "llama", + "hidden_size": 1024, + "num_hidden_layers": 8, + "num_attention_heads": 16, + "num_key_value_heads": 4, + "intermediate_size": 4096, + "vocab_size": 32000, + } + + +class TestFlopSplitAndAttention: + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed. + + Non-attention FLOPs are tracked per real token; attention FLOPs are tracked as + coeff * Σ(Lᵢ²) over per-doc real lengths. These tests lock in the formula and + its invariants (shape synthesis for BSHD, cu_seq_lens handling for THD, CP + division, unpadded/padded behavior, fallbacks). + """ + + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)² — the whole point of the fix.""" + # Doc lengths 3, 5, 7 → cumulative [0, 3, 8, 15] + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 # 83 real QK pairs per layer + assert work < 15 * 15 # old formula would have said 225 + + def test_cp_size_divides_attention_only(self): + """Dividing attention by cp_size must leave the non-attention term untouched.""" + cfg = _llama_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 # 61 (unpadded doc lens 5 and 6) + assert work != 8**2 + 8**2 # 128 (padded slot lens 8 and 8) + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + ContextParallelDataLoaderWrapper pre-splits the sequence so each rank's + input_ids.shape is (B, S/cp), not (B, S). The helper returns a GLOBAL + quantity (the caller divides by cp_size), so the BSHD synthesis branch + must multiply per-rank shape² by cp_size² to recover global B*S². + Without this correction, BSHD+CP attention FLOPs would be undercounted + by a factor of cp² (the bug surfaced when running real-data llama3/og2 + BSHD benchmarks at cp=8). + """ + # Pretend a rank has shape (1, 16) — this would correspond to global S=16*cp. + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + # cp_size=1 → per-rank shape == global shape: 1*16² = 256 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + # cp_size=8 → global S = 16*8 = 128, global B*S² = 128² = 16384 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + # THD path is unaffected by cp_size since cu_seq_lens_q is already global + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded (zigzag pad); False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 # 61 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 # 128 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask (real per-row lengths); True uses full shape.""" + # 2 rows, each padded to 8 slots; real lengths 5 and 3. + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + # Unpadded (real): 5² + 3² = 34 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 34 + # Padded (hardware view): 2 * 8² = 128 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 413b9262c7..8882cdd311 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -141,7 +141,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict(), + ) gc.collect() torch.cuda.empty_cache() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index da19daa2a7..2c16daf557 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -155,7 +155,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict(), + ) gc.collect() torch.cuda.empty_cache() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index eaf1a1b39f..17d3ece756 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -177,7 +177,12 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args, start_step=start_step) + perf_logger = PerfLogger( + dist_config, + args, + start_step=start_step, + model_config_dict=config.to_dict(), + ) gc.collect() torch.cuda.empty_cache() diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md index 52e1b45986..a7ff80e56b 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/README.md @@ -411,6 +411,29 @@ Validation logging during training can be enabled with `validation.enabled=true` validation data (e.g. a JSONL file). The `og2_7b_thd_gqa` config enables validation by default. Control evaluation frequency with `validation.eval_interval` and `validation.num_batches`.This can be helpful when debugging training convergence. +## MFU Tracking + +Enable per-step MFU logging by adding `log_mfu=true`: + +```bash +torchrun --nproc_per_node=2 train_fsdp2_cp.py log_mfu=true +``` + +Two pairs of metrics are emitted per logging interval: + +- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds. + Non-attention uses the unpadded token count; attention uses `Σ(Lᵢ²)` from `cu_seq_lens_q` (THD) + or per-row `attention_mask.sum()` (BSHD). +- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view. Counts every slot the + GPU processes, including CP-zigzag and BSHD row padding. HFU-like. + +The two pairs agree when the batch has no padding (e.g. dense single-doc THD packs — common for +genomic data windowed to `max_seq_length`). The formula is CP-aware and auto-detects GQA/SwiGLU +from the HF config. Implementation in `perf_logger.py`. + +Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak per window +(`torch.cuda.max_memory_allocated()` + `reset_peak_memory_stats()`); `_mean_gb` is resting. + ## Developer Guide ### Running tests diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml index 4295017f26..d532a83f06 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/hydra_config/defaults.yaml @@ -103,6 +103,8 @@ fp8_stats_config: fp8_stats_file: ./fp8_debugging_stats.yaml fp8_log_dir: ./log_fp8_stats +log_mfu: false + profiler: enabled: false start_step: 10 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py index 081103beb5..249f6ff0b1 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/perf_logger.py @@ -41,18 +41,160 @@ logger = logging.getLogger(__name__) +# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list +# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU. +_GPU_PEAK_TFLOPS_BF16 = { + "H100": 989.0, + "H200": 989.0, + "A100": 312.0, + "A6000": 155.0, + "L40": 181.0, + "GH200": 989.0, + "B200": 2250.0, + "GB200": 2250.0, + "B300": 2500.0, + "GB300": 2500.0, +} + +# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2. +_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"}) + + +def _detect_peak_tflops_bf16(): + """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name).""" + if not torch.cuda.is_available(): + return None, "unknown" + name = torch.cuda.get_device_name(0) + for key, tflops in _GPU_PEAK_TFLOPS_BF16.items(): + if key.lower() in name.lower(): + return tflops, name + return None, name + + +def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int: + """Per-token FLOPs for everything EXCEPT the S² attention term. + + Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the + actual total token count of the batch to get per-step non-attention FLOPs. Pairs + with ``_compute_attn_flop_coeff``, which contributes the attention term as + ``coeff · Σ(Lᵢ²)`` from cu_seq_lens. + """ + h = model_config_dict["hidden_size"] + n_heads = model_config_dict["num_attention_heads"] + n_kv = model_config_dict.get("num_key_value_heads", n_heads) + head_dim = h // n_heads + kv_dim = n_kv * head_dim + ffn = model_config_dict["intermediate_size"] + vocab = model_config_dict.get("vocab_size", 0) + if use_padded_vocab: + # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for + # FP8/tensor-core friendliness); logits are sliced back post-matmul. + vocab = model_config_dict.get("padded_vocab_size") or vocab + num_layers = model_config_dict["num_hidden_layers"] + model_type = model_config_dict.get("model_type", "") + num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2 + + per_layer = ( + 2 * h * h # Q projection + + 4 * h * kv_dim # K + V projections (GQA-aware) + + 2 * h * h # O projection + + 2 * num_mlp_proj * h * ffn # MLP (2 or 3 projections) + ) + lm_head = 2 * h * vocab if vocab > 0 else 0 + return 3 * (num_layers * per_layer + lm_head) + + +def _compute_attn_flop_coeff(model_config_dict: dict) -> int: + """Coefficient K such that per-step attention FLOPs = K · Σ(Lᵢ²) globally. + + Per CP rank: ``K · Σ(Lᵢ²) / cp_size`` — each CP rank computes 1/cp_size of each + doc's Lᵢ * Lᵢ score matrix. The 4 counts QK^T (2) + softmax·V (2); the 3 is + fwd+bwd. Hidden size appears linearly because attention is over heads and each + contributes head_dim, and heads * head_dim == h. + """ + h = model_config_dict["hidden_size"] + num_layers = model_config_dict["num_hidden_layers"] + return 3 * num_layers * 4 * h + + +def _attn_work_from_batch( + batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False +) -> torch.Tensor: + """Return GLOBAL Σ(Lᵢ²) for this batch as an int64 scalar tensor. + + The caller divides by cp_size in log_step to convert this global number into + per-rank attention work; this helper always returns a pre-CP-shard quantity. + + ``include_padding=False`` (default) counts only real tokens — "useful work": + * THD: uses ``cu_seq_lens_q`` (real per-doc lengths, already global). + * BSHD: uses ``attention_mask.sum(dim=-1)`` per row, scaled by ``cp_size²`` to + recover global. Exact when padding distributes evenly across cp chunks; + approximate when padding is concentrated on one end. + + ``include_padding=True`` counts padded positions too — "hardware view": + * THD: uses ``cu_seq_lens_q_padded`` (includes CP zigzag-divisibility padding). + * BSHD: uses full ``input_ids.shape``, scaled by ``cp_size²``. + + Int32 lens cast to int64 BEFORE squaring (overflow at L ≈ 46k otherwise). + + NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment), the + cu_seq_lens_q tensor is mutated in place to include an appended mock pad sequence + and no ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's + per-sequence CP padding). In that path the unpadded and padded metrics collapse, + inflated by ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically + <10⁻⁵ and below measurement noise. Known limitation; see + https://github.com/NVIDIA/bionemo-framework/issues/1561. + """ + if include_padding: + cu = batch.get("cu_seq_lens_q_padded") + if cu is None: + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + cu = batch.get("cu_seq_lens_q") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + mask = batch.get("attention_mask") + if mask is not None: + per_row_real = mask.sum(dim=-1).to(torch.int64) + return (per_row_real * per_row_real).sum() * cp_size * cp_size + cu = batch.get("cu_seq_lens_q_padded") + if cu is not None: + lens = (cu[1:] - cu[:-1]).to(torch.int64) + return (lens * lens).sum() + shape = batch["input_ids"].shape + batch_size, seq_len_per_rank = int(shape[0]), int(shape[-1]) + return torch.tensor( + batch_size * seq_len_per_rank * seq_len_per_rank * cp_size * cp_size, + dtype=torch.int64, + device=device, + ) + + class PerfLogger: """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. Args: dist_config: The distributed configuration. args: The arguments. + model_config_dict: Optional HF-style model config dict. When supplied together with + ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization + (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step. Attributes: min_loss: The minimum loss seen so far. """ - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -62,6 +204,34 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.logging_frequency = args.logger.frequency + # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per + # step are derived at log time from the tracked unpadded token count, which already + # reflects each rank's share under DP/CP and sequence packing. + self._log_mfu = args.log_mfu and model_config_dict is not None + self._non_attn_per_token_flops = 0 + self._non_attn_per_token_flops_padded = 0 + self._attn_flop_coeff = 0 + self._cp_size = int(args.get("cp_size", 1)) + self._peak_tflops: float | None = None + if self._log_mfu: + self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict) + self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops( + model_config_dict, use_padded_vocab=True + ) + self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict) + self._peak_tflops, gpu_name = _detect_peak_tflops_bf16() + if dist_config.local_rank == 0: + logger.info( + "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, " + "non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d", + gpu_name, + f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown", + float(self._non_attn_per_token_flops), + float(self._attn_flop_coeff), + args.dataset.max_seq_length, + self._cp_size, + ) + metrics_dict = { "train/loss": torchmetrics.MeanMetric(), "train/grad_norm": torchmetrics.MeanMetric(), @@ -73,6 +243,15 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), } + if self._log_mfu: + # Two TFLOPS/MFU pairs: + # * tflops_per_gpu / mfu_pct — useful work only (no padding) + # * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots) + metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric() + metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric() + if self._peak_tflops is not None: + metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric() + metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric() self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. @@ -95,6 +274,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): # Gradient accumulation tracking self.num_tokens = 0 self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + # Σ(Lᵢ²) over grad-acc micro-batches — two flavors: + # unpadded: only real tokens (useful work), drives mfu_pct + # padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct + self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) + self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device) self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 @@ -127,6 +311,15 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Cau else: # Fallback for pure sequence packing with no padding: all tokens are unpadded self.num_unpadded_tokens += batch["input_ids"].numel() + if self._log_mfu: + # Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²). + # Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size. + self._attn_work_unpadded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=False + ) + self._attn_work_padded_accum += _attn_work_from_batch( + batch, self._device, self._cp_size, include_padding=True + ) @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( @@ -181,9 +374,42 @@ def log_step( self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + if self._log_mfu: + # Two MFU flavors reported side-by-side: + # mfu_pct = useful-work rate. Non-attn over real tokens, + # attn over real Σ(Lᵢ²). Drops both padding types. + # mfu_padded_pct = hardware view. Non-attn over all slots, attn over + # padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad). + attn_unpadded = int(self._attn_work_unpadded_accum.item()) + attn_padded = int(self._attn_work_padded_accum.item()) + num_unpadded = int(self.num_unpadded_tokens.item()) + + non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded + attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size + flops_unpadded = non_attn_unpadded + attn_flops_unpadded + tflops_unpadded = flops_unpadded / step_time / 1e12 + + non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens + attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size + flops_padded = non_attn_padded + attn_flops_padded + tflops_padded = flops_padded / step_time / 1e12 + + self.metrics["train/tflops_per_gpu"].update(tflops_unpadded) + self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded) + if self._peak_tflops is not None: + self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0) + self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0) + + # Report TRUE peak memory across the logging window (FSDP-gathered params + + # activations held for backward), not just the post-step resting footprint. + # Reset the peak counter so each window reports its own peak instead of a + # running max since process start. Both calls are pure host-side counter ops + # -- no sync, no kernel launch. + peak_gb = torch.cuda.max_memory_allocated() / (1024**3) + current_gb = torch.cuda.memory_allocated() / (1024**3) + torch.cuda.reset_peak_memory_stats() + self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb) metrics = self.metrics.compute() self.metrics.reset() @@ -205,6 +431,8 @@ def log_step( self.running_loss.zero_() self.num_tokens = 0 self.num_unpadded_tokens.zero_() + self._attn_work_unpadded_accum.zero_() + self._attn_work_padded_accum.zero_() self.grad_acc_step_count = 0 def log_validation(self, step: int, val_metrics: dict): diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py new file mode 100644 index 0000000000..7f21c4aeca --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/tests/test_perf_logger.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the non-attention + Σ(Lᵢ²) attention FLOP formula. + +Non-attention FLOPs are tracked per real token; attention FLOPs are tracked as +coeff * Σ(Lᵢ²) over per-doc real lengths. These tests lock in the formula and its +invariants so future drift between sibling recipes is caught immediately. +""" + +import torch + +from perf_logger import ( + _attn_work_from_batch, + _compute_attn_flop_coeff, + _compute_non_attn_per_token_flops, +) + + +def _llama_cfg(): + """Llama-like OG2 config used by the split-formula tests.""" + return { + "model_type": "llama", + "hidden_size": 4096, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 8, # GQA + "intermediate_size": 14336, + "vocab_size": 256, # OG2's nucleotide vocab + } + + +class TestFlopSplitAndAttention: + """Verify the non-attn + Σ(Lᵢ²) attention formula is correctly computed.""" + + def test_bshd_shape_synthesis(self): + """BSHD batch (no cu_seq_lens) synthesizes Σ(Lᵢ²) = B·S² from input_ids shape.""" + b, s = 2, 512 + batch = {"input_ids": torch.zeros(b, s, dtype=torch.long)} + sigma_l_sq = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert sigma_l_sq == b * s * s + + def test_thd_single_doc_matches_bshd(self): + """cu_seq_lens_q=[0, S] (synthetic-single-doc) reproduces BSHD's Σ(Lᵢ²)=S².""" + s = 512 + bshd = {"input_ids": torch.zeros(1, s, dtype=torch.long)} + thd = { + "input_ids": torch.zeros(1, s, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, s], dtype=torch.int32), + } + assert _attn_work_from_batch(bshd, torch.device("cpu")).item() == s * s + assert _attn_work_from_batch(thd, torch.device("cpu")).item() == s * s + + def test_thd_multi_doc_uses_squared_sum(self): + """Multi-doc pack computes Σ(Lᵢ²), not (ΣLᵢ)² — the whole point of the fix.""" + cu = torch.tensor([0, 3, 8, 15], dtype=torch.int32) + batch = {"input_ids": torch.zeros(1, 15, dtype=torch.long), "cu_seq_lens_q": cu} + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 3**2 + 5**2 + 7**2 # 83 real QK pairs per layer + assert work < 15 * 15 # old formula would have said 225 + + def test_cp_size_divides_attention_only(self): + """cp_size divides the attention term only; non-attention stays untouched.""" + cfg = _llama_cfg() + non_attn_per_token = _compute_non_attn_per_token_flops(cfg) + coeff = _compute_attn_flop_coeff(cfg) + num_tokens, attn_work = 100, 10_000 + non_attn = non_attn_per_token * num_tokens + flops_cp1 = non_attn + (coeff * attn_work) // 1 + flops_cp4 = non_attn + (coeff * attn_work) // 4 + assert flops_cp1 - non_attn == coeff * attn_work + assert flops_cp4 - non_attn == (coeff * attn_work) // 4 + + def test_unpadded_preferred_over_padded(self): + """When both cu_seq_lens_q and cu_seq_lens_q_padded are present, _q wins.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + work = _attn_work_from_batch(batch, torch.device("cpu")).item() + assert work == 5**2 + 6**2 # 61 (unpadded doc lens 5 and 6) + assert work != 8**2 + 8**2 # 128 (padded slot lens 8 and 8) + + def test_padded_fallback_when_unpadded_absent(self): + """If only cu_seq_lens_q_padded is present, it is used as a fallback.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + assert _attn_work_from_batch(batch, torch.device("cpu")).item() == 8**2 + 8**2 + + def test_bshd_cp_correction(self): + """BSHD with CP: per-rank shape (B, S/cp) → helper must return global B*S². + + ContextParallelDataLoaderWrapper pre-splits the sequence so each rank's + input_ids.shape is (B, S/cp), not (B, S). The helper returns a GLOBAL + quantity (the caller divides by cp_size), so the BSHD synthesis branch + must multiply per-rank shape² by cp_size² to recover global B*S². + """ + batch = {"input_ids": torch.zeros(1, 16, dtype=torch.long)} + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=1).item() == 1 * 16 * 16 + assert _attn_work_from_batch(batch, torch.device("cpu"), cp_size=8).item() == 128 * 128 + thd = { + "input_ids": torch.zeros(1, 64, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 3, 8, 15], dtype=torch.int32), + } + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=1).item() == 3**2 + 5**2 + 7**2 + assert _attn_work_from_batch(thd, torch.device("cpu"), cp_size=8).item() == 3**2 + 5**2 + 7**2 + + def test_include_padding_thd(self): + """THD include_padding=True uses cu_seq_lens_q_padded; False uses cu_seq_lens_q.""" + batch = { + "input_ids": torch.zeros(1, 16, dtype=torch.long), + "cu_seq_lens_q": torch.tensor([0, 5, 11], dtype=torch.int32), + "cu_seq_lens_q_padded": torch.tensor([0, 8, 16], dtype=torch.int32), + } + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 6**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 8**2 + 8**2 + + def test_include_padding_bshd_with_attention_mask(self): + """BSHD include_padding=False uses attention_mask; True uses full shape.""" + mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int64) + batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "attention_mask": mask} + dev = torch.device("cpu") + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=False).item() == 5**2 + 3**2 + assert _attn_work_from_batch(batch, dev, cp_size=1, include_padding=True).item() == 2 * 8 * 8 diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py index 15d173b955..701391c316 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2.py @@ -258,7 +258,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) # Setup validation if enabled val_config = getattr(args, "validation", None) diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py index 3319fb5d25..e97462ebb6 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/train_fsdp2_cp.py @@ -298,7 +298,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger( + dist_config, + args, + model_config_dict=config.to_dict(), + ) gc.collect() torch.cuda.empty_cache()