From 8a31579118be2fb30a633c91330a642a9c85778a Mon Sep 17 00:00:00 2001 From: jwilber Date: Tue, 21 Apr 2026 09:16:17 -0700 Subject: [PATCH] Add notebooks Signed-off-by: jwilber --- .../codonfm/notebooks/01_quickstart.ipynb | 415 ++++++++++ .../codonfm/notebooks/02_codon_analysis.ipynb | 673 +++++++++++++++ .../notebooks/03_gene_enrichment.ipynb | 485 +++++++++++ .../codonfm/notebooks/04_dashboard.ipynb | 538 ++++++++++++ .../codonfm/notebooks/05_auto_interp.ipynb | 721 ++++++++++++++++ .../codonfm/notebooks/06_probing.ipynb | 772 ++++++++++++++++++ 6 files changed, 3604 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/01_quickstart.ipynb create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/02_codon_analysis.ipynb create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/03_gene_enrichment.ipynb create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/04_dashboard.ipynb create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/05_auto_interp.ipynb create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/06_probing.ipynb diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/01_quickstart.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/01_quickstart.ipynb new file mode 100644 index 0000000000..9ded225f03 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/01_quickstart.ipynb @@ -0,0 +1,415 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CodonFM SAE — Quickstart\n", + "\n", + "This notebook loads a trained Sparse Autoencoder for CodonFM (a codon language model) and runs basic health checks. 
By the end you'll know whether your SAE is working: how many features are alive, what fraction of variance it explains, and what individual feature activations look like." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Paths need to be configured for your environment. Update the checkpoint, model, and data paths below to match your local setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── Configure these paths for your environment ──\n", + "SAE_CHECKPOINT = \"../outputs/1b_layer16/checkpoints/checkpoint_final.pt\"\n", + "MODEL_PATH = \"../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\"\n", + "CSV_PATH = \"../../../../../../../codonfm/data/codonfm_ref_only.csv\"\n", + "LAYER = 16\n", + "CONTEXT_LENGTH = 2048\n", + "BATCH_SIZE = 8\n", + "NUM_SEQUENCES = 500 # Keep small for quick iteration\n", + "DEVICE = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "\n", + "\n", + "# Add recipe paths\n", + "_REPO_ROOT = Path(\"..\").resolve().parent.parent.parent.parent.parent\n", + "_CODONFM_TE_DIR = _REPO_ROOT / \"recipes\" / \"codonfm_ptl_te\"\n", + "sys.path.insert(0, str(_CODONFM_TE_DIR))\n", + "sys.path.insert(0, str(Path(\"..\").resolve()))\n", + "\n", + "from codonfm_sae.data import read_codon_csv\n", + "from codonfm_sae.eval import evaluate_codonfm_loss_recovered\n", + "from sae.architectures import TopKSAE\n", + "from sae.utils import set_seed\n", + "from src.data.preprocess.codon_sequence import process_item\n", + "from src.inference.encodon import EncodonInference\n", + "\n", + "\n", + "set_seed(42)\n", + "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Device: {device}\")" + ] + }, + 
{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load SAE Checkpoint\n", + "\n", + "Load the trained SAE from a checkpoint file. The checkpoint contains the model weights, architecture config (`input_dim`, `hidden_dim`, `top_k`), and training metadata." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load(SAE_CHECKPOINT, map_location=\"cpu\", weights_only=False)\n", + "state_dict = ckpt[\"model_state_dict\"]\n", + "if any(k.startswith(\"module.\") for k in state_dict):\n", + " state_dict = {k.removeprefix(\"module.\"): v for k, v in state_dict.items()}\n", + "\n", + "input_dim = ckpt.get(\"input_dim\") or state_dict[\"encoder.weight\"].shape[1]\n", + "hidden_dim = ckpt.get(\"hidden_dim\") or state_dict[\"encoder.weight\"].shape[0]\n", + "model_config = ckpt.get(\"model_config\", {})\n", + "top_k = model_config.get(\"top_k\")\n", + "\n", + "sae = TopKSAE(\n", + " input_dim=input_dim,\n", + " hidden_dim=hidden_dim,\n", + " top_k=top_k,\n", + " normalize_input=model_config.get(\"normalize_input\", False),\n", + ")\n", + "sae.load_state_dict(state_dict)\n", + "sae = sae.eval().to(device)\n", + "\n", + "print(f\"SAE: {input_dim} \\u2192 {hidden_dim:,} features (top-{top_k})\")\n", + "print(f\"Expansion factor: {hidden_dim / input_dim:.1f}x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Encodon Model & Data\n", + "\n", + "Load the CodonFM Encodon model (the base language model the SAE was trained on) and a CSV of codon sequences." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inference = EncodonInference(\n", + " model_path=MODEL_PATH,\n", + " task_type=\"embedding_prediction\",\n", + " use_transformer_engine=True,\n", + ")\n", + "inference.configure_model()\n", + "inference.model.to(device).eval()\n", + "\n", + "num_layers = len(inference.model.model.layers)\n", + "target_layer = LAYER if LAYER >= 0 else num_layers + LAYER\n", + "print(f\"Encodon: {num_layers} layers, target layer {target_layer}\")\n", + "\n", + "# Load sequences\n", + "records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)\n", + "sequences = [r.sequence for r in records]\n", + "print(f\"Loaded {len(sequences)} sequences\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loss Recovered\n", + "\n", + "The primary quality metric for an SAE. Measures what fraction of the language model's cross-entropy loss is preserved when hidden states are routed through the SAE. A value of 1.0 means perfect reconstruction; 0.0 means the SAE output is no better than zero. Good SAEs typically achieve 0.85–0.95." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = evaluate_codonfm_loss_recovered(\n", + " sae=sae,\n", + " inference=inference,\n", + " sequences=sequences[:200], # Use subset for speed\n", + " layer=LAYER,\n", + " context_length=CONTEXT_LENGTH,\n", + " batch_size=BATCH_SIZE,\n", + " device=device,\n", + ")\n", + "\n", + "print(f\"Loss recovered: {result.loss_recovered:.4f}\")\n", + "print(f\"CE (original): {result.ce_original:.4f}\")\n", + "print(f\"CE (with SAE): {result.ce_sae:.4f}\")\n", + "print(f\"CE (zero ablation): {result.ce_zero:.4f}\")\n", + "print(f\"Tokens evaluated: {result.n_tokens:,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Feature Activation Statistics\n", + "\n", + "Extract activations and compute basic per-feature statistics. Key things to check: what fraction of features are \"alive\" (fire at least once), the distribution of activation frequencies, and the sparsity pattern." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "\n", + "# Stream through sequences, compute per-feature max and firing counts\n", + "n_features = sae.hidden_dim\n", + "fire_counts = np.zeros(n_features, dtype=np.int64)\n", + "max_activations = np.zeros(n_features, dtype=np.float32)\n", + "total_tokens = 0\n", + "\n", + "with torch.no_grad():\n", + " for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc=\"Extracting\"):\n", + " batch_seqs = sequences[i : i + BATCH_SIZE]\n", + " items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=inference.tokenizer) for s in batch_seqs]\n", + " batch = {\n", + " \"input_ids\": torch.tensor(np.stack([it[\"input_ids\"] for it in items])).to(device),\n", + " \"attention_mask\": torch.tensor(np.stack([it[\"attention_mask\"] for it in items])).to(device),\n", + " }\n", + " out = inference.model(batch, return_hidden_states=True)\n", + " hidden = out.all_hidden_states[LAYER]\n", + "\n", + " for j, it in enumerate(items):\n", + " seq_len = it[\"attention_mask\"].sum()\n", + " emb = hidden[j, 1 : seq_len - 1, :].float() # Strip CLS/SEP\n", + " codes = sae.encode(emb) # [n_codons, n_features]\n", + " active = (codes > 0).cpu().numpy()\n", + " fire_counts += active.sum(axis=0)\n", + " np.maximum(max_activations, codes.max(dim=0).values.cpu().numpy(), out=max_activations)\n", + " total_tokens += emb.shape[0]\n", + "\n", + " del out, hidden, batch\n", + "\n", + "alive = fire_counts > 0\n", + "n_alive = alive.sum()\n", + "n_dead = n_features - n_alive\n", + "freq = fire_counts / total_tokens\n", + "\n", + "print(f\"Total tokens: {total_tokens:,}\")\n", + "print(f\"Alive features: {n_alive:,} / {n_features:,} ({n_alive / n_features:.1%})\")\n", + "print(f\"Dead features: {n_dead:,} ({n_dead / n_features:.1%})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activation Frequency Distribution\n", + 
"\n", + "How often each feature fires. A healthy SAE has a broad range of frequencies — some features fire rarely (specialized detectors) and some fire often (common patterns). A spike at zero means dead features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "\n", + "# Log-scale histogram of activation frequencies (alive only)\n", + "alive_freq = freq[alive]\n", + "axes[0].hist(np.log10(alive_freq + 1e-10), bins=50, color=\"#76b900\", edgecolor=\"white\", alpha=0.8)\n", + "axes[0].set_xlabel(\"log\\u2081\\u2080(activation frequency)\")\n", + "axes[0].set_ylabel(\"Number of features\")\n", + "axes[0].set_title(f\"Activation Frequency Distribution ({n_alive:,} alive features)\")\n", + "axes[0].axvline(\n", + " np.log10(np.median(alive_freq)), color=\"red\", linestyle=\"--\", label=f\"Median: {np.median(alive_freq):.4f}\"\n", + ")\n", + "axes[0].legend()\n", + "\n", + "# Max activation distribution\n", + "alive_max = max_activations[alive]\n", + "axes[1].hist(alive_max, bins=50, color=\"#0074DF\", edgecolor=\"white\", alpha=0.8)\n", + "axes[1].set_xlabel(\"Max activation value\")\n", + "axes[1].set_ylabel(\"Number of features\")\n", + "axes[1].set_title(\"Max Activation Distribution\")\n", + "axes[1].axvline(np.median(alive_max), color=\"red\", linestyle=\"--\", label=f\"Median: {np.median(alive_max):.1f}\")\n", + "axes[1].legend()\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inspect Individual Features\n", + "\n", + "Pick a few features and look at what codons they activate on. This gives intuition for what the SAE learned — whether features correspond to recognizable biological patterns." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pick the top 5 most frequently firing features\n", + "top_feature_ids = np.argsort(freq)[-5:][::-1]\n", + "\n", + "CODON_TO_AA = {\n", + " \"TTT\": \"F\",\n", + " \"TTC\": \"F\",\n", + " \"TTA\": \"L\",\n", + " \"TTG\": \"L\",\n", + " \"CTT\": \"L\",\n", + " \"CTC\": \"L\",\n", + " \"CTA\": \"L\",\n", + " \"CTG\": \"L\",\n", + " \"ATT\": \"I\",\n", + " \"ATC\": \"I\",\n", + " \"ATA\": \"I\",\n", + " \"ATG\": \"M\",\n", + " \"GTT\": \"V\",\n", + " \"GTC\": \"V\",\n", + " \"GTA\": \"V\",\n", + " \"GTG\": \"V\",\n", + " \"TCT\": \"S\",\n", + " \"TCC\": \"S\",\n", + " \"TCA\": \"S\",\n", + " \"TCG\": \"S\",\n", + " \"CCT\": \"P\",\n", + " \"CCC\": \"P\",\n", + " \"CCA\": \"P\",\n", + " \"CCG\": \"P\",\n", + " \"ACT\": \"T\",\n", + " \"ACC\": \"T\",\n", + " \"ACA\": \"T\",\n", + " \"ACG\": \"T\",\n", + " \"GCT\": \"A\",\n", + " \"GCC\": \"A\",\n", + " \"GCA\": \"A\",\n", + " \"GCG\": \"A\",\n", + " \"TAT\": \"Y\",\n", + " \"TAC\": \"Y\",\n", + " \"TAA\": \"*\",\n", + " \"TAG\": \"*\",\n", + " \"CAT\": \"H\",\n", + " \"CAC\": \"H\",\n", + " \"CAA\": \"Q\",\n", + " \"CAG\": \"Q\",\n", + " \"AAT\": \"N\",\n", + " \"AAC\": \"N\",\n", + " \"AAA\": \"K\",\n", + " \"AAG\": \"K\",\n", + " \"GAT\": \"D\",\n", + " \"GAC\": \"D\",\n", + " \"GAA\": \"E\",\n", + " \"GAG\": \"E\",\n", + " \"TGT\": \"C\",\n", + " \"TGC\": \"C\",\n", + " \"TGA\": \"*\",\n", + " \"TGG\": \"W\",\n", + " \"CGT\": \"R\",\n", + " \"CGC\": \"R\",\n", + " \"CGA\": \"R\",\n", + " \"CGG\": \"R\",\n", + " \"AGT\": \"S\",\n", + " \"AGC\": \"S\",\n", + " \"AGA\": \"R\",\n", + " \"AGG\": \"R\",\n", + " \"GGT\": \"G\",\n", + " \"GGC\": \"G\",\n", + " \"GGA\": \"G\",\n", + " \"GGG\": \"G\",\n", + "}\n", + "\n", + "# Use decoder weights to find which codons each feature promotes\n", + "lm_head = inference.model.model.lm_head.weight.data.float() # [vocab, D]\n", + "decoder = 
sae.decoder.weight.data.float().to(device) # [input_dim, hidden_dim]\n", + "mean_acts = sae.pre_bias.data.float().to(device) if hasattr(sae, \"pre_bias\") else torch.zeros(input_dim, device=device)\n", + "\n", + "tokenizer = inference.tokenizer\n", + "codon_tokens = {}\n", + "for codon in CODON_TO_AA:\n", + " tok_id = tokenizer.token_to_id(codon)\n", + " if tok_id is not None:\n", + " codon_tokens[codon] = tok_id\n", + "\n", + "for feat_id in top_feature_ids:\n", + " feat_dir = decoder[:, feat_id]\n", + " logits = lm_head @ feat_dir # [vocab]\n", + " # Baseline: logits from mean\n", + " baseline = lm_head @ mean_acts\n", + " logit_diff = logits - baseline\n", + "\n", + " codon_logits = {c: float(logit_diff[tid]) for c, tid in codon_tokens.items()}\n", + " sorted_codons = sorted(codon_logits.items(), key=lambda x: x[1], reverse=True)\n", + "\n", + " top_pos = sorted_codons[:5]\n", + " top_neg = sorted_codons[-5:][::-1]\n", + "\n", + " print(f\"\\n{'=' * 50}\")\n", + " print(f\"Feature {feat_id} | freq={freq[feat_id]:.4f} | max_act={max_activations[feat_id]:.1f}\")\n", + " print(f\" Top promoted: {', '.join(f'{c}({CODON_TO_AA[c]})={v:.2f}' for c, v in top_pos)}\")\n", + " print(f\" Top suppressed: {', '.join(f'{c}({CODON_TO_AA[c]})={v:.2f}' for c, v in top_neg)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "Now that you've verified the SAE is working, explore deeper analyses:\n", + "\n", + "- **02_codon_analysis.ipynb** — Compute codon usage metrics (CAI, tAI, RSCU) per feature\n", + "- **03_gene_enrichment.ipynb** — Run GSEA to find which biological pathways each feature detects\n", + "- **04_dashboard.ipynb** — Export data and launch the interactive dashboard\n", + "- **05_auto_interp.ipynb** — Use an LLM to automatically describe each feature" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": 
"python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/02_codon_analysis.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/02_codon_analysis.ipynb new file mode 100644 index 0000000000..5a1156d48f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/02_codon_analysis.ipynb @@ -0,0 +1,673 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CodonFM SAE — Codon Analysis\n", + "\n", + "For each SAE feature, we compute codon-level properties of the positions where it activates. This reveals whether features track amino acid identity, synonymous codon preferences, codon usage bias, or positional patterns.\n", + "\n", + "We compute three standard codon optimality metrics:\n", + "- **CAI (Codon Adaptation Index)** — how well codons match the usage pattern of highly expressed genes. High CAI = optimized for efficient translation.\n", + "- **tAI (tRNA Adaptation Index)** — how well codons match tRNA availability. High tAI = faster translation due to abundant cognate tRNAs.\n", + "- **RSCU (Relative Synonymous Codon Usage)** — how biased the synonymous codon choice is. RSCU=1 means no preference among synonyms; >1 means the common synonym is preferred." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SAE_CHECKPOINT = \"../outputs/1b_layer16/checkpoints/checkpoint_final.pt\"\n", + "MODEL_PATH = \"../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\"\n", + "CSV_PATH = \"../../../../../../../codonfm/data/codonfm_ref_only.csv\"\n", + "LAYER = 16\n", + "CONTEXT_LENGTH = 2048\n", + "BATCH_SIZE = 8\n", + "NUM_SEQUENCES = 2000\n", + "DEVICE = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "_REPO_ROOT = Path(\"..\").resolve().parent.parent.parent.parent.parent\n", + "_CODONFM_TE_DIR = _REPO_ROOT / \"recipes\" / \"codonfm_ptl_te\"\n", + "sys.path.insert(0, str(_CODONFM_TE_DIR))\n", + "sys.path.insert(0, str(Path(\"..\").resolve()))\n", + "\n", + "from codonfm_sae.data import read_codon_csv\n", + "from sae.architectures import TopKSAE\n", + "from sae.utils import set_seed\n", + "from src.data.preprocess.codon_sequence import process_item\n", + "from src.inference.encodon import EncodonInference\n", + "\n", + "\n", + "set_seed(42)\n", + "device = DEVICE if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load SAE, Model, and Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load SAE\n", + "ckpt = torch.load(SAE_CHECKPOINT, map_location=\"cpu\", weights_only=False)\n", + "state_dict = ckpt[\"model_state_dict\"]\n", + "if any(k.startswith(\"module.\") for k in state_dict):\n", + " state_dict = {k.removeprefix(\"module.\"): v for k, v in 
state_dict.items()}\n", + "input_dim = ckpt.get(\"input_dim\") or state_dict[\"encoder.weight\"].shape[1]\n", + "hidden_dim = ckpt.get(\"hidden_dim\") or state_dict[\"encoder.weight\"].shape[0]\n", + "model_config = ckpt.get(\"model_config\", {})\n", + "sae = TopKSAE(\n", + " input_dim=input_dim,\n", + " hidden_dim=hidden_dim,\n", + " top_k=model_config.get(\"top_k\"),\n", + " normalize_input=model_config.get(\"normalize_input\", False),\n", + ")\n", + "sae.load_state_dict(state_dict)\n", + "sae = sae.eval().to(device)\n", + "print(f\"SAE: {input_dim} \\u2192 {hidden_dim:,} features\")\n", + "\n", + "# Load Encodon\n", + "inference = EncodonInference(model_path=MODEL_PATH, task_type=\"embedding_prediction\", use_transformer_engine=True)\n", + "inference.configure_model()\n", + "inference.model.to(device).eval()\n", + "print(f\"Encodon loaded ({len(inference.model.model.layers)} layers)\")\n", + "\n", + "# Load data\n", + "records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)\n", + "sequences = [r.sequence for r in records]\n", + "print(f\"Loaded {len(sequences)} sequences\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Codon Optimality Reference Tables\n", + "\n", + "We use three reference tables to compute per-codon optimality weights:\n", + "- **Human codon usage frequencies** (Kazusa database) for CAI\n", + "- **Human tRNA gene copy numbers** (GtRNAdb, hg38) for tAI\n", + "- **RSCU** derived from the usage frequencies\n", + "\n", + "For each metric, the weight of a codon is normalized relative to the best synonym for the same amino acid." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "HUMAN_CODON_USAGE = {\n", + " \"TTT\": 17.6,\n", + " \"TTC\": 20.3,\n", + " \"TTA\": 7.7,\n", + " \"TTG\": 12.9,\n", + " \"CTT\": 13.2,\n", + " \"CTC\": 19.6,\n", + " \"CTA\": 7.2,\n", + " \"CTG\": 39.6,\n", + " \"ATT\": 16.0,\n", + " \"ATC\": 20.8,\n", + " \"ATA\": 7.5,\n", + " \"ATG\": 22.0,\n", + " \"GTT\": 11.0,\n", + " \"GTC\": 14.5,\n", + " \"GTA\": 7.1,\n", + " \"GTG\": 28.1,\n", + " \"TCT\": 15.2,\n", + " \"TCC\": 17.7,\n", + " \"TCA\": 12.2,\n", + " \"TCG\": 4.4,\n", + " \"CCT\": 17.5,\n", + " \"CCC\": 19.8,\n", + " \"CCA\": 16.9,\n", + " \"CCG\": 6.9,\n", + " \"ACT\": 13.1,\n", + " \"ACC\": 18.9,\n", + " \"ACA\": 15.1,\n", + " \"ACG\": 6.1,\n", + " \"GCT\": 18.4,\n", + " \"GCC\": 27.7,\n", + " \"GCA\": 15.8,\n", + " \"GCG\": 7.4,\n", + " \"TAT\": 12.2,\n", + " \"TAC\": 15.3,\n", + " \"TAA\": 1.0,\n", + " \"TAG\": 0.8,\n", + " \"CAT\": 10.9,\n", + " \"CAC\": 15.1,\n", + " \"CAA\": 12.3,\n", + " \"CAG\": 34.2,\n", + " \"AAT\": 17.0,\n", + " \"AAC\": 19.1,\n", + " \"AAA\": 24.4,\n", + " \"AAG\": 31.9,\n", + " \"GAT\": 21.8,\n", + " \"GAC\": 25.1,\n", + " \"GAA\": 29.0,\n", + " \"GAG\": 39.6,\n", + " \"TGT\": 10.6,\n", + " \"TGC\": 12.6,\n", + " \"TGA\": 1.6,\n", + " \"TGG\": 13.2,\n", + " \"CGT\": 4.5,\n", + " \"CGC\": 10.4,\n", + " \"CGA\": 6.2,\n", + " \"CGG\": 11.4,\n", + " \"AGT\": 12.1,\n", + " \"AGC\": 19.5,\n", + " \"AGA\": 12.2,\n", + " \"AGG\": 12.0,\n", + " \"GGT\": 10.8,\n", + " \"GGC\": 22.2,\n", + " \"GGA\": 16.5,\n", + " \"GGG\": 16.5,\n", + "}\n", + "\n", + "CODON_TO_AA = {\n", + " \"TTT\": \"F\",\n", + " \"TTC\": \"F\",\n", + " \"TTA\": \"L\",\n", + " \"TTG\": \"L\",\n", + " \"CTT\": \"L\",\n", + " \"CTC\": \"L\",\n", + " \"CTA\": \"L\",\n", + " \"CTG\": \"L\",\n", + " \"ATT\": \"I\",\n", + " \"ATC\": \"I\",\n", + " \"ATA\": \"I\",\n", + " \"ATG\": \"M\",\n", + " \"GTT\": \"V\",\n", + " \"GTC\": \"V\",\n", + " \"GTA\": \"V\",\n", 
+ " \"GTG\": \"V\",\n", + " \"TCT\": \"S\",\n", + " \"TCC\": \"S\",\n", + " \"TCA\": \"S\",\n", + " \"TCG\": \"S\",\n", + " \"CCT\": \"P\",\n", + " \"CCC\": \"P\",\n", + " \"CCA\": \"P\",\n", + " \"CCG\": \"P\",\n", + " \"ACT\": \"T\",\n", + " \"ACC\": \"T\",\n", + " \"ACA\": \"T\",\n", + " \"ACG\": \"T\",\n", + " \"GCT\": \"A\",\n", + " \"GCC\": \"A\",\n", + " \"GCA\": \"A\",\n", + " \"GCG\": \"A\",\n", + " \"TAT\": \"Y\",\n", + " \"TAC\": \"Y\",\n", + " \"TAA\": \"*\",\n", + " \"TAG\": \"*\",\n", + " \"CAT\": \"H\",\n", + " \"CAC\": \"H\",\n", + " \"CAA\": \"Q\",\n", + " \"CAG\": \"Q\",\n", + " \"AAT\": \"N\",\n", + " \"AAC\": \"N\",\n", + " \"AAA\": \"K\",\n", + " \"AAG\": \"K\",\n", + " \"GAT\": \"D\",\n", + " \"GAC\": \"D\",\n", + " \"GAA\": \"E\",\n", + " \"GAG\": \"E\",\n", + " \"TGT\": \"C\",\n", + " \"TGC\": \"C\",\n", + " \"TGA\": \"*\",\n", + " \"TGG\": \"W\",\n", + " \"CGT\": \"R\",\n", + " \"CGC\": \"R\",\n", + " \"CGA\": \"R\",\n", + " \"CGG\": \"R\",\n", + " \"AGT\": \"S\",\n", + " \"AGC\": \"S\",\n", + " \"AGA\": \"R\",\n", + " \"AGG\": \"R\",\n", + " \"GGT\": \"G\",\n", + " \"GGC\": \"G\",\n", + " \"GGA\": \"G\",\n", + " \"GGG\": \"G\",\n", + "}\n", + "\n", + "_HUMAN_TRNA_COPY_NUMBERS = {\n", + " \"TTT\": 10,\n", + " \"TTC\": 20,\n", + " \"TTA\": 6,\n", + " \"TTG\": 11,\n", + " \"CTT\": 10,\n", + " \"CTC\": 20,\n", + " \"CTA\": 5,\n", + " \"CTG\": 20,\n", + " \"ATT\": 15,\n", + " \"ATC\": 23,\n", + " \"ATA\": 5,\n", + " \"ATG\": 23,\n", + " \"GTT\": 11,\n", + " \"GTC\": 14,\n", + " \"GTA\": 5,\n", + " \"GTG\": 16,\n", + " \"TCT\": 11,\n", + " \"TCC\": 17,\n", + " \"TCA\": 7,\n", + " \"TCG\": 4,\n", + " \"CCT\": 10,\n", + " \"CCC\": 12,\n", + " \"CCA\": 13,\n", + " \"CCG\": 5,\n", + " \"ACT\": 10,\n", + " \"ACC\": 20,\n", + " \"ACA\": 10,\n", + " \"ACG\": 6,\n", + " \"GCT\": 16,\n", + " \"GCC\": 34,\n", + " \"GCA\": 10,\n", + " \"GCG\": 6,\n", + " \"TAT\": 10,\n", + " \"TAC\": 16,\n", + " \"TAA\": 0,\n", + " \"TAG\": 0,\n", + " \"CAT\": 10,\n", + " 
\"CAC\": 15,\n", + " \"CAA\": 10,\n", + " \"CAG\": 34,\n", + " \"AAT\": 14,\n", + " \"AAC\": 20,\n", + " \"AAA\": 15,\n", + " \"AAG\": 34,\n", + " \"GAT\": 17,\n", + " \"GAC\": 25,\n", + " \"GAA\": 16,\n", + " \"GAG\": 40,\n", + " \"TGT\": 10,\n", + " \"TGC\": 20,\n", + " \"TGA\": 0,\n", + " \"TGG\": 10,\n", + " \"CGT\": 6,\n", + " \"CGC\": 15,\n", + " \"CGA\": 5,\n", + " \"CGG\": 5,\n", + " \"AGT\": 8,\n", + " \"AGC\": 18,\n", + " \"AGA\": 10,\n", + " \"AGG\": 8,\n", + " \"GGT\": 10,\n", + " \"GGC\": 22,\n", + " \"GGA\": 10,\n", + " \"GGG\": 8,\n", + "}\n", + "\n", + "from collections import defaultdict\n", + "\n", + "\n", + "aa_codons = defaultdict(list)\n", + "for codon, aa in CODON_TO_AA.items():\n", + " if aa != \"*\":\n", + " aa_codons[aa].append(codon)\n", + "\n", + "CAI_WEIGHTS = {}\n", + "for aa, codons in aa_codons.items():\n", + " freqs = [HUMAN_CODON_USAGE.get(c, 0.0) for c in codons]\n", + " max_freq = max(freqs)\n", + " for c, f in zip(codons, freqs):\n", + " CAI_WEIGHTS[c] = f / max_freq if max_freq > 0 else 0.0\n", + "\n", + "RSCU_VALUES = {}\n", + "for aa, codons in aa_codons.items():\n", + " freqs = [HUMAN_CODON_USAGE.get(c, 0.0) for c in codons]\n", + " total = sum(freqs)\n", + " n_syn = len(codons)\n", + " for c, f in zip(codons, freqs):\n", + " RSCU_VALUES[c] = (f * n_syn / total) if total > 0 else 1.0\n", + "\n", + "TAI_WEIGHTS = {}\n", + "for aa, codons in aa_codons.items():\n", + " copies = [_HUMAN_TRNA_COPY_NUMBERS.get(c, 0) for c in codons]\n", + " max_copy = max(copies)\n", + " for c, cp in zip(codons, copies):\n", + " TAI_WEIGHTS[c] = cp / max_copy if max_copy > 0 else 0.0\n", + "\n", + "print(f\"Built weights for {len(CAI_WEIGHTS)} codons\")\n", + "print(\n", + " f\"Example \\u2014 CTG(Leu): CAI={CAI_WEIGHTS['CTG']:.3f}, tAI={TAI_WEIGHTS['CTG']:.3f}, RSCU={RSCU_VALUES['CTG']:.3f}\"\n", + ")\n", + "print(\n", + " f\"Example \\u2014 CTA(Leu): CAI={CAI_WEIGHTS['CTA']:.3f}, tAI={TAI_WEIGHTS['CTA']:.3f}, RSCU={RSCU_VALUES['CTA']:.3f}\"\n", + 
")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming Codon Analysis\n", + "\n", + "We stream sequences through the model and SAE, accumulating per-feature statistics:\n", + "- **Amino acid identity**: which amino acid each feature fires on most\n", + "- **Codon usage bias**: rare vs common codons\n", + "- **Wobble position**: GC vs AT at the 3rd (wobble) position\n", + "- **CpG sites**: whether the feature tracks CpG dinucleotides across codon boundaries\n", + "- **CAI/tAI/RSCU**: codon optimality metrics (geometric mean for CAI/tAI, arithmetic mean for RSCU)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_features = sae.hidden_dim\n", + "all_aas = sorted(set(CODON_TO_AA.values()))\n", + "aa_to_idx = {aa: i for i, aa in enumerate(all_aas)}\n", + "n_aa = len(all_aas)\n", + "\n", + "# Accumulators\n", + "aa_counts = np.zeros((n_aa, n_features), dtype=np.int64)\n", + "rare_counts = np.zeros(n_features, dtype=np.int64)\n", + "common_counts = np.zeros(n_features, dtype=np.int64)\n", + "wobble_gc_counts = np.zeros(n_features, dtype=np.int64)\n", + "wobble_at_counts = np.zeros(n_features, dtype=np.int64)\n", + "cai_log_sum = np.zeros(n_features, dtype=np.float64)\n", + "tai_log_sum = np.zeros(n_features, dtype=np.float64)\n", + "rscu_sum = np.zeros(n_features, dtype=np.float64)\n", + "optimality_count = np.zeros(n_features, dtype=np.int64)\n", + "\n", + "n_batches = (len(sequences) + BATCH_SIZE - 1) // BATCH_SIZE\n", + "\n", + "with torch.no_grad():\n", + " for batch_start in tqdm(range(0, len(sequences), BATCH_SIZE), total=n_batches, desc=\"Streaming\"):\n", + " batch_seqs = sequences[batch_start : batch_start + BATCH_SIZE]\n", + " items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=inference.tokenizer) for s in batch_seqs]\n", + " batch = {\n", + " \"input_ids\": torch.tensor(np.stack([it[\"input_ids\"] for it in items])).to(device),\n", + " 
\"attention_mask\": torch.tensor(np.stack([it[\"attention_mask\"] for it in items])).to(device),\n", + " }\n", + " out = inference.model(batch, return_hidden_states=True)\n", + " hidden = out.all_hidden_states[LAYER].float()\n", + " attn = batch[\"attention_mask\"]\n", + "\n", + " # Build mask excluding CLS/SEP\n", + " keep = attn.clone()\n", + " keep[:, 0] = 0\n", + " lengths = attn.sum(dim=1)\n", + " for b in range(keep.shape[0]):\n", + " sep = int(lengths[b].item()) - 1\n", + " if sep > 0:\n", + " keep[b, sep] = 0\n", + "\n", + " for b in range(len(batch_seqs)):\n", + " vl = int(keep[b].sum().item())\n", + " if vl == 0:\n", + " continue\n", + " emb = hidden[b, :vl, :]\n", + " _, codes = sae(emb)\n", + " codes_cpu = codes.cpu().numpy()\n", + "\n", + " seq = batch_seqs[b]\n", + " codons = [seq[j * 3 : (j + 1) * 3].upper() for j in range(vl)]\n", + " active = codes_cpu > 0\n", + "\n", + " # Amino acid counts\n", + " aa_idx = np.array([aa_to_idx.get(CODON_TO_AA.get(c, \"?\"), 0) for c in codons])\n", + " for a in range(n_aa):\n", + " mask = aa_idx == a\n", + " if mask.any():\n", + " aa_counts[a] += active[mask].sum(axis=0)\n", + "\n", + " # Usage bias\n", + " is_rare = np.array([HUMAN_CODON_USAGE.get(c, 10.0) < 10.0 for c in codons])\n", + " rare_counts += active[is_rare].sum(axis=0) if is_rare.any() else 0\n", + " common_counts += active[~is_rare].sum(axis=0) if (~is_rare).any() else 0\n", + "\n", + " # Wobble\n", + " is_gc = np.array([c[2] in (\"G\", \"C\") if len(c) == 3 else False for c in codons])\n", + " wobble_gc_counts += active[is_gc].sum(axis=0) if is_gc.any() else 0\n", + " wobble_at_counts += active[~is_gc].sum(axis=0) if (~is_gc).any() else 0\n", + "\n", + " # CAI/tAI/RSCU\n", + " non_stop = np.array([CODON_TO_AA.get(c, \"*\") != \"*\" for c in codons])\n", + " cai_w = np.array([CAI_WEIGHTS.get(c, 0.0) for c in codons])\n", + " tai_w = np.array([TAI_WEIGHTS.get(c, 0.0) for c in codons])\n", + " rscu_v = np.array([RSCU_VALUES.get(c, 1.0) for c in 
codons])\n", + "\n", + " valid_cai = non_stop & (cai_w > 0)\n", + " valid_tai = non_stop & (tai_w > 0)\n", + "\n", + " if valid_cai.any():\n", + " log_cai = np.log(cai_w[valid_cai])\n", + " cai_log_sum += (active[valid_cai] * log_cai[:, None]).sum(axis=0)\n", + " if valid_tai.any():\n", + " log_tai = np.log(tai_w[valid_tai])\n", + " tai_log_sum += (active[valid_tai] * log_tai[:, None]).sum(axis=0)\n", + " if non_stop.any():\n", + " rscu_sum += (active[non_stop] * rscu_v[non_stop, None]).sum(axis=0)\n", + " optimality_count += active[non_stop].sum(axis=0)\n", + "\n", + " del out, hidden, batch\n", + " torch.cuda.empty_cache()\n", + "\n", + "print(\"Streaming complete.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Per-Feature Optimality Metrics\n", + "\n", + "Now we compute the final CAI, tAI, and RSCU for each feature. CAI and tAI use the geometric mean (exp of mean log-weights), while RSCU uses the arithmetic mean." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alive_mask = optimality_count > 10\n", + "n_alive_opt = alive_mask.sum()\n", + "\n", + "feature_cai = np.full(n_features, np.nan)\n", + "feature_tai = np.full(n_features, np.nan)\n", + "feature_rscu = np.full(n_features, np.nan)\n", + "\n", + "feature_cai[alive_mask] = np.exp(cai_log_sum[alive_mask] / optimality_count[alive_mask])\n", + "feature_tai[alive_mask] = np.exp(tai_log_sum[alive_mask] / optimality_count[alive_mask])\n", + "feature_rscu[alive_mask] = rscu_sum[alive_mask] / optimality_count[alive_mask]\n", + "\n", + "print(f\"Features with optimality metrics: {n_alive_opt:,}\")\n", + "print(\n", + " f\"\\nCAI: mean={np.nanmean(feature_cai):.4f}, std={np.nanstd(feature_cai):.4f}, range=[{np.nanmin(feature_cai):.4f}, {np.nanmax(feature_cai):.4f}]\"\n", + ")\n", + "print(\n", + " f\"tAI: mean={np.nanmean(feature_tai):.4f}, std={np.nanstd(feature_tai):.4f}, range=[{np.nanmin(feature_tai):.4f}, 
{np.nanmax(feature_tai):.4f}]\"\n", + ")\n", + "print(\n", + " f\"RSCU: mean={np.nanmean(feature_rscu):.4f}, std={np.nanstd(feature_rscu):.4f}, range=[{np.nanmin(feature_rscu):.4f}, {np.nanmax(feature_rscu):.4f}]\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n", + "\n", + "for ax, vals, name, color in [\n", + " (axes[0], feature_cai, \"CAI\", \"#76b900\"),\n", + " (axes[1], feature_tai, \"tAI\", \"#0074DF\"),\n", + " (axes[2], feature_rscu, \"RSCU\", \"#9525C6\"),\n", + "]:\n", + " valid = vals[~np.isnan(vals)]\n", + " ax.hist(valid, bins=50, color=color, edgecolor=\"white\", alpha=0.8)\n", + " ax.axvline(np.median(valid), color=\"red\", linestyle=\"--\", label=f\"Median: {np.median(valid):.3f}\")\n", + " ax.set_xlabel(name)\n", + " ax.set_ylabel(\"Number of features\")\n", + " ax.set_title(f\"{name} Distribution\")\n", + " ax.legend()\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CAI vs tAI Correlation\n", + "\n", + "CAI and tAI measure related but distinct properties. CAI reflects codon usage in highly expressed genes; tAI reflects tRNA availability. They're correlated (both favor 'optimal' codons) but can diverge — some codons have high usage but low tRNA counts." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "valid = alive_mask\n", + "fig, ax = plt.subplots(figsize=(6, 6))\n", + "sc = ax.scatter(feature_cai[valid], feature_tai[valid], c=feature_rscu[valid], s=3, alpha=0.3, cmap=\"viridis\")\n", + "ax.set_xlabel(\"CAI (Codon Adaptation Index)\")\n", + "ax.set_ylabel(\"tAI (tRNA Adaptation Index)\")\n", + "ax.set_title(\"CAI vs tAI per Feature (colored by RSCU)\")\n", + "plt.colorbar(sc, label=\"RSCU\")\n", + "ax.plot([0, 1], [0, 1], \"k--\", alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Extreme Features\n", + "\n", + "Features with unusually high or low optimality scores are the most interesting:\n", + "- **High CAI features** likely fire on codons in highly expressed, translationally optimized genes\n", + "- **Low CAI features** may track rare codons used in tissue-specific or lowly expressed genes\n", + "- **High RSCU features** strongly prefer the dominant synonymous codon\n", + "- **Low RSCU features** prefer the rare synonym — potentially detecting regulatory signals encoded in codon choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for metric, vals, name in [(\"CAI\", feature_cai, \"CAI\"), (\"tAI\", feature_tai, \"tAI\"), (\"RSCU\", feature_rscu, \"RSCU\")]:\n", + " valid_idx = np.where(~np.isnan(vals))[0]\n", + " sorted_idx = valid_idx[np.argsort(vals[valid_idx])]\n", + "\n", + " print(f\"\\n{'=' * 50}\")\n", + " print(f\"Top 5 highest {name}:\")\n", + " for i in sorted_idx[-5:][::-1]:\n", + " # Find dominant amino acid\n", + " total = aa_counts[:, i].sum()\n", + " if total > 0:\n", + " best_aa = all_aas[aa_counts[:, i].argmax()]\n", + " aa_frac = aa_counts[:, i].max() / total\n", + " aa_str = f\"AA={best_aa} ({aa_frac:.0%})\"\n", + " else:\n", + " aa_str = \"no data\"\n", + " print(f\" 
Feature {i:>5d}: {name}={vals[i]:.4f} {aa_str}\")\n", + "\n", + " print(f\"\\nTop 5 lowest {name}:\")\n", + " for i in sorted_idx[:5]:\n", + " total = aa_counts[:, i].sum()\n", + " if total > 0:\n", + " best_aa = all_aas[aa_counts[:, i].argmax()]\n", + " aa_frac = aa_counts[:, i].max() / total\n", + " aa_str = f\"AA={best_aa} ({aa_frac:.0%})\"\n", + " else:\n", + " aa_str = \"no data\"\n", + " print(f\" Feature {i:>5d}: {name}={vals[i]:.4f} {aa_str}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Results\n", + "\n", + "Save the per-feature optimality metrics for use in the dashboard (notebook 04) or further analysis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = Path(\"../outputs/1b_layer16/analysis\")\n", + "output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "results = {}\n", + "for f in range(n_features):\n", + " if alive_mask[f]:\n", + " results[f] = {\n", + " \"cai\": round(float(feature_cai[f]), 4),\n", + " \"tai\": round(float(feature_tai[f]), 4),\n", + " \"rscu\": round(float(feature_rscu[f]), 4),\n", + " }\n", + "\n", + "with open(output_dir / \"codon_optimality.json\", \"w\") as fp:\n", + " json.dump(results, fp, indent=2)\n", + "\n", + "print(f\"Saved optimality metrics for {len(results)} features to {output_dir / 'codon_optimality.json'}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/03_gene_enrichment.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/03_gene_enrichment.ipynb new file mode 100644 index 0000000000..aa25589fe8 --- /dev/null +++ 
b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/03_gene_enrichment.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CodonFM SAE — Gene-Level Enrichment (GSEA)\n", + "\n", + "The quickstart and codon analysis notebooks examine features at the **codon level** — which codons a feature promotes, what amino acids it responds to, whether it tracks usage bias. But many biological patterns are best understood at the **gene level**: a feature that fires on all ribosomal protein genes, or one that lights up specifically on olfactory receptors.\n", + "\n", + "This notebook runs **Gene Set Enrichment Analysis (GSEA)** on each SAE feature. The idea is simple:\n", + "1. For each feature, rank all genes by how strongly the feature activates on them.\n", + "2. Test whether the top-ranked genes are enriched for known biological annotations.\n", + "\n", + "We test against five annotation databases, each capturing a different axis of biology:\n", + "\n", + "| Database | What it captures |\n", + "|---|---|\n", + "| **GO Biological Process** | Pathways and processes (e.g., \"translation\", \"apoptosis\") |\n", + "| **GO Molecular Function** | Biochemical activity (e.g., \"kinase activity\", \"DNA binding\") |\n", + "| **GO Cellular Component** | Subcellular location (e.g., \"ribosome\", \"mitochondrial matrix\") |\n", + "| **InterPro Domains** | Protein domain families (e.g., \"immunoglobulin-like fold\") |\n", + "| **Pfam Domains** | Conserved protein domain families |\n", + "\n", + "Beyond GSEA, we also run two lighter-weight analyses:\n", + "- **Gene family detection**: Do the top genes for a feature share a naming prefix (e.g., `OR` for olfactory receptors, `RPS` for ribosomal proteins)?\n", + "- **pLI scores** (optional): Are the top genes evolutionarily constrained? 
pLI (probability of loss-of-function intolerance) from gnomAD measures how much evolutionary pressure there is to keep a gene intact. Features that activate on high-pLI genes may be tracking essential cellular functions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Config\n", + "SAE_CHECKPOINT = \"../outputs/1b_layer16/checkpoints/checkpoint_final.pt\"\n", + "MODEL_PATH = \"../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\"\n", + "CSV_PATH = \"../../../../../../../codonfm/data/codonfm_ref_only.csv\"\n", + "LAYER = 16\n", + "CONTEXT_LENGTH = 2048\n", + "BATCH_SIZE = 8\n", + "DEVICE = \"cuda\"\n", + "NUM_SEQUENCES = 5000\n", + "N_WORKERS = 8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "\n", + "_REPO_ROOT = Path(\"..\").resolve().parent.parent.parent.parent.parent\n", + "_CODONFM_TE_DIR = _REPO_ROOT / \"recipes\" / \"codonfm_ptl_te\"\n", + "sys.path.insert(0, str(_CODONFM_TE_DIR))\n", + "sys.path.insert(0, str(Path(\"..\").resolve()))\n", + "\n", + "from codonfm_sae.data import read_codon_csv\n", + "from sae.architectures import TopKSAE\n", + "from sae.utils import set_seed\n", + "from src.data.preprocess.codon_sequence import process_item\n", + "from src.inference.encodon import EncodonInference\n", + "\n", + "\n", + "set_seed(42)\n", + "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load SAE & Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load(SAE_CHECKPOINT, 
map_location=\"cpu\", weights_only=False)\n", + "state_dict = ckpt[\"model_state_dict\"]\n", + "if any(k.startswith(\"module.\") for k in state_dict):\n", + " state_dict = {k.removeprefix(\"module.\"): v for k, v in state_dict.items()}\n", + "\n", + "input_dim = ckpt.get(\"input_dim\") or state_dict[\"encoder.weight\"].shape[1]\n", + "hidden_dim = ckpt.get(\"hidden_dim\") or state_dict[\"encoder.weight\"].shape[0]\n", + "model_config = ckpt.get(\"model_config\", {})\n", + "top_k = model_config.get(\"top_k\")\n", + "\n", + "sae = TopKSAE(\n", + " input_dim=input_dim,\n", + " hidden_dim=hidden_dim,\n", + " top_k=top_k,\n", + " normalize_input=model_config.get(\"normalize_input\", False),\n", + ")\n", + "sae.load_state_dict(state_dict)\n", + "sae = sae.eval().to(device)\n", + "\n", + "print(f\"SAE: {input_dim} -> {hidden_dim:,} features (top-{top_k})\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inference = EncodonInference(\n", + " model_path=MODEL_PATH,\n", + " task_type=\"embedding_prediction\",\n", + " use_transformer_engine=True,\n", + ")\n", + "inference.configure_model()\n", + "inference.model.to(device).eval()\n", + "\n", + "num_layers = len(inference.model.model.layers)\n", + "target_layer = LAYER if LAYER >= 0 else num_layers + LAYER\n", + "print(f\"Encodon: {num_layers} layers, target layer {target_layer}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data and Extract Gene Names\n", + "\n", + "GSEA requires gene-level labels. We load the CSV and filter to sequences that have a `gene` column populated. Each sequence maps to one gene, but a gene may appear in multiple sequences (e.g., different transcripts or species). We'll collapse across sequences later by taking the max activation per gene." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)\n", + "\n", + "gene_names = []\n", + "valid_records = []\n", + "for rec in records:\n", + " gene = rec.metadata.get(\"gene\")\n", + " if gene and str(gene).strip():\n", + " gene_names.append(str(gene).strip())\n", + " valid_records.append(rec)\n", + "\n", + "sequences = [r.sequence for r in valid_records]\n", + "print(f\"{len(sequences)} sequences with gene names ({len(set(gene_names))} unique genes)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Per-Gene Activations\n", + "\n", + "We need to go from per-codon SAE activations to a single score per (gene, feature) pair. The aggregation pipeline is:\n", + "\n", + "1. **Per codon**: Run each sequence through the model + SAE to get activations at every codon position.\n", + "2. **Per sequence**: Take the **max** activation across all codons in that sequence, for each feature. This captures whether the feature fires *anywhere* in the sequence.\n", + "3. **Per gene**: If a gene appears in multiple sequences, take the **max** across sequences. This gives us the strongest signal the feature has for that gene.\n", + "\n", + "The result is a dictionary: `feature_idx -> gene_name -> activation_score`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "\n", + "n_features = sae.hidden_dim\n", + "n_sequences = len(sequences)\n", + "\n", + "# Phase 1: Compute per-sequence max activations\n", + "# Stream one batch at a time to avoid OOM\n", + "seq_max_acts = np.zeros((n_sequences, n_features), dtype=np.float32)\n", + "\n", + "with torch.no_grad():\n", + " for i in tqdm(range(0, n_sequences, BATCH_SIZE), desc=\"Extracting activations\"):\n", + " batch_seqs = sequences[i : i + BATCH_SIZE]\n", + " items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=inference.tokenizer) for s in batch_seqs]\n", + " batch = {\n", + " \"input_ids\": torch.tensor(np.stack([it[\"input_ids\"] for it in items])).to(device),\n", + " \"attention_mask\": torch.tensor(np.stack([it[\"attention_mask\"] for it in items])).to(device),\n", + " }\n", + " out = inference.model(batch, return_hidden_states=True)\n", + " hidden = out.all_hidden_states[LAYER]\n", + "\n", + " for j, it in enumerate(items):\n", + " seq_len = it[\"attention_mask\"].sum()\n", + " emb = hidden[j, 1 : seq_len - 1, :].float() # Strip CLS/SEP\n", + " codes = sae.encode(emb) # [n_codons, n_features]\n", + " seq_max_acts[i + j] = codes.max(dim=0).values.cpu().numpy()\n", + "\n", + " del out, hidden, batch\n", + " torch.cuda.empty_cache()\n", + "\n", + "print(f\"Computed max activations: {seq_max_acts.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Phase 2: Collapse to per-gene max\n", + "df = pd.DataFrame(seq_max_acts)\n", + "df[\"gene\"] = gene_names\n", + "gene_max = df.groupby(\"gene\").max() # (n_genes, n_features)\n", + "\n", + "# Convert to the dict format expected by run_gene_enrichment:\n", + "# feature_idx -> gene_name -> score\n", + "gene_activations = {}\n", + "for feat_idx in range(n_features):\n", + " col = gene_max[feat_idx]\n", 
+ " nonzero = col[col > 0]\n", + " if len(nonzero) > 0:\n", + " gene_activations[feat_idx] = nonzero.to_dict()\n", + "\n", + "print(f\"{len(gene_activations)} features with non-zero gene activations\")\n", + "print(f\"{len(gene_max)} unique genes in activation matrix\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run GSEA\n", + "\n", + "For each feature, we run `gseapy.prerank()` against all five annotation databases. This uses the pre-ranked gene list (sorted by activation strength) and tests whether genes belonging to each annotation term are concentrated at the top of the ranking.\n", + "\n", + "We use an FDR (false discovery rate) threshold of 0.05 — a feature is considered \"enriched\" if at least one term from any database passes this threshold. The `min_size=5` parameter ensures we only test terms with at least 5 genes, avoiding spurious hits from tiny gene sets.\n", + "\n", + "This step is parallelized across features since each GSEA run is independent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from codonfm_sae.eval.gene_enrichment import ANNOTATION_DATABASES, run_gene_enrichment\n", + "\n", + "\n", + "report = run_gene_enrichment(\n", + " gene_activations,\n", + " databases=ANNOTATION_DATABASES,\n", + " fdr_threshold=0.05,\n", + " n_workers=N_WORKERS,\n", + ")\n", + "\n", + "print(f\"Enriched: {report.n_features_with_enrichment}/{report.n_features_total} ({report.frac_enriched:.1%})\")\n", + "for db, stats in report.per_database_stats.items():\n", + " print(f\" {db}: {stats['n_enriched']} features, {stats['n_unique_terms']} terms\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Explore Results\n", + "\n", + "Each enriched feature gets a \"best\" label — the annotation term with the lowest FDR across all databases. 
Features enriched for GO Biological Process terms like \"translation\" or \"immune response\" are strong evidence that the SAE has learned biologically meaningful decompositions.\n", + "\n", + "Let's look at the top 10 most significantly enriched features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Top enriched features by FDR\n", + "sorted_features = sorted(\n", + " report.per_feature,\n", + " key=lambda x: x.overall_best.fdr if x.overall_best else 1.0,\n", + ")\n", + "\n", + "print(\"Top 10 most significantly enriched features:\")\n", + "print(f\"{'Feature':>8} {'FDR':>8} {'NES':>6} {'Database':<30} {'Term'}\")\n", + "print(\"-\" * 90)\n", + "for fl in sorted_features[:10]:\n", + " if fl.overall_best:\n", + " b = fl.overall_best\n", + " print(f\"{fl.feature_idx:>8} {b.fdr:>8.4f} {b.enrichment_score:>6.2f} {b.database:<30} {b.term_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show all significant terms for a single feature\n", + "if sorted_features and sorted_features[0].overall_best:\n", + " example_feat = sorted_features[0]\n", + " print(f\"\\nAll significant terms for feature {example_feat.feature_idx}:\")\n", + " print(f\"{'Database':<30} {'FDR':>8} {'NES':>6} {'Term'}\")\n", + " print(\"-\" * 90)\n", + " for er in sorted(example_feat.all_significant, key=lambda x: x.fdr):\n", + " print(f\"{er.database:<30} {er.fdr:>8.4f} {er.enrichment_score:>6.2f} {er.term_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gene Family Detection\n", + "\n", + "A simpler heuristic that complements GSEA: look at the top-K genes for each feature and check if they share a naming prefix. 
Gene families often use systematic prefixes — `OR` for olfactory receptors, `RPS`/`RPL` for ribosomal proteins, `HLA` for MHC genes, `KRT` for keratins, etc.\n", + "\n", + "A feature where 8 out of 10 top genes start with `OR` is almost certainly tracking olfactory receptor genes, even if GSEA didn't find a significant hit (perhaps because the gene set databases don't have an explicit \"olfactory receptor\" term)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from codonfm_sae.eval.gene_enrichment import detect_gene_families\n", + "\n", + "\n", + "families = detect_gene_families(gene_activations)\n", + "print(f\"{len(families)} features with dominant gene family\")\n", + "\n", + "print(f\"\\n{'Feature':>8} {'Family'}\")\n", + "print(\"-\" * 40)\n", + "for feat, label in list(families.items())[:15]:\n", + " print(f\"{feat:>8} {label}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## pLI Scores (Optional)\n", + "\n", + "**pLI** (probability of Loss-of-function Intolerance) is a measure from the gnomAD project that quantifies how much evolutionary selection pressure acts against loss-of-function mutations in a gene. A pLI score close to 1.0 means the gene is highly constrained — losing one copy is likely lethal or strongly deleterious.\n", + "\n", + "By computing the mean pLI of a feature's top-activating genes, we can ask: **does this feature preferentially fire on essential genes?** Features with high mean pLI may be tracking conserved regulatory patterns or housekeeping functions.\n", + "\n", + "This requires the gnomAD constraint file (`gnomad.v2.1.1.lof_metrics.by_gene.txt.bgz`), which is not included in this repo. Uncomment the cell below if you have it available." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and update the path if you have the gnomAD pLI file:\n", + "# PLI_PATH = \"./datasets/gnomad.v2.1.1.lof_metrics.by_gene.txt.bgz\"\n", + "# pli_scores = load_pli_scores(PLI_PATH)\n", + "# print(f\"Loaded pLI scores for {len(pli_scores)} genes\")\n", + "#\n", + "# feature_pli = compute_feature_pli(gene_activations, pli_scores)\n", + "# print(f\"{len(feature_pli)} features with pLI metrics\")\n", + "#\n", + "# # Show features with highest mean pLI (most constrained)\n", + "# top_pli = sorted(feature_pli.items(), key=lambda x: x[1][\"mean_pli\"], reverse=True)[:10]\n", + "# print(f\"\\n{'Feature':>8} {'Mean pLI':>8} {'Frac constrained':>16} {'Max pLI':>8}\")\n", + "# print(\"-\" * 50)\n", + "# for feat, metrics in top_pli:\n", + "# print(f\"{feat:>8} {metrics['mean_pli']:>8.3f} {metrics['frac_constrained']:>16.3f} {metrics['max_pli']:>8.3f}\")\n", + "\n", + "print(\"pLI analysis requires gnomAD constraint file. See docstring above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Results\n", + "\n", + "Save the enrichment report as JSON for downstream use (e.g., enriching the dashboard atlas, providing GSEA context to auto-interp prompts)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import asdict\n", + "\n", + "\n", + "output_dir = Path(\"../outputs/1b_layer16/gene_enrichment\")\n", + "output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "# Save report JSON\n", + "def _enrichment_result_to_dict(er):\n", + " if er is None:\n", + " return None\n", + " return asdict(er)\n", + "\n", + "\n", + "report_data = {\n", + " \"databases_used\": report.databases_used,\n", + " \"n_features_with_enrichment\": report.n_features_with_enrichment,\n", + " \"n_features_total\": report.n_features_total,\n", + " \"frac_enriched\": report.frac_enriched,\n", + " \"per_database_stats\": report.per_database_stats,\n", + " \"significance_threshold\": report.significance_threshold,\n", + " \"per_feature\": [\n", + " {\n", + " \"feature_idx\": fl.feature_idx,\n", + " \"overall_best\": _enrichment_result_to_dict(fl.overall_best),\n", + " \"go_slim_term\": fl.go_slim_term,\n", + " \"go_slim_name\": fl.go_slim_name,\n", + " \"best_per_database\": {db: _enrichment_result_to_dict(er) for db, er in fl.best_per_database.items()},\n", + " \"n_significant\": len(fl.all_significant),\n", + " }\n", + " for fl in report.per_feature\n", + " ],\n", + "}\n", + "\n", + "report_path = output_dir / \"gene_enrichment_report.json\"\n", + "with open(report_path, \"w\") as f:\n", + " json.dump(report_data, f, indent=2)\n", + "\n", + "print(f\"Saved report to {report_path}\")\n", + "print(f\" {report.n_features_with_enrichment} features enriched across {len(report.databases_used)} databases\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "- **04_dashboard.ipynb** — Export feature atlas + examples for the interactive dashboard, incorporating GSEA labels\n", + "- **05_auto_interp.ipynb** — Use an LLM to generate natural-language descriptions, with GSEA context as input\n", + "- Run the 
standalone script for production workloads:\n", + " ```bash\n", + " python scripts/eval_gene_enrichment.py \\\n", + " --checkpoint outputs/1b_layer16/checkpoints/checkpoint_final.pt \\\n", + " --model-path $MODEL_PATH --layer 16 \\\n", + " --csv-path $CSV_PATH --n-workers 8 \\\n", + " --output-dir outputs/1b_layer16/gene_enrichment\n", + " ```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/04_dashboard.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/04_dashboard.ipynb new file mode 100644 index 0000000000..5be432b4f8 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/04_dashboard.ipynb @@ -0,0 +1,538 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CodonFM SAE — Dashboard Export\n", + "\n", + "The interactive dashboard provides a visual interface for exploring SAE features. It shows:\n", + "\n", + "- A **UMAP scatter plot** of all features, computed from the SAE decoder weight vectors. Features that decode to similar directions in activation space appear nearby, revealing natural clusters (e.g., codon-usage features vs. domain-specific features).\n", + "- **Feature cards** with top-activating sequences, where each codon is colored by activation intensity. 
This lets you see *what* a feature responds to in context.\n", + "- **Crossfiltering** by metadata columns — GSEA labels, codon annotations, gene families, variant analysis scores.\n", + "\n", + "This notebook generates the three data files the dashboard needs:\n", + "\n", + "| File | Contents |\n", + "|---|---|\n", + "| `features_atlas.parquet` | One row per feature: UMAP coordinates, activation frequency, max activation, cluster ID, plus any enrichment columns |\n", + "| `feature_metadata.parquet` | Per-feature description and stats (activation frequency, max activation) |\n", + "| `feature_examples.parquet` | Top-K activating sequences per feature, with per-codon activation arrays |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Config\n", + "SAE_CHECKPOINT = \"../outputs/1b_layer16/checkpoints/checkpoint_final.pt\"\n", + "MODEL_PATH = \"../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\"\n", + "CSV_PATH = \"../../../../../../../codonfm/data/codonfm_ref_only.csv\"\n", + "LAYER = 16\n", + "CONTEXT_LENGTH = 2048\n", + "BATCH_SIZE = 8\n", + "DEVICE = \"cuda\"\n", + "NUM_SEQUENCES = 2000\n", + "N_EXAMPLES = 6 # Top examples per feature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "\n", + "_REPO_ROOT = Path(\"..\").resolve().parent.parent.parent.parent.parent\n", + "_CODONFM_TE_DIR = _REPO_ROOT / \"recipes\" / \"codonfm_ptl_te\"\n", + "sys.path.insert(0, str(_CODONFM_TE_DIR))\n", + "sys.path.insert(0, str(Path(\"..\").resolve()))\n", + "\n", + "from codonfm_sae.data import read_codon_csv\n", + "from sae.architectures import TopKSAE\n", + "from sae.utils import set_seed\n", + "from 
src.data.preprocess.codon_sequence import process_item\n", + "from src.inference.encodon import EncodonInference\n", + "\n", + "\n", + "set_seed(42)\n", + "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load SAE, Model, and Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load(SAE_CHECKPOINT, map_location=\"cpu\", weights_only=False)\n", + "state_dict = ckpt[\"model_state_dict\"]\n", + "if any(k.startswith(\"module.\") for k in state_dict):\n", + " state_dict = {k.removeprefix(\"module.\"): v for k, v in state_dict.items()}\n", + "\n", + "input_dim = ckpt.get(\"input_dim\") or state_dict[\"encoder.weight\"].shape[1]\n", + "hidden_dim = ckpt.get(\"hidden_dim\") or state_dict[\"encoder.weight\"].shape[0]\n", + "model_config = ckpt.get(\"model_config\", {})\n", + "top_k = model_config.get(\"top_k\")\n", + "\n", + "sae = TopKSAE(\n", + " input_dim=input_dim,\n", + " hidden_dim=hidden_dim,\n", + " top_k=top_k,\n", + " normalize_input=model_config.get(\"normalize_input\", False),\n", + ")\n", + "sae.load_state_dict(state_dict)\n", + "sae = sae.eval().to(device)\n", + "\n", + "print(f\"SAE: {input_dim} -> {hidden_dim:,} features (top-{top_k})\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inference = EncodonInference(\n", + " model_path=MODEL_PATH,\n", + " task_type=\"embedding_prediction\",\n", + " use_transformer_engine=True,\n", + ")\n", + "inference.configure_model()\n", + "inference.model.to(device).eval()\n", + "\n", + "num_layers = len(inference.model.model.layers)\n", + "target_layer = LAYER if LAYER >= 0 else num_layers + LAYER\n", + "print(f\"Encodon: {num_layers} layers, target layer {target_layer}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)\n", + "sequences = [r.sequence for r in records]\n", + "sequence_ids = [r.id for r in records]\n", + "print(f\"Loaded {len(sequences)} sequences\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Extract Activations\n", + "\n", + "We extract 3D activations (sequences x positions x hidden_dim) from the target layer, stripping the CLS and SEP tokens. These are the raw inputs the SAE was trained on." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "\n", + "all_embeddings = []\n", + "all_masks = []\n", + "\n", + "n_batches = (len(sequences) + BATCH_SIZE - 1) // BATCH_SIZE\n", + "\n", + "with torch.no_grad():\n", + " for i in tqdm(range(0, len(sequences), BATCH_SIZE), total=n_batches, desc=\"Extracting activations\"):\n", + " batch_seqs = sequences[i : i + BATCH_SIZE]\n", + " items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=inference.tokenizer) for s in batch_seqs]\n", + "\n", + " batch = {\n", + " \"input_ids\": torch.tensor(np.stack([it[\"input_ids\"] for it in items])).to(device),\n", + " \"attention_mask\": torch.tensor(np.stack([it[\"attention_mask\"] for it in items])).to(device),\n", + " }\n", + "\n", + " out = inference.model(batch, return_hidden_states=True)\n", + " hidden = out.all_hidden_states[LAYER].float().cpu()\n", + " attn_mask = batch[\"attention_mask\"].cpu()\n", + "\n", + " # Build mask excluding CLS (pos 0) and SEP (last real pos)\n", + " keep = attn_mask.clone()\n", + " keep[:, 0] = 0\n", + " lengths = attn_mask.sum(dim=1)\n", + " for b in range(keep.shape[0]):\n", + " sep = int(lengths[b].item()) - 1\n", + " if sep > 0:\n", + " keep[b, sep] = 0\n", + "\n", + " all_embeddings.append(hidden)\n", + " all_masks.append(keep)\n", + "\n", + " del out, batch\n", + " 
torch.cuda.empty_cache()\n", + "\n", + "# Pad to same length\n", + "max_len = max(e.shape[1] for e in all_embeddings)\n", + "padded_emb = []\n", + "padded_masks = []\n", + "for emb, msk in zip(all_embeddings, all_masks):\n", + " B, L, D = emb.shape\n", + " if L < max_len:\n", + " emb = torch.cat([emb, torch.zeros(B, max_len - L, D)], dim=1)\n", + " msk = torch.cat([msk, torch.zeros(B, max_len - L, dtype=msk.dtype)], dim=1)\n", + " padded_emb.append(emb)\n", + " padded_masks.append(msk)\n", + "\n", + "activations = torch.cat(padded_emb, dim=0)\n", + "masks = torch.cat(padded_masks, dim=0)\n", + "activations_flat = activations[masks.bool()]\n", + "\n", + "print(f\"Activations: {activations.shape}\")\n", + "print(f\"Valid codons: {activations_flat.shape[0]:,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Feature Statistics\n", + "\n", + "For each feature, compute global statistics (activation frequency, mean activation, max activation) from the flattened activations. These become columns in the atlas and are used for filtering in the dashboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sae.analysis import compute_feature_stats, compute_feature_umap, save_feature_atlas\n", + "\n", + "\n", + "stats, _ = compute_feature_stats(sae, activations_flat, device=device)\n", + "print(f\"Computed stats for {len(stats)} features\")\n", + "print(f\" Alive features: {sum(1 for s in stats.values() if s.get('activation_freq', 0) > 0)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute UMAP from Decoder Weights\n", + "\n", + "The UMAP is computed from the SAE **decoder weight matrix** (not from activations). Each feature has a decoder vector in activation space — the direction it represents. 
# %% [markdown]
# UMAP projects these high-dimensional vectors to 2D, preserving local structure.
#
# Features with similar decoder vectors appear nearby on the UMAP, even if they fire on
# different sequences. This reveals the geometric structure the SAE has learned: clusters
# of synonymous-codon features, amino-acid features, domain-specific features, etc.
#
# We also run HDBSCAN clustering on the UMAP coordinates to automatically identify
# feature groups.

# %%
umap_kwargs = {
    "n_neighbors": 15,
    "min_dist": 0.1,
    "random_state": 42,
    "compute_clusters": True,
    "hdbscan_min_cluster_size": 20,
}
geometry = compute_feature_umap(sae, **umap_kwargs)

# Cluster id -1 is not counted as a cluster (it is subtracted, as in the original cell).
cluster_ids = set(geometry.get("cluster_id", {}).values())
cluster_ids.discard(-1)
n_clusters = len(cluster_ids)
print(f"UMAP computed: {len(geometry.get('umap_x', {}))} features")
print(f"HDBSCAN found {n_clusters} clusters")

# %% [markdown]
# ## Export Feature Atlas
#
# The atlas is the central parquet file — one row per feature with all metadata columns.
# The dashboard loads this to render the UMAP and populate filter controls.

# %%
output_dir = Path("../outputs/1b_layer16/dashboard")
output_dir.mkdir(parents=True, exist_ok=True)

atlas_path = output_dir / "features_atlas.parquet"
save_feature_atlas(stats, geometry, atlas_path)
print(f"Saved atlas to {atlas_path}")

# %% [markdown]
# ## Export Feature Examples
#
# For each feature, we find the top-K sequences with the highest max activation and
# extract per-codon activation arrays.
# %% [markdown]
# This is the data behind the "feature cards" in the dashboard — the highlighted
# sequences that show what each feature responds to.
#
# The export uses a two-pass approach to avoid holding all per-codon activations in memory:
# 1. **Pass 1**: Compute max activation per (sequence, feature) to identify top examples.
# 2. **Pass 2**: Re-encode only the needed sequences to extract per-codon activations.

# %%
import pyarrow as pa
import pyarrow.parquet as pq


n_features = sae.hidden_dim
n_sequences = activations.shape[0]
# `masks` is 1 only at codon positions: the extraction cell zeroes CLS (position 0) and
# SEP (last attended position). Codons therefore live at positions 1..n, NOT 0..n-1.
codon_masks = masks.bool()
valid_lens = codon_masks.sum(dim=1).long()


def _encode_codons(seq_idx):
    """Return SAE codes for the codon positions of one sequence.

    Rows are selected with the boolean codon mask rather than `activations[i, :vl]`.
    The prefix slice would include the CLS vector and drop the final codon, shifting
    every exported activation by one position relative to the displayed codon string.

    Returns the second element of `sae(emb)`, one row per codon.
    """
    emb = activations[seq_idx][codon_masks[seq_idx]].to(device)
    with torch.no_grad():
        _, codes = sae(emb)
    return codes


# Pass 1: max activation per (sequence, feature)
print("Pass 1: Computing max activations per sequence...")
max_acts = torch.zeros(n_sequences, n_features)
for i in tqdm(range(n_sequences), desc="Max activations"):
    if int(valid_lens[i].item()) == 0:
        continue  # no codon positions -> row stays all-zero
    max_acts[i] = _encode_codons(i).max(dim=0).values.cpu()

# Find top examples per feature.
# NOTE(review): N_EXAMPLES is not defined in the visible part of this notebook;
# presumably it comes from the config cell — confirm.
print("Finding top examples per feature...")
top_indices = torch.topk(max_acts, k=min(N_EXAMPLES, n_sequences), dim=0).indices

# Reverse index: sequence -> set of features that need its per-codon activations.
needed_sequences = {}
for feat_idx in range(n_features):
    for seq_idx in top_indices[:, feat_idx].tolist():
        needed_sequences.setdefault(seq_idx, set()).add(feat_idx)

# Pass 2: extract per-codon activations for top examples
print(f"Pass 2: Extracting per-codon activations ({len(needed_sequences)} sequences)...")
example_acts = {}
for seq_idx in tqdm(sorted(needed_sequences.keys()), desc="Per-codon activations"):
    if int(valid_lens[seq_idx].item()) == 0:
        continue
    codes_cpu = _encode_codons(seq_idx).cpu()
    for feat_idx in needed_sequences[seq_idx]:
        example_acts[(seq_idx, feat_idx)] = codes_cpu[:, feat_idx].numpy().tolist()

# Write feature_metadata.parquet
print("Writing feature_metadata.parquet...")
meta_rows = []
for feat_idx in range(n_features):
    feat_max = max_acts[:, feat_idx]
    meta_rows.append(
        {
            "feature_id": feat_idx,
            "description": f"Feature {feat_idx}",
            # Fraction of *sequences* on which the feature fires at least once.
            "activation_freq": (feat_max > 0).float().mean().item(),
            "max_activation": feat_max.max().item(),
        }
    )

meta_table = pa.table(
    {
        "feature_id": pa.array([r["feature_id"] for r in meta_rows], type=pa.int32()),
        "description": pa.array([r["description"] for r in meta_rows]),
        "activation_freq": pa.array([r["activation_freq"] for r in meta_rows], type=pa.float32()),
        "max_activation": pa.array([r["max_activation"] for r in meta_rows], type=pa.float32()),
    }
)
pq.write_table(meta_table, output_dir / "feature_metadata.parquet", compression="snappy")

# Write feature_examples.parquet
print("Writing feature_examples.parquet...")
example_rows = []
for feat_idx in range(n_features):
    for rank, seq_idx in enumerate(top_indices[:, feat_idx].tolist()):
        key = (seq_idx, feat_idx)
        if key not in example_acts:
            continue  # sequence had no codon positions
        raw_seq = sequences[seq_idx]
        n_codons = len(raw_seq) // 3
        codon_seq = " ".join(raw_seq[j * 3 : (j + 1) * 3] for j in range(n_codons))
        acts = example_acts[key]
        example_rows.append(
            {
                "feature_id": feat_idx,
                "example_rank": rank,
                "protein_id": sequence_ids[seq_idx],
                "sequence": codon_seq,
                "activations": acts,
                "max_activation": max(acts) if acts else 0.0,
            }
        )

example_rows.sort(key=lambda r: (r["feature_id"], r["example_rank"]))
examples_table = pa.table(
    {
        "feature_id": pa.array([r["feature_id"] for r in example_rows], type=pa.int32()),
        "example_rank": pa.array([r["example_rank"] for r in example_rows], type=pa.int8()),
        "protein_id": pa.array([r["protein_id"] for r in example_rows]),
        "sequence": pa.array([r["sequence"] for r in example_rows]),
        "activations": pa.array([r["activations"] for r in example_rows], type=pa.list_(pa.float32())),
        "max_activation": pa.array([r["max_activation"] for r in example_rows], type=pa.float32()),
    }
)
pq.write_table(
    examples_table, output_dir / "feature_examples.parquet", row_group_size=N_EXAMPLES * 100, compression="snappy"
)

print(f"Wrote {len(meta_rows)} features, {len(example_rows)} examples")

# %% [markdown]
# ## Enrich Atlas with Analysis Results
#
# If you've already run notebooks 02 (codon analysis) and 03 (gene enrichment), their
# output files can be loaded and merged into the atlas. This adds columns like
# `gsea_overall_best`, `gsea_GO_Biological_Process`, etc., which become available as
# color/filter dimensions in the dashboard.
# %%
import pyarrow as pa
import pyarrow.parquet as pq


atlas_path = output_dir / "features_atlas.parquet"
table = pq.read_table(atlas_path)
n = table.num_rows
print(f"Atlas has {n} features, {len(table.column_names)} columns")

# Add GSEA columns if gene_enrichment_report.json exists.
# NOTE(review): `json` is used below but is not imported in this cell; presumably an
# earlier cell of this notebook imports it — confirm.
gsea_report_path = Path("../outputs/1b_layer16/gene_enrichment/gene_enrichment_report.json")
if not gsea_report_path.exists():
    print(f" No GSEA report found at {gsea_report_path}")
    print(" Run notebook 03_gene_enrichment.ipynb first to add GSEA labels.")
else:
    with open(gsea_report_path) as f:
        gsea_data = json.load(f)

    # feature_idx -> name of the best enriched term (only for features that have one)
    best_labels = {}
    for entry in gsea_data.get("per_feature", []):
        feature_idx = entry["feature_idx"]
        best = entry.get("overall_best")
        if not best:
            continue
        if best.get("term_name"):
            best_labels[feature_idx] = best["term_name"]

    # Assumes atlas row i corresponds to feature index i — TODO confirm against save_feature_atlas.
    gsea_col = [best_labels.get(i, "unlabeled") for i in range(n)]
    n_labeled = sum(1 for label in gsea_col if label != "unlabeled")

    # Replace any stale column from a previous run so the cell is re-runnable.
    if "gsea_overall_best" in table.column_names:
        table = table.drop("gsea_overall_best")
    table = table.append_column("gsea_overall_best", pa.array(gsea_col))
    print(f" Added GSEA labels: {n_labeled} features labeled")

    pq.write_table(table, atlas_path, compression="snappy")
    print(f" Updated {atlas_path}")

# %% [markdown]
# ## Launch Dashboard
#
# The dashboard is a React app served by a Python backend. Point it at the output
# directory containing the parquet files.
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Dashboard files generated:\")\n", + "print(f\" {output_dir / 'features_atlas.parquet'}\")\n", + "print(f\" {output_dir / 'feature_metadata.parquet'}\")\n", + "print(f\" {output_dir / 'feature_examples.parquet'}\")\n", + "print()\n", + "print(\"To launch the dashboard, run:\")\n", + "print(f\" python scripts/launch_dashboard.py --data-dir {output_dir}\")\n", + "print()\n", + "print(\"Or for a production run with variant analysis (uses the dashboard.py script):\")\n", + "print(\" python scripts/dashboard.py \\\\\")\n", + "print(\" --checkpoint outputs/1b_layer16/checkpoints/checkpoint_final.pt \\\\\")\n", + "print(\" --model-path $MODEL_PATH \\\\\")\n", + "print(\" --csv-path $CSV_PATH \\\\\")\n", + "print(\" --layer 16 --num-sequences 2000 \\\\\")\n", + "print(\" --output-dir outputs/1b_layer16/dashboard\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "- **05_auto_interp.ipynb** — Generate natural-language feature descriptions with an LLM\n", + "- Explore the dashboard interactively: color the UMAP by GSEA labels, click features to see their top sequences, filter by activation frequency" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/05_auto_interp.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/05_auto_interp.ipynb new file mode 100644 index 0000000000..aa0f21f3ab --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/05_auto_interp.ipynb @@ -0,0 +1,721 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "# CodonFM SAE — Automated Feature Interpretation\n", + "\n", + "The previous notebooks give us *quantitative* descriptions of features: which codons they promote, what GSEA terms they're enriched for, their activation frequency. But these are still fragmented signals — a human has to synthesize \"this feature promotes GCC/GCT, fires on ribosomal protein genes, and has high tAI correlation\" into \"this feature detects optimally-translated codons in highly-expressed housekeeping genes.\"\n", + "\n", + "**Auto-interpretation** automates that synthesis step by sending all available evidence about a feature to a large language model and asking it to produce:\n", + "1. A **description** (2-3 sentences) of what the feature detects and why.\n", + "2. A **label** (one concise phrase) suitable for a dashboard tooltip.\n", + "3. A **confidence score** (0.0-1.0) reflecting how interpretable the pattern is.\n", + "\n", + "The prompt includes:\n", + "- **Top promoted/suppressed codons** from the decoder weight projection through the LM head\n", + "- **Top activating sequences** with per-codon highlighting (codons where the feature fires strongly are marked)\n", + "- **Gene metadata** (gene names, pathogenicity labels, variant info) for each example\n", + "- **GSEA enrichment context** (if available from notebook 03) — the top enriched biological terms\n", + "\n", + "This is expensive (one LLM call per feature), so we demo on a small subset here. The `analyze.py` script handles production-scale runs with checkpointing and parallelism." 
# %% [markdown]
# ## Setup

# %%
# Config
SAE_CHECKPOINT = "../outputs/1b_layer16/checkpoints/checkpoint_final.pt"
MODEL_PATH = "../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1"
CSV_PATH = "../../../../../../../codonfm/data/codonfm_ref_only.csv"
LAYER = 16
CONTEXT_LENGTH = 2048
BATCH_SIZE = 8
DEVICE = "cuda"

# Auto-interp specific
LLM_PROVIDER = "nvidia-internal"  # Options: "anthropic", "openai", "nim", "nvidia-internal"
NUM_FEATURES_TO_INTERPRET = 20  # Small subset for demo
NUM_SEQUENCES = 500  # Enough to find good examples

# %%
import json
import sys
from pathlib import Path

import numpy as np
import torch


# Make the CodonFM recipe and this recipe's packages importable from the notebook dir.
_RECIPE_DIR = Path("..").resolve()
_REPO_ROOT = _RECIPE_DIR.parent.parent.parent.parent.parent
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
sys.path.insert(0, str(_CODONFM_TE_DIR))
sys.path.insert(0, str(_RECIPE_DIR))

from codonfm_sae.data import read_codon_csv
from sae.architectures import TopKSAE
from sae.utils import set_seed
from src.data.preprocess.codon_sequence import process_item
from src.inference.encodon import EncodonInference


set_seed(42)
device = DEVICE if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# %% [markdown]
# ## Load SAE, Model, and Data

# %%
# NOTE(review): weights_only=False unpickles arbitrary objects; only load checkpoints
# you produced yourself or otherwise trust.
ckpt = torch.load(SAE_CHECKPOINT, map_location="cpu", weights_only=False)
state_dict = ckpt["model_state_dict"]
# Drop the "module." prefix when any key carries it.
has_wrapper_prefix = any(name.startswith("module.") for name in state_dict)
if has_wrapper_prefix:
    state_dict = {name.removeprefix("module."): tensor for name, tensor in state_dict.items()}

# Prefer dimensions stored in the checkpoint; otherwise read them off the encoder weight.
encoder_shape = state_dict["encoder.weight"].shape
input_dim = ckpt.get("input_dim") or encoder_shape[1]
hidden_dim = ckpt.get("hidden_dim") or encoder_shape[0]
model_config = ckpt.get("model_config", {})
top_k = model_config.get("top_k")

sae = TopKSAE(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    top_k=top_k,
    normalize_input=model_config.get("normalize_input", False),
)
sae.load_state_dict(state_dict)
sae = sae.eval().to(device)

print(f"SAE: {input_dim} -> {hidden_dim:,} features (top-{top_k})")

# %%
inference = EncodonInference(
    model_path=MODEL_PATH,
    task_type="embedding_prediction",
    use_transformer_engine=True,
)
inference.configure_model()
inference.model.to(device).eval()

num_layers = len(inference.model.model.layers)
# Negative LAYER counts from the end of the stack.
target_layer = LAYER if LAYER >= 0 else num_layers + LAYER
print(f"Encodon: {num_layers} layers, target layer {target_layer}")

# %%
records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)
sequences = [rec.sequence for rec in records]
print(f"Loaded {len(sequences)} sequences")

# %% [markdown]
# ## Compute Vocabulary Logits
#
# Before building prompts, we need to know which codons each feature promotes and suppresses.
# %% [markdown]
# This comes from projecting the SAE decoder weights through the model's LM head — the
# same vocabulary projection the model uses to predict the next codon.

# %%
# Standard genetic code, laid out as: second base (T, C, A, G) -> first base -> third base.
# Iteration order (and therefore dict order) matches the hand-written table it replaces.
_BASES = "TCAG"
_AA_BY_SECOND_BASE = (
    "FFLLLLLLIIIMVVVV",  # xTx
    "SSSSPPPPTTTTAAAA",  # xCx
    "YY**HHQQNNKKDDEE",  # xAx
    "CC*WRRRRSSRRGGGG",  # xGx
)
CODON_TO_AA = {
    first + second + third: aa_row[4 * i + k]
    for second, aa_row in zip(_BASES, _AA_BY_SECOND_BASE)
    for i, first in enumerate(_BASES)
    for k, third in enumerate(_BASES)
}

# Map each codon to its token id, skipping codons the tokenizer does not know.
tokenizer = inference.tokenizer
codon_tokens = {}
for codon in CODON_TO_AA:
    tok_id = tokenizer.token_to_id(codon)
    if tok_id is None:
        continue
    codon_tokens[codon] = tok_id

# Project decoder through LM head
encodon = inference.model.model
lm_head = encodon.cls
W_dec = sae.decoder.weight.to(device)

# Baseline input: the SAE pre-bias when present, else the zero vector.
if hasattr(sae, "pre_bias"):
    mean_acts = sae.pre_bias.data.float().to(device)
else:
    mean_acts = torch.zeros(input_dim, device=device)

with torch.no_grad():
    logits = lm_head(W_dec.T)  # (n_features, vocab_size)
    baseline = lm_head(mean_acts.unsqueeze(0)).squeeze(0)

# NOTE(review): this is lm_head(direction) - lm_head(baseline), not
# lm_head(baseline + direction) - lm_head(baseline). Whether those agree depends on
# the LM head's structure, which is not visible here — confirm this is intended.
logit_diff = logits - baseline.unsqueeze(0)  # (n_features, vocab_size)

print(f"Computed vocab logit diffs: {logit_diff.shape}")

# %% [markdown]
# ## Prepare Feature Examples
#
# For each feature we want to interpret, we need to find the sequences where it fires
# most strongly and format them for the prompt. Each example shows the codon sequence
# with high-activation positions highlighted using `***markers***`.
# %%
from tqdm import tqdm


def _tokenize_batch(seqs):
    """Tokenize `seqs` and stack them into the input dict passed to the Encodon model.

    Returns the per-sequence items from `process_item` together with the stacked dict.
    """
    items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=tokenizer) for s in seqs]
    model_inputs = {
        key: torch.tensor(np.stack([item[key] for item in items])).to(device)
        for key in ("input_ids", "attention_mask")
    }
    return items, model_inputs


# First, find the most frequently-firing features to interpret
n_features = sae.hidden_dim
fire_counts = np.zeros(n_features, dtype=np.int64)  # codon positions where each feature fires
max_activations = np.zeros(n_features, dtype=np.float32)
# Per-sequence max activation, used later to pick the top examples for each feature
seq_max_acts = np.zeros((len(sequences), n_features), dtype=np.float32)

with torch.no_grad():
    for start in tqdm(range(0, len(sequences), BATCH_SIZE), desc="Computing activations"):
        items, batch = _tokenize_batch(sequences[start : start + BATCH_SIZE])
        out = inference.model(batch, return_hidden_states=True)
        hidden = out.all_hidden_states[LAYER]

        for offset, item in enumerate(items):
            n_tokens = item["attention_mask"].sum()
            # Positions 1 .. n_tokens-2 are the codons (CLS first, SEP last).
            codes = sae.encode(hidden[offset, 1 : n_tokens - 1, :].float())
            fire_counts += (codes > 0).cpu().numpy().sum(axis=0)
            per_seq_max = codes.max(dim=0).values.cpu().numpy()
            np.maximum(max_activations, per_seq_max, out=max_activations)
            seq_max_acts[start + offset] = per_seq_max

        del out, hidden, batch
        torch.cuda.empty_cache()

# Select top features by activation frequency
alive = fire_counts > 0
alive_indices = np.flatnonzero(alive)
by_freq_desc = np.argsort(fire_counts[alive_indices])[::-1]
features_to_interpret = alive_indices[by_freq_desc][:NUM_FEATURES_TO_INTERPRET].tolist()

print(f"{alive.sum()} alive features, interpreting top {len(features_to_interpret)} by frequency")

# %%
# For each feature, get the top-5 sequences and extract per-codon activations
N_EXAMPLES = 5

feature_examples = {}  # feat_idx -> list of (max_act, seq_idx, per_codon_acts)

for feat_idx in tqdm(features_to_interpret, desc="Collecting examples"):
    # Sequences ranked by this feature's max activation, strongest first
    ranked = np.argsort(seq_max_acts[:, feat_idx])[::-1][:N_EXAMPLES]

    examples = []
    for seq_idx in ranked:
        if seq_max_acts[seq_idx, feat_idx] == 0:
            continue  # feature never fires on this sequence
        items, batch = _tokenize_batch([sequences[seq_idx]])
        with torch.no_grad():
            out = inference.model(batch, return_hidden_states=True)
            n_tokens = items[0]["attention_mask"].sum()
            emb = out.all_hidden_states[LAYER][0, 1 : n_tokens - 1, :].float()
            _, codes = sae(emb)
            acts = codes[:, feat_idx].cpu().numpy()

        examples.append((float(acts.max()), int(seq_idx), acts))
        del out, batch

    feature_examples[feat_idx] = examples

torch.cuda.empty_cache()
print(f"Collected examples for {len(feature_examples)} features")

# %% [markdown]
# ## Build the Interpretation Prompt
#
# This is the core of auto-interp: assembling all available evidence into a structured
# prompt. The LLM receives:
# %% [markdown]
# 1. **Context** about the model (CodonFM, a DNA codon language model)
# 2. **Decoder logit analysis** — which codons the feature promotes/suppresses
# 3. **Top activating sequences** with highlighted codons and metadata
# 4. **GSEA enrichment** (if available) — biological annotations
# 5. **Output format** instructions
#
# Let's build a prompt for one feature to see what it looks like.

# %%
def build_prompt(
    feat_idx,
    examples,
    logit_diff_tensor,
    codon_tokens_map,
    records_list,
    sequences_list,
    gsea_context=None,
    max_examples=5,
    n_top_codons=8,
    max_codons_shown=60,
    highlight_percentile=80,
):
    """Build the auto-interp prompt for a single feature.

    Args:
        feat_idx: Index of the SAE feature.
        examples: List of ``(max_act, seq_idx, per_codon_acts)`` tuples, strongest first.
        logit_diff_tensor: 2-D array/tensor indexed as ``[feat_idx, token_id]``.
        codon_tokens_map: Mapping codon string -> token id.
        records_list: Records whose ``.metadata`` mapping may hold ``gene`` and
            ``is_pathogenic``; may be empty/None.
        sequences_list: Raw nucleotide sequences, indexed by ``seq_idx``.
        gsea_context: Optional mapping ``str(feat_idx)`` -> enrichment text or JSON-able object.
        max_examples: Number of examples to include (default 5, as before).
        n_top_codons: Number of promoted / suppressed codons to list (default 8).
        max_codons_shown: Leading codons shown per example (default 60).
        highlight_percentile: Percentile of the positive activations above which a codon
            is highlighted (default 80).

    Returns:
        The prompt string.
    """
    # Top promoted/suppressed codons
    feat_logits = {c: float(logit_diff_tensor[feat_idx, tid]) for c, tid in codon_tokens_map.items()}
    sorted_codons = sorted(feat_logits.items(), key=lambda x: x[1], reverse=True)
    top_pos = sorted_codons[:n_top_codons]
    top_neg = sorted_codons[::-1][:n_top_codons]  # most suppressed first

    pos_str = ", ".join(f"{c}({CODON_TO_AA[c]})={v:.2f}" for c, v in top_pos)
    neg_str = ", ".join(f"{c}({CODON_TO_AA[c]})={v:.2f}" for c, v in top_neg)

    # Format example sequences
    examples_parts = []
    for rank, (max_act, seq_idx, acts) in enumerate(examples[:max_examples]):
        seq = sequences_list[seq_idx]
        acts = np.asarray(acts)
        codons = [seq[j * 3 : (j + 1) * 3] for j in range(len(acts))]

        positive = acts > 0
        threshold = np.percentile(acts[positive], highlight_percentile) if positive.any() else 0
        highlight = acts > threshold
        if positive.any() and not highlight.any():
            # A single firing codon (or all firing values tied) makes the percentile equal
            # the maximum, so a strict ">" would mark nothing. Mark the peak(s) instead.
            highlight = positive & (acts >= acts.max())

        marked = []
        for codon, is_hot in zip(codons, highlight):
            aa = CODON_TO_AA.get(codon.upper(), "?")
            marked.append(f"***{codon}({aa})***" if is_hot else f"{codon}({aa})")

        # Metadata
        meta_str = ""
        if records_list and seq_idx < len(records_list):
            m = records_list[seq_idx].metadata
            meta_parts = []
            gene = m.get("gene")
            if gene:
                meta_parts.append(f"gene={gene}")
            is_path = m.get("is_pathogenic")
            if is_path is not None:
                meta_parts.append(f"pathogenic={is_path}")
            if meta_parts:
                meta_str = f" [{', '.join(meta_parts)}]"

        # Show only the leading codons to keep the prompt manageable.
        # NOTE(review): highlights beyond this window are not shown to the LLM.
        codon_str = " ".join(marked[:max_codons_shown])
        if len(marked) > max_codons_shown:
            codon_str += f" ... ({len(marked)} codons total)"

        examples_parts.append(f"Example {rank + 1} (max_act={max_act:.2f}){meta_str}:\n{codon_str}")

    examples_str = "\n\n".join(examples_parts)

    # GSEA context
    gsea_str = "Not available (run notebook 03 first)"
    if gsea_context and str(feat_idx) in gsea_context:
        ctx = gsea_context[str(feat_idx)]
        gsea_str = ctx if isinstance(ctx, str) else json.dumps(ctx, indent=2)

    # The "Label:" line now carries a placeholder like its Description/Confidence siblings.
    prompt = f"""Analyze this sparse autoencoder feature from a DNA codon language model (CodonFM).

CodonFM is trained on coding DNA sequences (codons = 3-nucleotide triplets that encode amino acids).
This SAE feature was learned from the model's internal representations. Your task is to identify
what biological pattern this feature detects.

Top promoted codons (decoder logits, higher = feature promotes this codon):
{pos_str}

Top suppressed codons (decoder logits, lower = feature suppresses this codon):
{neg_str}

Top activating sequences (***highlighted*** = high activation codons):

{examples_str}

Gene-level GSEA enrichment:
{gsea_str}

Based on all available evidence, describe what this feature detects.
Consider: codon usage bias, amino acid preferences, gene family specificity,
sequence composition patterns, biological pathway associations.

Format your response EXACTLY as:
Description: <2-3 sentences explaining the activation pattern and its biological significance>
Label: <one concise phrase>
Confidence: <0.00 to 1.00, how confident you are in this interpretation>"""

    return prompt


# Build and display a prompt for the first feature
example_feat = features_to_interpret[0]
example_prompt = build_prompt(
    example_feat,
    feature_examples[example_feat],
    logit_diff,
    codon_tokens,
    records,
    sequences,
)
print(f"Prompt for feature {example_feat} ({len(example_prompt)} chars):")
print("=" * 80)
print(example_prompt)
print("=" * 80)

# %% [markdown]
# ## Run LLM Interpretation
#
# Send the prompt to an LLM.
# %% [markdown]
# The `sae.autointerp` module provides clients for multiple providers:
# - `AnthropicClient` — Claude (requires `ANTHROPIC_API_KEY`)
# - `OpenAIClient` — GPT-4 (requires `OPENAI_API_KEY`)
# - `NIMClient` — NVIDIA NIM (requires `NIM_API_KEY`)
# - `NVIDIAInternalClient` — Internal NVIDIA endpoint (requires `CLAUDE_SONNET_INFERENCE_API_KEY`)
#
# Uncomment the appropriate client below and ensure the environment variable is set.

# %%
# Uncomment the client for your LLM provider:
# NOTE(review): the client classes are not imported anywhere in the visible notebook;
# presumably they must first be imported from `sae.autointerp` — confirm.
# client = NVIDIAInternalClient()  # Requires CLAUDE_SONNET_INFERENCE_API_KEY env var
# client = AnthropicClient()  # Requires ANTHROPIC_API_KEY env var
# client = OpenAIClient()  # Requires OPENAI_API_KEY env var
# client = NIMClient()  # Requires NIM_API_KEY env var

# For demo purposes, we'll show the prompt without calling the LLM.
# To actually run interpretation, uncomment one of the clients above and
# uncomment the code in the next cell.
print("LLM client setup. Uncomment the appropriate client above to run interpretation.")

# %%
# Load GSEA context if available (from notebook 03)
gsea_report_path = Path("../outputs/1b_layer16/gene_enrichment/gene_enrichment_report.json")
if not gsea_report_path.exists():
    gsea_context = None
    print("No GSEA report found. Run notebook 03 first for richer prompts.")
else:
    with open(gsea_report_path) as f:
        gsea_data = json.load(f)
    # str(feature_idx) -> one-line summary of the best enriched term
    gsea_context = {}
    for entry in gsea_data.get("per_feature", []):
        feature_key = str(entry["feature_idx"])
        best = entry.get("overall_best")
        if not best:
            continue
        gsea_context[feature_key] = f"{best['term_name']} ({best['database']}, FDR={best['fdr']:.4f})"
    print(f"Loaded GSEA context for {len(gsea_context)} features")

# %%
# Uncomment to run LLM interpretation on a few features:
#
# results = {}
# for feat_idx in features_to_interpret[:5]:  # Start with just 5
#     prompt = build_prompt(
#         feat_idx,
#         feature_examples[feat_idx],
#         logit_diff,
#         codon_tokens,
#         records,
#         sequences,
#         gsea_context=gsea_context,
#     )
#     response = client.generate(prompt)
#     results[feat_idx] = response.text
#     print(f"\nFeature {feat_idx}:")
#     print(response.text)
#     print("-" * 40)

print("Uncomment the cell above to run LLM interpretation.")

# %% [markdown]
# ## Parse Results
#
# The LLM response follows a structured format. We extract the label, description, and
# confidence score from each response.
# %%
import re


# "Description:", "label :", "**Confidence:** ..." — optional markdown decoration around the name.
_FIELD_RE = re.compile(
    r"^[\s*_#>\-]*(description|label|confidence)\s*[*_]*\s*:\s*[*_]*\s*(.*)$",
    re.IGNORECASE,
)
# One decimal number; unlike "[\d.]+" it cannot swallow a trailing sentence period.
_NUMBER_RE = re.compile(r"\d+(?:\.\d+)?|\.\d+")


def parse_interp_response(text):
    """Extract description, label, and confidence from an LLM response.

    The response is expected to contain ``Description:``, ``Label:`` and ``Confidence:``
    lines (case-insensitive, optionally wrapped in markdown emphasis). A description
    that wraps onto following lines is joined with spaces. When a field appears more
    than once, the last occurrence wins.

    Args:
        text: Raw response text.

    Returns:
        Dict with ``description`` (str), ``label`` (str) and ``confidence`` (float in
        [0, 1]; ``0.0`` when missing or unparsable; ``"85%"`` is read as ``0.85``).
    """
    parsed = {"description": "", "label": "", "confidence": 0.0}
    current = None  # field the previous non-empty line belonged to

    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        field = _FIELD_RE.match(line)
        if field is None:
            # Only the description may legitimately wrap over several lines.
            if current == "description":
                parsed["description"] = f"{parsed['description']} {line}".strip()
            continue

        current = field.group(1).lower()
        value = field.group(2).strip()
        if current != "confidence":
            parsed[current] = value
            continue

        number = _NUMBER_RE.search(value)
        if number is None:
            parsed["confidence"] = 0.0
            continue
        score = float(number.group())
        if value[number.end() :].lstrip().startswith("%"):
            score /= 100.0
        parsed["confidence"] = min(max(score, 0.0), 1.0)

    return parsed


# Example parsing (using a mock response since we may not have run the LLM)
mock_response = """Description: This feature detects GC-rich optimal codons (GCC, GCG, CTG) that are preferentially used in highly expressed human genes. The highlighted positions cluster in regions encoding alanine and leucine with strong codon usage bias toward tRNA-abundant codons.
Label: GC-rich optimal codon usage
Confidence: 0.85"""

parsed = parse_interp_response(mock_response)
print(f"Label: {parsed['label']}")
print(f"Confidence: {parsed['confidence']}")
print(f"Description: {parsed['description']}")

# %% [markdown]
# ## Example Results
#
# Below is what interpreted features look like.
Each feature gets a human-readable label that can be displayed in the dashboard, along with a confidence score indicating how interpretable the pattern is.\n", + "\n", + "Features with high confidence (>0.7) typically have clear codon usage or gene family patterns. Low confidence features may detect more subtle or combinatorial patterns that are harder to describe in words." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If you ran the LLM interpretation, parse and display results:\n", + "#\n", + "# parsed_results = {}\n", + "# for feat_idx, text in results.items():\n", + "# parsed = parse_interp_response(text)\n", + "# parsed_results[feat_idx] = parsed\n", + "# print(f\"Feature {feat_idx}:\")\n", + "# print(f\" Label: {parsed['label']}\")\n", + "# print(f\" Confidence: {parsed['confidence']}\")\n", + "# print(f\" Description: {parsed['description'][:100]}...\")\n", + "# print()\n", + "#\n", + "# # Save results\n", + "# output_dir = Path(\"../outputs/1b_layer16/analysis\")\n", + "# output_dir.mkdir(parents=True, exist_ok=True)\n", + "# with open(output_dir / \"auto_interp_results.json\", \"w\") as f:\n", + "# json.dump({str(k): v for k, v in parsed_results.items()}, f, indent=2)\n", + "# print(f\"Saved {len(parsed_results)} interpretations\")\n", + "\n", + "print(\"Uncomment the cells above after running LLM interpretation.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scaling Up\n", + "\n", + "Interpreting 20 features is fine for exploration, but a full SAE may have thousands of alive features. 
The `analyze.py` script handles this at scale with:\n", + "- **Checkpointing**: saves progress after each batch so you can resume if interrupted\n", + "- **Parallel LLM calls**: uses `--auto-interp-workers` to send multiple requests concurrently\n", + "- **GSEA context injection**: pass `--gsea-report` to include enrichment data in prompts\n", + "- **Dashboard integration**: pass `--dashboard-dir` to write labels directly to the atlas parquet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\"\"To run auto-interp on all features:\n", + "\n", + " python scripts/analyze.py \\\\\n", + " --checkpoint outputs/1b_layer16/checkpoints/checkpoint_final.pt \\\\\n", + " --model-path $MODEL_PATH \\\\\n", + " --csv-path $CSV_PATH \\\\\n", + " --layer 16 --auto-interp \\\\\n", + " --llm-provider nvidia-internal \\\\\n", + " --gsea-report outputs/1b_layer16/gene_enrichment/gene_enrichment_report.json \\\\\n", + " --auto-interp-workers 8 \\\\\n", + " --output-dir outputs/1b_layer16/analysis\n", + "\"\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps\n", + "\n", + "With all analysis notebooks complete, you have:\n", + "\n", + "1. **01_quickstart** — SAE health checks, loss recovered, basic feature stats\n", + "2. **02_codon_analysis** — Codon usage metrics (CAI, tAI, RSCU) per feature\n", + "3. **03_gene_enrichment** — GSEA labels, gene families, pLI scores\n", + "4. **04_dashboard** — Interactive visualization export\n", + "5. **05_auto_interp** — LLM-generated natural language descriptions\n", + "\n", + "Together these provide a comprehensive picture of what each SAE feature has learned about codon biology." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/06_probing.ipynb b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/06_probing.ipynb new file mode 100644 index 0000000000..973b535b48 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/notebooks/06_probing.ipynb @@ -0,0 +1,772 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": [ + "# CodonFM SAE — Linear Probing\n", + "\n", + "Do individual SAE features encode interpretable biological properties? We test this by **linear probing**: for each candidate biological label (e.g., \"is this a rare codon?\"), we fit a simple logistic regression `sigmoid(w · z_i + b)` on every latent `z_i` independently and measure how well it predicts the label.\n", + "\n", + "If a single latent achieves high AUROC for a label, that latent has learned a clean representation of that concept — it's a biological detector. 
If no single latent works but a linear combination does, the concept is distributed across features.\n", + "\n", + "We implement two levels of probing:\n", + "\n", + "**Codon-level probes** — one label per codon position:\n", + "- Rare codon detection (CAI weight < 0.5: used less than half as often as the most common synonymous codon)\n", + "- ATG codon (the start codon; every ATG is labeled, including internal methionines)\n", + "- Wobble position GC (3rd nucleotide is G or C)\n", + "- CpG site (codon ends in C and the next codon in the same sequence starts with G)\n", + "\n", + "**Gene-level probes** — one label per gene (using mean-pooled activations):\n", + "- Housekeeping proxy (top quartile of genes by codon count in this dataset; a rough stand-in, not a curated housekeeping list)\n", + "- High pLI (loss-of-function intolerant, pLI > 0.9; needs the optional gnomAD file set in `PLI_PATH`)\n", + "\n", + "For each probe, we report the best single latent, its AUROC, and the full distribution of per-latent AUROCs." + ] + }, + { + "cell_type": "markdown", + "id": "b2c3d4e5", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3d4e5f6", + "metadata": {}, + "outputs": [], + "source": [ + "# ── Configure paths ──\n", + "SAE_CHECKPOINT = \"../outputs/1b_layer16/checkpoints/checkpoint_final.pt\"\n", + "MODEL_PATH = \"../../../../../../../checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\"\n", + "CSV_PATH = \"../../../../../../../codonfm/data/codonfm_ref_only.csv\"\n", + "LAYER = 16\n", + "CONTEXT_LENGTH = 2048\n", + "BATCH_SIZE = 8\n", + "NUM_SEQUENCES = 1000 # Use more for reliable results; 1000 is fast for iteration\n", + "DEVICE = \"cuda\"\n", + "\n", + "# Optional: path to gnomAD pLI file for gene-level probing\n", + "PLI_PATH = None # e.g., \"../datasets/gnomad.v2.1.1.lof_metrics.by_gene.txt.bgz\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4e5f6a7", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from collections import defaultdict\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from sklearn.linear_model import
# Load SAE
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point SAE_CHECKPOINT at checkpoints you trust.
ckpt = torch.load(SAE_CHECKPOINT, map_location="cpu", weights_only=False)
state_dict = ckpt["model_state_dict"]
# Strip a leading "module." from parameter names if present
# (presumably left by a DataParallel/DDP wrapper at training time - confirm).
if any(k.startswith("module.") for k in state_dict):
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
# Dimensions come from the checkpoint when stored (and truthy); otherwise they
# are read off the encoder weight, indexed here as (hidden_dim, input_dim).
input_dim = ckpt.get("input_dim") or state_dict["encoder.weight"].shape[1]
hidden_dim = ckpt.get("hidden_dim") or state_dict["encoder.weight"].shape[0]
model_config = ckpt.get("model_config", {})
# NOTE(review): top_k is passed as None when the checkpoint has no
# model_config["top_k"]; whether TopKSAE accepts that is not visible here.
sae = TopKSAE(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    top_k=model_config.get("top_k"),
    normalize_input=model_config.get("normalize_input", False),
)
sae.load_state_dict(state_dict)
sae = sae.eval().to(device)
# n_features is used by later cells to size per-gene accumulators.
n_features = sae.hidden_dim
print(f"SAE: {input_dim} \u2192 {hidden_dim:,} features (top-{model_config.get('top_k')})")

# Load Encodon
inference = EncodonInference(model_path=MODEL_PATH, task_type="embedding_prediction", use_transformer_engine=True)
inference.configure_model()
inference.model.to(device).eval()
print(f"Encodon: {len(inference.model.model.layers)} layers")

# Load data
# max_codons is CONTEXT_LENGTH - 2, presumably to leave room for the two
# special tokens (CLS/SEP) that the extraction cell masks out - confirm.
records = read_codon_csv(CSV_PATH, max_sequences=NUM_SEQUENCES, max_codons=CONTEXT_LENGTH - 2)
sequences = [r.sequence for r in records]
# Records without a "gene" metadata entry get an empty gene name.
gene_names = [r.metadata.get("gene", "") for r in records]
print(f"Loaded {len(sequences)} sequences")
# Collect per-codon SAE activations + metadata for probing.
# We store, per processed sequence: the (n_codons, n_features) activation
# matrix and the list of codon strings; plus one gene name per codon.
# Sequences are kept separate in `all_codons`, so sequence boundaries can be
# recovered later from the per-sequence list lengths.
# For memory, we cap the total number of codons stored.
# NOTE(review): `codon_acts` below is a dense (total_codons, n_features) array.
# At the cap and with tens of thousands of features this can reach tens of GB;
# lower MAX_CODONS / NUM_SEQUENCES if you run out of host memory.

MAX_CODONS = 500_000  # Cap for codon-level probes

all_activations = []  # list of numpy arrays, each (n_codons_in_seq, n_features)
all_codons = []  # list of codon string lists, one list per sequence
all_gene_labels = []  # gene name per codon (repeated for each codon in gene)
total_codons = 0

# Gene-level: accumulate summed activations and codon counts per gene;
# the mean is formed later as sum / count.
gene_act_sum = defaultdict(lambda: np.zeros(n_features, dtype=np.float64))
gene_act_count = defaultdict(int)

print(f"Extracting SAE activations (max {MAX_CODONS:,} codons)...")
with torch.no_grad():
    for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc="Extracting"):
        if total_codons >= MAX_CODONS:
            break
        batch_seqs = sequences[i : i + BATCH_SIZE]
        batch_genes = gene_names[i : i + BATCH_SIZE]
        items = [process_item(s, context_length=CONTEXT_LENGTH, tokenizer=inference.tokenizer) for s in batch_seqs]
        batch = {
            "input_ids": torch.tensor(np.stack([it["input_ids"] for it in items])).to(device),
            "attention_mask": torch.tensor(np.stack([it["attention_mask"] for it in items])).to(device),
        }
        out = inference.model(batch, return_hidden_states=True)
        hidden = out.all_hidden_states[LAYER].float()
        attn = batch["attention_mask"]

        # Build keep mask (exclude CLS at 0, SEP at last attended position)
        keep = attn.clone()
        keep[:, 0] = 0
        lengths = attn.sum(dim=1)
        for b in range(keep.shape[0]):
            sep = int(lengths[b].item()) - 1
            if sep > 0:
                keep[b, sep] = 0

        for b in range(len(batch_seqs)):
            if total_codons >= MAX_CODONS:
                break
            codon_mask = keep[b].bool()
            vl = int(codon_mask.sum().item())
            if vl == 0:
                continue
            # Select exactly the positions the keep mask marks as codons.
            # The previous slice `hidden[b, :vl, :]` started at token 0, so it
            # included CLS and dropped the last codon, pairing every
            # activation row with the codon one position to its right.
            emb = hidden[b, codon_mask, :]
            codes = sae.encode(emb).cpu().numpy()  # (vl, n_features)

            seq = batch_seqs[b]
            # Codon j of the sequence is nucleotides [3j, 3j+3); row j of
            # `codes` is now the activation at that same codon.
            codons = [seq[j * 3 : (j + 1) * 3].upper() for j in range(vl)]
            gene = str(batch_genes[b]).strip() if batch_genes[b] else ""

            all_activations.append(codes)
            all_codons.append(codons)
            all_gene_labels.extend([gene] * vl)
            total_codons += vl

            # Gene-level accumulation (records without a gene name are skipped)
            if gene:
                gene_act_sum[gene] += codes.sum(axis=0)
                gene_act_count[gene] += vl

        # Release the batch's tensors before the next forward pass.
        del out, hidden, batch
        torch.cuda.empty_cache()

# Stack codon-level data
codon_acts = np.concatenate(all_activations, axis=0)  # (total_codons, n_features)
codon_strings = []
for cl in all_codons:
    codon_strings.extend(cl)

print(f"Collected {codon_acts.shape[0]:,} codons \u00d7 {codon_acts.shape[1]:,} features")
print(f"Gene-level: {len(gene_act_sum)} unique genes")
# ── Codon optimality weights (same as notebook 02) ──
# Standard genetic code: codon -> one-letter amino acid ("*" = stop).
CODON_TO_AA = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
    "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
    "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
    "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
    "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
    "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
    "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
    "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
    "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
    "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
    "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}

# Per-codon usage values for human, used only to form ratios within each
# amino acid (presumably frequency per thousand codons - confirm the source).
HUMAN_CODON_USAGE = {
    "TTT": 17.6, "TTC": 20.3, "TTA": 7.7, "TTG": 12.9,
    "CTT": 13.2, "CTC": 19.6, "CTA": 7.2, "CTG": 39.6,
    "ATT": 16.0, "ATC": 20.8, "ATA": 7.5, "ATG": 22.0,
    "GTT": 11.0, "GTC": 14.5, "GTA": 7.1, "GTG": 28.1,
    "TCT": 15.2, "TCC": 17.7, "TCA": 12.2, "TCG": 4.4,
    "CCT": 17.5, "CCC": 19.8, "CCA": 16.9, "CCG": 6.9,
    "ACT": 13.1, "ACC": 18.9, "ACA": 15.1, "ACG": 6.1,
    "GCT": 18.4, "GCC": 27.7, "GCA": 15.8, "GCG": 7.4,
    "TAT": 12.2, "TAC": 15.3, "TAA": 1.0, "TAG": 0.8,
    "CAT": 10.9, "CAC": 15.1, "CAA": 12.3, "CAG": 34.2,
    "AAT": 17.0, "AAC": 19.1, "AAA": 24.4, "AAG": 31.9,
    "GAT": 21.8, "GAC": 25.1, "GAA": 29.0, "GAG": 39.6,
    "TGT": 10.6, "TGC": 12.6, "TGA": 1.6, "TGG": 13.2,
    "CGT": 4.5, "CGC": 10.4, "CGA": 6.2, "CGG": 11.4,
    "AGT": 12.1, "AGC": 19.5, "AGA": 12.2, "AGG": 12.0,
    "GGT": 10.8, "GGC": 22.2, "GGA": 16.5, "GGG": 16.5,
}

# CAI weights: freq / max_freq among codons for the same amino acid.
# Stop codons are excluded, so they have no entry in CAI_WEIGHTS.
# (`defaultdict` is already imported in the notebook's import cell.)
_aa_codons = defaultdict(list)
for c, aa in CODON_TO_AA.items():
    if aa != "*":
        _aa_codons[aa].append(c)

CAI_WEIGHTS = {}
for aa, codons in _aa_codons.items():
    freqs = [HUMAN_CODON_USAGE[c] for c in codons]
    mx = max(freqs)
    for c, f in zip(codons, freqs):
        CAI_WEIGHTS[c] = f / mx

# ── Build codon-level probe labels ──
n_codons_total = len(codon_strings)

probes = {}

# 1. Rare codon: CAI weight < 0.5, i.e. used less than half as often as the
#    most frequent synonymous codon. Codons without a weight (stops, anything
#    not in the table) default to 1.0 and are therefore never labeled rare.
probes["rare_codon"] = np.array([1 if CAI_WEIGHTS.get(c, 1.0) < 0.5 else 0 for c in codon_strings], dtype=np.float32)

# 2. Start codon (ATG) - labels every ATG, not only the first codon of a CDS
probes["start_codon_ATG"] = np.array([1 if c == "ATG" else 0 for c in codon_strings], dtype=np.float32)

# 3. GC-rich wobble position (3rd nt is G or C)
probes["wobble_GC"] = np.array(
    [1 if len(c) == 3 and c[2] in ("G", "C") else 0 for c in codon_strings], dtype=np.float32
)

# 4. CpG site (last nt = C, next codon first nt = G)
# Walk each sequence's own codon list so a CpG is never formed across two
# different sequences. The earlier version compared gene names of adjacent
# codons as a proxy, which joins consecutive records that share a gene name
# (or both have none). `codon_strings` is the concatenation of `all_codons`,
# so `seq_start + j` is codon j's position in the flat label array.
cpg_labels = np.zeros(n_codons_total, dtype=np.float32)
seq_start = 0
for seq_codons in all_codons:
    for j in range(len(seq_codons) - 1):
        c_curr = seq_codons[j]
        c_next = seq_codons[j + 1]
        if len(c_curr) == 3 and len(c_next) >= 1 and c_curr[2] == "C" and c_next[0] == "G":
            cpg_labels[seq_start + j] = 1.0
    seq_start += len(seq_codons)
probes["CpG_site"] = cpg_labels

for name, labels in probes.items():
    n_pos = int(labels.sum())
    print(f" {name}: {n_pos:,} positive / {len(labels):,} total ({n_pos / len(labels):.1%})")
def compute_single_latent_aurocs(activations, labels, max_features=None, subsample=100_000):
    """Compute AUROC for each latent used on its own as a classifier score.

    AUROC is computed from ranks (the Mann-Whitney U statistic with average
    ranks for ties), which gives the same value as the area under the ROC
    curve. Latents are processed in column chunks so the whole batch is
    ranked with array operations instead of one library call per latent.

    Args:
        activations: (n_samples, n_features) array of latent activations.
        labels: (n_samples,) binary array; entries equal to 1 are positives.
        max_features: if set, evaluate at most this many latents, choosing the
            highest-variance ones. Zero-variance latents are never evaluated.
        subsample: if set and n_samples exceeds it, evaluate on a fixed-seed
            random subset of this many rows.

    Returns:
        aurocs: (n_features,) array. Latents that were not evaluated (zero
            variance, or cut by ``max_features``) keep the chance value 0.5.
        feature_indices: indices of the latents selected for evaluation.

    Raises:
        ValueError: if ``labels`` does not have one entry per activation row.
    """
    # Local import keeps the helper self-contained; scipy is already a
    # dependency of the scikit-learn install this notebook imports.
    from scipy.stats import rankdata

    activations = np.asarray(activations)
    labels = np.asarray(labels).reshape(-1)
    n_samples, n_feats = activations.shape
    if labels.shape[0] != n_samples:
        raise ValueError(f"labels has {labels.shape[0]} entries but activations has {n_samples} rows")

    # Subsample for speed (fixed seed so repeated runs pick the same rows)
    if subsample and n_samples > subsample:
        rng = np.random.RandomState(42)
        idx = rng.choice(n_samples, subsample, replace=False)
        activations = activations[idx]
        labels = labels[idx]
        n_samples = subsample

    # A latent that is constant on this sample cannot rank anything, so its
    # AUROC is 0.5 by definition; skip it in both selection modes.
    feat_var = activations.var(axis=0)
    feature_indices = np.where(feat_var > 0)[0]
    if max_features and max_features < feature_indices.size:
        by_variance = np.argsort(feat_var[feature_indices])
        feature_indices = feature_indices[by_variance[-max_features:]]

    aurocs = np.full(n_feats, 0.5)

    pos_mask = labels == 1
    n_pos = int(pos_mask.sum())
    n_neg = n_samples - n_pos

    if n_pos < 5 or n_neg < 5:
        print(f" WARNING: Too few positives ({n_pos}) or negatives ({n_neg}), skipping")
        return aurocs, feature_indices

    # U = (sum of positive ranks) - n_pos * (n_pos + 1) / 2 ; AUROC = U / (n_pos * n_neg)
    min_rank_sum = n_pos * (n_pos + 1) / 2.0
    n_pairs = float(n_pos) * float(n_neg)

    # Bound the size of the float64 rank matrix held at any one time.
    chunk = max(1, 32_000_000 // max(n_samples, 1))
    for start in range(0, feature_indices.size, chunk):
        cols = feature_indices[start : start + chunk]
        ranks = rankdata(activations[:, cols], axis=0)  # ties get their average rank
        pos_rank_sum = ranks[pos_mask].sum(axis=0)
        aurocs[cols] = (pos_rank_sum - min_rank_sum) / n_pairs

    return aurocs, feature_indices
# One histogram panel per probe that was actually run. The panel count follows
# `codon_probe_results` (the dict being iterated), and colors cycle, so adding
# a fifth probe no longer drops its panel silently.
n_panels = len(codon_probe_results)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
axes = axes[0]  # squeeze=False always returns a 2-D array; take the single row

colors = ["#76b900", "#0074DF", "#9525C6", "#EF2020"]

for panel, (ax, (probe_name, result)) in enumerate(zip(axes, codon_probe_results.items())):
    color = colors[panel % len(colors)]
    aurocs = result["aurocs"]
    # 0.5 is the fill value for latents that were not evaluated (dead or not
    # among the selected high-variance latents); leave them out of the histogram.
    alive_aurocs = aurocs[aurocs != 0.5]

    ax.hist(alive_aurocs, bins=50, color=color, edgecolor="white", alpha=0.8)
    ax.axvline(
        result["best_auroc"],
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Best: {result['best_auroc']:.3f} (#{result['best_feature']})",
    )
    ax.axvline(0.5, color="gray", linestyle=":", alpha=0.5)  # chance level
    ax.set_xlabel("AUROC")
    ax.set_ylabel("Number of features")
    ax.set_title(probe_name.replace("_", " "))
    ax.legend(fontsize=9)
    # NOTE: latents with AUROC below 0.3 (strongly anti-correlated) fall outside this view.
    ax.set_xlim(0.3, 1.0)

plt.tight_layout()
plt.show()
# Build gene-level mean activations
# One row per gene: summed activations divided by the number of codons that
# contributed, in sorted gene-name order (gene_names_unique fixes the row order).
gene_names_unique = sorted(gene_act_sum.keys())
gene_mean_acts = np.array([gene_act_sum[g] / gene_act_count[g] for g in gene_names_unique])
print(f"Gene-level activations: {gene_mean_acts.shape[0]} genes \u00d7 {gene_mean_acts.shape[1]} features")

gene_probe_results = {}

# ── pLI probe (requires gnomAD file) ──
if PLI_PATH:
    import pandas as pd

    # Reads only the "gene" and "pLI" columns, drops rows with missing values,
    # and keeps the first row for any gene that appears more than once.
    pli_df = pd.read_csv(PLI_PATH, sep="\t", compression="gzip", usecols=["gene", "pLI"]).dropna()
    pli_df = pli_df.drop_duplicates(subset=["gene"], keep="first")
    pli_map = dict(zip(pli_df["gene"], pli_df["pLI"]))

    # NOTE(review): genes absent from the pLI table default to 0 and are
    # therefore counted as negatives (not constrained) rather than excluded.
    pli_labels = np.array([1.0 if pli_map.get(g, 0) > 0.9 else 0.0 for g in gene_names_unique])
    n_constrained = int(pli_labels.sum())
    print(f"pLI probe: {n_constrained} constrained / {len(pli_labels)} genes ({n_constrained / len(pli_labels):.1%})")

    aurocs, feat_idx = compute_single_latent_aurocs(gene_mean_acts, pli_labels, max_features=5000)
    best_idx = np.argmax(aurocs)
    print(f" Best latent: {best_idx} (AUROC = {aurocs[best_idx]:.4f})")
    gene_probe_results["pLI_constrained"] = {
        "aurocs": aurocs,
        "best_feature": int(best_idx),
        "best_auroc": float(aurocs[best_idx]),
    }
else:
    print("PLI_PATH not set \u2014 skipping pLI probe.")
    print("Set PLI_PATH to the gnomAD constraint file to enable gene-level pLI probing.")

# ── Housekeeping proxy probe ──
# No expression or tissue data is used here. The label is derived purely from
# gene_act_count, i.e. the number of codons accumulated per gene name during
# extraction: genes in the top quartile of that count are labeled 1.
# NOTE(review): this count grows with coding-sequence length and with how many
# records share a gene name, so treat "housekeeping" as a placeholder name and
# substitute a curated housekeeping gene list for a meaningful probe.
gene_freq = {g: gene_act_count[g] for g in gene_names_unique}
freq_values = np.array([gene_freq[g] for g in gene_names_unique])
# Top quartile by codon count -> positive class
housekeeping_labels = (freq_values >= np.percentile(freq_values, 75)).astype(np.float32)
n_hk = int(housekeeping_labels.sum())
print(
    f"\nHousekeeping proxy probe: {n_hk} positive / {len(housekeeping_labels)} genes ({n_hk / len(housekeeping_labels):.1%})"
)

aurocs, feat_idx = compute_single_latent_aurocs(gene_mean_acts, housekeeping_labels, max_features=5000)
best_idx = np.argmax(aurocs)
print(f" Best latent: {best_idx} (AUROC = {aurocs[best_idx]:.4f})")
gene_probe_results["housekeeping_proxy"] = {
    "aurocs": aurocs,
    "best_feature": int(best_idx),
    "best_auroc": float(aurocs[best_idx]),
}
from sklearn.model_selection import cross_val_score


print("Logistic regression probes (5-fold CV):\n")

# Codon-level (subsample for speed)
# A fixed-seed random subset of at most 50,000 codons is used for every probe.
rng = np.random.RandomState(42)
n_sub = min(50_000, codon_acts.shape[0])
sub_idx = rng.choice(codon_acts.shape[0], n_sub, replace=False)
X_sub = codon_acts[sub_idx]

for probe_name, labels in probes.items():
    y_sub = labels[sub_idx]
    # Need at least 10 examples of each class in the subset.
    if y_sub.sum() < 10 or (1 - y_sub).sum() < 10:
        print(f" {probe_name}: skipped (insufficient labels)")
        continue

    clf = LogisticRegression(max_iter=500, C=1.0, solver="saga", random_state=42)
    scores = cross_val_score(clf, X_sub, y_sub, cv=5, scoring="roc_auc", n_jobs=-1)

    best_single = codon_probe_results[probe_name]["best_auroc"]
    print(f" {probe_name}:")
    print(f" Best single latent: AUROC = {best_single:.4f}")
    print(f" Logistic regression: AUROC = {scores.mean():.4f} \u00b1 {scores.std():.4f}")
    # gap = cross-validated multi-latent AUROC minus the best single-latent AUROC.
    # NOTE(review): best_single is the maximum over many latents scored on the
    # same data it is reported on (no held-out split), while the regression
    # score is cross-validated, so the two are not strictly comparable.
    # A negative gap (regression worse than the single latent) also lands in
    # the first branch, and the branch labels do not check whether either
    # AUROC is actually high - read them together with the printed values.
    gap = scores.mean() - best_single
    if gap < 0.02:
        print(" \u2192 Concept is MONOSEMANTIC (single feature captures it)")
    elif gap < 0.1:
        print(" \u2192 Concept is MOSTLY monosemantic (small multi-feature gain)")
    else:
        print(f" \u2192 Concept is DISTRIBUTED (large multi-feature gain: +{gap:.3f})")

# Gene-level
print("\nGene-level probes:")
for probe_name, result in gene_probe_results.items():
    # Rebuild the label vector for each gene-level probe that was run.
    # "pLI_constrained" is only present when PLI_PATH was set, in which case
    # pli_map was defined by the gene-level probe cell.
    if probe_name == "pLI_constrained":
        y = np.array([1.0 if pli_map.get(g, 0) > 0.9 else 0.0 for g in gene_names_unique])
    elif probe_name == "housekeeping_proxy":
        y = housekeeping_labels
    else:
        continue

    if y.sum() < 10 or (1 - y).sum() < 10:
        print(f" {probe_name}: skipped")
        continue

    clf = LogisticRegression(max_iter=500, C=1.0, solver="saga", random_state=42)
    scores = cross_val_score(clf, gene_mean_acts, y, cv=5, scoring="roc_auc", n_jobs=-1)
    best_single = result["best_auroc"]
    print(f" {probe_name}:")
    print(f" Best single latent: AUROC = {best_single:.4f}")
    print(f" Logistic regression: AUROC = {scores.mean():.4f} \u00b1 {scores.std():.4f}")
**Which concepts the model doesn't represent** — if even logistic regression can't predict a label from SAE activations, the base model likely doesn't encode that property in its hidden states at this layer.\n", + "\n", + "### Adding Custom Probes\n", + "\n", + "To add your own probe, define a binary label array and pass it to `compute_single_latent_aurocs`:\n", + "\n", + "```python\n", + "# Example: probe for codons ending in 'A'\n", + "my_labels = np.array([1.0 if c.endswith(\"A\") else 0.0 for c in codon_strings])\n", + "aurocs, _ = compute_single_latent_aurocs(codon_acts, my_labels, max_features=5000)\n", + "best = np.argmax(aurocs)\n", + "print(f\"Best feature for 'ends with A': {best} (AUROC={aurocs[best]:.4f})\")\n", + "```\n", + "\n", + "You can also probe for any gene-level property by adding a label per gene in `gene_names_unique` and using `gene_mean_acts`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}