From 60deff46f045dd2916b4cff3e53565e0a95b7c3c Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 21 May 2026 00:41:33 +0000 Subject: [PATCH 1/7] evo2_megatron: load Savanna HF checkpoints with weights_only=False MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch 2.6 changed the default of `weights_only` to True. The Savanna checkpoint pickle includes numpy globals (`numpy.core.multiarray._reconstruct`), which the safer loader rejects. The converter then exits 0 with no output written and the error gets buried in stderr — silent failure. The Savanna repos under arcinstitute/* are trusted sources, so load with weights_only=False. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py index 811b07153e..156ce530b5 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py @@ -135,7 +135,7 @@ def load_savanna_state_dict(path: Path) -> dict[str, torch.Tensor]: Returns: Flat state dict with keys like 'sequential.{i}.xxx'. """ - raw = torch.load(str(path), map_location="cpu", weights_only=True, mmap=True) + raw = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True) if "module" in raw: raw = raw["module"] From b640f66ab62600ccf5da05f8ced1bb7f2797534f Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 21 May 2026 00:43:15 +0000 Subject: [PATCH 2/7] interpretability/sae: add Evo2 1B SAE recipe Mirrors the existing esm2 / codonfm SAE recipes. Pipeline: chunk -> convert (Savanna->MBridge) -> predict_evo2 -> pt_to_parquet -> train Differences from esm2/codonfm are forced by Evo2 specifics: - Hyena/Megatron-Core model, no HF AutoModel path => reuses the existing `predict_evo2` CLI for inference instead of writing a custom extract.py - `pt_to_parquet.py` shim bridges predict_evo2's .pt output to the universal `sae.activation_store` parquet contract - `chunk_fasta.py` preprocessor keeps inputs within the model's trained context length (8192 bp for 1B); Hyena fftconv OOMs on long sequences even at micro-batch=1 - `train.py` is the same as codonfm's, copied verbatim per bionemo-recipes' KISS-over-DRY convention Validated end-to-end on 100 organelle sequences (Evo2 1B layer 12): loss 0.67 -> 0.045, FVU 0.90 -> 0.10, var_exp 0.10 -> 0.90, 2m14s wall. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2/README.md | 30 ++ .../recipes/evo2/pyproject.toml | 24 ++ .../recipes/evo2/scripts/1b.sh | 116 +++++++ .../recipes/evo2/scripts/chunk_fasta.py | 73 ++++ .../recipes/evo2/scripts/pt_to_parquet.py | 65 ++++ .../recipes/evo2/scripts/train.py | 321 ++++++++++++++++++ .../recipes/evo2/src/evo2_sae/__init__.py | 16 + 7 files changed, 645 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml create mode 100755 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md new file mode 100644 index 0000000000..ad749dbedb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -0,0 +1,30 @@ +# Evo2 SAE Recipe + +Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations. + +Pipeline: + +``` +HF Savanna ckpt --convert--> MBridge ckpt + | + predict_evo2 --embedding-layer N (FASTA in, .pt out) + | + pt_to_parquet shim (.pt -> ActivationStore parquet shards) + | + train.py (TopK SAE) +``` + +The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1. + +## Quick start (1B model, single GPU) + +```bash +bash scripts/1b.sh +``` + +This will: + +1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format +2. Run `predict_evo2` on the OpenGenome2 organelle FASTA, extracting layer-12 embeddings +3. Convert the .pt outputs to parquet shards +4. Train a TopK SAE (expansion=8, k=32) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml new file mode 100644 index 0000000000..26eff6b55c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "evo2-sae" +version = "0.1.0" +description = "Sparse Autoencoders for the Evo2 DNA language model" +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "sae", + "torch>=2.0", + "numpy>=1.20", + "tqdm>=4.60", + "pyarrow>=10.0", +] + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.uv.sources] +sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh new file mode 100755 index 0000000000..d499b4f365 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -0,0 +1,116 @@ +#!/bin/bash +# Evo2 1B SAE pipeline: convert -> predict_evo2 -> pt_to_parquet -> train. +# +# Assumes: +# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and +# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge. +# - The sae workspace package is importable in that same venv. +# - HF_TOKEN is set if Savanna checkpoint repo is gated. +# +# Override any of these by exporting before invocation. + +set -euo pipefail + +EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}" +RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)" + +MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}" +MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}" +LAYER="${LAYER:-12}" +# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended). +CHUNK_BP="${CHUNK_BP:-8192}" + +FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}" +WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}" + +CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge" +PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt" +PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet" +OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}" + +source "${EVO2_MEGATRON_DIR}/.venv/bin/activate" + +echo "============================================================" +echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)" +echo "============================================================" +# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed. +INPUT_STEM="$(basename "$FASTA")" +INPUT_STEM="${INPUT_STEM%.gz}" +INPUT_STEM="${INPUT_STEM%.fasta}" +CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta" +if [[ -f "$CHUNKED_FASTA" ]]; then + echo "Reusing existing chunked FASTA: $CHUNKED_FASTA" +else + python "${RECIPE_DIR}/scripts/chunk_fasta.py" \ + --input "$FASTA" \ + --output "$CHUNKED_FASTA" \ + --window "$CHUNK_BP" +fi +FASTA="$CHUNKED_FASTA" + +echo "============================================================" +echo "STEP 1: Convert Savanna -> MBridge" +echo "============================================================" +if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path "$MODEL" \ + --mbridge-ckpt-dir "$CKPT_DIR" \ + --model-size "$MODEL_SIZE" \ + --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512" +else + echo "Reusing existing checkpoint at $CKPT_DIR" +fi + +echo "============================================================" +echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)" +echo "============================================================" +mkdir -p "$PREDICT_DIR" +if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then + echo "Reusing existing .pt files in $PREDICT_DIR" +else + predict_evo2 \ + --fasta "$FASTA" \ + --ckpt-dir "$CKPT_DIR" \ + --output-dir "$PREDICT_DIR" \ + --embedding-layer "$LAYER" \ + --micro-batch-size 1 \ + --devices 1 \ + --write-interval batch +fi + +echo "============================================================" +echo "STEP 3: Convert .pt -> parquet ActivationStore" +echo "============================================================" +if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "Reusing existing parquet shards at $PARQUET_DIR" +else + python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \ + --predict-dir "$PREDICT_DIR" \ + --output "$PARQUET_DIR" \ + --model-name "$MODEL" \ + --layer "$LAYER" +fi + +echo "============================================================" +echo "STEP 4: Train TopK SAE" +echo "============================================================" +python "${RECIPE_DIR}/scripts/train.py" \ + --cache-dir "$PARQUET_DIR" \ + --model-path "$MODEL" \ + --layer "$LAYER" \ + --model-type topk \ + --expansion-factor 8 --top-k 32 \ + --auxk 64 --auxk-coef 0.03125 \ + --init-pre-bias \ + --n-epochs 3 \ + --batch-size 4096 \ + --lr 3e-4 \ + --log-interval 50 \ + --no-wandb \ + --output-dir "$OUTPUT_DIR" \ + --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \ + --checkpoint-steps 999999 + +echo "============================================================" +echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" +echo "============================================================" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py new file mode 100644 index 0000000000..55b26cad30 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context. + +Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena +fftconv path (intermediates scale super-linearly with L). For 7B/40B raise +--window to whatever those checkpoints were context-extended to. + +Non-overlapping windows by default. Each chunk gets a header of the form +">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped. +""" + +import argparse +import gzip +from pathlib import Path + + +def parse_fasta(path: Path): + """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" + opener = gzip.open if path.suffix == ".gz" else open + seq_id, parts = None, [] + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if seq_id is not None: + yield seq_id, "".join(parts) + seq_id = line[1:].split()[0] + parts = [] + else: + parts.append(line) + if seq_id is not None: + yield seq_id, "".join(parts) + + +def main(): + """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA.""" + p = argparse.ArgumentParser() + p.add_argument("--input", type=Path, required=True) + p.add_argument("--output", type=Path, required=True) + p.add_argument("--window", type=int, default=8192) + args = p.parse_args() + + n_in = n_out = bp_out = 0 + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as out: + for seq_id, seq in parse_fasta(args.input): + n_in += 1 + for start in range(0, len(seq), args.window): + end = min(start + args.window, len(seq)) + chunk = seq[start:end] + out.write(f">{seq_id}:{start}-{end}\n{chunk}\n") + n_out += 1 + bp_out += len(chunk) + + print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py new file mode 100644 index 0000000000..6a182b575d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert predict_evo2 .pt outputs to SAE ActivationStore parquet shards. + +predict_evo2 with --embedding-layer writes dicts of: + hidden_embeddings: [B, S, H] (bf16) + pad_mask: [B, S] (1 = valid token, 0 = padding) + seq_idx, tokens: metadata, ignored here + +We read each file, mask out padding, flatten to [N_tokens, H], and append +to an ActivationStore so train.py's load_activations() can consume it. +""" + +import argparse +import json +from pathlib import Path + +import torch +from sae.activation_store import ActivationStore, ActivationStoreConfig +from tqdm import tqdm + + +def main(): + """Walk predict_evo2 .pt files, mask padding, and write to an ActivationStore.""" + p = argparse.ArgumentParser() + p.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt") + p.add_argument("--output", type=Path, required=True, help="ActivationStore output dir") + p.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json") + p.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json") + p.add_argument("--shard-size", type=int, default=100_000) + args = p.parse_args() + + pt_files = sorted(args.predict_dir.rglob("predictions__*.pt")) + if not pt_files: + raise FileNotFoundError(f"No predictions__*.pt under {args.predict_dir}") + + store = ActivationStore(args.output, ActivationStoreConfig(shard_size=args.shard_size)) + n_sequences = 0 + for pt in tqdm(pt_files, desc="pt->parquet"): + d = torch.load(pt, map_location="cpu", weights_only=False) + hidden = d["hidden_embeddings"] + mask = d["pad_mask"].bool() + flat = hidden[mask].float() + store.append(flat) + n_sequences += hidden.shape[0] + + store.finalize(metadata={"model_name": args.model_name, "layer": args.layer, "n_sequences": n_sequences}) + print(json.dumps(store.metadata, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py new file mode 100644 index 0000000000..19355822ae --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py @@ -0,0 +1,321 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Step 2: Train SAE from cached CodonFM activations. + +Loads pre-extracted activations from an ActivationStore cache directory +and trains a Sparse Autoencoder. Requires extract.py to have been run first. + +Single-GPU: + python scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 + +Multi-GPU DDP: + torchrun --nproc_per_node=4 scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 \ + --dp-size 4 +""" + +import argparse +import os +from pathlib import Path + +import numpy as np +import torch +from sae.activation_store import load_activations +from sae.architectures import ReLUSAE, TopKSAE +from sae.perf_logger import PerfLogger +from sae.training import ParallelConfig, Trainer, TrainingConfig, WandbConfig +from sae.utils import get_device, set_seed + + +def parse_args(): # noqa: D103 + p = argparse.ArgumentParser( + description="Train SAE from cached CodonFM activations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required + p.add_argument("--cache-dir", type=str, required=True, help="Path to activation cache (from extract.py)") + p.add_argument("--model-path", type=str, required=True, help="Encodon model path (for cache validation)") + p.add_argument("--layer", type=int, required=True, help="Layer index (for cache validation)") + + # SAE architecture + sae_group = p.add_argument_group("SAE model") + sae_group.add_argument("--model-type", type=str, default="topk", choices=["topk", "relu"]) + sae_group.add_argument("--expansion-factor", type=int, default=8) + sae_group.add_argument("--top-k", type=int, default=32) + sae_group.add_argument("--normalize-input", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--auxk", type=int, default=None) + sae_group.add_argument("--auxk-coef", type=float, default=1 / 32) + sae_group.add_argument("--dead-tokens-threshold", type=int, default=10_000_000) + sae_group.add_argument("--init-pre-bias", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--l1-coeff", type=float, default=1e-2, help="L1 coefficient (relu only)") + + # Training + train_group = p.add_argument_group("Training") + train_group.add_argument("--lr", type=float, default=3e-4) + train_group.add_argument("--n-epochs", type=int, default=3) + train_group.add_argument("--batch-size", type=int, default=4096) + train_group.add_argument("--log-interval", type=int, default=50) + train_group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, default=True) + train_group.add_argument("--num-workers", type=int, default=0) + train_group.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--max-grad-norm", type=float, default=None) + train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048) + train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps") + train_group.add_argument( + "--lr-schedule", + type=str, + default="constant", + choices=["constant", "cosine", "linear"], + help="LR schedule after warmup", + ) + train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules") + train_group.add_argument( + "--lr-decay-steps", + type=int, + default=None, + help="Total steps for LR decay (None = full training)", + ) + + # W&B + wb_group = p.add_argument_group("Weights & Biases") + wb_group.add_argument("--wandb", action=argparse.BooleanOptionalAction, default=False, dest="wandb_enabled") + wb_group.add_argument("--wandb-project", type=str, default="sae_codonfm_recipe") + wb_group.add_argument("--wandb-run-name", type=str, default=None) + wb_group.add_argument("--wandb-group", type=str, default=None) + wb_group.add_argument("--wandb-job-type", type=str, default=None) + + # Checkpointing + ckpt_group = p.add_argument_group("Checkpointing") + ckpt_group.add_argument("--checkpoint-dir", type=str, default=None) + ckpt_group.add_argument("--checkpoint-steps", type=int, default=None) + ckpt_group.add_argument("--resume-from", type=str, default=None) + + # Infrastructure + p.add_argument("--dp-size", type=int, default=1) + p.add_argument("--output-dir", type=str, default="./outputs") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--device", type=str, default=None) + p.add_argument( + "--num-sequences", + type=int, + default=None, + help="Subset cached activations to this many sequences' worth of shards", + ) + + return p.parse_args() + + +def build_sae(args, input_dim: int) -> torch.nn.Module: # noqa: D103 + hidden_dim = input_dim * args.expansion_factor + + if args.model_type == "topk": + return TopKSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + top_k=args.top_k, + normalize_input=args.normalize_input, + auxk=args.auxk, + auxk_coef=args.auxk_coef, + dead_tokens_threshold=args.dead_tokens_threshold, + ) + elif args.model_type == "relu": + return ReLUSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + l1_coeff=args.l1_coeff, + ) + else: + raise ValueError(f"Unknown model type: {args.model_type}") + + +def build_training_config(args, device: str) -> TrainingConfig: # noqa: D103 + return TrainingConfig( + lr=args.lr, + n_epochs=args.n_epochs, + batch_size=args.batch_size, + device=device, + log_interval=args.log_interval, + shuffle=args.shuffle, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + checkpoint_dir=args.checkpoint_dir, + checkpoint_steps=args.checkpoint_steps, + lr_scale_with_latents=args.lr_scale_with_latents, + lr_reference_hidden_dim=args.lr_reference_hidden_dim, + warmup_steps=args.warmup_steps, + max_grad_norm=args.max_grad_norm, + lr_schedule=args.lr_schedule, + lr_min=args.lr_min, + lr_decay_steps=args.lr_decay_steps, + ) + + +def build_wandb_config(args) -> WandbConfig: # noqa: D103 + return WandbConfig( + enabled=args.wandb_enabled, + project=args.wandb_project, + run_name=args.wandb_run_name, + group=args.wandb_group, + job_type=args.wandb_job_type, + config=vars(args), + ) + + +def build_parallel_config(args) -> ParallelConfig: # noqa: D103 + return ParallelConfig(dp_size=args.dp_size) + + +def main(): # noqa: D103 + args = parse_args() + + set_seed(args.seed) + device = args.device or get_device() + print(f"Using device: {device}") + print(f"Config: {vars(args)}") + + # Load cached activations + cache_path = Path(args.cache_dir) + if not (cache_path / "metadata.json").exists(): + raise FileNotFoundError(f"No cache found at {cache_path}. Run extract.py first.") + + store = load_activations(cache_path) + meta = store.metadata + + # Validate cache matches config + cached_model = meta.get("model_path", meta.get("model_name", "")) + if cached_model and cached_model != args.model_path: + print(f"WARNING: Cache model '{cached_model}' != '{args.model_path}'") + if meta.get("layer") != args.layer: + raise ValueError(f"Cache layer mismatch: {meta['layer']} vs {args.layer}") + + # Compute subsetting + cached_sequences = meta.get("n_sequences", None) + max_shards = None + if args.num_sequences and cached_sequences and args.num_sequences < cached_sequences: + keep_ratio = args.num_sequences / cached_sequences + max_shards = max(1, int(np.ceil(keep_ratio * meta["n_shards"]))) + print( + f"Subsetting: {args.num_sequences}/{cached_sequences} sequences " + f"-> using {max_shards}/{meta['n_shards']} shards (~{keep_ratio:.1%})" + ) + + # Estimate memory + n_shards_to_use = max_shards or meta["n_shards"] + shard_size = meta.get("shard_size", 100_000) + est_tokens = n_shards_to_use * shard_size + est_gb = est_tokens * meta["hidden_dim"] * 4 / (1024**3) + use_streaming = est_gb > 50 + + input_dim = meta["hidden_dim"] + sae = build_sae(args, input_dim) + print(f"SAE: {args.model_type}, input_dim={input_dim}, hidden_dim={sae.hidden_dim}") + + # Initialize pre_bias + if args.init_pre_bias and hasattr(sae, "init_pre_bias_from_data"): + print("Initializing pre_bias from geometric median of data...") + first_shard = torch.from_numpy(store._load_shard(0)).float() + sample_size = min(32768, len(first_shard)) + sae.init_pre_bias_from_data(first_shard[:sample_size]) + print(f" pre_bias initialized (mean={sae.pre_bias.mean().item():.4f})") + del first_shard + + # Build configs + training_config = build_training_config(args, device) + wandb_config = build_wandb_config(args) + parallel_config = build_parallel_config(args) + + perf_logger = PerfLogger( + log_interval=args.log_interval, + use_wandb=args.wandb_enabled, + print_logs=True, + device=device, + ) + + # Train + trainer = Trainer( + sae, + training_config, + wandb_config=wandb_config, + perf_logger=perf_logger, + parallel_config=parallel_config, + ) + + if use_streaming: + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + print( + f"Streaming from disk (~{est_gb:.0f}GB). " + f"Peak RAM: ~{shard_size * meta['hidden_dim'] * 4 / (1024**3):.1f}GB/process" + ) + + dataloader = store.get_streaming_dataloader( + batch_size=args.batch_size, + shuffle=args.shuffle, + seed=args.seed, + rank=rank, + world_size=world_size, + max_shards=max_shards, + ) + # Compute min batch count across all ranks to keep DDP in sync + # Read parquet footers for all ranks' shards (a few KB each, no data loading) + if world_size > 1: + import pyarrow.parquet as pq_meta + + dataset = dataloader.dataset + per_rank = len(dataset.shard_indices) + # Each rank got per_rank contiguous shards; compute batch count for each rank + min_batches = None + for r in range(world_size): + total_rows = sum( + pq_meta.read_metadata(store.path / f"shard_{idx:05d}.parquet").num_rows + for idx in range(r * per_rank, (r + 1) * per_rank) + ) + batches = total_rows // args.batch_size + if min_batches is None or batches < min_batches: + min_batches = batches + dataset.max_batches = min_batches + print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync") + trainer.fit( + dataloader, + resume_from=args.resume_from, + data_sharded=True, + ) + else: + shards = [] + for i, shard in enumerate(store.iter_shards(shuffle_shards=False)): + if max_shards is not None and i >= max_shards: + break + shards.append(torch.from_numpy(shard).float()) + activations_flat = torch.cat(shards) + print(f"Loaded {activations_flat.shape[0]:,} cached activations into memory") + + trainer.fit( + activations_flat, + resume_from=args.resume_from, + ) + + print("Training complete.") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py new file mode 100644 index 0000000000..d8ac513dc8 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse autoencoders for the Evo2 DNA language model.""" From 5edbf6ef0d8b658d5cc24a1b36dff634237689fc Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 26 May 2026 21:14:43 +0000 Subject: [PATCH 3/7] evo2 recipe: drop empty src/evo2_sae package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recipe currently has no model-specific Python module — the extractor is upstream (`predict_evo2`) and the two scripts are simple CLIs in scripts/. Drop the empty package and adjust pyproject.toml so setuptools doesn't try to discover anything. Will reintroduce when there's actual library code to put there (eval, dashboard, dataloaders). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2/pyproject.toml | 7 +++++-- .../recipes/evo2/src/evo2_sae/__init__.py | 16 ---------------- 2 files changed, 5 insertions(+), 18 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index 26eff6b55c..1f00a62bc5 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -17,8 +17,11 @@ dependencies = [ "pyarrow>=10.0", ] -[tool.setuptools.packages.find] -where = ["src"] +# No package code lives here yet — the recipe is just an entry-point for +# scripts/ that depends on the shared `sae` workspace package. Declare no +# packages so setuptools doesn't try to discover anything. +[tool.setuptools] +packages = [] [tool.uv.sources] sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py deleted file mode 100644 index d8ac513dc8..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sparse autoencoders for the Evo2 DNA language model.""" From 89ff40c3a2ee925811415eac9aef74b6b50ab732 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 27 May 2026 16:34:28 +0000 Subject: [PATCH 4/7] evo2_megatron predict: skip empty batches on DP shard boundary torchrun --nproc_per_node N can hand a rank an empty batch when the last micro-batch falls past the shard boundary. _padding_collate_fn then crashed in max() with "iterable argument is empty". Return None from the collate when batch is empty and skip the loop iteration in predict(). Required for predict_evo2 to run reliably under DP > 1. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2_megatron/src/bionemo/evo2/run/predict.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index c888d46c94..92dfb2051a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -586,8 +586,11 @@ def _padding_collate_fn( min_length: Minimum length to pad to Returns: - Dictionary with batched and padded tensors + Dictionary with batched and padded tensors, or None when the input + batch is empty (can happen on DP shard boundaries — caller must skip). """ + if not batch: + return None max_len = max(sample["tokens"].shape[0] for sample in batch) if min_length is not None: max_len = max(max_len, min_length) @@ -1197,6 +1200,9 @@ def predict( with torch.no_grad(): for batch_idx, batch_data in enumerate(dataloader): + # Empty batches can be handed to a rank on DP shard boundaries. + if batch_data is None: + continue # Move to GPU batch_gpu = { k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch_data.items() From 9b7856a8d35583c9d3947dbba1469265d6b49397 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 27 May 2026 23:56:36 +0000 Subject: [PATCH 5/7] evo2 sae: streaming extractor + prok+euk FASTA composer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit scripts/extract.py Codonfm-style streaming activation extractor. Reuses predict_evo2's Megatron model/DP/dataloader machinery by monkey-patching its _write_predictions_batch, then streams pad-stripped layer-N activations directly into an ActivationStore (parquet shards) inside the inference loop — skipping the .pt intermediate that pt_to_parquet had to walk. --max-tokens caps each rank's budget. File-based rank wait + merge (not dist.barrier — predict.main tears down the process group before the writer hook returns, so the barrier silently no-ops and rank 0 races ahead; observed orphaned 18M tokens before this was fixed). Saves ~30 min and ~7 TB scratch per 25M-token run vs the old pipeline. scripts/compose_prokeuk_fasta.py Builds a balanced prokaryotic + eukaryotic mixed FASTA from OpenGenome2 subsets (metagenomes + eukaryotic_genic_windows). Truncates metagenome contigs to --metagenome-window bp each (default 50k) — they average ~1.1 Mbp, so a handful of full contigs would dominate the mix. Emits unique seq_{i} headers so predict_evo2's dup-id check passes. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../evo2/scripts/compose_prokeuk_fasta.py | 142 ++++++++++++ .../recipes/evo2/scripts/extract.py | 206 ++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py new file mode 100644 index 0000000000..37a1e4b0a0 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/compose_prokeuk_fasta.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Compose a prokaryotic + eukaryotic FASTA from OpenGenome2 subsets. + +Two sources only: +- Prokaryotic: filtered_metagenomes_pt1 (truncated to --metagenome-window bp/contig) +- Eukaryotic: eukaryotic_genic_windows (~5kb euk genic regions) + +Output headers are renumbered as `>seq_{i} {source}` to satisfy predict_evo2's +unique-id check (the source files share NCBI-style accession headers across +records). +""" + +import argparse +import gzip +import subprocess +import sys +from pathlib import Path + + +def _open_text(path: Path): + """Open a .fasta or .fasta.gz file in text mode.""" + if str(path).endswith(".gz"): + return gzip.open(path, "rt") + return open(path) + + +def _iter_records(fh): + """Yield (header, seq_lines) from a FASTA file handle.""" + header = None + lines = [] + for line in fh: + line = line.rstrip("\n") + if line.startswith(">"): + if header is not None: + yield header, lines + header = line + lines = [] + elif line: + lines.append(line) + if header is not None: + yield header, lines + + +def _take_n(fh, n): + """Take the first n records from fh.""" + for i, (h, ls) in enumerate(_iter_records(fh)): + if i >= n: + return + yield h, ls + + +def _take_n_truncated(fh, n, max_bp): + """Take the first n records, truncating each sequence to <= max_bp bases.""" + for i, (h, ls) in enumerate(_iter_records(fh)): + if i >= n: + return + seq = "".join(ls)[:max_bp] + lines = [seq[j : j + 80] for j in range(0, len(seq), 80)] + yield h, lines + + +def main(): + """Compose a prok+euk mixed FASTA with unique seq_{i} headers.""" + ap = argparse.ArgumentParser() + ap.add_argument( + "--root", + type=Path, + default=Path("/data/interp/evo2/OpenGenome2/fasta"), + help="Root dir holding the OpenGenome2 subset directories.", + ) + ap.add_argument( + "--output", + type=Path, + default=Path("/data/interp/evo2/scratch/mixed_25M_prokeuk.fasta"), + ) + ap.add_argument("--n-metagenome", type=int, default=1000, help="Metagenome contigs (prok).") + ap.add_argument("--metagenome-window", type=int, default=50_000, help="Max bp per metagenome contig (truncate).") + ap.add_argument("--n-euk-windows", type=int, default=10_000, help="Eukaryotic_genic_windows records (euk).") + args = ap.parse_args() + + metagenome_file = args.root / "metagenomes" / "filtered_metagenomes_pt1.fasta.gz" + euk_parts = sorted((args.root / "eukaryotic_genic_windows").glob("*.fasta.gz.*")) + + if not metagenome_file.exists(): + print(f"ERROR: missing metagenome source at {metagenome_file}", file=sys.stderr) + sys.exit(1) + if not euk_parts: + print(f"ERROR: no eukaryotic_genic_windows parts under {args.root}", file=sys.stderr) + sys.exit(1) + + args.output.parent.mkdir(parents=True, exist_ok=True) + counter = 0 + bp_by_source: dict[str, int] = {} + + def _emit(out, header, lines, source): + """Write a record with a globally unique header and tally bp.""" + nonlocal counter + out.write(f">seq_{counter} {source}\n") + counter += 1 + for line in lines: + out.write(line + "\n") + bp_by_source[source] = bp_by_source.get(source, 0) + sum(len(line) for line in lines) + + with open(args.output, "w") as out: + # 1. Prokaryotic: metagenome contigs, truncated to --metagenome-window each. + print(f"adding {args.n_metagenome} metagenome contigs (truncated to {args.metagenome_window} bp each)...") + with _open_text(metagenome_file) as fh: + for h, ls in _take_n_truncated(fh, args.n_metagenome, args.metagenome_window): + _emit(out, h, ls, "prok_metagenomes") + + # 2. Eukaryotic: read split parts as one stream. + print(f"adding {args.n_euk_windows} eukaryotic_genic_windows...") + cat_parts = subprocess.Popen( + ["bash", "-c", f"cat {' '.join(str(p) for p in euk_parts)} | zcat"], + stdout=subprocess.PIPE, + text=True, + ) + try: + for h, ls in _take_n(cat_parts.stdout, args.n_euk_windows): + _emit(out, h, ls, "euk_genic_windows") + finally: + cat_parts.stdout.close() + cat_parts.terminate() + + total_bp = sum(bp_by_source.values()) + print(f"\nwrote {counter} sequences, {total_bp:,} total bp -> {args.output}") + print("by source (bp):") + for src, bp in bp_by_source.items(): + pct = 100 * bp / total_bp if total_bp else 0 + print(f" {src:<22} {bp:>12,} bp ({pct:>5.1f}%)") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py new file mode 100644 index 0000000000..2b4d753bcd --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/extract.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Streaming Evo2 activation extractor — codonfm-style. + +Reuses `bionemo.evo2.run.predict` for all the heavy machinery (Megatron +model load, DP/CP/TP/PP setup, FASTA dataloader, inference loop) but +swaps the per-batch `.pt` writer for an in-process `ActivationStore` +that streams parquet shards directly during inference. + +Why: predict_evo2's `.pt` intermediate doubles disk volume and forces a +slow downstream pt->parquet shim. For SAE training, the activation tensor +is all we need; writing it directly into the SAE's ActivationStore format +removes the shim entirely, mirroring how codonfm's scripts/extract.py +already works. + +Invocation: + + torchrun --nproc_per_node 4 extract.py \ + --fasta path/to/seq.fasta \ + --ckpt-dir path/to/mbridge_ckpt \ + --embedding-layer 20 \ + --activation-store-dir /data/.../parquet_out \ + --max-tokens 25000000 \ + --micro-batch-size 4 + +All non-`--activation-store-dir`/`--max-tokens` flags are forwarded to +predict_evo2's argparse (`--fasta`, `--ckpt-dir`, `--embedding-layer`, +`--micro-batch-size`, etc.) so the inference surface is identical. +""" + +import argparse +import os +import shutil +import sys +from pathlib import Path + +# predict.py provides the entire Megatron inference plumbing. +from bionemo.evo2.run import predict as predict_mod # noqa: E402 + +# The SAE activation store — same format the existing pt_to_parquet.py emits. +from sae.activation_store import ActivationStore, ActivationStoreConfig # noqa: E402 + +# Reuse the merge step we already have in pt_to_parquet.py. +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from pt_to_parquet import _merge_temp_stores # noqa: E402 + + +# Per-rank state. Each torchrun rank is its own Python process, so this +# module-level dict is rank-local — exactly what we want. +_state: dict = { + "store": None, # ActivationStore for this rank + "n_tokens": 0, # tokens appended on this rank + "n_sequences": 0, # hidden.shape[0] across batches (raw seqs) + "budget": 0, # per-rank token cap (0 = no cap) + "store_root": None, # Path — final output dir; per-rank tmp lives under .tmp_rank_ + "rank_tmp": None, # Path — this rank's temp dir +} + + +def _store_writer( + predictions, + output_dir, + batch_idx, + global_rank, + dp_rank, + files_per_subdir=None, + num_files_written=0, + data_parallel_world_size=1, +): + """Replacement for predict._write_predictions_batch — append to ActivationStore. + + Signature matches the original; return shape `(path, updated_count, 0)`. + """ + if not predictions: + return output_dir, num_files_written, 0 + + # Once we've hit the per-rank budget, skip remaining writes (forward + # passes still run; cheap relative to the I/O we're skipping). + if _state["budget"] and _state["n_tokens"] >= _state["budget"]: + return output_dir, num_files_written, 0 + + hidden = predictions["hidden_embeddings"] # [B, S, H] + mask = predictions["pad_mask"].bool() + flat = hidden[mask].cpu() # [N_unpadded_tokens, H] + + if _state["store"] is None: + rank_tmp = _state["store_root"].with_name( + _state["store_root"].name + f".tmp_rank_{dp_rank}" + ) + rank_tmp.mkdir(parents=True, exist_ok=True) + _state["rank_tmp"] = rank_tmp + _state["store"] = ActivationStore(rank_tmp, ActivationStoreConfig(shard_size=100_000)) + + _state["store"].append(flat) + _state["n_tokens"] += flat.shape[0] + _state["n_sequences"] += hidden.shape[0] + return output_dir, num_files_written + 1, 0 + + +def _finalize_and_maybe_merge(model_name: str, layer: int) -> None: + """Finalize this rank's store, then rank 0 waits for all ranks and merges. + + We use a file-based wait (poll for sibling ranks' metadata.json) rather + than torch.distributed.barrier(): predict.main() tears down the process + group before this hook runs, so dist.barrier() silently no-ops and rank 0 + would race ahead of slower ranks (observed in the prok+euk run — rank 0 + merged its own dir before ranks 1-3 finalized, leaving 18M tokens orphaned). + """ + if _state["store"] is not None: + _state["store"].finalize(metadata={"n_sequences": _state["n_sequences"]}) + + if int(os.environ.get("RANK", "0")) != 0: + return + + # Rank 0 waits for all siblings to finalize before merging. + import time + + store_root: Path = _state["store_root"] + world_size = int(os.environ.get("WORLD_SIZE", "1")) + deadline = time.time() + 600 # 10 min wait cap + + def _ready_count() -> int: + return sum( + 1 + for r in range(world_size) + if (store_root.with_name(store_root.name + f".tmp_rank_{r}") / "metadata.json").exists() + ) + + while time.time() < deadline: + ready = _ready_count() + if ready >= world_size: + break + time.sleep(2) + else: + print( + f"[extract] WARN: only {_ready_count()}/{world_size} ranks finalized within 10 min — " + "merging what's available; some activations may be orphaned" + ) + + tmp_dirs = sorted( + p + for p in store_root.parent.glob(store_root.name + ".tmp_rank_*") + if p.is_dir() and (p / "metadata.json").exists() + ) + if not tmp_dirs: + print(f"[extract] no rank tmp dirs found under {store_root.parent} — nothing to merge") + return + print(f"[extract] merging {len(tmp_dirs)} rank tmp dirs into {store_root}") + final = _merge_temp_stores(tmp_dirs, store_root, model_name, layer) + print(f"[extract] done: {final}") + + +def main() -> None: + """Parse the extractor-specific flags, monkey-patch predict's writer, run predict.main().""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--activation-store-dir", type=Path, required=True) + parser.add_argument("--max-tokens", type=int, default=0, help="Cap total tokens across DP ranks (0 = no cap).") + parser.add_argument("--model-name", type=str, default="arcinstitute/savanna_evo2_1b_base") + extract_args, remaining = parser.parse_known_args() + + dp_size = int(os.environ.get("WORLD_SIZE", "1")) + _state["store_root"] = extract_args.activation_store_dir + _state["budget"] = extract_args.max_tokens // dp_size if extract_args.max_tokens else 0 + + # Force batch write-interval so our writer is called every iteration + # (epoch mode would buffer everything in memory, defeating the point). + if "--write-interval" not in remaining: + remaining.extend(["--write-interval", "batch"]) + + # predict.main() requires --output-dir; we point it at a throwaway path + # (writer never actually writes there). + if "--output-dir" not in remaining: + scratch = _state["store_root"].with_name(_state["store_root"].name + ".predict_unused") + scratch.mkdir(parents=True, exist_ok=True) + remaining.extend(["--output-dir", str(scratch)]) + + # Capture for the merge metadata; we need to know which layer / model + # to stamp into the merged ActivationStore.metadata. + layer = 0 + for i, a in enumerate(remaining): + if a == "--embedding-layer": + layer = int(remaining[i + 1]) + + # Substitute our writer for predict's. predict.py calls the bare name + # `_write_predictions_batch(...)` in its module scope, so module-attr + # replacement is enough. + predict_mod._write_predictions_batch = _store_writer + + # Hand predict's parser only the args it expects. + sys.argv = [sys.argv[0]] + remaining + + try: + predict_mod.main() + finally: + _finalize_and_maybe_merge(extract_args.model_name, layer) + + +if __name__ == "__main__": + main() From a5b1db7857a5c6907248ca8875319b38e1391422 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 29 May 2026 04:41:18 +0000 Subject: [PATCH 6/7] evo2 sae recipe: streaming-extract pipeline, drop pt_to_parquet shim extract.py replaces the predict_evo2 -> .pt -> pt_to_parquet path with a single streaming step that writes ActivationStore parquet shards directly during inference. Delete the now-unused shim, rewrite 1b.sh as a 3-step pipeline (convert -> extract -> train), and update the README accordingly. 1b.sh: - collapse predict_evo2 + pt_to_parquet into a single 'STEP 2: extract' that calls torchrun extract.py - expose RUN_TAG, PARQUET_DIR/OUTPUT_DIR, MAX_TOKENS, MICRO_BATCH, DEVICES, and SAE training hyperparams (EXPANSION_FACTOR, TOP_K, AUXK, AUXK_COEF, DEAD_TOKENS_THRESHOLD, N_EPOCHS, LR) as env overrides so the same script drives a multi-config sweep - TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet - WANDB_API_KEY gates wandb logging; WANDB_PROJECT/WANDB_RUN_NAME override README: pipeline diagram + quick-start examples for the new env-overridable flow; remove all references to .pt intermediates and pt_to_parquet. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2/README.md | 34 +++++-- .../recipes/evo2/scripts/1b.sh | 91 ++++++++++++------- .../recipes/evo2/scripts/pt_to_parquet.py | 65 ------------- 3 files changed, 84 insertions(+), 106 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md index ad749dbedb..1b6a7eb0d5 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -7,16 +7,17 @@ Pipeline: ``` HF Savanna ckpt --convert--> MBridge ckpt | - predict_evo2 --embedding-layer N (FASTA in, .pt out) - | - pt_to_parquet shim (.pt -> ActivationStore parquet shards) + extract.py (FASTA in, ActivationStore parquet shards out) | train.py (TopK SAE) ``` +`extract.py` monkey-patches `predict_evo2`'s writer to stream parquet shards inline, +so there is no `.pt` intermediate and no separate shim step. + The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1. -## Quick start (1B model, single GPU) +## Quick start (1B model, 4 GPU) ```bash bash scripts/1b.sh @@ -25,6 +26,25 @@ bash scripts/1b.sh This will: 1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format -2. Run `predict_evo2` on the OpenGenome2 organelle FASTA, extracting layer-12 embeddings -3. Convert the .pt outputs to parquet shards -4. Train a TopK SAE (expansion=8, k=32) +2. Run `extract.py` on the OpenGenome2 organelle FASTA, streaming layer-12 + activations directly to parquet shards (no `.pt` intermediate) +3. Train a TopK SAE (expansion=8, k=32, auxk=512) + +Common overrides: + +```bash +# Different layer, different FASTA, tagged output paths +LAYER=22 FASTA=/data/.../prokeuk_25M.fasta RUN_TAG=25M_prokeuk bash scripts/1b.sh + +# Skip extraction (assumes parquet already exists at PARQUET_DIR) +TRAIN_ONLY=1 PARQUET_DIR=/data/.../parquet_25M_prokeuk bash scripts/1b.sh + +# Sweep hyperparams +LAYER=22 RUN_TAG=auxk2048 AUXK=2048 N_EPOCHS=4 bash scripts/1b.sh +``` + +See the top of `scripts/1b.sh` for the full list of env-overridable variables +(`MODEL`, `LAYER`, `CHUNK_BP`, `FASTA`, `RUN_TAG`, `MAX_TOKENS`, `MICRO_BATCH`, +`DEVICES`, `EXPANSION_FACTOR`, `TOP_K`, `AUXK`, `AUXK_COEF`, +`DEAD_TOKENS_THRESHOLD`, `N_EPOCHS`, `LR`, `WANDB_API_KEY`, `WANDB_PROJECT`, +`WANDB_RUN_NAME`, `TRAIN_ONLY`). diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh index d499b4f365..a3262ec0c5 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -1,11 +1,11 @@ #!/bin/bash -# Evo2 1B SAE pipeline: convert -> predict_evo2 -> pt_to_parquet -> train. +# Evo2 1B SAE pipeline: convert -> extract (streaming parquet) -> train. # # Assumes: # - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and # its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge. # - The sae workspace package is importable in that same venv. -# - HF_TOKEN is set if Savanna checkpoint repo is gated. +# - HF_TOKEN is set if the Savanna checkpoint repo is gated. # # Override any of these by exporting before invocation. @@ -23,13 +23,30 @@ CHUNK_BP="${CHUNK_BP:-8192}" FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}" WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}" -CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge" -PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt" -PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet" -OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}" +# Default output paths can be overridden per-run. Set RUN_TAG to suffix the +# activation and SAE paths at once (e.g. RUN_TAG=100M_mixed -> ..._parquet_100M_mixed), +# or override each path individually for full control. +RUN_TAG="${RUN_TAG:-}" +_SUFFIX="${RUN_TAG:+_${RUN_TAG}}" + +CKPT_DIR="${CKPT_DIR:-${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge}" +PARQUET_DIR="${PARQUET_DIR:-${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet${_SUFFIX}}" +OUTPUT_DIR="${OUTPUT_DIR:-${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}${_SUFFIX}}" source "${EVO2_MEGATRON_DIR}/.venv/bin/activate" +# TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet. +if [[ "${TRAIN_ONLY:-0}" == "1" ]]; then + echo "============================================================" + echo "TRAIN_ONLY=1 — skipping chunk / convert / extract;" + echo "expecting an existing parquet at: $PARQUET_DIR" + echo "============================================================" + if [[ ! -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "ERROR: TRAIN_ONLY=1 but no parquet at $PARQUET_DIR" + exit 1 + fi +else + echo "============================================================" echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)" echo "============================================================" @@ -58,55 +75,61 @@ if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then --model-size "$MODEL_SIZE" \ --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512" else - echo "Reusing existing checkpoint at $CKPT_DIR" + echo "Reusing existing MBridge checkpoint at $CKPT_DIR" fi echo "============================================================" -echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)" +echo "STEP 2: Extract layer-${LAYER} activations directly to parquet" echo "============================================================" -mkdir -p "$PREDICT_DIR" -if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then - echo "Reusing existing .pt files in $PREDICT_DIR" +# extract.py monkey-patches predict_evo2's writer to stream parquet shards +# inline; no .pt intermediate and no separate shim. MAX_TOKENS=0 = uncapped. +if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "Reusing existing parquet shards at $PARQUET_DIR" else - predict_evo2 \ + torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/extract.py" \ + --activation-store-dir "$PARQUET_DIR" \ + --max-tokens "${MAX_TOKENS:-0}" \ + --model-name "$MODEL" \ --fasta "$FASTA" \ --ckpt-dir "$CKPT_DIR" \ - --output-dir "$PREDICT_DIR" \ --embedding-layer "$LAYER" \ - --micro-batch-size 1 \ - --devices 1 \ - --write-interval batch + --micro-batch-size "${MICRO_BATCH:-4}" fi +fi # end if TRAIN_ONLY + echo "============================================================" -echo "STEP 3: Convert .pt -> parquet ActivationStore" +echo "STEP 3: Train TopK SAE" echo "============================================================" -if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then - echo "Reusing existing parquet shards at $PARQUET_DIR" -else - python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \ - --predict-dir "$PREDICT_DIR" \ - --output "$PARQUET_DIR" \ - --model-name "$MODEL" \ - --layer "$LAYER" +# Wandb is enabled iff WANDB_API_KEY is in the env. WANDB_PROJECT/RUN can be overridden. +WANDB_FLAGS=("--no-wandb") +if [[ -n "${WANDB_API_KEY:-}" ]]; then + WANDB_FLAGS=( + "--wandb" + "--wandb-project" "${WANDB_PROJECT:-evo2-sae}" + ) + if [[ -n "${WANDB_RUN_NAME:-}" ]]; then + WANDB_FLAGS+=("--wandb-run-name" "$WANDB_RUN_NAME") + fi fi -echo "============================================================" -echo "STEP 4: Train TopK SAE" -echo "============================================================" -python "${RECIPE_DIR}/scripts/train.py" \ +torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/train.py" \ --cache-dir "$PARQUET_DIR" \ --model-path "$MODEL" \ --layer "$LAYER" \ --model-type topk \ - --expansion-factor 8 --top-k 32 \ - --auxk 64 --auxk-coef 0.03125 \ + --expansion-factor "${EXPANSION_FACTOR:-8}" \ + --top-k "${TOP_K:-32}" \ + --auxk "${AUXK:-512}" \ + --auxk-coef "${AUXK_COEF:-0.03125}" \ + --dead-tokens-threshold "${DEAD_TOKENS_THRESHOLD:-500000}" \ --init-pre-bias \ - --n-epochs 3 \ + --n-epochs "${N_EPOCHS:-3}" \ --batch-size 4096 \ - --lr 3e-4 \ + --dp-size "${DEVICES:-4}" \ + --lr "${LR:-3e-4}" \ --log-interval 50 \ - --no-wandb \ + "${WANDB_FLAGS[@]}" \ --output-dir "$OUTPUT_DIR" \ --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \ --checkpoint-steps 999999 diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py deleted file mode 100644 index 6a182b575d..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert predict_evo2 .pt outputs to SAE ActivationStore parquet shards. - -predict_evo2 with --embedding-layer writes dicts of: - hidden_embeddings: [B, S, H] (bf16) - pad_mask: [B, S] (1 = valid token, 0 = padding) - seq_idx, tokens: metadata, ignored here - -We read each file, mask out padding, flatten to [N_tokens, H], and append -to an ActivationStore so train.py's load_activations() can consume it. -""" - -import argparse -import json -from pathlib import Path - -import torch -from sae.activation_store import ActivationStore, ActivationStoreConfig -from tqdm import tqdm - - -def main(): - """Walk predict_evo2 .pt files, mask padding, and write to an ActivationStore.""" - p = argparse.ArgumentParser() - p.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt") - p.add_argument("--output", type=Path, required=True, help="ActivationStore output dir") - p.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json") - p.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json") - p.add_argument("--shard-size", type=int, default=100_000) - args = p.parse_args() - - pt_files = sorted(args.predict_dir.rglob("predictions__*.pt")) - if not pt_files: - raise FileNotFoundError(f"No predictions__*.pt under {args.predict_dir}") - - store = ActivationStore(args.output, ActivationStoreConfig(shard_size=args.shard_size)) - n_sequences = 0 - for pt in tqdm(pt_files, desc="pt->parquet"): - d = torch.load(pt, map_location="cpu", weights_only=False) - hidden = d["hidden_embeddings"] - mask = d["pad_mask"].bool() - flat = hidden[mask].float() - store.append(flat) - n_sequences += hidden.shape[0] - - store.finalize(metadata={"model_name": args.model_name, "layer": args.layer, "n_sequences": n_sequences}) - print(json.dumps(store.metadata, indent=2)) - - -if __name__ == "__main__": - main() From 958a48ae9d0fd64d45556d57e4c072b89ffe5e3a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Fri, 29 May 2026 20:09:06 +0000 Subject: [PATCH 7/7] sae topk: fix aux-loss residual; evo2 1b.sh: 10M dead-threshold default topk.py: aux-loss target was `x - recon + pre_bias`, which simplifies to `x - decoder(codes)` -- norm dominated by ||pre_bias|| (~449 on evo2 L22) rather than the actual reconstruction error (~8). The denominator (`target_var = residual.pow(2).mean(-1)`) was inflated by the same factor, so the aux gradient was scaled by roughly (||pre_bias|| / ||error||)^2 ~ 3000x below the canonical formulation. Fix to `residual = x - recon`, matching the OpenAI/Gao TopK formulation. Numerically verified on the 500M L22 checkpoint: residual (a) ||x - recon|| = 8.0 vs buggy (b) ||x - recon + pre_bias|| = 449.7. 1b.sh: default DEAD_TOKENS_THRESHOLD to 10_000_000, matching the train.py default and codonfm convention (Gao et al.). Prior 500_000 default flagged ~70% of latents as 'dead' even when they were firing once per ~half-million tokens, vs codonfm's 0.003% under the canonical threshold. Still overridable via env. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../sparse_autoencoders/recipes/evo2/scripts/1b.sh | 2 +- .../sparse_autoencoders/sae/src/sae/architectures/topk.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh index a3262ec0c5..6ed5c0fcc1 100755 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -122,7 +122,7 @@ torchrun --nproc_per_node "${DEVICES:-4}" "${RECIPE_DIR}/scripts/train.py" \ --top-k "${TOP_K:-32}" \ --auxk "${AUXK:-512}" \ --auxk-coef "${AUXK_COEF:-0.03125}" \ - --dead-tokens-threshold "${DEAD_TOKENS_THRESHOLD:-500000}" \ + --dead-tokens-threshold "${DEAD_TOKENS_THRESHOLD:-10000000}" \ --init-pre-bias \ --n-epochs "${N_EPOCHS:-3}" \ --batch-size 4096 \ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py index 3be46cdf7f..53226f93b1 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py @@ -289,7 +289,10 @@ def _compute_auxk_loss( recon_aux = F.linear(codes_aux, self.decoder.weight[:, dead_indices], self.decoder.bias) # Target is the residual (what primary reconstruction missed) - # Work in normalized space for the aux loss + # Canonical: residual = x - recon, i.e. the actual reconstruction error. + # The previous form `x - recon + pre_bias` simplifies to `x - decoder(codes)`, + # which has norm dominated by ||pre_bias|| rather than the actual error, + # weakening the aux gradient by roughly (||pre_bias|| / ||error||)^2. if self.normalize_input and norm_info is not None: # Normalize x to match the space where encoding happened x_norm = (x - norm_info["mu"]) / norm_info["std"] @@ -297,7 +300,7 @@ def _compute_auxk_loss( recon_norm = self.decoder(codes) + self.pre_bias residual = x_norm - recon_norm.detach() else: - residual = x - recon.detach() + self.pre_bias.detach() + residual = x - recon.detach() # Normalized MSE: MSE / variance of target mse = (recon_aux - residual).pow(2).mean(dim=-1) # [batch]