diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md new file mode 100644 index 0000000000..1b6a7eb0d5 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -0,0 +1,50 @@ +# Evo2 SAE Recipe + +Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations. + +Pipeline: + +``` +HF Savanna ckpt --convert--> MBridge ckpt + | + extract.py (FASTA in, ActivationStore parquet shards out) + | + train.py (TopK SAE) +``` + +`extract.py` monkey-patches `predict_evo2`'s writer to stream parquet shards inline, +so there is no `.pt` intermediate and no separate shim step. + +The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1. + +## Quick start (1B model, 4 GPU) + +```bash +bash scripts/1b.sh +``` + +This will: + +1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format +2. Run `extract.py` on the OpenGenome2 organelle FASTA, streaming layer-12 + activations directly to parquet shards (no `.pt` intermediate) +3. Train a TopK SAE (expansion=8, k=32, auxk=512) + +Common overrides: + +```bash +# Different layer, different FASTA, tagged output paths +LAYER=22 FASTA=/data/.../prokeuk_25M.fasta RUN_TAG=25M_prokeuk bash scripts/1b.sh + +# Skip extraction (assumes parquet already exists at PARQUET_DIR) +TRAIN_ONLY=1 PARQUET_DIR=/data/.../parquet_25M_prokeuk bash scripts/1b.sh + +# Sweep hyperparams +LAYER=22 RUN_TAG=auxk2048 AUXK=2048 N_EPOCHS=4 bash scripts/1b.sh +``` + +See the top of `scripts/1b.sh` for the full list of env-overridable variables +(`MODEL`, `LAYER`, `CHUNK_BP`, `FASTA`, `RUN_TAG`, `MAX_TOKENS`, `MICRO_BATCH`, +`DEVICES`, `EXPANSION_FACTOR`, `TOP_K`, `AUXK`, `AUXK_COEF`, +`DEAD_TOKENS_THRESHOLD`, `N_EPOCHS`, `LR`, `WANDB_API_KEY`, `WANDB_PROJECT`, +`WANDB_RUN_NAME`, `TRAIN_ONLY`). diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml new file mode 100644 index 0000000000..1f00a62bc5 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "evo2-sae" +version = "0.1.0" +description = "Sparse Autoencoders for the Evo2 DNA language model" +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "sae", + "torch>=2.0", + "numpy>=1.20", + "tqdm>=4.60", + "pyarrow>=10.0", +] + +# No package code lives here yet — the recipe is just an entry-point for +# scripts/ that depends on the shared `sae` workspace package. Declare no +# packages so setuptools doesn't try to discover anything. +[tool.setuptools] +packages = [] + +[tool.uv.sources] +sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh new file mode 100755 index 0000000000..6ed5c0fcc1 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# Evo2 1B SAE pipeline: convert -> extract (streaming parquet) -> train. +# +# Assumes: +# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and +# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge. +# - The sae workspace package is importable in that same venv. +# - HF_TOKEN is set if the Savanna checkpoint repo is gated. +# +# Override any of these by exporting before invocation. + +set -euo pipefail + +EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}" +RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)" + +MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}" +MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}" +LAYER="${LAYER:-12}" +# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended). +CHUNK_BP="${CHUNK_BP:-8192}" + +FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}" +WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}" + +# Default output paths can be overridden per-run. Set RUN_TAG to suffix the +# activation and SAE paths at once (e.g. RUN_TAG=100M_mixed -> ..._parquet_100M_mixed), +# or override each path individually for full control. +RUN_TAG="${RUN_TAG:-}" +_SUFFIX="${RUN_TAG:+_${RUN_TAG}}" + +CKPT_DIR="${CKPT_DIR:-${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge}" +PARQUET_DIR="${PARQUET_DIR:-${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet${_SUFFIX}}" +OUTPUT_DIR="${OUTPUT_DIR:-${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}${_SUFFIX}}" + +source "${EVO2_MEGATRON_DIR}/.venv/bin/activate" + +# TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet. +if [[ "${TRAIN_ONLY:-0}" == "1" ]]; then + echo "============================================================" + echo "TRAIN_ONLY=1 — skipping chunk / convert / extract;" + echo "expecting an existing parquet at: $PARQUET_DIR" + echo "============================================================" + if [[ ! -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "ERROR: TRAIN_ONLY=1 but no parquet at $PARQUET_DIR" + exit 1 + fi +else + +echo "============================================================" +echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)" +echo "============================================================" +# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed. +INPUT_STEM="$(basename "$FASTA")" +INPUT_STEM="${INPUT_STEM%.gz}" +INPUT_STEM="${INPUT_STEM%.fasta}" +CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta" +if [[ -f "$CHUNKED_FASTA" ]]; then + echo "Reusing existing chunked FASTA: $CHUNKED_FASTA" +else + python "${RECIPE_DIR}/scripts/chunk_fasta.py" \ + --input "$FASTA" \ + --output "$CHUNKED_FASTA" \ + --window "$CHUNK_BP" +fi +FASTA="$CHUNKED_FASTA" + +echo "============================================================" +echo "STEP 1: Convert Savanna -> MBridge" +echo "============================================================" +if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path "$MODEL" \ + --mbridge-ckpt-dir "$CKPT_DIR" \ + --model-size "$MODEL_SIZE" \ + --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512" +else + echo "Reusing existing MBridge checkpoint at $CKPT_DIR" +fi + +echo "============================================================" +echo "STEP 2: Extract layer-${LAYER} activations directly to parquet" +echo "============================================================" +# extract.py monkey-patches predict_evo2's writer to stream parquet shards +# inline; no .pt intermediate and no separate shim. MAX_TOKENS=0 = uncapped. +if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "Reusing existing parquet shards at $PARQUET_DIR" +else + torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/extract.py" \ + --activation-store-dir "$PARQUET_DIR" \ + --max-tokens "${MAX_TOKENS:-0}" \ + --model-name "$MODEL" \ + --fasta "$FASTA" \ + --ckpt-dir "$CKPT_DIR" \ + --embedding-layer "$LAYER" \ + --micro-batch-size "${MICRO_BATCH:-4}" +fi + +fi # end if TRAIN_ONLY + +echo "============================================================" +echo "STEP 3: Train TopK SAE" +echo "============================================================" +# Wandb is enabled iff WANDB_API_KEY is in the env. WANDB_PROJECT/RUN can be overridden. +WANDB_FLAGS=("--no-wandb") +if [[ -n "${WANDB_API_KEY:-}" ]]; then + WANDB_FLAGS=( + "--wandb" + "--wandb-project" "${WANDB_PROJECT:-evo2-sae}" + ) + if [[ -n "${WANDB_RUN_NAME:-}" ]]; then + WANDB_FLAGS+=("--wandb-run-name" "$WANDB_RUN_NAME") + fi +fi + +torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/train.py" \ + --cache-dir "$PARQUET_DIR" \ + --model-path "$MODEL" \ + --layer "$LAYER" \ + --model-type topk \ + --expansion-factor "${EXPANSION_FACTOR:-8}" \ + --top-k "${TOP_K:-32}" \ + --auxk "${AUXK:-512}" \ + --auxk-coef "${AUXK_COEF:-0.03125}" \ + --dead-tokens-threshold "${DEAD_TOKENS_THRESHOLD:-10000000}" \ + --init-pre-bias \ + --n-epochs "${N_EPOCHS:-3}" \ + --batch-size 4096 \ + --dp-size "${DEVICES:-4}" \ + --lr "${LR:-3e-4}" \ + --log-interval 50 \ + "${WANDB_FLAGS[@]}" \ + --output-dir "$OUTPUT_DIR" \ + --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \ + --checkpoint-steps 999999 + +echo "============================================================" +echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" +echo "============================================================" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py new file mode 100644 index 0000000000..55b26cad30 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context. + +Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena +fftconv path (intermediates scale super-linearly with L). For 7B/40B raise +--window to whatever those checkpoints were context-extended to. + +Non-overlapping windows by default. Each chunk gets a header of the form +">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped. +""" + +import argparse +import gzip +from pathlib import Path + + +def parse_fasta(path: Path): + """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" + opener = gzip.open if path.suffix == ".gz" else open + seq_id, parts = None, [] + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if seq_id is not None: + yield seq_id, "".join(parts) + seq_id = line[1:].split()[0] + parts = [] + else: + parts.append(line) + if seq_id is not None: + yield seq_id, "".join(parts) + + +def main(): + """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA.""" + p = argparse.ArgumentParser() + p.add_argument("--input", type=Path, required=True) + p.add_argument("--output", type=Path, required=True) + p.add_argument("--window", type=int, default=8192) + args = p.parse_args() + + n_in = n_out = bp_out = 0 + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as out: + for seq_id, seq in parse_fasta(args.input): + n_in += 1 + for start in range(0, len(seq), args.window): + end = min(start + args.window, len(seq)) + chunk = seq[start:end] + out.write(f">{seq_id}:{start}-{end}\n{chunk}\n") + n_out += 1 + bp_out += len(chunk) + + print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py new file mode 100644 index 0000000000..37a1e4b0a0 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Compose a prokaryotic + eukaryotic FASTA from OpenGenome2 subsets. + +Two sources only: +- Prokaryotic: filtered_metagenomes_pt1 (truncated to --metagenome-window bp/contig) +- Eukaryotic: eukaryotic_genic_windows (~5kb euk genic regions) + +Output headers are renumbered as `>seq_{i} {source}` to satisfy predict_evo2's +unique-id check (the source files share NCBI-style accession headers across +records). +""" + +import argparse +import gzip +import subprocess +import sys +from pathlib import Path + + +def _open_text(path: Path): + """Open a .fasta or .fasta.gz file in text mode.""" + if str(path).endswith(".gz"): + return gzip.open(path, "rt") + return open(path) + + +def _iter_records(fh): + """Yield (header, seq_lines) from a FASTA file handle.""" + header = None + lines = [] + for line in fh: + line = line.rstrip("\n") + if line.startswith(">"): + if header is not None: + yield header, lines + header = line + lines = [] + elif line: + lines.append(line) + if header is not None: + yield header, lines + + +def _take_n(fh, n): + """Take the first n records from fh.""" + for i, (h, ls) in enumerate(_iter_records(fh)): + if i >= n: + return + yield h, ls + + +def _take_n_truncated(fh, n, max_bp): + """Take the first n records, truncating each sequence to <= max_bp bases.""" + for i, (h, ls) in enumerate(_iter_records(fh)): + if i >= n: + return + seq = "".join(ls)[:max_bp] + lines = [seq[j : j + 80] for j in range(0, len(seq), 80)] + yield h, lines + + +def main(): + """Compose a prok+euk mixed FASTA with unique seq_{i} headers.""" + ap = argparse.ArgumentParser() + ap.add_argument( + "--root", + type=Path, + default=Path("/data/interp/evo2/OpenGenome2/fasta"), + help="Root dir holding the OpenGenome2 subset directories.", + ) + ap.add_argument( + "--output", + type=Path, + default=Path("/data/interp/evo2/scratch/mixed_25M_prokeuk.fasta"), + ) + ap.add_argument("--n-metagenome", type=int, default=1000, help="Metagenome contigs (prok).") + ap.add_argument("--metagenome-window", type=int, default=50_000, help="Max bp per metagenome contig (truncate).") + ap.add_argument("--n-euk-windows", type=int, default=10_000, help="Eukaryotic_genic_windows records (euk).") + args = ap.parse_args() + + metagenome_file = args.root / "metagenomes" / "filtered_metagenomes_pt1.fasta.gz" + euk_parts = sorted((args.root / "eukaryotic_genic_windows").glob("*.fasta.gz.*")) + + if not metagenome_file.exists(): + print(f"ERROR: missing metagenome source at {metagenome_file}", file=sys.stderr) + sys.exit(1) + if not euk_parts: + print(f"ERROR: no eukaryotic_genic_windows parts under {args.root}", file=sys.stderr) + sys.exit(1) + + args.output.parent.mkdir(parents=True, exist_ok=True) + counter = 0 + bp_by_source: dict[str, int] = {} + + def _emit(out, header, lines, source): + """Write a record with a globally unique header and tally bp.""" + nonlocal counter + out.write(f">seq_{counter} {source}\n") + counter += 1 + for line in lines: + out.write(line + "\n") + bp_by_source[source] = bp_by_source.get(source, 0) + sum(len(line) for line in lines) + + with open(args.output, "w") as out: + # 1. Prokaryotic: metagenome contigs, truncated to --metagenome-window each. + print(f"adding {args.n_metagenome} metagenome contigs (truncated to {args.metagenome_window} bp each)...") + with _open_text(metagenome_file) as fh: + for h, ls in _take_n_truncated(fh, args.n_metagenome, args.metagenome_window): + _emit(out, h, ls, "prok_metagenomes") + + # 2. Eukaryotic: read split parts as one stream. + print(f"adding {args.n_euk_windows} eukaryotic_genic_windows...") + cat_parts = subprocess.Popen( + ["bash", "-c", f"cat {' '.join(str(p) for p in euk_parts)} | zcat"], + stdout=subprocess.PIPE, + text=True, + ) + try: + for h, ls in _take_n(cat_parts.stdout, args.n_euk_windows): + _emit(out, h, ls, "euk_genic_windows") + finally: + cat_parts.stdout.close() + cat_parts.terminate() + + total_bp = sum(bp_by_source.values()) + print(f"\nwrote {counter} sequences, {total_bp:,} total bp -> {args.output}") + print("by source (bp):") + for src, bp in bp_by_source.items(): + pct = 100 * bp / total_bp if total_bp else 0 + print(f" {src:<22} {bp:>12,} bp ({pct:>5.1f}%)") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py new file mode 100644 index 0000000000..2b4d753bcd --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Streaming Evo2 activation extractor — codonfm-style. + +Reuses `bionemo.evo2.run.predict` for all the heavy machinery (Megatron +model load, DP/CP/TP/PP setup, FASTA dataloader, inference loop) but +swaps the per-batch `.pt` writer for an in-process `ActivationStore` +that streams parquet shards directly during inference. + +Why: predict_evo2's `.pt` intermediate doubles disk volume and forces a +slow downstream pt->parquet shim. For SAE training, the activation tensor +is all we need; writing it directly into the SAE's ActivationStore format +removes the shim entirely, mirroring how codonfm's scripts/extract.py +already works. + +Invocation: + + torchrun --nproc_per_node 4 extract.py \ + --fasta path/to/seq.fasta \ + --ckpt-dir path/to/mbridge_ckpt \ + --embedding-layer 20 \ + --activation-store-dir /data/.../parquet_out \ + --max-tokens 25000000 \ + --micro-batch-size 4 + +All non-`--activation-store-dir`/`--max-tokens` flags are forwarded to +predict_evo2's argparse (`--fasta`, `--ckpt-dir`, `--embedding-layer`, +`--micro-batch-size`, etc.) so the inference surface is identical. +""" + +import argparse +import os +import shutil +import sys +from pathlib import Path + +# predict.py provides the entire Megatron inference plumbing. +from bionemo.evo2.run import predict as predict_mod # noqa: E402 + +# The SAE activation store — same format the existing pt_to_parquet.py emits. +from sae.activation_store import ActivationStore, ActivationStoreConfig # noqa: E402 + +# Reuse the merge step we already have in pt_to_parquet.py. +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from pt_to_parquet import _merge_temp_stores # noqa: E402 + + +# Per-rank state. Each torchrun rank is its own Python process, so this +# module-level dict is rank-local — exactly what we want. +_state: dict = { + "store": None, # ActivationStore for this rank + "n_tokens": 0, # tokens appended on this rank + "n_sequences": 0, # hidden.shape[0] across batches (raw seqs) + "budget": 0, # per-rank token cap (0 = no cap) + "store_root": None, # Path — final output dir; per-rank tmp lives under .tmp_rank_ + "rank_tmp": None, # Path — this rank's temp dir +} + + +def _store_writer( + predictions, + output_dir, + batch_idx, + global_rank, + dp_rank, + files_per_subdir=None, + num_files_written=0, + data_parallel_world_size=1, +): + """Replacement for predict._write_predictions_batch — append to ActivationStore. + + Signature matches the original; return shape `(path, updated_count, 0)`. + """ + if not predictions: + return output_dir, num_files_written, 0 + + # Once we've hit the per-rank budget, skip remaining writes (forward + # passes still run; cheap relative to the I/O we're skipping). + if _state["budget"] and _state["n_tokens"] >= _state["budget"]: + return output_dir, num_files_written, 0 + + hidden = predictions["hidden_embeddings"] # [B, S, H] + mask = predictions["pad_mask"].bool() + flat = hidden[mask].cpu() # [N_unpadded_tokens, H] + + if _state["store"] is None: + rank_tmp = _state["store_root"].with_name( + _state["store_root"].name + f".tmp_rank_{dp_rank}" + ) + rank_tmp.mkdir(parents=True, exist_ok=True) + _state["rank_tmp"] = rank_tmp + _state["store"] = ActivationStore(rank_tmp, ActivationStoreConfig(shard_size=100_000)) + + _state["store"].append(flat) + _state["n_tokens"] += flat.shape[0] + _state["n_sequences"] += hidden.shape[0] + return output_dir, num_files_written + 1, 0 + + +def _finalize_and_maybe_merge(model_name: str, layer: int) -> None: + """Finalize this rank's store, then rank 0 waits for all ranks and merges. + + We use a file-based wait (poll for sibling ranks' metadata.json) rather + than torch.distributed.barrier(): predict.main() tears down the process + group before this hook runs, so dist.barrier() silently no-ops and rank 0 + would race ahead of slower ranks (observed in the prok+euk run — rank 0 + merged its own dir before ranks 1-3 finalized, leaving 18M tokens orphaned). + """ + if _state["store"] is not None: + _state["store"].finalize(metadata={"n_sequences": _state["n_sequences"]}) + + if int(os.environ.get("RANK", "0")) != 0: + return + + # Rank 0 waits for all siblings to finalize before merging. + import time + + store_root: Path = _state["store_root"] + world_size = int(os.environ.get("WORLD_SIZE", "1")) + deadline = time.time() + 600 # 10 min wait cap + + def _ready_count() -> int: + return sum( + 1 + for r in range(world_size) + if (store_root.with_name(store_root.name + f".tmp_rank_{r}") / "metadata.json").exists() + ) + + while time.time() < deadline: + ready = _ready_count() + if ready >= world_size: + break + time.sleep(2) + else: + print( + f"[extract] WARN: only {_ready_count()}/{world_size} ranks finalized within 10 min — " + "merging what's available; some activations may be orphaned" + ) + + tmp_dirs = sorted( + p + for p in store_root.parent.glob(store_root.name + ".tmp_rank_*") + if p.is_dir() and (p / "metadata.json").exists() + ) + if not tmp_dirs: + print(f"[extract] no rank tmp dirs found under {store_root.parent} — nothing to merge") + return + print(f"[extract] merging {len(tmp_dirs)} rank tmp dirs into {store_root}") + final = _merge_temp_stores(tmp_dirs, store_root, model_name, layer) + print(f"[extract] done: {final}") + + +def main() -> None: + """Parse the extractor-specific flags, monkey-patch predict's writer, run predict.main().""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--activation-store-dir", type=Path, required=True) + parser.add_argument("--max-tokens", type=int, default=0, help="Cap total tokens across DP ranks (0 = no cap).") + parser.add_argument("--model-name", type=str, default="arcinstitute/savanna_evo2_1b_base") + extract_args, remaining = parser.parse_known_args() + + dp_size = int(os.environ.get("WORLD_SIZE", "1")) + _state["store_root"] = extract_args.activation_store_dir + _state["budget"] = extract_args.max_tokens // dp_size if extract_args.max_tokens else 0 + + # Force batch write-interval so our writer is called every iteration + # (epoch mode would buffer everything in memory, defeating the point). + if "--write-interval" not in remaining: + remaining.extend(["--write-interval", "batch"]) + + # predict.main() requires --output-dir; we point it at a throwaway path + # (writer never actually writes there). + if "--output-dir" not in remaining: + scratch = _state["store_root"].with_name(_state["store_root"].name + ".predict_unused") + scratch.mkdir(parents=True, exist_ok=True) + remaining.extend(["--output-dir", str(scratch)]) + + # Capture for the merge metadata; we need to know which layer / model + # to stamp into the merged ActivationStore.metadata. + layer = 0 + for i, a in enumerate(remaining): + if a == "--embedding-layer": + layer = int(remaining[i + 1]) + + # Substitute our writer for predict's. predict.py calls the bare name + # `_write_predictions_batch(...)` in its module scope, so module-attr + # replacement is enough. + predict_mod._write_predictions_batch = _store_writer + + # Hand predict's parser only the args it expects. + sys.argv = [sys.argv[0]] + remaining + + try: + predict_mod.main() + finally: + _finalize_and_maybe_merge(extract_args.model_name, layer) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py new file mode 100644 index 0000000000..19355822ae --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py @@ -0,0 +1,321 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Step 2: Train SAE from cached CodonFM activations. + +Loads pre-extracted activations from an ActivationStore cache directory +and trains a Sparse Autoencoder. Requires extract.py to have been run first. + +Single-GPU: + python scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 + +Multi-GPU DDP: + torchrun --nproc_per_node=4 scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 \ + --dp-size 4 +""" + +import argparse +import os +from pathlib import Path + +import numpy as np +import torch +from sae.activation_store import load_activations +from sae.architectures import ReLUSAE, TopKSAE +from sae.perf_logger import PerfLogger +from sae.training import ParallelConfig, Trainer, TrainingConfig, WandbConfig +from sae.utils import get_device, set_seed + + +def parse_args(): # noqa: D103 + p = argparse.ArgumentParser( + description="Train SAE from cached CodonFM activations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required + p.add_argument("--cache-dir", type=str, required=True, help="Path to activation cache (from extract.py)") + p.add_argument("--model-path", type=str, required=True, help="Encodon model path (for cache validation)") + p.add_argument("--layer", type=int, required=True, help="Layer index (for cache validation)") + + # SAE architecture + sae_group = p.add_argument_group("SAE model") + sae_group.add_argument("--model-type", type=str, default="topk", choices=["topk", "relu"]) + sae_group.add_argument("--expansion-factor", type=int, default=8) + sae_group.add_argument("--top-k", type=int, default=32) + sae_group.add_argument("--normalize-input", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--auxk", type=int, default=None) + sae_group.add_argument("--auxk-coef", type=float, default=1 / 32) + sae_group.add_argument("--dead-tokens-threshold", type=int, default=10_000_000) + sae_group.add_argument("--init-pre-bias", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--l1-coeff", type=float, default=1e-2, help="L1 coefficient (relu only)") + + # Training + train_group = p.add_argument_group("Training") + train_group.add_argument("--lr", type=float, default=3e-4) + train_group.add_argument("--n-epochs", type=int, default=3) + train_group.add_argument("--batch-size", type=int, default=4096) + train_group.add_argument("--log-interval", type=int, default=50) + train_group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, default=True) + train_group.add_argument("--num-workers", type=int, default=0) + train_group.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--max-grad-norm", type=float, default=None) + train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048) + train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps") + train_group.add_argument( + "--lr-schedule", + type=str, + default="constant", + choices=["constant", "cosine", "linear"], + help="LR schedule after warmup", + ) + train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules") + train_group.add_argument( + "--lr-decay-steps", + type=int, + default=None, + help="Total steps for LR decay (None = full training)", + ) + + # W&B + wb_group = p.add_argument_group("Weights & Biases") + wb_group.add_argument("--wandb", action=argparse.BooleanOptionalAction, default=False, dest="wandb_enabled") + wb_group.add_argument("--wandb-project", type=str, default="sae_codonfm_recipe") + wb_group.add_argument("--wandb-run-name", type=str, default=None) + wb_group.add_argument("--wandb-group", type=str, default=None) + wb_group.add_argument("--wandb-job-type", type=str, default=None) + + # Checkpointing + ckpt_group = p.add_argument_group("Checkpointing") + ckpt_group.add_argument("--checkpoint-dir", type=str, default=None) + ckpt_group.add_argument("--checkpoint-steps", type=int, default=None) + ckpt_group.add_argument("--resume-from", type=str, default=None) + + # Infrastructure + p.add_argument("--dp-size", type=int, default=1) + p.add_argument("--output-dir", type=str, default="./outputs") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--device", type=str, default=None) + p.add_argument( + "--num-sequences", + type=int, + default=None, + help="Subset cached activations to this many sequences' worth of shards", + ) + + return p.parse_args() + + +def build_sae(args, input_dim: int) -> torch.nn.Module: # noqa: D103 + hidden_dim = input_dim * args.expansion_factor + + if args.model_type == "topk": + return TopKSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + top_k=args.top_k, + normalize_input=args.normalize_input, + auxk=args.auxk, + auxk_coef=args.auxk_coef, + dead_tokens_threshold=args.dead_tokens_threshold, + ) + elif args.model_type == "relu": + return ReLUSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + l1_coeff=args.l1_coeff, + ) + else: + raise ValueError(f"Unknown model type: {args.model_type}") + + +def build_training_config(args, device: str) -> TrainingConfig: # noqa: D103 + return TrainingConfig( + lr=args.lr, + n_epochs=args.n_epochs, + batch_size=args.batch_size, + device=device, + log_interval=args.log_interval, + shuffle=args.shuffle, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + checkpoint_dir=args.checkpoint_dir, + checkpoint_steps=args.checkpoint_steps, + lr_scale_with_latents=args.lr_scale_with_latents, + lr_reference_hidden_dim=args.lr_reference_hidden_dim, + warmup_steps=args.warmup_steps, + max_grad_norm=args.max_grad_norm, + lr_schedule=args.lr_schedule, + lr_min=args.lr_min, + lr_decay_steps=args.lr_decay_steps, + ) + + +def build_wandb_config(args) -> WandbConfig: # noqa: D103 + return WandbConfig( + enabled=args.wandb_enabled, + project=args.wandb_project, + run_name=args.wandb_run_name, + group=args.wandb_group, + job_type=args.wandb_job_type, + config=vars(args), + ) + + +def build_parallel_config(args) -> ParallelConfig: # noqa: D103 + return ParallelConfig(dp_size=args.dp_size) + + +def main(): # noqa: D103 + args = parse_args() + + set_seed(args.seed) + device = args.device or get_device() + print(f"Using device: {device}") + print(f"Config: {vars(args)}") + + # Load cached activations + cache_path = Path(args.cache_dir) + if not (cache_path / "metadata.json").exists(): + raise FileNotFoundError(f"No cache found at {cache_path}. Run extract.py first.") + + store = load_activations(cache_path) + meta = store.metadata + + # Validate cache matches config + cached_model = meta.get("model_path", meta.get("model_name", "")) + if cached_model and cached_model != args.model_path: + print(f"WARNING: Cache model '{cached_model}' != '{args.model_path}'") + if meta.get("layer") != args.layer: + raise ValueError(f"Cache layer mismatch: {meta['layer']} vs {args.layer}") + + # Compute subsetting + cached_sequences = meta.get("n_sequences", None) + max_shards = None + if args.num_sequences and cached_sequences and args.num_sequences < cached_sequences: + keep_ratio = args.num_sequences / cached_sequences + max_shards = max(1, int(np.ceil(keep_ratio * meta["n_shards"]))) + print( + f"Subsetting: {args.num_sequences}/{cached_sequences} sequences " + f"-> using {max_shards}/{meta['n_shards']} shards (~{keep_ratio:.1%})" + ) + + # Estimate memory + n_shards_to_use = max_shards or meta["n_shards"] + shard_size = meta.get("shard_size", 100_000) + est_tokens = n_shards_to_use * shard_size + est_gb = est_tokens * meta["hidden_dim"] * 4 / (1024**3) + use_streaming = est_gb > 50 + + input_dim = meta["hidden_dim"] + sae = build_sae(args, input_dim) + print(f"SAE: {args.model_type}, input_dim={input_dim}, hidden_dim={sae.hidden_dim}") + + # Initialize pre_bias + if args.init_pre_bias and hasattr(sae, "init_pre_bias_from_data"): + print("Initializing pre_bias from geometric median of data...") + first_shard = torch.from_numpy(store._load_shard(0)).float() + sample_size = min(32768, len(first_shard)) + sae.init_pre_bias_from_data(first_shard[:sample_size]) + print(f" pre_bias initialized (mean={sae.pre_bias.mean().item():.4f})") + del first_shard + + # Build configs + training_config = build_training_config(args, device) + wandb_config = build_wandb_config(args) + parallel_config = build_parallel_config(args) + + perf_logger = PerfLogger( + log_interval=args.log_interval, + use_wandb=args.wandb_enabled, + print_logs=True, + device=device, + ) + + # Train + trainer = Trainer( + sae, + training_config, + wandb_config=wandb_config, + perf_logger=perf_logger, + parallel_config=parallel_config, + ) + + if use_streaming: + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + print( + f"Streaming from disk (~{est_gb:.0f}GB). " + f"Peak RAM: ~{shard_size * meta['hidden_dim'] * 4 / (1024**3):.1f}GB/process" + ) + + dataloader = store.get_streaming_dataloader( + batch_size=args.batch_size, + shuffle=args.shuffle, + seed=args.seed, + rank=rank, + world_size=world_size, + max_shards=max_shards, + ) + # Compute min batch count across all ranks to keep DDP in sync + # Read parquet footers for all ranks' shards (a few KB each, no data loading) + if world_size > 1: + import pyarrow.parquet as pq_meta + + dataset = dataloader.dataset + per_rank = len(dataset.shard_indices) + # Each rank got per_rank contiguous shards; compute batch count for each rank + min_batches = None + for r in range(world_size): + total_rows = sum( + pq_meta.read_metadata(store.path / f"shard_{idx:05d}.parquet").num_rows + for idx in range(r * per_rank, (r + 1) * per_rank) + ) + batches = total_rows // args.batch_size + if min_batches is None or batches < min_batches: + min_batches = batches + dataset.max_batches = min_batches + print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync") + trainer.fit( + dataloader, + resume_from=args.resume_from, + data_sharded=True, + ) + else: + shards = [] + for i, shard in enumerate(store.iter_shards(shuffle_shards=False)): + if max_shards is not None and i >= max_shards: + break + shards.append(torch.from_numpy(shard).float()) + activations_flat = torch.cat(shards) + print(f"Loaded {activations_flat.shape[0]:,} cached activations into memory") + + trainer.fit( + activations_flat, + resume_from=args.resume_from, + ) + + print("Training complete.") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py index 3be46cdf7f..53226f93b1 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py @@ -289,7 +289,10 @@ def _compute_auxk_loss( recon_aux = F.linear(codes_aux, self.decoder.weight[:, dead_indices], self.decoder.bias) # Target is the residual (what primary reconstruction missed) - # Work in normalized space for the aux loss + # Canonical: residual = x - recon, i.e. the actual reconstruction error. + # The previous form `x - recon + pre_bias` simplifies to `x - decoder(codes)`, + # which has norm dominated by ||pre_bias|| rather than the actual error, + # weakening the aux gradient by roughly (||pre_bias|| / ||error||)^2. if self.normalize_input and norm_info is not None: # Normalize x to match the space where encoding happened x_norm = (x - norm_info["mu"]) / norm_info["std"] @@ -297,7 +300,7 @@ def _compute_auxk_loss( recon_norm = self.decoder(codes) + self.pre_bias residual = x_norm - recon_norm.detach() else: - residual = x - recon.detach() + self.pre_bias.detach() + residual = x - recon.detach() # Normalized MSE: MSE / variance of target mse = (recon_aux - residual).pow(2).mean(dim=-1) # [batch] diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index c888d46c94..92dfb2051a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -586,8 +586,11 @@ def _padding_collate_fn( min_length: Minimum length to pad to Returns: - Dictionary with batched and padded tensors + Dictionary with batched and padded tensors, or None when the input + batch is empty (can happen on DP shard boundaries — caller must skip). """ + if not batch: + return None max_len = max(sample["tokens"].shape[0] for sample in batch) if min_length is not None: max_len = max(max_len, min_length) @@ -1197,6 +1200,9 @@ def predict( with torch.no_grad(): for batch_idx, batch_data in enumerate(dataloader): + # Empty batches can be handed to a rank on DP shard boundaries. + if batch_data is None: + continue # Move to GPU batch_gpu = { k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch_data.items() diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py index 811b07153e..156ce530b5 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py @@ -135,7 +135,7 @@ def load_savanna_state_dict(path: Path) -> dict[str, torch.Tensor]: Returns: Flat state dict with keys like 'sequential.{i}.xxx'. """ - raw = torch.load(str(path), map_location="cpu", weights_only=True, mmap=True) + raw = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True) if "module" in raw: raw = raw["module"]