diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md new file mode 100644 index 0000000000..ad749dbedb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -0,0 +1,30 @@ +# Evo2 SAE Recipe + +Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations. + +Pipeline: + +``` +HF Savanna ckpt --convert--> MBridge ckpt + | + predict_evo2 --embedding-layer N (FASTA in, .pt out) + | + pt_to_parquet shim (.pt -> ActivationStore parquet shards) + | + train.py (TopK SAE) +``` + +The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1; `evo2_dashboard_mockup/` is a UI mockup driven by synthetic data, not by outputs of this pipeline. + +## Quick start (1B model, single GPU) + +```bash +bash scripts/1b.sh +``` + +This will: + +1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format +2. Run `predict_evo2` on the OpenGenome2 organelle FASTA, extracting layer-12 embeddings +3. Convert the `.pt` outputs to parquet shards +4. 
Train a TopK SAE (expansion=8, k=32) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/.gitignore b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/.gitignore new file mode 100644 index 0000000000..f4daeb0de8 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +package-lock.json +dist/ +.vite/ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/README.md new file mode 100644 index 0000000000..b0c15231cc --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/README.md @@ -0,0 +1,55 @@ +# Evo 2 SAE Feature Explorer — Mockup + +**Mockup, not a real artifact.** This is a fork of `recipes/codonfm/codon_dashboard` adapted for DNA / Evo 2, populated with **synthetic data**. No real SAE outputs flow through it yet. The point of this v1 is to lock in the data contract that the future real eval pipeline will target. + +A `MOCKUP — synthetic data, not from a real SAE run` banner is rendered at the top of the app so nobody mistakes it for actual results. + +## Quick start (local) + +```bash +# In this directory: +npm install +npm run dev +# open http://localhost:5173 +``` + +The dashboard reads three parquet fixtures from `public/`: + +- `features_atlas.parquet` — UMAP coordinates + per-feature aggregates +- `feature_metadata.parquet` — feature label/stats table +- `feature_examples.parquet` — long table of (feature_id, example_rank, sequence_id, start, end, sequence, activations, ...) rows + +The fixtures are committed to the repo. To regenerate them: + +```bash +python ../scripts/make_mockup_features.py +``` + +That writes all three files into `public/`. Seed is fixed (`--seed 42`). 
+ +## What's mocked vs. real + +| Thing | Source | +| ------------------------------------ | --------------------------------------------------------------- | +| Number of features | 20, hardcoded | +| Feature labels | Hardcoded biological-sounding strings | +| UMAP coordinates | 4 cluster centers + gaussian noise — fake but visibly clustered | +| Top activator windows | Random `ACGT` with a label-matching central motif spliced in | +| Per-token activations | Gaussian bump centered randomly in [80, 120], sigma ~= 8 bp | +| Vocab logits (promoted / suppressed) | Empty arrays — not in scope for v1 | + +## v2 roadmap placeholders + +A few greyed-out stats on each feature card (`Annotation`, `Sensitivity`, `Recon Δ`) and two empty sections on the feature detail page (`Annotations`, `Conservation`) hint at what's coming in v2. They render as em-dashes / dashed empty boxes with hover tooltips explaining what they'll show. + +## Out of scope (v1) + +- Real SAE inference or activation pass +- Annotation overlays (RefSeq / Rfam / JASPAR) +- Conservation tracks (phyloP) +- Strand handling, codon framing, chromosome ideograms +- External link-outs (UCSC, Ensembl) +- `sae.launch_dashboard()` Python wiring — run `npm run dev` directly +- Lepton-based serving + +These are deferred to v2, once the real eval pipeline produces matching parquet shapes. diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/index.html b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/index.html new file mode 100644 index 0000000000..5d09d6738d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/index.html @@ -0,0 +1,104 @@ + + + + + + Evo 2 SAE Feature Explorer (Mockup) + + + +
+ + + diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/package.json b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/package.json new file mode 100644 index 0000000000..53674056a3 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/package.json @@ -0,0 +1,25 @@ +{ + "name": "evo2-dashboard-mockup", + "version": "0.1.0", + "private": true, + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "@uwdata/mosaic-core": "^0.21.1", + "@uwdata/mosaic-sql": "^0.21.1", + "@uwdata/vgplot": "^0.21.1", + "embedding-atlas": "^0.16.1", + "lucide-react": "^0.577.0", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "umap-js": "^1.4.0" + }, + "devDependencies": { + "@vitejs/plugin-react": "^4.2.0", + "vite": "^5.0.0" + } +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_examples.parquet b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_examples.parquet new file mode 100644 index 0000000000..6b69aa32f7 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_examples.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_metadata.parquet b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_metadata.parquet new file mode 100644 index 0000000000..8a46695799 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/feature_metadata.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/features_atlas.parquet 
b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/features_atlas.parquet new file mode 100644 index 0000000000..8a46695799 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/features_atlas.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_0.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_0.png new file mode 100644 index 0000000000..1aa768bcfa Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_0.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_1.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_1.png new file mode 100644 index 0000000000..5638d8ffd6 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_1.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_10.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_10.png new file mode 100644 index 0000000000..8ea17e2517 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_10.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_2.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_2.png new file mode 100644 index 0000000000..8bea049ab0 Binary files /dev/null and 
b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_2.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_3.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_3.png new file mode 100644 index 0000000000..a8fee6946f Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_3.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_4.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_4.png new file mode 100644 index 0000000000..7522420131 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_4.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_5.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_5.png new file mode 100644 index 0000000000..778ce71ced Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_5.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_6.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_6.png new file mode 100644 index 0000000000..51ee2ab9d1 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_6.png differ diff --git 
a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_7.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_7.png new file mode 100644 index 0000000000..f17f231dcd Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_7.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_8.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_8.png new file mode 100644 index 0000000000..7f6e2d4c57 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_8.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_9.png b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_9.png new file mode 100644 index 0000000000..65a7c91cc8 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/logos/feature_9.png differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/steering_data.json b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/steering_data.json new file mode 100644 index 0000000000..95172ada43 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/public/steering_data.json @@ -0,0 +1,146 @@ +{ + "seeds": { + "ecoli_16s": { + "name": "E. 
coli 16S rRNA region (A1408 example)", + "sequence": "ATCCGTCAACCTTCAAGCATCCAAACGGCGATGATCAAGGCATAAGCCTACAGGGTACCATAGCGAAGCTGCGTCTGAAGCGAACTGGCAAACGGCTAACAGCG", + "length": 100, + "alphabet": "DNA", + "default_target_position": 8, + "context_note": "Position 8 in this synthetic sequence is the analog of position 1408 in the full 16S rRNA. Sequence shown as DNA bases to match Evo2 tokenization." + }, + "promoter": { + "name": "Promoter region", + "sequence": "ATCGATGCGTAGCATGCATGGCATATATAAGCATCGATCGATCGATGCATGCTAGCATGCTAGCATGCATGCAT", + "length": 73, + "alphabet": "DNA", + "default_target_position": 24, + "context_note": "Non-rRNA context. AMR features should not shift predictions here." + }, + "brca1_exon": { + "name": "BRCA1 exon fragment", + "sequence": "ATGGATTTATCTGCTCTTCGCGTTGAAGAAGTACAAAATGTCATTAATGCTATGCAGAAAATCTTAGAGTGTCCCATCTGTCTGGAGTTGAT", + "length": 90, + "alphabet": "DNA", + "default_target_position": 45, + "context_note": "Human exonic context. AMR features (bacterial rRNA) should have no effect here." + }, + "random": { + "name": "Random control sequence", + "sequence": "GACTGCATCGATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGC", + "length": 73, + "alphabet": "DNA", + "default_target_position": 30, + "context_note": "No biological structure; serves as a negative control for AMR steering." 
+ } + }, + "features_available": [ + {"id": 12, "label": "kanamycin_resistance", "is_amr": true}, + {"id": 13, "label": "streptomycin_resistance", "is_amr": true} + ], + "comparisons": { + "ecoli_16s__kanamycin_resistance__pos8": { + "seed": "ecoli_16s", + "feature_id": 12, + "target_position": 8, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.91, "C": 0.04, "G": 0.03, "T": 0.02}}, + "0": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}}, + "2": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.42, "C": 0.08, "G": 0.43, "T": 0.07}}, + "5": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.18, "C": 0.07, "G": 0.71, "T": 0.04}} + }, + "narrative_type": "headline_amr" + }, + "ecoli_16s__streptomycin_resistance__pos8": { + "seed": "ecoli_16s", + "feature_id": 13, + "target_position": 8, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.88, "C": 0.05, "G": 0.04, "T": 0.03}}, + "0": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}}, + "2": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.50, "C": 0.09, "G": 0.34, "T": 0.07}}, + "5": {"baseline": {"A": 0.74, "C": 0.12, "G": 0.08, "T": 0.06}, "steered": {"A": 0.26, "C": 0.08, "G": 0.62, "T": 0.04}} + }, + "narrative_type": "headline_amr" + }, + "promoter__kanamycin_resistance__pos24": { + "seed": "promoter", + "feature_id": 12, + "target_position": 24, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.63, "C": 0.10, "G": 0.08, "T": 0.19}}, + "0": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}}, + 
"2": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.59, "C": 0.11, "G": 0.10, "T": 0.20}}, + "5": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.55, "C": 0.12, "G": 0.13, "T": 0.20}} + }, + "narrative_type": "null_result" + }, + "promoter__streptomycin_resistance__pos24": { + "seed": "promoter", + "feature_id": 13, + "target_position": 24, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.63, "C": 0.10, "G": 0.08, "T": 0.19}}, + "0": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}}, + "2": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.60, "C": 0.10, "G": 0.10, "T": 0.20}}, + "5": {"baseline": {"A": 0.62, "C": 0.10, "G": 0.08, "T": 0.20}, "steered": {"A": 0.57, "C": 0.11, "G": 0.12, "T": 0.20}} + }, + "narrative_type": "null_result" + }, + "brca1_exon__kanamycin_resistance__pos45": { + "seed": "brca1_exon", + "feature_id": 12, + "target_position": 45, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}}, + "0": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}}, + "2": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.30, "C": 0.21, "G": 0.29, "T": 0.20}}, + "5": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.28, "C": 0.21, "G": 0.31, "T": 0.20}} + }, + "narrative_type": "null_result" + }, + "brca1_exon__streptomycin_resistance__pos45": { + "seed": "brca1_exon", + "feature_id": 13, + "target_position": 45, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}}, + "0": 
{"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}}, + "2": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.30, "C": 0.22, "G": 0.28, "T": 0.20}}, + "5": {"baseline": {"A": 0.31, "C": 0.21, "G": 0.28, "T": 0.20}, "steered": {"A": 0.29, "C": 0.22, "G": 0.29, "T": 0.20}} + }, + "narrative_type": "null_result" + }, + "random__kanamycin_resistance__pos30": { + "seed": "random", + "feature_id": 12, + "target_position": 30, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.27, "C": 0.23, "G": 0.27, "T": 0.23}}, + "0": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}}, + "2": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.26, "C": 0.24, "G": 0.28, "T": 0.22}}, + "5": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.25, "C": 0.25, "G": 0.29, "T": 0.21}} + }, + "narrative_type": "null_result" + }, + "random__streptomycin_resistance__pos30": { + "seed": "random", + "feature_id": 13, + "target_position": 30, + "neighbor_count": 1, + "results_by_clamp": { + "-2": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}}, + "0": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}}, + "2": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.27, "C": 0.23, "G": 0.28, "T": 0.22}}, + "5": {"baseline": {"A": 0.28, "C": 0.22, "G": 0.27, "T": 0.23}, "steered": {"A": 0.26, "C": 0.24, "G": 0.29, "T": 0.21}} + }, + "narrative_type": "null_result" + } + } +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/App.jsx 
b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/App.jsx new file mode 100644 index 0000000000..9b82027783 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/App.jsx @@ -0,0 +1,1645 @@ +import React, { useState, useEffect, useRef, useMemo, useCallback } from 'react' +import * as vg from '@uwdata/vgplot' +import { wasmConnector, MosaicClient } from '@uwdata/mosaic-core' +import { Query, sql, literal } from '@uwdata/mosaic-sql' +import FeatureCard from './FeatureCard' +import FeatureList from './FeatureList' +import EmbeddingView from './EmbeddingView' +import Histogram from './Histogram' +import InfoButton from './InfoButton' +import { Sun, Moon } from 'lucide-react' + +const styles = { + container: { + height: '100vh', + display: 'flex', + flexDirection: 'column', + padding: '16px', + gap: '4px', + overflow: 'hidden', + background: 'var(--bg)', + color: 'var(--text)', + }, + header: { + flexShrink: 0, + }, + title: { + fontSize: '22px', + fontWeight: '600', + marginBottom: '2px', + color: 'var(--text-heading)', + }, + subtitle: { + color: 'var(--text-secondary)', + fontSize: '13px', + margin: 0, + }, + mainContent: { + flex: 1, + display: 'grid', + gridTemplateColumns: '3fr 2fr', + gap: '16px', + minHeight: 0, + overflow: 'hidden', + }, + leftPanel: { + display: 'flex', + flexDirection: 'column', + gap: '12px', + minHeight: 0, + minWidth: 0, + overflow: 'hidden', + }, + embeddingPanel: { + flex: 1, + background: 'var(--bg-card)', + borderRadius: '8px', + border: '1px solid var(--border)', + padding: '12px', + display: 'flex', + flexDirection: 'column', + minHeight: '300px', + minWidth: 0, + overflow: 'hidden', + }, + embeddingContainer: { + flex: 1, + minHeight: 0, + overflow: 'hidden', + }, + histogramRow: { + display: 'grid', + gridTemplateColumns: '1fr 1fr 1fr', + gap: '12px', + flexShrink: 0, + height: '100px', + marginBottom: '4px', + }, + histogramPanel: 
{ + background: 'var(--bg-card)', + borderRadius: '8px', + border: '1px solid var(--border)', + padding: '8px', + overflow: 'hidden', + }, + panelHeader: { + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + marginBottom: '8px', + flexShrink: 0, + }, + panelTitle: { + fontSize: '14px', + fontWeight: '600', + color: 'var(--text-heading)', + }, + rightPanel: { + display: 'flex', + flexDirection: 'column', + gap: '10px', + minHeight: 0, + minWidth: 0, + height: '100%', + overflow: 'hidden', + }, + searchBar: { + display: 'flex', + gap: '8px', + flexShrink: 0, + }, + searchInput: { + flex: 0.81, + padding: '8px 12px', + fontSize: '13px', + border: '1px solid var(--border-input)', + borderRadius: '6px', + outline: 'none', + background: 'var(--bg-input)', + color: 'var(--text)', + }, + sortSelect: { + padding: '8px 12px', + fontSize: '13px', + border: '1px solid var(--border-input)', + borderRadius: '6px', + background: 'var(--bg-input)', + color: 'var(--text)', + cursor: 'pointer', + }, + stats: { + padding: '4px 0', + fontSize: '12px', + color: 'var(--text-secondary)', + flexShrink: 0, + }, + featureList: { + flex: 1, + overflowY: 'auto', + overflowX: 'hidden', + display: 'flex', + flexDirection: 'column', + gap: '10px', + paddingRight: '8px', + minHeight: 0, + }, + loading: { + textAlign: 'center', + padding: '40px', + color: 'var(--text-secondary)', + }, + error: { + textAlign: 'center', + padding: '40px', + color: '#c00', + }, + colorSelect: { + padding: '4px 8px', + fontSize: '12px', + border: '1px solid var(--border-input)', + borderRadius: '4px', + background: 'var(--bg-input)', + color: 'var(--text)', + cursor: 'pointer', + }, + clearButton: { + padding: '4px 12px', + fontSize: '12px', + border: '2px solid var(--accent)', + borderRadius: '4px', + background: 'transparent', + color: 'var(--accent)', + fontWeight: '600', + cursor: 'pointer', + }, +} + +export default function App({ title = "Evo 2 SAE Feature Explorer (Mockup)", subtitle 
= "Synthetic data — not from a real SAE run" }) { + const [darkMode, setDarkMode] = useState(true) + + // Toggle dark class on document root + useEffect(() => { + document.documentElement.classList.toggle('dark', darkMode) + }, [darkMode]) + + const [features, setFeatures] = useState([]) + const [loading, setLoading] = useState(true) + const [loadingProgress, setLoadingProgress] = useState({ step: 0, total: 4, message: 'Starting up...' }) + const [error, setError] = useState(null) + const [sortBy, setSortBy] = useState('frequency') + const [selectedFeatureIds, setSelectedFeatureIds] = useState(null) // null = all selected + const [mosaicReady, setMosaicReady] = useState(false) + const [categoryColumns, setCategoryColumns] = useState([]) + const [selectedCategory, setSelectedCategory] = useState('cluster_id') + const [hiddenCategories, setHiddenCategories] = useState(new Set()) + const [clickedFeatureId, setClickedFeatureId] = useState(null) + const [clusterLabels, setClusterLabels] = useState(null) + const [vocabLogits, setVocabLogits] = useState(null) + const [featureAnalysis, setFeatureAnalysis] = useState(null) + + const brushRef = useRef(null) + const [showGuideModal, setShowGuideModal] = useState(false) + const [showMetricsModal, setShowMetricsModal] = useState(false) + const [searchTerm, setSearchTerm] = useState('') + const [cardResetKey, setCardResetKey] = useState(0) + const [plotResetKey, setPlotResetKey] = useState(0) + const [viewportState, setViewportState] = useState(null) // null = let embedding-atlas auto-fit on first load + const [displayedCardCount, setDisplayedCardCount] = useState(20) // Pagination: start with 20 cards + const [showEditedOnly, setShowEditedOnly] = useState(false) // Filter for edited features only + const [histMetric1, setHistMetric1] = useState('log_frequency') + const [histMetric2, setHistMetric2] = useState('max_activation') + const [histMetric3, setHistMetric3] = useState('cluster_id') // tracks color-by selection + const 
featureRefs = useRef({}) + const featureListRef = useRef(null) + const endOfListRef = useRef(null) + const searchSource = useRef({ source: 'search' }) + const editedSource = useRef({ source: 'edited' }) + const legendSource = useRef({ source: 'legend' }) + const loadingMoreRef = useRef(false) + + // Lazy-load examples for a single feature from DuckDB (feature_examples VIEW) + const loadExamplesForFeature = useCallback(async (featureId) => { + const result = await vg.coordinator().query( + `SELECT * FROM feature_examples WHERE feature_id = ${featureId} ORDER BY example_rank` + ) + return result.toArray().map(row => ({ + sequence_id: row.sequence_id, + start: row.start, + end: row.end, + sequence: row.sequence, + activations: Array.from(row.activations), + max_activation: row.max_activation, + best_annotation: row.best_annotation, + })) + }, []) + + // Intersection Observer for infinite scroll pagination + useEffect(() => { + const sentinel = endOfListRef.current + const scrollContainer = featureListRef.current + if (!sentinel || !scrollContainer) return + + const observer = new IntersectionObserver( + entries => { + console.log('[scroll] sentinel intersecting:', entries[0].isIntersecting, 'loadingMore:', loadingMoreRef.current) + if (entries[0].isIntersecting && !loadingMoreRef.current) { + loadingMoreRef.current = true + setDisplayedCardCount(prev => prev + 20) + // Reset flag after a delay to allow next batch + setTimeout(() => { + loadingMoreRef.current = false + }, 300) + } + }, + { root: scrollContainer, threshold: 0.1, rootMargin: '200px' } + ) + + observer.observe(sentinel) + + return () => { + observer.disconnect() + } + }, [mosaicReady]) + + // Handle click on a feature in the UMAP (or null for empty canvas click) + const animationRef = useRef(null) + const currentViewportRef = useRef(null) + const initialViewportRef = useRef(null) + + // Handle viewport changes from the UMAP component + const handleViewportChange = useCallback((vp) => { + // Capture 
initial viewport on first report, slightly zoomed out so all points fit + if (!initialViewportRef.current && vp) { + initialViewportRef.current = { ...vp, scale: vp.scale * 0.5 } + setViewportState(initialViewportRef.current) + currentViewportRef.current = { ...initialViewportRef.current } + } + // Clamp zoom to max scale of 5 + if (vp && vp.scale > 5) { + const clamped = { ...vp, scale: 5 } + setViewportState(clamped) + currentViewportRef.current = clamped + return + } + // Always track current viewport (but not during our own animations) + if (!animationRef.current) { + currentViewportRef.current = vp + } + }, []) + + // Easing functions + const easeOutQuart = (t) => 1 - Math.pow(1 - t, 4) + const easeInOutCubic = (t) => t < 0.5 ? 4 * t * t * t : 1 - Math.pow(-2 * t + 2, 3) / 2 + const easeInOutQuad = (t) => t < 0.5 ? 2 * t * t : 1 - Math.pow(-2 * t + 2, 2) / 2 + + // Smooth zoom-in with "fly-to" trajectory (zoom out -> pan -> zoom in) + const zoomToPoint = useCallback((x, y) => { + if (x == null || y == null) return + + if (animationRef.current) { + cancelAnimationFrame(animationRef.current) + animationRef.current = null + } + + const start = currentViewportRef.current || initialViewportRef.current || { x: 0, y: 0, scale: 1 } + const targetScale = 4 // capped below max zoom of 5 + const duration = 800 + const startTime = performance.now() + + // Calculate how far we need to pan (in data space) + const panDistance = Math.sqrt(Math.pow(x - start.x, 2) + Math.pow(y - start.y, 2)) + + // Determine the "cruise altitude" - how much to zoom out during the pan + // Zoom out more for longer distances, less for short distances + const minScale = Math.min(start.scale, 0.8) // Never zoom out below 0.8 + const maxZoomOut = Math.max(0, start.scale - minScale) + const zoomOutAmount = Math.min(maxZoomOut, panDistance * 0.1) // Scale zoom-out with distance + const cruiseScale = start.scale - zoomOutAmount + + const animate = (currentTime) => { + const elapsed = currentTime - 
startTime + const t = Math.min(elapsed / duration, 1) + + // Use smooth ease-in-out for the overall progress + const smoothT = easeInOutCubic(t) + + // Pan follows the smooth progress + const panT = smoothT + + // Zoom follows a "U-shaped" profile: + // - First half: ease from start.scale down to cruiseScale (or stay flat if already low) + // - Second half: ease from cruiseScale up to targetScale + let zoomScale + if (t < 0.4) { + // First 40%: zoom out slightly (ease-out) + const zoomOutT = t / 0.4 + const easeOut = 1 - Math.pow(1 - zoomOutT, 2) + zoomScale = start.scale + (cruiseScale - start.scale) * easeOut + } else if (t < 0.6) { + // Middle 20%: hold at cruise altitude + zoomScale = cruiseScale + } else { + // Last 40%: zoom in to target (ease-in then ease-out) + const zoomInT = (t - 0.6) / 0.4 + const easeInOut = easeInOutQuad(zoomInT) + zoomScale = cruiseScale + (targetScale - cruiseScale) * easeInOut + } + + const newViewport = { + x: start.x + (x - start.x) * panT, + y: start.y + (y - start.y) * panT, + scale: zoomScale + } + + setViewportState(newViewport) + + if (t < 1) { + animationRef.current = requestAnimationFrame(animate) + } else { + currentViewportRef.current = { x, y, scale: targetScale } + animationRef.current = null + } + } + + animationRef.current = requestAnimationFrame(animate) + }, []) + + // Smooth zoom-out: zoom out first, then pan back + const resetViewport = useCallback(() => { + if (animationRef.current) { + cancelAnimationFrame(animationRef.current) + animationRef.current = null + } + + const start = currentViewportRef.current || { x: 0, y: 0, scale: 1 } + const target = initialViewportRef.current || { x: 0, y: 0, scale: 1 } + const duration = 600 + const startTime = performance.now() + + const animate = (currentTime) => { + const elapsed = currentTime - startTime + const t = Math.min(elapsed / duration, 1) + + // Zoom out fast at start (ease-out) + const zoomT = easeOutQuart(t) + + // Pan eases in-out + const panT = 
easeInOutCubic(t) + + const newViewport = { + x: start.x + (target.x - start.x) * panT, + y: start.y + (target.y - start.y) * panT, + scale: start.scale + (target.scale - start.scale) * zoomT + } + + setViewportState(newViewport) + + if (t < 1) { + animationRef.current = requestAnimationFrame(animate) + } else { + currentViewportRef.current = { ...target } + animationRef.current = null + } + } + + animationRef.current = requestAnimationFrame(animate) + }, []) + + // Handle click on a feature in the UMAP (with coordinates for zooming) + const handleFeatureClick = useCallback((featureId, x, y) => { + + setClickedFeatureId(featureId) + + if (featureId == null) return + + // Scroll to the feature card + setTimeout(() => { + const ref = featureRefs.current[featureId] + if (ref) { + ref.scrollIntoView({ behavior: 'smooth', block: 'center' }) + } + }, 50) + }, []) + + // Handle click on a feature card (highlights point in UMAP, no zoom) + const handleCardClick = useCallback(async (featureId, isExpanding) => { + + if (!isExpanding) { + setClickedFeatureId(null) + return + } + + setClickedFeatureId(featureId) + }, []) + + // Initialize Mosaic and load data + useEffect(() => { + async function init() { + try { + // Step 1: Initialize DuckDB-WASM + setLoadingProgress({ step: 1, total: 4, message: 'Initializing database engine...' }) + const wasm = wasmConnector() + vg.coordinator().databaseConnector(wasm) + + // Step 2: Load parquet data + setLoadingProgress({ step: 2, total: 4, message: 'Loading embedding data...' }) + const urlParams = new URLSearchParams(window.location.search) + const dataPath = urlParams.get('data') || '/features_atlas.parquet' + const parquetUrl = dataPath.startsWith('http') + ? 
dataPath + : new URL(dataPath, window.location.origin).href + + + await vg.coordinator().exec(` + CREATE TABLE features AS + SELECT * FROM read_parquet('${parquetUrl}') + `) + + // HDBSCAN assigns -1 to noise points; embedding-atlas casts category + // columns to UTINYINT which can't hold negatives. Remap to NULL. + try { + await vg.coordinator().exec(` + UPDATE features SET cluster_id = NULL WHERE cluster_id < 0 + `) + } catch (e) { + // cluster_id column may not exist — that's fine + } + + // Step 3: Process columns and categories + setLoadingProgress({ step: 3, total: 4, message: 'Processing columns...' }) + const schemaResult = await vg.coordinator().query(` + SELECT column_name, column_type + FROM (DESCRIBE features) + `) + + const columns = schemaResult.toArray().map(row => ({ + name: row.column_name, + type: row.column_type + })) + + const detectedCategories = [] + const sequentialColumns = [] + + for (const col of columns) { + if (['x', 'y', 'feature_id', 'top_example_idx', 'logo_path'].includes(col.name)) continue + + if (col.type === 'VARCHAR') { + const isGsea = col.name.startsWith('gsea_') + const maxUnique = isGsea ? Infinity : 50 + const cardinalityResult = await vg.coordinator().query(` + SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled' + `) + const nUnique = cardinalityResult.toArray()[0]?.n_unique ?? 
0 + if (nUnique > 0 && nUnique <= maxUnique) { + // For high-cardinality GSEA columns, collapse to top 20 + "other" + if (isGsea && nUnique > 20) { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT * REPLACE ( + CASE + WHEN "${col.name}" IS NULL OR "${col.name}" = 'unlabeled' THEN 'unlabeled' + WHEN "${col.name}" IN ( + SELECT "${col.name}" FROM features + WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled' + GROUP BY "${col.name}" ORDER BY COUNT(*) DESC LIMIT 20 + ) THEN "${col.name}" + ELSE 'other' + END AS "${col.name}" + ) FROM features + `) + detectedCategories.push({ name: col.name, type: 'string', nUnique: 22 }) + } else { + detectedCategories.push({ name: col.name, type: 'string', nUnique }) + } + } + } else if (col.type === 'BIGINT' || col.type === 'INTEGER') { + if (col.name.includes('cluster') || col.name.includes('category') || col.name.includes('group')) { + const cardinalityResult = await vg.coordinator().query(` + SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL + `) + const nUnique = cardinalityResult.toArray()[0]?.n_unique ?? 
0 + if (nUnique > 0 && nUnique <= 50) { + detectedCategories.push({ name: col.name, type: 'integer', nUnique }) + } + } + } else if (col.type === 'DOUBLE' || col.type === 'FLOAT') { + // Numeric columns for sequential coloring + if (['log_frequency', 'max_activation', 'activation_freq', 'frequency', + 'mean_variant_1bcdwt', + 'high_score_fraction', 'clinvar_fraction', + 'mean_phylop', 'mean_variant_delta', 'mean_site_delta', 'mean_local_delta', + 'high_score_delta', 'low_score_delta', + 'gc_mean', 'gc_std', + 'trinuc_entropy', 'trinuc_dominant_frac', + 'pli_mean_pli', 'pli_frac_constrained', 'pli_max_pli', + 'codon_cai', 'codon_tai', 'codon_rscu', + 'gene_entropy', 'gene_n_unique', 'gene_dominant_frac', + ].includes(col.name)) { + sequentialColumns.push({ name: col.name, type: 'sequential' }) + } + } + } + + // Create integer-encoded versions of string category columns + for (const col of detectedCategories) { + if (col.type === 'string') { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT *, + CASE WHEN "${col.name}" IS NULL THEN NULL + ELSE DENSE_RANK() OVER (ORDER BY "${col.name}") - 1 + END AS "${col.name}_cat" + FROM features + `) + } + } + + // Create binned versions of sequential columns (10 bins) + const NUM_BINS = 10 + for (const col of sequentialColumns) { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT *, + CASE WHEN "${col.name}" IS NULL THEN NULL + ELSE LEAST(${NUM_BINS - 1}, CAST( + (("${col.name}" - (SELECT MIN("${col.name}") FROM features)) / + NULLIF((SELECT MAX("${col.name}") - MIN("${col.name}") FROM features), 0)) * ${NUM_BINS} + AS INTEGER)) + END AS "${col.name}_bin" + FROM features + `) + detectedCategories.push({ name: col.name, type: 'sequential', nUnique: NUM_BINS }) + } + + setCategoryColumns(detectedCategories) + + // Create crossfilter selection + brushRef.current = vg.Selection.crossfilter() + + + // Step 4: Load feature metadata from parquet via DuckDB + setLoadingProgress({ 
step: 4, total: 4, message: 'Loading feature metadata...' }) + const metaUrl = new URL('/feature_metadata.parquet', window.location.origin).href + const examplesUrl = new URL('/feature_examples.parquet', window.location.origin).href + + await vg.coordinator().exec(` + CREATE TABLE IF NOT EXISTS feature_metadata AS + SELECT * FROM read_parquet('${metaUrl}') + `) + await vg.coordinator().exec(` + CREATE VIEW IF NOT EXISTS feature_examples AS + SELECT * FROM read_parquet('${examplesUrl}') + `) + + // Load features from the features table (which has labels + category columns) + const categorySelectCols = detectedCategories + .filter(c => c.type === 'string' || c.type === 'integer') + .map(c => `"${c.name}"`) + .join(', ') + const extraSelect = categorySelectCols ? `, ${categorySelectCols}` : '' + // logo_path is optional — older parquets won't have it, so detect and + // include it only if the column exists. + const hasLogoPath = columns.some(c => c.name === 'logo_path') + const logoSelect = hasLogoPath ? 
', logo_path' : '' + const featuresResult = await vg.coordinator().query(` + SELECT + feature_id, + label, + activation_freq, + max_activation, + x, + y + ${logoSelect} + ${extraSelect} + FROM features + ORDER BY feature_id + `) + const loadedFeatures = featuresResult.toArray().map(row => { + const f = { + feature_id: row.feature_id, + label: row.label, + description: row.label, + activation_freq: row.activation_freq, + max_activation: row.max_activation, + x: row.x, + y: row.y, + logo_path: row.logo_path, + } + for (const col of detectedCategories) { + if (col.type === 'string' || col.type === 'integer') { + f[col.name] = row[col.name] + } + } + return f + }) + setFeatures(loadedFeatures) + + // Generate cluster labels from DuckDB (non-fatal if cluster_id doesn't exist) + try { + const clusterResult = await vg.coordinator().query(` + SELECT + cluster_id, + AVG(x) as cx, + AVG(y) as cy, + MODE(label) as top_label, + COUNT(*) as n + FROM features + WHERE cluster_id IS NOT NULL + GROUP BY cluster_id + ORDER BY n DESC + `) + const labels = clusterResult.toArray() + .filter(row => row.top_label && !row.top_label.startsWith('Feature ')) + .map((row, i) => ({ + x: Number(row.cx), + y: Number(row.cy), + text: row.top_label.length > 40 ? row.top_label.slice(0, 40) + '...' 
: row.top_label, + priority: row.n, + level: 0, + })) + console.log('[cluster labels] generated:', labels.length, labels.slice(0, 5)) + if (labels.length > 0) { + setClusterLabels(labels) + } + } catch (e) { + console.log('[cluster labels] query failed:', e.message) + } + + // Load cluster labels from file (overrides computed ones if present) + try { + const labelsRes = await fetch('./cluster_labels.json') + if (labelsRes.ok) { + const labelsData = await labelsRes.json() + setClusterLabels(labelsData) + } + } catch (labelErr) { + } + + // Load vocab logits (non-fatal if missing) + try { + const logitsRes = await fetch('./vocab_logits.json') + if (logitsRes.ok) { + const logitsData = await logitsRes.json() + setVocabLogits(logitsData) + } + } catch (e) { + } + + // Load feature analysis (non-fatal if missing) + try { + const analysisRes = await fetch('./feature_analysis.json') + if (analysisRes.ok) { + const analysisData = await analysisRes.json() + setFeatureAnalysis(analysisData) + } + } catch (e) { + } + + setMosaicReady(true) + setLoading(false) + + } catch (err) { + console.error('Init error:', err) + setError(err.message) + setLoading(false) + } + } + + init() + }, []) + + // Create a Mosaic client that receives filtered feature IDs + useEffect(() => { + if (!mosaicReady || !brushRef.current) return + + const coordinator = vg.coordinator() + const selection = brushRef.current + const totalFeatures = features.length + + // Create a class that extends MosaicClient + class FeatureFilterClient extends MosaicClient { + constructor(filterBy) { + super(filterBy) + this._isConnected = true + } + + query(filter = []) { + // Use Mosaic's Query builder + const q = Query + .select({ feature_id: 'feature_id' }) + .distinct() + .from('features') + + // Apply filter if present + if (filter.length > 0) { + q.where(filter) + } + + return q + } + + queryResult(data) { + if (!this._isConnected) return + + try { + let ids = new Set() + if (data && typeof data.getChild === 
'function') { + const col = data.getChild('feature_id') + if (col) { + for (let i = 0; i < col.length; i++) { + ids.add(col.get(i)) + } + } + } else if (data && data.toArray) { + ids = new Set(data.toArray().map(r => r.feature_id)) + } + setSelectedFeatureIds(ids.size > 0 && ids.size < totalFeatures ? ids : null) + } catch (err) { + console.error('Error processing result:', err) + } + } + + // Required by Mosaic for selection updates + update() { + return this + } + + queryError(err) { + if (this._isConnected) { + console.error('FeatureFilterClient error:', err) + } + } + + disconnect() { + this._isConnected = false + } + } + + const client = new FeatureFilterClient(selection) + + // Delay connection slightly to ensure Mosaic is fully ready + const timeoutId = setTimeout(() => { + try { + coordinator.connect(client) + } catch (err) { + console.warn('Error connecting FeatureFilterClient:', err) + } + }, 0) + + return () => { + clearTimeout(timeoutId) + try { + client.disconnect() + coordinator.disconnect(client) + } catch (err) { + // Ignore disconnect errors + } + } + }, [mosaicReady, features.length]) + + // Clear ALL selections (search, histograms, UMAP, clicked feature) + const handleClearSelection = useCallback(() => { + if (brushRef.current) { + const selection = brushRef.current + // Clear each clause by updating with null predicate for each source + const clauses = selection.clauses || [] + for (const clause of clauses) { + if (clause.source) { + try { + selection.update({ source: clause.source, predicate: null, value: null }) + } catch (e) { + // Ignore errors from clearing + } + } + } + // Also clear the search clause specifically + if (searchSource.current) { + try { + selection.update({ source: searchSource.current, predicate: null, value: null }) + } catch (e) { + // Ignore + } + } + } + setSelectedFeatureIds(null) + setSearchTerm('') + setClickedFeatureId(null) + setHiddenCategories(new Set()) + // Reset viewport to the auto-fit view captured on first 
load + if (initialViewportRef.current) { + setViewportState({ ...initialViewportRef.current }) + currentViewportRef.current = { ...initialViewportRef.current } + } else { + setViewportState(null) + currentViewportRef.current = null + } + // Reset all cards to collapsed state + setCardResetKey(k => k + 1) + // Reset histograms and UMAP to clear brush visuals + setPlotResetKey(k => k + 1) + }, []) + + // Export all edited features to CSV with full data + const handleExportEdited = useCallback(() => { + // Get all edited features + const editedFeatures = features.filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + + if (editedFeatures.length === 0) { + alert('No edited features to export') + return + } + + const lines = [] + const escapeCsv = (str) => `"${(str || '').toString().replace(/"/g, '""')}"` + + // Codon mapping for amino acids + const CODON_AA = { + 'TTT':'F','TTC':'F','TTA':'L','TTG':'L','TCT':'S','TCC':'S','TCA':'S','TCG':'S', + 'TAT':'Y','TAC':'Y','TAA':'*','TAG':'*','TGT':'C','TGC':'C','TGA':'*','TGG':'W', + 'CTT':'L','CTC':'L','CTA':'L','CTG':'L','CCT':'P','CCC':'P','CCA':'P','CCG':'P', + 'CAT':'H','CAC':'H','CAA':'Q','CAG':'Q','CGT':'R','CGC':'R','CGA':'R','CGG':'R', + 'ATT':'I','ATC':'I','ATA':'I','ATG':'M','ACT':'T','ACC':'T','ACA':'T','ACG':'T', + 'AAT':'N','AAC':'N','AAA':'K','AAG':'K','AGT':'S','AGC':'S','AGA':'R','AGG':'R', + 'GTT':'V','GTC':'V','GTA':'V','GTG':'V','GCT':'A','GCC':'A','GCA':'A','GCG':'A', + 'GAT':'D','GAC':'D','GAA':'E','GAG':'E','GGT':'G','GGC':'G','GGA':'G','GGG':'G', + } + + editedFeatures.forEach((f, idx) => { + const userTitle = localStorage.getItem(`featureTitle_${f.feature_id}`) + const label = f.label || `Feature ${f.feature_id}` + + // Add separator for readability + if (idx > 0) lines.push('') + + // Feature metadata + lines.push(`=== FEATURE ${f.feature_id} ===`) + lines.push(`Feature ID,${f.feature_id}`) + lines.push(`Original Label,${escapeCsv(label)}`) + lines.push(`Your 
Title,${escapeCsv(userTitle)}`) + lines.push(`Activation Frequency,${(f.activation_freq || 0).toFixed(6)}`) + lines.push(`Max Activation,${(f.max_activation || 0).toFixed(4)}`) + lines.push('') + + // Vocab logits + const logits = vocabLogits?.[String(f.feature_id)] + if (logits) { + lines.push('TOP PROMOTED CODONS') + lines.push('Codon,Amino Acid,Logit Value') + ;(logits.top_positive || []).forEach(([codon, val]) => { + lines.push(`${codon},${CODON_AA[codon] || '?'},${val.toFixed(4)}`) + }) + lines.push('') + + lines.push('TOP SUPPRESSED CODONS') + lines.push('Codon,Amino Acid,Logit Value') + ;(logits.top_negative || []).forEach(([codon, val]) => { + lines.push(`${codon},${CODON_AA[codon] || '?'},${val.toFixed(4)}`) + }) + lines.push('') + } + + // Feature analysis + const analysis = featureAnalysis?.[String(f.feature_id)] + if (analysis?.codon_annotations) { + lines.push('CODON ANNOTATIONS') + const ann = analysis.codon_annotations + if (ann.amino_acid) { + lines.push(`Amino Acid,${ann.amino_acid.aa}`) + lines.push(`AA Frequency,${(ann.amino_acid.fraction * 100).toFixed(1)}%`) + } + if (ann.codon_usage) { + lines.push(`Codon Usage,${ann.codon_usage.bias}`) + } + if (ann.wobble) { + lines.push(`Wobble Position,${ann.wobble.preference}`) + } + if (ann.cpg) { + lines.push(`CpG Context,${ann.cpg.fraction}`) + } + lines.push('') + } + }) + + // Create and download file + const csv = lines.join('\n') + const blob = new Blob([csv], { type: 'text/csv' }) + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = `edited_features_${new Date().toISOString().split('T')[0]}.csv` + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(url) + }, [features, vocabLogits, featureAnalysis]) + + // Update Mosaic crossfilter when "Edited Only" toggle changes + useEffect(() => { + if (!brushRef.current || !mosaicReady) return + + const selection = brushRef.current + + if (showEditedOnly) { + // 
Get all edited feature IDs from localStorage + const editedIds = features + .filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + .map(f => f.feature_id) + + if (editedIds.length > 0) { + // Create predicate: feature_id IN (id1, id2, id3, ...) + const idsStr = editedIds.join(',') + // Use raw SQL string, not literal() which would quote it as a string + const predicateSql = `feature_id IN (${idsStr})` + + try { + selection.update({ + source: editedSource.current, + predicate: predicateSql, + value: 'edited' + }) + } catch (err) { + console.warn('Error updating edited filter:', err) + } + } + } else { + // Clear the edited filter + try { + selection.update({ + source: editedSource.current, + predicate: null, + value: null + }) + } catch (err) { + console.warn('Error clearing edited filter:', err) + } + } + }, [showEditedOnly, mosaicReady, features]) + + // Update Mosaic crossfilter when legend selection changes + useEffect(() => { + if (!brushRef.current || !mosaicReady) return + + const selection = brushRef.current + + if (hiddenCategories.size > 0 && selectedCategory && selectedCategory !== 'none') { + const colInfo = categoryColumns.find(c => c.name === selectedCategory) + if (colInfo && (colInfo.type === 'string' || colInfo.type === 'integer')) { + const values = Array.from(hiddenCategories).map(v => `'${v.replace(/'/g, "''")}'`).join(',') + const predicateSql = `"${selectedCategory}" IN (${values})` + + try { + selection.update({ + source: legendSource.current, + predicate: predicateSql, + value: Array.from(hiddenCategories).join(',') + }) + } catch (err) { + console.warn('Legend filter update failed:', err) + } + } + } else { + try { + selection.update({ + source: legendSource.current, + predicate: null, + value: null + }) + } catch (err) { + // Ignore + } + } + }, [hiddenCategories, selectedCategory, mosaicReady, categoryColumns]) + + // Handle search - updates both Mosaic crossfilter (for UMAP/histograms) and local state (for cards) + 
const handleSearchChange = useCallback((e) => { + const term = e.target.value + setSearchTerm(term) + + // Also update Mosaic crossfilter so UMAP and histograms filter + if (brushRef.current) { + const selection = brushRef.current + + try { + if (term.trim()) { + // Build predicate using sql template - ILIKE for case-insensitive search + const pattern = literal('%' + term.trim() + '%') + const predicate = sql`label ILIKE ${pattern}` + + selection.update({ + source: searchSource.current, + predicate: predicate, + value: term.trim() + }) + } else { + // Clear search by removing the clause + selection.update({ + source: searchSource.current, + predicate: null, + value: null + }) + } + } catch (err) { + console.warn('Search update error:', err) + } + } + }, []) + + // Filter and sort features + const filteredFeatures = useMemo(() => { + let result = features + + // Filter by Mosaic selection (includes UMAP brush) + if (selectedFeatureIds !== null) { + result = result.filter(f => selectedFeatureIds.has(f.feature_id)) + } + + // Also filter by search term client-side (searches metadata fields) + if (searchTerm.trim()) { + const q = searchTerm.toLowerCase() + result = result.filter(f => + f.description?.toLowerCase().includes(q) || + f.feature_id.toString().includes(q) || + f.best_annotation?.toLowerCase().includes(q) + ) + } + + // Filter by edited features only + if (showEditedOnly) { + result = result.filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + } + + // Helper: unlabeled features sort last + const isUnlabeled = (f) => { + const lbl = (f.label || f.description || '').toLowerCase() + return !lbl || lbl.startsWith('feature ') || lbl.includes('common codons') + } + + // Sort (labeled features first, then by chosen metric) + if (sortBy === 'frequency') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.activation_freq || 0) - (a.activation_freq || 0)) + } else if (sortBy === 'max_activation') { + result = 
[...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.max_activation || 0) - (a.max_activation || 0)) + } else if (sortBy === 'feature_id') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || a.feature_id - b.feature_id) + } else if (sortBy === 'high_score_fraction') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.high_score_fraction || 0) - (a.high_score_fraction || 0)) + } else if (sortBy === 'mean_variant_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_variant_delta || 0) - Math.abs(a.mean_variant_delta || 0)) + } else if (sortBy === 'mean_site_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_site_delta || 0) - Math.abs(a.mean_site_delta || 0)) + } else if (sortBy === 'mean_local_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_local_delta || 0) - Math.abs(a.mean_local_delta || 0)) + } else if (sortBy === 'clinvar_fraction') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.clinvar_fraction || 0) - (a.clinvar_fraction || 0)) + } else if (sortBy === 'mean_phylop') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.mean_phylop || 0) - (a.mean_phylop || 0)) + } else if (sortBy === 'gc_mean') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs((b.gc_mean || 0.5) - 0.5) - Math.abs((a.gc_mean || 0.5) - 0.5)) + } else if (sortBy === 'trinuc_entropy') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.trinuc_entropy ?? 99) - (b.trinuc_entropy ?? 99)) + } else if (sortBy === 'gene_entropy') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.gene_entropy ?? 99) - (b.gene_entropy ?? 
99)) + } else if (sortBy === 'gene_n_unique') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.gene_n_unique || 999) - (b.gene_n_unique || 999)) + } + + return result + }, [features, sortBy, selectedFeatureIds, searchTerm, showEditedOnly]) + + // Reset pagination when filters change + useEffect(() => { + setDisplayedCardCount(20) + loadingMoreRef.current = false + }, [searchTerm, sortBy, selectedFeatureIds, showEditedOnly]) + + if (loading) { + const pct = Math.round(((loadingProgress.step - 1) / loadingProgress.total) * 100) + return ( +
+
Loading dashboard...
+
+
+
+
{loadingProgress.message}
+
+ ) + } + + if (error) { + return ( +
+

Error: {error}

+

+ Make sure features_atlas.parquet, feature_metadata.parquet, and feature_examples.parquet exist in the public/ folder. +

+
+ ) + } + + return ( +
+
+ MOCKUP — synthetic data, not from a real SAE run +
+
+
+

Evo 2 SAE Feature Explorer (Mockup)

+
+
+ + +
+
+ +
+
+
+
+ + Decoder UMAP + +
+ + + setShowMetricsModal(true)} + style={{ + display: 'inline-flex', alignItems: 'center', justifyContent: 'center', + width: '15px', height: '15px', borderRadius: '50%', border: '1px solid var(--border-input)', + fontSize: '10px', fontWeight: '600', color: 'var(--text-tertiary)', cursor: 'pointer', + userSelect: 'none', lineHeight: 1, flexShrink: 0, + }} + >i + +
+
+
+ {mosaicReady && ( + + )} + {selectedCategory && selectedCategory !== 'none' && (() => { + const colInfo = categoryColumns.find(c => c.name === selectedCategory) + if (!colInfo) return null + + if (colInfo.type === 'sequential') { + const colors = [ + "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500", + "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020" + ] + const vals = features + .map(f => f[selectedCategory]) + .filter(v => v != null && !isNaN(v)) + const minVal = vals.length > 0 ? Math.min(...vals) : 0 + const maxVal = vals.length > 0 ? Math.max(...vals) : 1 + const fmt = (v) => Math.abs(v) >= 100 ? v.toFixed(0) : Math.abs(v) >= 1 ? v.toFixed(1) : v.toFixed(3) + return ( +
+ {fmt(maxVal)} +
+ {fmt(minVal)} + + {selectedCategory.replace(/_/g, ' ')} + +
+ ) + } + + if (colInfo.type === 'string' || colInfo.type === 'integer') { + const catColors = [ + "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", + "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", + "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", + "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5" + ] + // Count occurrences of each category value, sorted alphabetically + // (matching DENSE_RANK ORDER BY which is alphabetical) + const counts = {} + for (const f of features) { + const val = f[selectedCategory] + if (val != null && val !== '') { + counts[val] = (counts[val] || 0) + 1 + } + } + // Sort alphabetically to match dense_rank ordering + const sortedCategories = Object.keys(counts).sort() + return ( +
+
+ {selectedCategory.replace(/_/g, ' ').replace('gsea ', '')} +
+ {sortedCategories.map((cat, i) => { + const hasFilter = hiddenCategories.size > 0 + const isHidden = hasFilter && !hiddenCategories.has(cat) + return ( +
{ + if (e.metaKey || e.ctrlKey) { + // Cmd/Ctrl+click: toggle this category in the selection + setHiddenCategories(prev => { + const next = new Set(prev) + if (next.has(cat)) { + next.delete(cat) + // If nothing left selected, clear filter + return next.size === 0 ? new Set() : next + } else { + next.add(cat) + return next + } + }) + } else { + // Regular click: solo this category (or clear if already solo'd) + setHiddenCategories(prev => { + if (prev.size === 1 && prev.has(cat)) return new Set() + return new Set([cat]) + }) + } + }} + style={{ + display: 'flex', alignItems: 'center', gap: '5px', padding: '2px 0', + cursor: 'pointer', opacity: isHidden ? 0.15 : 1, + userSelect: 'none', + }} + > + + + {cat} + + + {counts[cat]} + +
+ ) + })} +
+ ) + } + + return null + })()} +
+
+ +
+ {[ + { value: histMetric1, setter: setHistMetric1 }, + { value: histMetric2, setter: setHistMetric2 }, + { value: histMetric3, setter: setHistMetric3 }, + ].map(({ value, setter }, i) => ( +
+
+ +
+ {mosaicReady && value && value !== 'none' && ( + + )} +
+ ))} +
+
+ +
+
+ + + +
+ +
+ + Showing {filteredFeatures.length} of {features.length} features + {selectedFeatureIds !== null && ` (${selectedFeatureIds.size} selected in UMAP)`} + + setShowGuideModal(true)} + style={{ + display: 'inline-flex', alignItems: 'center', justifyContent: 'center', + width: '15px', height: '15px', borderRadius: '50%', border: '1px solid #bbb', + fontSize: '10px', fontWeight: '600', color: '#888', cursor: 'pointer', + userSelect: 'none', lineHeight: 1, flexShrink: 0, + }} + >i +
+ + +
+
+ + {showGuideModal && ( +
setShowGuideModal(false)} + style={{ + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.45)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, + }} + > +
e.stopPropagation()} + style={{ + background: 'var(--bg-card)', borderRadius: '10px', maxWidth: '560px', width: '90%', + maxHeight: '80vh', overflowY: 'auto', padding: '28px 32px', + boxShadow: '0 8px 30px rgba(0,0,0,0.2)', + }} + > +
+

Feature Card Guide

+ setShowGuideModal(false)} + style={{ cursor: 'pointer', fontSize: '20px', color: '#999', lineHeight: 1 }} + >× +
+ +
+

Decoder Logits

+

+ The decoder logits histogram shows the projection of each feature's learned decoder weight vector through the language model's prediction head, after subtracting the mean logit vector computed across all features. This mean-centering removes the model's shared baseline bias toward common codons (e.g. GCC), so values reflect what each feature specifically promotes or suppresses relative to the average feature. Each bar represents a codon. Green bars indicate codons the feature promotes above baseline; red bars indicate codons it suppresses below baseline. Gray bars have no feature-specific effect. This tells you what the feature pushes the model to output — not what activates it. Stop codons (TAA, TAG, TGA) are excluded because the model was trained on coding sequences where internal stops almost never appear, so all features uniformly suppress them. +

+ +

Top Activating Sequences

+

+ These are the protein-coding sequences where this feature fires most strongly. Each codon is colored by its activation value — brighter highlights mean the feature responds more strongly at that position. This shows what inputs trigger the feature, which is conceptually distinct from decoder logits. A feature can activate strongly on a particular codon (e.g., lysine codons) without promoting that same codon in the output — it may instead influence downstream or contextual predictions. +

+ +
+
+
+ )} + + {showMetricsModal && ( +
setShowMetricsModal(false)} + style={{ + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.45)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, + }} + > +
e.stopPropagation()} + style={{ + background: 'var(--bg-card)', borderRadius: '10px', maxWidth: '620px', width: '90%', + maxHeight: '80vh', overflowY: 'auto', padding: '28px 32px', + boxShadow: '0 8px 30px rgba(0,0,0,0.2)', + }} + > +
+

Variant Analysis Metrics

+ setShowMetricsModal(false)} + style={{ cursor: 'pointer', fontSize: '20px', color: '#999', lineHeight: 1 }} + >× +
+ +
+

Mean Variant Score (per model)

+

+ For each feature, the average model effect score across variant sequences where the feature fires. Computed for the 1b_cdwt model score column. A high value means the feature preferentially activates on variants that the model predicts to be functionally impactful. +

+ +

High Score Fraction

+

+ Variants are split at the median model score. Among variants where a feature fires, what fraction are high-scoring? A value of 0.5 means no preference. Above 0.5 means the feature disproportionately fires on high-impact variants. Robust to outliers — measures distributional preference rather than average. +

+ +

ClinVar Fraction

+

+ Among variant sequences where the feature fires, the fraction from ClinVar vs COSMIC. ClinVar variants are germline (inherited, Mendelian disease). COSMIC variants are somatic (cancer mutations). High ClinVar fraction means the feature responds to germline disease patterns; low means it prefers somatic cancer mutation patterns. +

+ +

Mean PhyloP

+

+ Average evolutionary conservation score (PhyloP) across sequences where the feature fires. High values indicate conserved positions (functionally important). Negative values indicate rapidly evolving regions. Features with high mean PhyloP capture evolutionarily constrained patterns. +

+ +

Mean Variant Delta

+

+ For each gene, the difference in max feature activation between the variant and reference sequence: max_act(variant) − max_act(ref), averaged across all variant-ref pairs. Positive means the mutation increases feature activation; negative means it suppresses it. Near zero means the feature responds to the gene background, not the specific mutation. This controls for gene identity. +

+ +

Mean Site Delta

+

+ Like mean variant delta, but measured only at the exact codon position where the mutation occurs: activation_f(variant, pos) − activation_f(ref, pos). This captures direct effects — the feature responding to the changed codon itself. Compare with mean variant delta: a large variant delta but small site delta means the feature captures indirect/distal effects of the mutation (e.g., changes to predicted protein folding context), not the local codon change. +

+ +

Mean Local Delta

+

+ Like variant delta, but using the max activation within a 3-codon window around the variant site instead of the full sequence. Captures local effects of the mutation: max(window_variant) − max(window_ref). A large local delta with a small global delta means the mutation's effect is localized. Compare with site delta (exact position only) and variant delta (full sequence). +

+ +

GC Content (mean, std)

+

+ Mean and standard deviation of GC content across all sequences where the feature fires. Features with extreme GC mean (far from ~0.5) are GC-biased. Features with low GC std activate only on sequences with similar GC content — suggesting sensitivity to nucleotide composition rather than specific codon patterns. +

+ +

Trinuc Entropy

+

+ Shannon entropy (in bits) of the trinucleotide context distribution among variant sequences where the feature fires. Low entropy means the feature concentrates on specific mutation contexts (e.g., C[C>T]G for CpG transitions). High entropy means it fires across diverse mutation types. The dominant fraction shows what fraction of activations come from the most common trinuc context. +

+ +

Gene Distribution

+

+ Shannon entropy of the gene distribution among sequences where the feature fires. Low entropy means the feature is gene-specific — it concentrates on a few genes. High entropy means it fires broadly. gene_n_unique is the number of distinct genes. gene_dominant_frac is the fraction from the most common gene. A feature with low entropy and high dominant fraction has learned something specific to one gene family. +

+ +

High Score Delta

+

+ Same as mean variant delta, but averaged only over variants with model scores above the median. Shows how the feature responds specifically to high-impact mutations. Compare with low score delta: if high_score_delta >> low_score_delta, the feature selectively detects impactful mutations. +

+ +

Low Score Delta

+

+ Same as mean variant delta, but averaged only over variants with model scores below the median. Features where high score delta and low score delta differ significantly have learned to discriminate mutation severity. Features where both are similar just detect that a mutation occurred without distinguishing impact. +

+
+
+
+ )} +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/EmbeddingView.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/EmbeddingView.jsx new file mode 100644 index 0000000000..bc14226257 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/EmbeddingView.jsx @@ -0,0 +1,334 @@ +import React, { useEffect, useRef } from 'react' +import { EmbeddingViewMosaic } from 'embedding-atlas' + +// Color palette for categories (D3 category10 + extended) +const CATEGORY_COLORS = [ + "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", + "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", + "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", + "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5" +] + +// Sequential color palette (NVIDIA brand) +const SEQUENTIAL_COLORS = [ + "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500", + "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020" +] + +// Default color for uniform coloring (NVIDIA green) +const DEFAULT_COLOR = "#76b900" + +// Custom tooltip renderer +class FeatureTooltip { + constructor(node, props) { + this.node = node + this.inner = document.createElement("div") + this.inner.style.cssText = ` + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: 4px; + padding: 8px 12px; + font-family: 'NVIDIA Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; + font-size: 13px; + box-shadow: 0 2px 8px rgba(0,0,0,0.25); + max-width: 300px; + color: var(--text); + ` + this.node.appendChild(this.inner) + this.update(props) + } + + update(props) { + const { tooltip } = props + if (!tooltip) { + this.inner.innerHTML = "" + return + } + const featureId = tooltip.identifier ?? "" + const label = tooltip.fields?.label ?? tooltip.text ?? 
"" + const logFreq = tooltip.fields?.log_frequency + const maxAct = tooltip.fields?.max_activation + const colorField = tooltip.fields?.color_field + + this.inner.innerHTML = ` +
Feature #${featureId}
+
${label}
+ ${colorField ? `
Category: ${colorField}
` : ""} + ${logFreq !== undefined ? `
Log Frequency: ${logFreq.toFixed(3)}
` : ""} + ${maxAct !== undefined ? `
Max Activation: ${maxAct.toFixed(2)}
` : ""} + ` + } + + destroy() { + this.inner.remove() + } +} + +export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode, hiddenCategories }) { + const containerRef = useRef(null) + const viewRef = useRef(null) + const onFeatureClickRef = useRef(onFeatureClick) + const onViewportChangeRef = useRef(onViewportChange) + + // Keep the callback refs updated + useEffect(() => { + onFeatureClickRef.current = onFeatureClick + }, [onFeatureClick]) + + useEffect(() => { + onViewportChangeRef.current = onViewportChange + }, [onViewportChange]) + + // Update selection and tooltip when highlightedFeatureId changes + useEffect(() => { + if (viewRef.current && highlightedFeatureId != null) { + // Find the feature data + const feature = features?.find(f => f.feature_id === highlightedFeatureId) + + // Build tooltip fields + const fields = { + label: feature?.label || `Feature ${highlightedFeatureId}`, + log_frequency: feature?.log_frequency || feature?.activation_freq || 0, + max_activation: feature?.max_activation || 0, + color_field: null + } + + // Add selected category metric if available + if (selectedCategory && selectedCategory !== 'none' && feature) { + const metricName = selectedCategory.replace(/_/g, ' ') + const metricValue = feature[selectedCategory] + if (metricValue !== undefined && metricValue !== null) { + fields.color_field = `${metricName}: ${typeof metricValue === 'number' ? 
metricValue.toFixed(3) : metricValue}` + } + } + + // Construct tooltip object with feature data + const tooltipObj = { + identifier: highlightedFeatureId, + text: `Feature #${highlightedFeatureId}`, + x: feature?.x, + y: feature?.y, + fields: fields + } + // Clear previous selection first to avoid animated transition + viewRef.current.update({ + selection: null, + tooltip: null + }) + viewRef.current.update({ + selection: [highlightedFeatureId], + tooltip: tooltipObj + }) + } else if (viewRef.current && highlightedFeatureId == null) { + viewRef.current.update({ + selection: null, + tooltip: null + }) + } + }, [highlightedFeatureId, features, selectedCategory]) + + // Update viewport when viewportState changes (skip null to let auto-fit persist) + useEffect(() => { + if (viewRef.current && viewportState != null) { + viewRef.current.update({ + viewportState: viewportState + }) + } + }, [viewportState]) + + // Update color scheme when dark mode changes + useEffect(() => { + if (viewRef.current) { + viewRef.current.update({ + config: { colorScheme: darkMode ? 
"dark" : "light" } + }) + } + }, [darkMode]) + + // Update labels when they change + useEffect(() => { + if (viewRef.current && labels) { + console.log('[EmbeddingView] updating labels:', labels.length, labels.slice(0, 2)) + viewRef.current.update({ + labels: labels + }) + } + }, [labels]) + + useEffect(() => { + if (!containerRef.current || !brush) return + + // Clear previous view + if (viewRef.current) { + containerRef.current.innerHTML = '' + } + + // Determine category column and colors + let categoryColName = null + let colors = Array(50).fill(DEFAULT_COLOR) + let additionalFields = { + label: "label", + log_frequency: "log_frequency", + max_activation: "max_activation", + } + + if (categoryColumn && categoryColumn !== "none") { + const colInfo = categoryColumns?.find(c => c.name === categoryColumn) + if (colInfo) { + if (colInfo.type === 'sequential') { + // Sequential column - use binned version and sequential colors + categoryColName = `${categoryColumn}_bin` + colors = SEQUENTIAL_COLORS + } else if (colInfo.type === 'string') { + // Categorical string column + categoryColName = `${categoryColumn}_cat` + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } else { + // Integer categorical column + categoryColName = categoryColumn + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } + additionalFields.color_field = categoryColumn + } + } + + const width = containerRef.current.clientWidth + const height = containerRef.current.clientHeight + + try { + viewRef.current = new EmbeddingViewMosaic( + containerRef.current, + { + table: "features", + x: "x", + y: "y", + category: categoryColName, + text: "label", + identifier: "feature_id", + filter: brush, + rangeSelection: brush, + selection: highlightedFeatureId != null ? 
[highlightedFeatureId] : null, + viewportState: viewportState, + categoryColors: colors, + width: width, + height: height, + labels: labels || null, + config: { + mode: "points", + colorScheme: document.documentElement.classList.contains('dark') ? "dark" : "light", + autoLabelEnabled: false, + }, + theme: { + brandingLink: { + text: "NVIDIA BioNeMo", + href: "https://github.com/NVIDIA/bionemo-framework", + }, + }, + additionalFields: additionalFields, + customTooltip: FeatureTooltip, + onSelection: (selection) => { + // selection is DataPoint[] | null + if (!onFeatureClickRef.current) return + + if (selection && selection.length > 0) { + // Get the last clicked point (most recent selection) + const lastPoint = selection[selection.length - 1] + const featureId = lastPoint?.identifier ?? lastPoint + const x = lastPoint?.x + const y = lastPoint?.y + if (featureId != null) { + onFeatureClickRef.current(featureId, x, y) + } + } else { + // Clicked on empty canvas - clear selection + onFeatureClickRef.current(null) + } + }, + onViewportState: (vp) => { + if (onViewportChangeRef.current && vp) { + onViewportChangeRef.current(vp) + } + }, + } + ) + } catch (err) { + console.warn('Error creating EmbeddingViewMosaic:', err) + } + + return () => { + if (containerRef.current) { + containerRef.current.innerHTML = '' + } + } + }, [brush]) + + // Update category coloring in-place (without recreating the view) + useEffect(() => { + if (!viewRef.current) return + + let categoryColName = null + const HIDDEN_COLOR = darkMode ? 
"#0a0a0a" : "#fafafa" + let colors = Array(50).fill(HIDDEN_COLOR) + + if (categoryColumn && categoryColumn !== "none") { + const colInfo = categoryColumns?.find(c => c.name === categoryColumn) + if (colInfo) { + if (colInfo.type === 'sequential') { + categoryColName = `${categoryColumn}_bin` + colors = SEQUENTIAL_COLORS + } else if (colInfo.type === 'string') { + categoryColName = `${categoryColumn}_cat` + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + // Map colors to match DENSE_RANK order, dim non-selected when filtering + if (hiddenCategories && hiddenCategories.size > 0 && features) { + const allCatNames = [...new Set( + features.map(f => f[categoryColumn]).filter(v => v != null) + )].sort() + colors = colors.map((c, i) => { + const name = allCatNames[i] + if (!name) return c + return !hiddenCategories.has(name) ? HIDDEN_COLOR : c + }) + } + } else { + categoryColName = categoryColumn + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } + } + } + + viewRef.current.update({ + category: categoryColName, + categoryColors: colors, + selection: null, + tooltip: null, + }) + }, [categoryColumn, categoryColumns, hiddenCategories]) + + // Handle resize + useEffect(() => { + const handleResize = () => { + if (viewRef.current && containerRef.current) { + const width = containerRef.current.clientWidth + const height = containerRef.current.clientHeight + viewRef.current.update({ width, height }) + } + } + + const resizeObserver = new ResizeObserver(handleResize) + if (containerRef.current) { + resizeObserver.observe(containerRef.current) + } + + return () => { + resizeObserver.disconnect() + } + }, []) + + return ( +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureCard.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureCard.jsx new file mode 100644 index 0000000000..d45bc121df --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureCard.jsx @@ -0,0 +1,518 @@ +import React, { useState, useEffect, useRef, forwardRef } from 'react' +import SequenceView, { computeAlignInfo } from './SequenceView' +import FeatureDetailPage from './FeatureDetailPage' +import { getRegionLabel } from './utils' + +const styles = { + card: { + background: 'var(--bg-card)', + borderRadius: '8px', + border: '1px solid var(--border)', + flexShrink: 0, + }, + cardHighlighted: { + background: 'var(--bg-card)', + borderRadius: '8px', + border: '2px solid var(--highlight-border)', + flexShrink: 0, + boxShadow: '0 2px 8px var(--highlight-shadow)', + }, + header: { + padding: '12px 14px', + borderBottom: '1px solid var(--border-light)', + cursor: 'pointer', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'flex-start', + gap: '10px', + }, + headerLeft: { + flex: 1, + minWidth: 0, + }, + featureId: { + fontSize: '11px', + color: 'var(--text-tertiary)', + fontFamily: 'monospace', + marginBottom: '2px', + }, + description: { + fontSize: '13px', + fontWeight: '500', + wordBreak: 'break-word', + lineHeight: '1.4', + color: 'var(--text)', + }, + userTitle: { + fontSize: '13px', + fontWeight: '500', + wordBreak: 'break-word', + lineHeight: '1.4', + color: 'var(--accent)', + fontStyle: 'italic', + }, + stats: { + display: 'flex', + gap: '12px', + fontSize: '11px', + color: 'var(--text-secondary)', + flexShrink: 0, + }, + stat: { + display: 'flex', + flexDirection: 'column', + alignItems: 'flex-end', + }, + statLabel: { + color: 'var(--text-muted)', + fontSize: '9px', + textTransform: 'uppercase', + }, + statValue: { + 
fontFamily: 'monospace', + fontWeight: '500', + }, + expandIcon: { + color: 'var(--text-muted)', + fontSize: '10px', + marginLeft: '6px', + }, + expandedContent: { + padding: '10px 14px', + background: 'var(--bg-card-expanded)', + maxHeight: '900px', + overflowY: 'auto', + }, + sectionHeader: { + fontSize: '10px', + color: 'var(--text-tertiary)', + textTransform: 'uppercase', + marginBottom: '8px', + fontWeight: '500', + }, + example: { + marginBottom: '8px', + padding: '8px 10px', + background: 'var(--bg-example)', + borderRadius: '4px', + border: '1px solid var(--border-light)', + }, + exampleMeta: { + fontSize: '10px', + color: 'var(--text-muted)', + marginBottom: '4px', + fontFamily: 'monospace', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }, + proteinId: { + color: 'var(--text-heading)', + fontWeight: '700', + }, + annotation: { + color: 'var(--text-secondary)', + fontStyle: 'italic', + marginLeft: '8px', + }, + uniprotLink: { + color: 'var(--link)', + textDecoration: 'none', + fontSize: '11px', + marginLeft: '4px', + opacity: 0.6, + }, + noExamples: { + color: 'var(--text-muted)', + fontSize: '12px', + fontStyle: 'italic', + }, + densityBar: { + width: '50px', + height: '3px', + background: 'var(--density-bar-bg)', + borderRadius: '2px', + overflow: 'hidden', + marginTop: '3px', + }, + densityFill: { + height: '100%', + background: '#76b900', + borderRadius: '2px', + }, + alignBar: { + display: 'flex', + alignItems: 'center', + gap: '6px', + marginBottom: '10px', + fontSize: '10px', + color: '#888', + }, + alignLabel: { + textTransform: 'uppercase', + fontWeight: '500', + }, + alignBtn: { + padding: '2px 8px', + border: '1px solid #ddd', + borderRadius: '3px', + background: '#fff', + cursor: 'pointer', + fontSize: '10px', + color: '#555', + }, + alignBtnActive: { + padding: '2px 8px', + border: '1px solid #76b900', + borderRadius: '3px', + background: '#f0f9e0', + cursor: 'pointer', + fontSize: '10px', + color: '#333', + 
fontWeight: '600', + }, +} + +const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, forceExpanded, onClick, loadExamples }, ref) { + const [expanded, setExpanded] = useState(false) + const [showDetailPage, setShowDetailPage] = useState(false) + const [examples, setExamples] = useState([]) + const [loadingExamples, setLoadingExamples] = useState(false) + const examplesCacheRef = useRef(null) + const [alignMode, setAlignMode] = useState('start') + const scrollGroupRef = useRef([]) + const [editingTitle, setEditingTitle] = useState(false) + const [userTitle, setUserTitle] = useState('') + const inputRef = useRef(null) + + // Load user-provided title from localStorage + useEffect(() => { + const stored = localStorage.getItem(`featureTitle_${feature.feature_id}`) + if (stored) { + setUserTitle(stored) + } + }, [feature.feature_id]) + + // Focus input when editing starts + useEffect(() => { + if (editingTitle && inputRef.current) { + inputRef.current.focus() + inputRef.current.select() + } + }, [editingTitle]) + + // Reset scroll group when alignment changes + useEffect(() => { scrollGroupRef.current = [] }, [alignMode]) + + // If forceExpanded changes to true, expand the card + useEffect(() => { + if (forceExpanded) { + setExpanded(true) + } + }, [forceExpanded]) + + // Lazy-load examples from DuckDB when card is expanded + useEffect(() => { + if (!expanded || !loadExamples || examplesCacheRef.current) return + let cancelled = false + setLoadingExamples(true) + loadExamples(feature.feature_id).then(result => { + if (cancelled) return + examplesCacheRef.current = result + setExamples(result) + setLoadingExamples(false) + }).catch(err => { + if (cancelled) return + console.error('Error loading examples for feature', feature.feature_id, err) + setLoadingExamples(false) + }) + return () => { cancelled = true } + }, [expanded, loadExamples, feature.feature_id]) + + const freq = feature.activation_freq || 0 + const maxAct = feature.max_activation || 0 
+ const rawDesc = feature.label || feature.description || `Feature ${feature.feature_id}` + const description = rawDesc.toLowerCase().includes('common codons') ? 'Unidentified Feature' : rawDesc + + + const handleClick = () => { + const willExpand = !expanded + // Update UMAP highlight immediately, defer card expansion so it doesn't block + if (onClick) { + onClick(feature.feature_id, willExpand) + } + requestAnimationFrame(() => { + setExpanded(willExpand) + }) + } + + const handleSaveTitle = () => { + if (userTitle.trim()) { + localStorage.setItem(`featureTitle_${feature.feature_id}`, userTitle.trim()) + } else { + localStorage.removeItem(`featureTitle_${feature.feature_id}`) + setUserTitle('') + } + setEditingTitle(false) + } + + const handleCancelEdit = () => { + const stored = localStorage.getItem(`featureTitle_${feature.feature_id}`) + setUserTitle(stored || '') + setEditingTitle(false) + } + + const displayTitle = userTitle || description + + const handleTitleKeyDown = (e) => { + if (e.key === 'Enter') { + handleSaveTitle() + } else if (e.key === 'Escape') { + handleCancelEdit() + } + } + + const exportToCSV = () => { + const lines = [] + + // Feature metadata section + lines.push('=== FEATURE METADATA ===') + lines.push(`Feature ID,${feature.feature_id}`) + lines.push(`Label,${displayTitle}`) + if (userTitle) { + lines.push(`User Title,${userTitle}`) + } + lines.push(`Activation Frequency,${(freq * 100).toFixed(2)}%`) + lines.push(`Max Activation,${maxAct.toFixed(4)}`) + lines.push('') + + // Examples section + if (examples && examples.length > 0) { + lines.push('=== ACTIVATION EXAMPLES ===') + lines.push('Rank,Region,Max Activation,Sequence') + examples.forEach((ex, i) => { + lines.push(`${i + 1},${getRegionLabel(ex) || ''},${ex.max_activation?.toFixed(4) || ''},${ex.sequence || ''}`) + }) + } + + // Generate CSV + const csv = lines.join('\n') + + // Create download link + const filename = 
`feature_${feature.feature_id}_${displayTitle.replace(/[^a-z0-9]/gi, '_').substring(0, 20)}.csv` + const blob = new Blob([csv], { type: 'text/csv;charset=utf-8;' }) + const link = document.createElement('a') + link.setAttribute('href', URL.createObjectURL(blob)) + link.setAttribute('download', filename) + link.style.visibility = 'hidden' + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + } + + return ( +
+
+
+
Feature #{feature.feature_id}
+ {editingTitle ? ( +
+ setUserTitle(e.target.value)} + onKeyDown={handleTitleKeyDown} + onClick={(e) => e.stopPropagation()} + style={{ + fontSize: '13px', + fontWeight: '500', + padding: '4px 8px', + border: '1px solid #76b900', + borderRadius: '4px', + flex: 1, + }} + /> + + +
+ ) : ( +
+
{displayTitle}
+ { e.stopPropagation(); setEditingTitle(true) }} + style={{ + fontSize: '11px', + color: '#999', + cursor: 'pointer', + padding: '2px 4px', + borderRadius: '3px', + userSelect: 'none', + }} + title="Click to edit title" + > + ✎ + +
+ )} +
+
+
+ Freq + {(freq * 100).toFixed(1)}% +
+
+
+
+
+ Max + {maxAct.toFixed(1)} +
+ {/* v2 roadmap placeholders — populated when real eval pipeline lands. */} +
+ Annotation + +
+
+ Sensitivity + +
+
+ Recon Δ + +
+ {expanded ? '▼' : '▶'} +
+
+ + {/* Details and export buttons - shown when expanded */} + {expanded && ( +
+ + +
+ )} + + {expanded && ( +
+ {feature.logo_path && ( +
+
Sequence Logo
+ {`Sequence +
+ )} + {/* Sequence examples */} +
+
Top Activating Sequences
+
+ Align by: + {['start', 'first_activation', 'max_activation'].map(mode => ( + + ))} +
+
+ {loadingExamples ? ( +
+ Loading examples... +
+ ) : examples.length > 0 ? ( + <> + {(() => { + const visibleExamples = examples.slice(0, 6) + const { anchor: alignAnchor, totalLength } = computeAlignInfo(visibleExamples, alignMode) + return visibleExamples.map((ex, i) => ( +
+
+ + {getRegionLabel(ex)} + {ex.best_annotation && ( + {ex.best_annotation} + )} + + max: {ex.max_activation?.toFixed(3) || 'N/A'} +
+ +
+ )) + })()} + + + ) : ( +
No examples available
+ )} +
+ )} + + {showDetailPage && ( + setShowDetailPage(false)} + /> + )} +
+ ) +}) + +export default FeatureCard diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureDetailPage.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureDetailPage.jsx new file mode 100644 index 0000000000..b70fdcdbde --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureDetailPage.jsx @@ -0,0 +1,198 @@ +import React, { useState, useEffect, useRef } from 'react' +import SequenceView, { computeAlignInfo } from './SequenceView' +import { getRegionLabel } from './utils' + +const styles = { + overlay: { + position: 'fixed', + inset: 0, + background: 'rgba(0, 0, 0, 0.5)', + zIndex: 2000, + overflowY: 'auto', + }, + page: { + maxWidth: '960px', + margin: '20px auto', + background: 'var(--bg-card)', + borderRadius: '8px', + boxShadow: '0 4px 24px rgba(0,0,0,0.2)', + color: 'var(--text)', + }, + header: { + padding: '12px 20px', + borderBottom: '1px solid var(--border-light)', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }, + title: { + fontSize: '14px', + fontWeight: '700', + color: 'var(--text-heading)', + }, + closeBtn: { + background: 'none', + border: '1px solid var(--border-input)', + borderRadius: '4px', + padding: '3px 10px', + cursor: 'pointer', + fontSize: '11px', + color: 'var(--text-secondary)', + }, + section: { + padding: '10px 20px', + borderBottom: '1px solid var(--border-light)', + }, + sectionTitle: { + fontSize: '11px', + fontWeight: '600', + marginBottom: '6px', + color: 'var(--text-heading)', + textTransform: 'uppercase', + }, + example: { + marginBottom: '6px', + padding: '6px 8px', + background: 'var(--bg-example)', + borderRadius: '4px', + border: '1px solid var(--border-light)', + }, + exampleMeta: { + fontSize: '10px', + color: 'var(--text-secondary)', + marginBottom: '4px', + fontFamily: 'monospace', + display: 'flex', + justifyContent: 
'space-between', + }, + placeholder: { + border: '1px dashed var(--border)', + borderRadius: '6px', + padding: '24px', + textAlign: 'center', + color: 'var(--text-muted)', + fontSize: '12px', + fontStyle: 'italic', + }, + placeholderLabel: { + fontSize: '13px', + fontWeight: '500', + color: 'var(--text-muted)', + marginBottom: '8px', + }, +} + +export default function FeatureDetailPage({ feature, examples, onClose }) { + const [alignMode, setAlignMode] = useState('max_activation') + const scrollGroupRef = useRef(null) + + const freq = feature.activation_freq || 0 + const maxAct = feature.max_activation || 0 + const description = feature.description || feature.label || `Feature ${feature.feature_id}` + + useEffect(() => { + const handleKey = (e) => { if (e.key === 'Escape') onClose() } + document.addEventListener('keydown', handleKey) + return () => document.removeEventListener('keydown', handleKey) + }, [onClose]) + + const visibleExamples = (examples || []).slice(0, 30) + const { anchor: alignAnchor, totalLength } = computeAlignInfo(visibleExamples.slice(0, 6), alignMode) + + return ( +
{ if (e.target === e.currentTarget) onClose() }}> +
+ +
+
+
+ Feature #{feature.feature_id} + + {description} + +
+
+
+
+ freq: {(freq * 100).toFixed(1)}% + max: {maxAct.toFixed(1)} +
+ +
+
+ + {feature.logo_path && ( +
+
Sequence Logo
+ {`Sequence +
+ )} + +
+
+
Top Activating Sequences
+
+ {['start', 'first_activation', 'max_activation'].map(mode => ( + + ))} +
+
+ + {visibleExamples.length > 0 ? ( + visibleExamples.map((ex, i) => ( +
+
+ {getRegionLabel(ex)} + max: {ex.max_activation?.toFixed(3)} +
+ +
+ )) + ) : ( +
No examples loaded
+ )} +
+ + {/* v2 roadmap placeholders — populated when annotation + conservation pipelines land. */} +
+
Annotations
+
+ Annotation overlay (RefSeq, Rfam, JASPAR) — coming in v2 +
+
+ +
+
Conservation
+
+ Conservation track (phyloP) — coming in v2 +
+
+ +
+
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureList.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureList.jsx new file mode 100644 index 0000000000..26cd6c2457 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/FeatureList.jsx @@ -0,0 +1,83 @@ +import React, { memo } from 'react' +import FeatureCard from './FeatureCard' + +const styles = { + featureList: { + flex: 1, + overflowY: 'auto', + overflowX: 'hidden', + display: 'flex', + flexDirection: 'column', + gap: '10px', + paddingRight: '8px', + minHeight: 0, + }, +} + +function FeatureListComponent({ + filteredFeatures, + displayedCardCount, + clickedFeatureId, + features, + cardResetKey, + handleCardClick, + loadExamples, + vocabLogits, + featureAnalysis, + featureListRef, + endOfListRef, + featureRefs, +}) { + const visibleFeatures = filteredFeatures.slice(0, displayedCardCount) + const clickedIsVisible = clickedFeatureId != null && + visibleFeatures.some(f => Number(f.feature_id) === Number(clickedFeatureId)) + const clickedFeature = clickedFeatureId != null && !clickedIsVisible + ? features.find(f => Number(f.feature_id) === Number(clickedFeatureId)) + : null + + return ( +
+ {/* Only render clicked feature at top if NOT already in visible list */} + {clickedFeature && ( + { featureRefs.current[clickedFeature.feature_id] = el }} + feature={clickedFeature} + isHighlighted={true} + forceExpanded={true} + onClick={handleCardClick} + loadExamples={loadExamples} + vocabLogits={vocabLogits} + featureAnalysis={featureAnalysis} + /> + )} + {visibleFeatures.map(feature => ( + { featureRefs.current[feature.feature_id] = el }} + feature={feature} + isHighlighted={Number(clickedFeatureId) === Number(feature.feature_id)} + forceExpanded={Number(clickedFeatureId) === Number(feature.feature_id)} + onClick={handleCardClick} + loadExamples={loadExamples} + vocabLogits={vocabLogits} + featureAnalysis={featureAnalysis} + /> + ))} + {/* Sentinel element for infinite scroll detection */} +
+ {displayedCardCount < filteredFeatures.length && ( +
+ Scroll to load more... ({visibleFeatures.length} of {filteredFeatures.length}) +
+ )} + {filteredFeatures.length === 0 && clickedFeatureId == null && ( +
+ No features match your selection. +
+ )} +
+ ) +} + +export default memo(FeatureListComponent) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Histogram.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Histogram.jsx new file mode 100644 index 0000000000..553330862d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Histogram.jsx @@ -0,0 +1,85 @@ +import React, { useEffect, useRef } from 'react' +import * as vg from '@uwdata/vgplot' + +const FILL_COLOR = "#76b900" + +function injectAxisLine(plot, marginLeft, marginRight, marginBottom, height, axisColor) { + const svg = plot.tagName === 'svg' ? plot : plot.querySelector?.('svg') + if (!svg) return + // Remove any previously injected line + svg.querySelectorAll('.x-axis-line').forEach(el => el.remove()) + const svgWidth = svg.getAttribute('width') || svg.clientWidth + const line = document.createElementNS('http://www.w3.org/2000/svg', 'line') + line.classList.add('x-axis-line') + line.setAttribute('x1', marginLeft) + line.setAttribute('x2', svgWidth - marginRight) + line.setAttribute('y1', height - marginBottom) + line.setAttribute('y2', height - marginBottom) + line.setAttribute('stroke', axisColor) + line.setAttribute('stroke-width', '1') + svg.appendChild(line) +} + +export default function Histogram({ brush, column, label }) { + const containerRef = useRef(null) + + useEffect(() => { + if (!containerRef.current || !brush) return + + // Clear previous content + containerRef.current.innerHTML = '' + + const bgColor = getComputedStyle(document.documentElement).getPropertyValue('--density-bar-bg').trim() || '#e0e0e0' + const axisColor = getComputedStyle(document.documentElement).getPropertyValue('--text-tertiary').trim() || '#888' + const width = containerRef.current.clientWidth - 20 + const height = 50 + const marginLeft = 45 + const marginBottom = 20 + const marginRight = 10 + const marginTop = 
5 + + const plot = vg.plot( + // Background histogram: full data (no filterBy) + vg.rectY( + vg.from("features"), + { x: vg.bin(column), y: vg.count(), fill: bgColor, inset: 1 } + ), + // Foreground histogram: filtered data + vg.rectY( + vg.from("features", { filterBy: brush }), + { x: vg.bin(column), y: vg.count(), fill: FILL_COLOR, inset: 1 } + ), + vg.intervalX({ as: brush }), + vg.xLabel(null), + vg.yLabel(null), + vg.width(width), + vg.height(height), + vg.marginLeft(marginLeft), + vg.marginBottom(marginBottom), + vg.marginTop(marginTop), + vg.marginRight(marginRight) + ) + + containerRef.current.appendChild(plot) + + // Inject axis line into the SVG directly (immune to container resize) + // Use a short delay to ensure the SVG is rendered + const timer = setTimeout(() => { + injectAxisLine(plot, marginLeft, marginRight, marginBottom, height, axisColor) + }, 50) + + return () => { + clearTimeout(timer) + if (containerRef.current) { + containerRef.current.innerHTML = '' + } + } + }, [brush, column, label]) + + return ( +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/InfoButton.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/InfoButton.jsx new file mode 100644 index 0000000000..40184d0ef6 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/InfoButton.jsx @@ -0,0 +1,78 @@ +import React, { useState, useEffect, useRef } from 'react' +import { createPortal } from 'react-dom' + +export default function InfoButton({ text }) { + const [open, setOpen] = useState(false) + const wrapperRef = useRef(null) + const buttonRef = useRef(null) + const [pos, setPos] = useState(null) + + useEffect(() => { + if (!open) return + const handleClick = (e) => { + if (wrapperRef.current && !wrapperRef.current.contains(e.target)) { + setOpen(false) + } + } + document.addEventListener('mousedown', handleClick) + return () => document.removeEventListener('mousedown', handleClick) + }, [open]) + + useEffect(() => { + if (open && buttonRef.current) { + const rect = buttonRef.current.getBoundingClientRect() + setPos({ + top: rect.top - 8, + left: rect.left + rect.width / 2, + }) + } + }, [open]) + + return ( + + setOpen(o => !o)} + style={{ + display: 'inline-flex', + alignItems: 'center', + justifyContent: 'center', + width: '15px', + height: '15px', + borderRadius: '50%', + border: '1px solid var(--border-input)', + fontSize: '10px', + fontWeight: '600', + color: 'var(--text-tertiary)', + cursor: 'pointer', + userSelect: 'none', + lineHeight: 1, + }} + > + i + + {open && pos && createPortal( +
+ {text} +
, + document.body + )} +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Preview.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Preview.jsx new file mode 100644 index 0000000000..ce0e0cee76 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/Preview.jsx @@ -0,0 +1,88 @@ +import React, { useState } from 'react' +import App from './App' +import SteeringDemo from './SteeringDemo' + +// Hit http://localhost:5176/#preview to see the tabbed preview. The plain `/` +// URL still renders the unchanged main dashboard. + +const TABS = [ + { id: 'main', label: 'Main dashboard (features + atlas + WebLogos)' }, + { id: 'steering', label: 'Causal steering' }, +] + +const styles = { + container: { + fontFamily: 'system-ui, sans-serif', + color: 'var(--text, #222)', + background: 'var(--bg, #fafafa)', + minHeight: '100vh', + display: 'flex', + flexDirection: 'column', + }, + tabBar: { + display: 'flex', + gap: '4px', + padding: '8px 16px', + background: 'var(--bg-card, #fff)', + borderBottom: '1px solid var(--border, #ddd)', + flexShrink: 0, + }, + tab: (active) => ({ + padding: '6px 14px', + border: '1px solid', + borderColor: active ? 'var(--accent, #76b900)' : 'var(--border, #ddd)', + background: active ? 'var(--bg-card-expanded, #f0f8e8)' : '#fff', + borderRadius: '4px', + cursor: 'pointer', + fontSize: '12px', + fontWeight: active ? 600 : 400, + color: active ? 'var(--accent, #76b900)' : 'var(--text-secondary, #555)', + }), + tabContent: { + flex: 1, + overflow: 'auto', + background: 'var(--bg, #fafafa)', + }, + wrap: { padding: '24px' }, + title: { fontSize: '20px', fontWeight: 600, marginBottom: '4px' }, + subtitle: { fontSize: '12px', color: 'var(--text-secondary, #666)', marginBottom: '16px' }, +} + + +export default function Preview() { + const [tab, setTab] = useState('main') + + return ( +
+
+ {TABS.map((t) => ( + + ))} +
+ +
+ {tab === 'main' && ( + // Existing dashboard: feature catalog, UMAP atlas, FeatureCard + // expansions with WebLogo PNGs, histograms. Untouched. +
+ +
+ )} + + {tab === 'steering' && ( +
+
Causal steering of SAE features
+
+ Pick a position, pick a feature, drag the clamp. Features that genuinely represent + biological concepts move predictions toward biologically meaningful outputs; + unrelated features don't. +
+ +
+ )} +
+
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/RegionDetailModal.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/RegionDetailModal.jsx new file mode 100644 index 0000000000..d72dc41358 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/RegionDetailModal.jsx @@ -0,0 +1,157 @@ +import React, { useEffect } from 'react' +import ReactDOM from 'react-dom' +import SequenceView from './SequenceView' +import { getRegionLabel } from './utils' + +const styles = { + backdrop: { + position: 'fixed', + inset: 0, + background: 'rgba(0,0,0,0.5)', + zIndex: 9999, + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + }, + modal: { + background: '#fff', + borderRadius: '12px', + width: '90vw', + maxWidth: '1000px', + maxHeight: '85vh', + display: 'flex', + flexDirection: 'column', + overflow: 'hidden', + boxShadow: '0 20px 60px rgba(0,0,0,0.3)', + position: 'relative', + }, + closeBtn: { + position: 'absolute', + top: '12px', + right: '12px', + zIndex: 10, + background: 'rgba(255,255,255,0.9)', + border: '1px solid #ddd', + borderRadius: '50%', + width: '32px', + height: '32px', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + cursor: 'pointer', + fontSize: '16px', + color: '#555', + }, + body: { + flex: 1, + padding: '32px', + overflowY: 'auto', + display: 'flex', + flexDirection: 'column', + gap: '20px', + }, + header: { + display: 'flex', + alignItems: 'center', + gap: '8px', + flexWrap: 'wrap', + }, + regionLabel: { + fontSize: '18px', + fontWeight: '700', + fontFamily: 'monospace', + color: '#222', + }, + statsRow: { + display: 'flex', + gap: '20px', + flexWrap: 'wrap', + }, + statBox: { + padding: '10px 14px', + background: '#f9fafb', + borderRadius: '8px', + border: '1px solid #eee', + }, + statLabel: { + fontSize: '10px', + color: '#888', + textTransform: 
'uppercase', + marginBottom: '2px', + }, + statValue: { + fontSize: '14px', + fontWeight: '600', + fontFamily: 'monospace', + color: '#333', + }, + sectionLabel: { + fontSize: '11px', + color: '#888', + textTransform: 'uppercase', + fontWeight: '500', + }, + sequenceBox: { + background: '#fafafa', + border: '1px solid #eee', + borderRadius: '8px', + padding: '12px', + maxHeight: '300px', + overflowY: 'auto', + }, +} + +export default function RegionDetailModal({ region, onClose }) { + useEffect(() => { + const handleKey = (e) => { if (e.key === 'Escape') onClose() } + document.addEventListener('keydown', handleKey) + return () => document.removeEventListener('keydown', handleKey) + }, [onClose]) + + const label = getRegionLabel(region) + const sequenceLength = (region.sequence || '').length + + const modal = ( +
+
e.stopPropagation()}> +
x
+ +
+
+ {label} +
+ +
+
+
Max Activation
+
{(region.max_activation || 0).toFixed(4)}
+
+
+
Sequence Length
+
{sequenceLength} bp
+
+ {region.best_annotation && ( +
+
Annotation
+
{region.best_annotation}
+
+ )} +
+ +
+
Sequence (activation highlighted)
+
+ +
+
+
+
+
+ ) + + return ReactDOM.createPortal(modal, document.body) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SequenceView.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SequenceView.jsx new file mode 100644 index 0000000000..5583180fc5 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SequenceView.jsx @@ -0,0 +1,267 @@ +import React, { useState, useEffect, useRef } from 'react' +import { parseBases } from './utils' + +// Background colour for one base. Returns 'transparent' when maxValue <= 0 or +// value <= 0; otherwise n = min(value / maxValue, 1) drives a linear ramp from +// near-white (small n) to #76b900 (n = 1), returned as a '#rrggbb' string. +function activationColorHex(value, maxValue) { + if (maxValue <= 0 || value <= 0) return 'transparent' + const n = Math.min(value / maxValue, 1) + const r = Math.round(255 - n * 137) + const g = Math.round(255 - n * 70) + const b = Math.round(255 * (1 - n)) + const toHex = (c) => c.toString(16).padStart(2, '0') + return `#${toHex(r)}${toHex(g)}${toHex(b)}` +} + +const BASE_WIDTH = 12 + +const styles = { + container: { + fontFamily: 'Monaco, Menlo, "Courier New", monospace', + fontSize: '11px', + lineHeight: '1.2', + overflowX: 'auto', + position: 'relative', + }, + baseRow: { + display: 'inline-flex', + whiteSpace: 'nowrap', + }, + baseBlock: { + display: 'inline-flex', + flexDirection: 'column', + alignItems: 'center', + cursor: 'default', + borderRadius: '2px', + padding: '1px 1px', + marginRight: '0px', + minWidth: `${BASE_WIDTH}px`, + }, + padBlock: { + display: 'inline-flex', + flexDirection: 'column', + alignItems: 'center', + borderRadius: '2px', + padding: '1px 1px', + marginRight: '0px', + minWidth: `${BASE_WIDTH}px`, + background: 'var(--density-bar-bg)', + }, + padText: { + fontSize: '10px', + color: 'var(--text-muted)', + }, + baseText: { + fontSize: '10px', + letterSpacing: '0.5px', + color: 'var(--text)', + }, + idxText: { + fontSize: '7px', + color: 'var(--text-tertiary)', + marginTop: '0px', + lineHeight: '1', + }, + tooltip: { + position: 'fixed', + background:
'var(--bg-card)', + color: 'var(--text)', + border: '1px solid var(--border)', + padding: '4px 8px', + borderRadius: '4px', + fontSize: '10px', + fontFamily: 'monospace', + zIndex: 1000, + pointerEvents: 'none', + whiteSpace: 'nowrap', + }, +} + +// Show index under every Nth base to keep the row scannable +const INDEX_INTERVAL = 10 + +export default function SequenceView({ + sequence, activations, maxActivation, + alignMode, alignAnchor, totalLength, + scrollGroupRef, +}) { + const [tooltip, setTooltip] = useState(null) + const scrollRef = useRef(null) + const anchorRef = useRef(null) + + const bases = parseBases(sequence) + const acts = activations ? activations.slice(0, bases.length) : [] + const maxAct = maxActivation || Math.max(...acts, 0.001) + + // Compute local anchor index + let localAnchor = 0 + if (alignMode === 'first_activation') { + localAnchor = acts.findIndex(a => a > 0) + if (localAnchor < 0) localAnchor = 0 + } else if (alignMode === 'max_activation') { + let maxVal = -1 + acts.forEach((a, i) => { if (a > maxVal) { maxVal = a; localAnchor = i } }) + } + + // Padding + const isAligned = alignMode && alignMode !== 'start' && alignAnchor != null + const leftPad = isAligned ? Math.max(0, alignAnchor - localAnchor) : 0 + const rightPad = (totalLength != null) + ? 
Math.max(0, totalLength - leftPad - bases.length) + : 0 + + // Scroll to anchor when alignMode changes + useEffect(() => { + if (isAligned && anchorRef.current && scrollRef.current) { + anchorRef.current.scrollIntoView({ behavior: 'instant', inline: 'center', block: 'nearest' }) + } + }, [alignMode, alignAnchor]) + + // Synchronized scrolling across sequences in the same card + useEffect(() => { + const el = scrollRef.current + if (!el || !scrollGroupRef) return + + if (!scrollGroupRef.current) scrollGroupRef.current = [] + const group = scrollGroupRef.current + if (!group.includes(el)) group.push(el) + + let isSyncing = false + const handleScroll = () => { + if (isSyncing) return + isSyncing = true + const scrollLeft = el.scrollLeft + for (const other of group) { + if (other !== el) other.scrollLeft = scrollLeft + } + isSyncing = false + } + + el.addEventListener('scroll', handleScroll) + return () => { + el.removeEventListener('scroll', handleScroll) + const idx = group.indexOf(el) + if (idx !== -1) group.splice(idx, 1) + } + }, [scrollGroupRef]) + + if (!sequence || sequence.length === 0) { + return No sequence + } + + const handleMouseEnter = (e, base, idx, act) => { + setTooltip({ + x: e.clientX + 10, + y: e.clientY - 25, + text: `${base} pos ${idx + 1} — activation: ${act.toFixed(4)}`, + }) + } + + const handleMouseMove = (e) => { + if (tooltip) { + setTooltip((prev) => prev ? { ...prev, x: e.clientX + 10, y: e.clientY - 25 } : null) + } + } + + const handleMouseLeave = () => { + setTooltip(null) + } + + const shouldShowIdx = (idx) => (idx + 1) % INDEX_INTERVAL === 0 || idx === 0 + + return ( +
+
+ {/* Left padding */} + {Array.from({ length: leftPad }, (_, i) => ( + + · +   + + ))} + + {/* Actual bases */} + {bases.map((base, idx) => { + const act = acts[idx] || 0 + const bg = activationColorHex(act, maxAct) + const isAnchor = isAligned && idx === localAnchor + const hasActivation = act > 0 + const activeTextColor = hasActivation ? '#000' : undefined + return ( + handleMouseEnter(e, base, idx, act)} + onMouseMove={handleMouseMove} + onMouseLeave={handleMouseLeave} + > + {base} + {shouldShowIdx(idx) ? idx + 1 : ' '} + + ) + })} + + {/* Right padding */} + {Array.from({ length: rightPad }, (_, i) => ( + + · +   + + ))} +
+ {tooltip && ( + + {tooltip.text} + + )} +
+ ) +} + +/** + * Compute alignment info for a set of examples — same logic as the codonfm + * version, just operating on per-base activation arrays rather than per-codon. + */ +export function computeAlignInfo(examples, alignMode) { + if (!examples || examples.length === 0) return { anchor: 0, totalLength: 0 } + + if (alignMode === 'start') { + const maxLen = Math.max(...examples.map(ex => (ex.activations || []).length)) + return { anchor: 0, totalLength: maxLen } + } + + let maxAnchor = 0 + for (const ex of examples) { + const acts = ex.activations || [] + let anchor = 0 + if (alignMode === 'first_activation') { + anchor = acts.findIndex(a => a > 0) + if (anchor < 0) anchor = 0 + } else if (alignMode === 'max_activation') { + let maxVal = -1 + acts.forEach((a, i) => { if (a > maxVal) { maxVal = a; anchor = i } }) + } + if (anchor > maxAnchor) maxAnchor = anchor + } + + let totalLength = 0 + for (const ex of examples) { + const acts = ex.activations || [] + let anchor = 0 + if (alignMode === 'first_activation') { + anchor = acts.findIndex(a => a > 0) + if (anchor < 0) anchor = 0 + } else if (alignMode === 'max_activation') { + let maxVal = -1 + acts.forEach((a, i) => { if (a > maxVal) { maxVal = a; anchor = i } }) + } + const leftPad = maxAnchor - anchor + const thisTotal = leftPad + acts.length + if (thisTotal > totalLength) totalLength = thisTotal + } + + return { anchor: maxAnchor, totalLength } +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SteeringDemo.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SteeringDemo.jsx new file mode 100644 index 0000000000..8ebddc73f1 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/SteeringDemo.jsx @@ -0,0 +1,879 @@ +import React, { useEffect, useMemo, useState } from 'react' + +// Position-targeted steering demo. 
Click a position in the sequence, pick a +// feature, drag the clamp. Two side-by-side P(base) bar charts compare +// baseline vs steered. + +// Evo2 is DNA-tokenized; the model emits P(A/C/G/T) on every input, including +// rRNA contexts. We display T everywhere — never U — to match what the model +// actually predicts. +const BASES = ['A', 'C', 'G', 'T'] +const COLORS = { + A: '#59A14F', C: '#4E79A7', G: '#F28E2B', T: '#E15759', + accent: '#76b900', + pass: '#5a9c3f', + fail: '#c34a4a', + baseline: '#9C755F', + steered: '#76B7B2', +} + +// UI clamp axis is a multiplier on the feature's natural peak activation: +// 0× = baseline (no intervention), 1× = the highest value we'd see naturally +// in the training data (in-distribution upper bound), 2× = pushed past what +// the model has been trained on (out-of-distribution). +// Underlying JSON stores discrete shifts at {-2, 0, 2, 5}; the UI -0.5×/0×/1×/2× +// map to those four points so we can interpolate without re-mocking the data. +const UI_CLAMP_POINTS = [-0.5, 0, 1, 2] +const JSON_CLAMP_KEYS = ['-2', '0', '2', '5'] +const UI_CLAMP_MIN = -0.5 +const UI_CLAMP_MAX = 2 +const OOD_THRESHOLD = 1.0 // anything past 1× is out-of-distribution + +const BASES_PER_LINE = 60 + + +// Linearly interpolate two {base: probability} maps: for every key b of `low`, +// out[b] = low[b] + (high[b] - low[b]) * t, then divide each entry by the sum +// so the result adds up to 1. +// NOTE(review): a key missing from `high`, or a zero sum, would produce NaN — +// presumably callers always pass full, positive distributions; confirm. +function interpProbs(low, high, t) { + const out = {} + let s = 0 + for (const b of Object.keys(low)) { + out[b] = low[b] + (high[b] - low[b]) * t + s += out[b] + } + // renormalize (interp can drift from 1.0) + for (const b of Object.keys(out)) out[b] /= s + return out +} + + +// Map a UI clamp value (e.g. 0.65 = 0.65× natural peak) to two adjacent JSON +// clamp keys and an interpolation factor t in [0,1] between them.
+// Returns { lo, hi, t } where lo/hi are entries of JSON_CLAMP_KEYS. Values at or +// beyond either end of UI_CLAMP_POINTS pin to that end key with t = 0. +function pickSegment(uiValue) { + if (uiValue <= UI_CLAMP_POINTS[0]) return { lo: JSON_CLAMP_KEYS[0], hi: JSON_CLAMP_KEYS[0], t: 0 } + if (uiValue >= UI_CLAMP_POINTS[UI_CLAMP_POINTS.length - 1]) { + const last = JSON_CLAMP_KEYS[JSON_CLAMP_KEYS.length - 1] + return { lo: last, hi: last, t: 0 } + } + for (let i = 0; i < UI_CLAMP_POINTS.length - 1; i++) { + if (uiValue >= UI_CLAMP_POINTS[i] && uiValue <= UI_CLAMP_POINTS[i + 1]) { + const t = (uiValue - UI_CLAMP_POINTS[i]) / (UI_CLAMP_POINTS[i + 1] - UI_CLAMP_POINTS[i]) + return { lo: JSON_CLAMP_KEYS[i], hi: JSON_CLAMP_KEYS[i + 1], t } + } + } + // Fallback when no branch matched (e.g. uiValue is NaN): first key, t = 0. + return { lo: JSON_CLAMP_KEYS[0], hi: JSON_CLAMP_KEYS[0], t: 0 } +} + + +// Explanatory blurbs keyed by narrative type. +const NARRATIVES = { + headline_amr: + "Matches the known A1408G aminoglycoside-resistance mutation in E. coli 16S rRNA. The model learned this association without supervision; steering the kanamycin-resistance SAE feature reproduces the resistance mutation.", + tata_demo: + "Steering the TATA-box feature concentrates probability at A — the canonical first base of the TATAAA consensus.", + structural_demo: + "Amplifying the α-helix feature in a coding region biases the predicted base toward G — consistent with codons encoding helix-favoring amino acids.", + null_result: + "No meaningful shift. This is a random control with no biological context that would make any feature appropriate. A well-behaved feature shouldn't shift predictions where it has no reason to fire.", +} + + +// Pick the default pair to surface for a given seed. +// Returns the first key of data.comparisons that starts with `${seedId}__`, or +// undefined when no key has that prefix. +function defaultPairForSeed(seedId, data) { + const prefix = `${seedId}__` + return Object.keys(data.comparisons).find((k) => k.startsWith(prefix)) +} + + +export default function SteeringDemo() { + const [data, setData] = useState(null) + const [error, setError] = useState(null) + const [seedId, setSeedId] = useState('ecoli_16s') + // Multi-feature clamping: clamp 1..N features simultaneously, all at the + // same slider value.
With one selected the page behaves as before; with + // multiple, the per-feature steered-minus-baseline deltas sum and we + // renormalize. Cheap mock; the real backend would compute a joint forward. + const [featureIds, setFeatureIds] = useState([12]) + const [clamp, setClamp] = useState(1) // start at the in-distribution upper bound (1× natural peak) + const [mode, setMode] = useState('position') // 'position' = targeted; 'global' = clamp every position + const [neighbors, setNeighbors] = useState(1) + const [targetPos, setTargetPos] = useState(null) + + useEffect(() => { + fetch('/steering_data.json').then((r) => r.json()).then(setData).catch((e) => setError(e.message)) + }, []) + + // Whenever data lands or the seed changes, find the matching pair (if any) + // and align the feature + target position to it. + useEffect(() => { + if (!data) return + const pairKey = defaultPairForSeed(seedId, data) + if (pairKey) { + const cmp = data.comparisons[pairKey] + setFeatureIds([cmp.feature_id]) + setTargetPos(cmp.target_position) + } else { + setTargetPos(data.seeds[seedId].default_target_position) + } + }, [data, seedId]) + + // Find the seed's default comparison for narrative + baseline. The + // "primary" comparison is the one matching the first selected feature + // (if it exists), else the first comparison for this seed. + const primaryComparison = useMemo(() => { + if (!data) return null + const primaryFid = featureIds[0] + const exactKey = Object.keys(data.comparisons).find((k) => { + const c = data.comparisons[k] + return c.seed === seedId && c.feature_id === primaryFid && c.target_position === targetPos + }) + if (exactKey) return data.comparisons[exactKey] + const seedKey = defaultPairForSeed(seedId, data) + return seedKey ? data.comparisons[seedKey] : null + }, [data, seedId, featureIds, targetPos]) + + // Per-feature steered distributions at the current clamp value, additively + // combined into a single steered distribution. 
With one feature it's just + // that feature's steered probs; with more, sum (steered_f - baseline) over + // f and add to baseline, then renormalize. + // + // Global mode short-circuits: clamping every position smears the output, so + // we return a fixed low-confidence distribution with no clean winner — that + // pattern is the visible point of contrast with position-targeted steering. + const interpolated = useMemo(() => { + if (!data || !primaryComparison) return null + const { lo, hi, t } = pickSegment(clamp) + const baseline = primaryComparison.results_by_clamp['0']?.baseline + || primaryComparison.results_by_clamp[lo].baseline + if (!baseline) return null + + if (mode === 'global') { + // Clamping every position smears toward a low-confidence distribution. + // We interpolate from baseline (at clamp 0) to a target smear at the + // extreme |clamp| = 2, so the slider still has visible effect — just + // never a clean winner. + const SMEAR = { A: 0.31, G: 0.34, C: 0.19, T: 0.16 } + const t = Math.min(1, Math.abs(clamp) / UI_CLAMP_MAX) + const out = {} + let z = 0 + for (const b of BASES) { + out[b] = (1 - t) * (baseline[b] ?? 0) + t * SMEAR[b] + z += out[b] + } + for (const b of BASES) out[b] /= z + return { baseline, steered: out } + } + + // Find one comparison per selected feature (matching seed, falling back + // to any pair using that feature on this seed). 
+ const perFeatureSteered = featureIds.map((fid) => { + const k = Object.keys(data.comparisons).find((key) => { + const c = data.comparisons[key] + return c.seed === seedId && c.feature_id === fid + }) + if (!k) return null + const c = data.comparisons[k] + const loSet = c.results_by_clamp[lo] + const hiSet = c.results_by_clamp[hi] + if (!loSet || !hiSet) return null + return interpProbs(loSet.steered, hiSet.steered, t) + }).filter(Boolean) + + if (perFeatureSteered.length === 0) return null + if (perFeatureSteered.length === 1) return { baseline, steered: perFeatureSteered[0] } + + // Multi-feature combine: sum the deltas from baseline, add to baseline, renormalize. + const combined = { ...baseline } + for (const b of BASES) { + let d = 0 + for (const s of perFeatureSteered) { + d += (s[b] ?? 0) - (baseline[b] ?? 0) + } + combined[b] = Math.max(1e-6, (baseline[b] ?? 0) + d) + } + const z = BASES.reduce((acc, b) => acc + combined[b], 0) + for (const b of BASES) combined[b] /= z + return { baseline, steered: combined } + }, [data, seedId, featureIds, primaryComparison, clamp]) + + const comparison = primaryComparison // used downstream for the no-data fallback message + + if (error) return
Failed to load steering_data.json: {error}
+ if (!data) return
Loading steering demo…
+ + const seed = data.seeds[seedId] + + return ( +
+
+ MOCKUP — hand-rolled probability distributions per (seed, feature, clamp). Position-targeted + steering protocol. +
+ + + + + + {mode === 'global' && comparison && ( + + )} + + {comparison && interpolated ? ( + + ) : ( +
+ No demo data for this combination. Pick a seed; the feature + target position will snap + to the demo pair available for that seed. +
+ )} +
+ ) +} + + +function Controls({ data, seedId, setSeedId, featureIds, setFeatureIds, clamp, setClamp, neighbors, setNeighbors, mode, setMode }) { + const primaryFid = featureIds[0] + const additionalFids = featureIds.slice(1) + + const setPrimary = (fid) => { + // Move new primary to front; drop it from additional if present. + const remaining = featureIds.filter((x) => x !== fid) + setFeatureIds([fid, ...remaining]) + } + const toggleAdditional = (fid) => { + if (additionalFids.includes(fid)) { + setFeatureIds([primaryFid, ...additionalFids.filter((x) => x !== fid)]) + } else { + setFeatureIds([primaryFid, ...additionalFids, fid]) + } + } + + return ( +
+
+ + +
+ +
+ + +
+ + {data.features_available.length > 1 && ( +
+ +
+ {data.features_available + .filter((f) => f.id !== primaryFid) + .map((f) => { + const active = additionalFids.includes(f.id) + return ( + + ) + })} +
+
+ )} +
+ {featureIds.length === 1 + ? 'Clamping 1 feature.' + : `Clamping ${featureIds.length} features at the same clamp value. Per-feature shifts add and renormalize (mock).`} +
+ +
+ +
+ setClamp(parseFloat(e.target.value))} + style={styles.slider} + /> +
+ setClamp(-0.5)} style={styles.tick}>−0.5× suppress + setClamp(0)} style={styles.tick}>0× baseline + setClamp(1)} style={{ ...styles.tick, fontWeight: 600 }}>1× natural peak + setClamp(2)} style={styles.tick}>2× OOD +
+
+ + = {clamp.toFixed(2)}×{clamp > OOD_THRESHOLD ? ' (OOD)' : ''} + +
+ +
+ + {[ + { id: 'position', label: 'Position-restricted' }, + { id: 'global', label: 'Global (all positions)' }, + ].map((m) => ( + + ))} +
+ +
+ + {[0, 1, 2, 3, 4].map((n) => ( + + ))} +
+
+ ) +} + + +// Three-segment colored bar above the slider that visually marks the zones: +// suppress (left of 0×), in-distribution (0× to 1×), OOD (past 1×). Widths are +// proportional to the actual UI range so the bar lines up under the slider thumb. +function ClampZoneBar() { + const total = UI_CLAMP_MAX - UI_CLAMP_MIN + const w = (a, b) => `${((b - a) / total) * 100}%` + return ( +
+
+ suppress +
+
+ in-distribution +
+
+ out-of-distribution +
+
+ ) +} + + +function SequenceTarget({ seed, targetPos, setTargetPos, neighbors }) { + const seq = seed.sequence + const lines = [] + for (let start = 0; start < seq.length; start += BASES_PER_LINE) { + lines.push({ start, end: Math.min(start + BASES_PER_LINE, seq.length) }) + } + // Neighbor positions: `neighbors` bases immediately before the target. + const neighborSet = new Set() + if (targetPos != null) { + for (let k = 1; k <= neighbors; k++) { + const p = targetPos - k + if (p >= 0) neighborSet.add(p) + } + } + return ( +
+
+ Click a position to target it. Currently targeting position {targetPos != null ? targetPos + 1 : '—'}. +
+
+ {lines.map(({ start, end }) => ( +
+ {String(start + 1).padStart(4, ' ')} + + {[...seq.slice(start, end)].map((base, j) => { + const pos = start + j + const isTarget = pos === targetPos + const isNeighbor = neighborSet.has(pos) + let style = styles.baseChar + if (isTarget) style = { ...style, ...styles.baseTarget } + else if (isNeighbor) style = { ...style, ...styles.baseNeighbor } + return ( + setTargetPos(pos)} + style={style} + title={isNeighbor ? `Position ${pos + 1} (clamped neighbor)` : `Position ${pos + 1}`} + > + {base} + + ) + })} + +
+ ))} +
+
+ ) +} + + +// Global-mode visualization: clamping every position degrades the argmax +// sequence almost everywhere. We render the baseline argmax (~= input seq, since +// Evo2 reproduces a real bio sequence on its own context) vs a deterministically +// scrambled steered argmax. Fraction of flipped positions scales with |clamp|. +function SequenceStrip({ seed, seedId, featureIds, clamp }) { + const SMEAR_WEIGHTS = [['A', 0.31], ['G', 0.34], ['C', 0.19], ['T', 0.16]] + const baseChars = seed.sequence.toUpperCase().split('').filter((c) => 'ACGT'.includes(c)) + + // mulberry32 PRNG so the scramble is stable for a given (seed, feature, clamp) + function hashStr(s) { + let h = 2166136261 + for (let i = 0; i < s.length; i++) { h ^= s.charCodeAt(i); h = (h * 16777619) >>> 0 } + return h >>> 0 + } + const mulberry32 = (a) => () => { + a = (a + 0x6D2B79F5) >>> 0 + let t = a + t = Math.imul(t ^ (t >>> 15), t | 1) + t ^= t + Math.imul(t ^ (t >>> 7), t | 61) + return ((t ^ (t >>> 14)) >>> 0) / 4294967296 + } + const rngKey = `${seedId}|${featureIds.join(',')}|${clamp.toFixed(2)}` + const rand = mulberry32(hashStr(rngKey)) + + const flipFrac = Math.min(1, Math.abs(clamp) / UI_CLAMP_MAX) + + const steeredChars = baseChars.map((c) => { + if (rand() >= flipFrac) return c + // pick a new base weighted by the smear distribution + let r = rand(), acc = 0 + for (const [b, w] of SMEAR_WEIGHTS) { acc += w; if (r < acc) return b } + return SMEAR_WEIGHTS[SMEAR_WEIGHTS.length - 1][0] + }) + + const changed = baseChars.reduce((n, c, i) => n + (steeredChars[i] !== c ? 1 : 0), 0) + const preserved = baseChars.length - changed + + return ( +
+
+ Effect across all positions (global clamp) + + argmax preserved:{' '} + {preserved} / {baseChars.length}{' '} + ({((preserved / baseChars.length) * 100).toFixed(0)}%) + +
+ +
+ Baseline + + {baseChars.map((c, i) => ( + {c} + ))} + +
+
+ Steered + + {steeredChars.map((c, i) => { + const flipped = c !== baseChars[i] + return ( + + {c} + + ) + })} + +
+ +
+ Position-restricted steering would change exactly 1 position (or 1 + neighbors). + Global clamping degrades the prediction nearly everywhere. +
+
+ ) +} + + +function BarComparison({ targetPos, baseline, steered, mode }) { + // top base in each + let topB = BASES[0], topS = BASES[0] + for (const b of BASES) { + if (baseline[b] > baseline[topB]) topB = b + if ((steered[b] ?? 0) > (steered[topS] ?? 0)) topS = b + } + const flipped = topB !== topS + return ( +
+
+ Predicted base distribution at position {targetPos + 1} +
+
+ + +
+
+ {mode === 'global' ? ( + <> + Top base:{' '} + {topB} ({baseline[topB].toFixed(2)}) →{' '} + {topS} ({steered[topS].toFixed(2)}) + no clean flip — degraded + + ) : flipped ? ( + <> + Top base changed:{' '} + {topB} ({baseline[topB].toFixed(2)}) →{' '} + {topS} ({steered[topS].toFixed(2)}) + FLIPPED + + ) : ( + <> + Top base unchanged:{' '} + {topB} ({baseline[topB].toFixed(2)}) → {topS}{' '} + ({steered[topS].toFixed(2)}) + + )} +
+
+ ) +} + + +function BarChart({ title, dist, top }) { + return ( +
+
{title}
+ {BASES.map((b) => { + const p = dist[b] ?? 0 + const isTop = b === top + return ( +
+ {b} +
+
+
+ {p.toFixed(2)} +
+ ) + })} +
+ ) +} + + +function Selectivity({ rows, clamp }) { + return ( +
+
Selectivity check — does only the right feature shift the prediction?
+ + + + + + + + + + + {rows.map((r) => ( + + + + + + + ))} + +
FeatureSteered topP(top)Related?
{r.feature_label} + {r.steered_top_base} + {r.p_top.toFixed(2)} + {r.is_amr ? ( + + ) : ( + + )} +
+
+ Selectivity values shown at clamp = +5 (canonical headline strength); current slider at {clamp.toFixed(1)}. +
+
+ ) +} + + +function Narrative({ type }) { + const text = NARRATIVES[type] + if (!text) return null + return ( +
+ 💡 + {text} +
+ ) +} + + +const styles = { + container: { fontFamily: 'system-ui, sans-serif', color: 'var(--text, #222)' }, + banner: { + background: '#fff3cd', border: '1px solid #ffeeba', color: '#856404', + padding: '6px 12px', borderRadius: '4px', fontSize: '11px', marginBottom: '12px', + }, + controls: { + background: 'var(--bg-card, #fff)', border: '1px solid var(--border, #ddd)', + borderRadius: '6px', padding: '10px 14px', marginBottom: '12px', + position: 'sticky', top: 0, zIndex: 10, + }, + controlRow: { display: 'flex', alignItems: 'center', gap: '12px', marginBottom: '6px', flexWrap: 'wrap' }, + controlLabel: { fontSize: '12px', fontWeight: 600, color: 'var(--text-secondary, #555)', minWidth: '120px' }, + select: { padding: '4px 8px', fontSize: '12px', borderRadius: '4px', border: '1px solid var(--border, #ddd)', background: '#fff', minWidth: '260px' }, + featureChips: { display: 'flex', flexWrap: 'wrap', gap: '10px' }, + checkLabel: { + display: 'inline-flex', + alignItems: 'center', + gap: '4px', + padding: '3px 8px', + border: '1px solid var(--border, #ddd)', + background: '#fff', + borderRadius: '4px', + cursor: 'pointer', + fontSize: '11px', + color: 'var(--text-secondary, #555)', + userSelect: 'none', + }, + checkLabelActive: { + display: 'inline-flex', + alignItems: 'center', + gap: '4px', + padding: '3px 8px', + border: '1px solid var(--accent, #76b900)', + background: 'var(--bg-card-expanded, #f0f8e8)', + borderRadius: '4px', + cursor: 'pointer', + fontSize: '11px', + color: 'var(--accent, #76b900)', + fontWeight: 600, + userSelect: 'none', + }, + checkbox: { margin: 0, cursor: 'pointer' }, + multiHint: { + marginLeft: '120px', + fontSize: '10px', + fontStyle: 'italic', + color: 'var(--text-muted, #888)', + marginBottom: '8px', + }, + sliderColumn: { display: 'flex', flexDirection: 'column', flex: 1, maxWidth: '520px', gap: '4px' }, + slider: { width: '100%' }, + sliderTicks: { display: 'flex', justifyContent: 'space-between', fontSize: '10px', color: 
'var(--text-muted, #888)' }, + tick: { cursor: 'pointer', userSelect: 'none' }, + clampValue: { fontFamily: 'monospace', fontSize: '12px', fontWeight: 600, minWidth: '80px', whiteSpace: 'nowrap' }, + zoneBar: { + display: 'flex', + height: '14px', + border: '1px solid var(--border, #ddd)', + borderRadius: '3px', + overflow: 'hidden', + fontSize: '9px', + fontWeight: 600, + textTransform: 'uppercase', + }, + zone: { + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + borderRight: '1px solid #fff', + whiteSpace: 'nowrap', + overflow: 'hidden', + }, + infoIcon: { marginLeft: '4px', color: 'var(--text-muted, #aaa)', cursor: 'help', fontSize: '11px' }, + neighborBtn: { + padding: '3px 12px', border: '1px solid var(--border, #ddd)', background: '#fff', + borderRadius: '4px', cursor: 'pointer', fontSize: '11px', fontFamily: 'monospace', color: 'var(--text-secondary, #555)', + }, + neighborBtnActive: { + padding: '3px 12px', border: '1px solid var(--accent, #76b900)', + background: 'var(--bg-card-expanded, #f0f8e8)', borderRadius: '4px', cursor: 'pointer', + fontSize: '11px', fontFamily: 'monospace', color: 'var(--accent, #76b900)', fontWeight: 700, + }, + seqPanel: { + background: 'var(--bg-card, #fff)', border: '1px solid var(--border, #ddd)', + borderRadius: '6px', padding: '10px 14px', marginBottom: '12px', + }, + seqHeader: { fontSize: '11px', color: 'var(--text-secondary, #555)', marginBottom: '8px' }, + seqBody: { fontFamily: 'monospace', fontSize: '13px', lineHeight: '1.7' }, + seqLine: { display: 'flex', gap: '8px', alignItems: 'baseline' }, + seqIndex: { color: 'var(--text-muted, #aaa)', fontSize: '11px', minWidth: '32px', textAlign: 'right' }, + seqBases: { letterSpacing: '1px' }, + baseChar: { padding: '0 1px', cursor: 'pointer', borderRadius: '2px' }, + baseTarget: { + outline: `2px solid ${COLORS.accent}`, + background: 'rgba(118, 185, 0, 0.18)', + fontWeight: 700, + }, + baseNeighbor: { + background: 'rgba(118, 185, 0, 0.10)', + outline: 
`1px dashed ${COLORS.accent}`, + }, + barPanel: { + background: 'var(--bg-card, #fff)', border: '1px solid var(--border, #ddd)', + borderRadius: '6px', padding: '10px 14px', marginBottom: '12px', + }, + barTitle: { fontSize: '12px', fontWeight: 600, color: 'var(--text-heading, #222)', marginBottom: '10px' }, + barCharts: { display: 'grid', gridTemplateColumns: '1fr 1fr', gap: '14px' }, + barCard: { + background: 'var(--bg-card-expanded, #fafafa)', + border: '1px solid var(--border-light, #eee)', borderRadius: '4px', padding: '10px', + }, + barCardTitle: { + fontSize: '10px', textTransform: 'uppercase', fontWeight: 600, + color: 'var(--text-tertiary, #888)', marginBottom: '6px', + }, + barRow: { display: 'grid', gridTemplateColumns: '18px 1fr 40px', gap: '6px', alignItems: 'center', marginBottom: '4px' }, + barBaseLabel: { fontFamily: 'monospace', fontWeight: 700 }, + barTrack: { height: '14px', background: '#f0f0f0', borderRadius: '3px', overflow: 'hidden' }, + barFill: { height: '100%', borderRadius: '3px' }, + barProb: { fontFamily: 'monospace', fontSize: '11px', textAlign: 'right' }, + barSummary: { marginTop: '8px', fontSize: '12px', color: 'var(--text-secondary, #444)' }, + flipBadge: { + marginLeft: '8px', background: '#fcebea', color: '#c34', padding: '1px 6px', + borderRadius: '3px', fontSize: '9px', fontWeight: 700, + }, + degradedBadge: { + marginLeft: '8px', background: '#f0f0f0', color: '#666', padding: '1px 6px', + borderRadius: '3px', fontSize: '10px', fontStyle: 'italic', + }, + modeBtn: { + padding: '4px 10px', border: '1px solid var(--border, #ddd)', background: '#fff', + borderRadius: '4px', cursor: 'pointer', fontSize: '11px', color: 'var(--text-secondary, #555)', + }, + modeBtnActive: { + padding: '4px 10px', border: '1px solid var(--accent, #76b900)', + background: 'var(--bg-card-expanded, #f0f8e8)', borderRadius: '4px', cursor: 'pointer', + fontSize: '11px', color: 'var(--accent, #76b900)', fontWeight: 600, + }, + stripPanel: { + border: '1px 
solid var(--border, #ddd)', background: 'var(--bg-card, #fff)', + borderRadius: '6px', padding: '12px 16px', marginBottom: '12px', + }, + stripHeader: { + display: 'flex', justifyContent: 'space-between', alignItems: 'baseline', + marginBottom: '8px', fontSize: '12px', + }, + stripTitle: { fontWeight: 600, color: 'var(--text, #222)' }, + stripStat: { color: 'var(--text-secondary, #555)', fontSize: '11px' }, + stripRow: { display: 'flex', alignItems: 'center', gap: '8px', marginBottom: '2px' }, + stripLabel: { + width: '60px', fontSize: '10px', color: 'var(--text-secondary, #666)', + textTransform: 'uppercase', letterSpacing: '0.5px', + }, + stripSeq: { + fontFamily: 'ui-monospace, "SF Mono", Menlo, monospace', fontSize: '11px', + letterSpacing: '0', whiteSpace: 'pre-wrap', wordBreak: 'break-all', lineHeight: 1.5, + }, + stripCell: { display: 'inline-block', width: '11px', textAlign: 'center', fontWeight: 600 }, + stripFooter: { + marginTop: '8px', fontSize: '10px', color: 'var(--text-secondary, #888)', + fontStyle: 'italic', + }, + selPanel: { + background: 'var(--bg-card, #fff)', border: '1px solid var(--border, #ddd)', + borderRadius: '6px', padding: '10px 14px', marginBottom: '12px', + }, + selTitle: { fontSize: '12px', fontWeight: 600, color: 'var(--text-heading, #222)', marginBottom: '8px' }, + selTable: { width: '100%', borderCollapse: 'collapse', fontSize: '12px' }, + selTableHeader: { + fontSize: '10px', textTransform: 'uppercase', color: 'var(--text-tertiary, #888)', fontWeight: 600, + borderBottom: '1px solid var(--border-light, #eee)', + }, + selRow: { borderBottom: '1px solid var(--border-light, #f5f5f5)' }, + selCell: { padding: '5px 8px', textAlign: 'center' }, + selFootnote: { marginTop: '6px', fontSize: '10px', color: 'var(--text-muted, #888)', fontStyle: 'italic' }, + narrative: { + display: 'flex', alignItems: 'flex-start', gap: '8px', + background: '#eef6ff', border: '1px solid #bcd9ff', + borderRadius: '4px', padding: '8px 12px', + fontSize: 
'12px', color: '#1a3a6a', lineHeight: '1.5', + }, + narrativeIcon: { fontSize: '16px' }, + loading: { padding: '40px', textAlign: 'center', color: 'var(--text-muted, #aaa)', fontStyle: 'italic' }, + empty: { + padding: '24px', textAlign: 'center', background: 'var(--bg-card-expanded, #f8f8f8)', + border: '1px dashed var(--border, #ddd)', borderRadius: '6px', fontSize: '12px', + color: 'var(--text-muted, #888)', + }, + error: { padding: '20px', background: '#fee', color: '#c34', borderRadius: '4px', fontSize: '12px', fontFamily: 'monospace' }, +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/index.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/index.jsx new file mode 100644 index 0000000000..776d4b28fb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/index.jsx @@ -0,0 +1,14 @@ +import React from 'react' +import ReactDOM from 'react-dom/client' +import App from './App' +import Preview from './Preview' + +// Hit `/#preview` to see the new ColoredSequence + GeneUMAPView demo without +// disturbing the production dashboard routing. +const isPreview = typeof window !== 'undefined' && window.location.hash === '#preview' + +ReactDOM.createRoot(document.getElementById('root')).render( + <React.StrictMode> + {isPreview ? <Preview /> : <App />} + </React.StrictMode> +) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/utils.js b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/utils.js new file mode 100644 index 0000000000..936d971eae --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/src/utils.js @@ -0,0 +1,22 @@ +/** + * Build a human-readable label for a genomic region example. + * Expects an object with sequence_id, start, end fields. Falls back + * gracefully if any of those are missing.
+ */ +export function getRegionLabel(example) { + if (!example) return '' + const sid = example.sequence_id || example.protein_id || '' + if (example.start != null && example.end != null) { + return `${sid}:${example.start}-${example.end}` + } + return sid +} + +/** + * Parse a DNA sequence into an array of single-base tokens. + * No codon framing — each base is rendered independently. + */ +export function parseBases(sequence) { + if (!sequence) return [] + return sequence.split('') +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/vite.config.js b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/vite.config.js new file mode 100644 index 0000000000..6df23efee3 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/evo2_dashboard_mockup/vite.config.js @@ -0,0 +1,13 @@ +import { defineConfig } from 'vite' +import react from '@vitejs/plugin-react' + +export default defineConfig({ + plugins: [react()], + root: '.', + build: { + outDir: 'dist', + }, + server: { + port: 5176, + }, +}) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml new file mode 100644 index 0000000000..1f00a62bc5 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "evo2-sae" +version = "0.1.0" +description = "Sparse Autoencoders for the Evo2 DNA language model" +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "sae", + "torch>=2.0", + "numpy>=1.20", + "tqdm>=4.60", + "pyarrow>=10.0", +] + +# No package code lives here yet — the recipe is just an entry-point for +# scripts/ that depends on the shared `sae` workspace package. 
Declare no +# packages so setuptools doesn't try to discover anything. +[tool.setuptools] +packages = [] + +[tool.uv.sources] +sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh new file mode 100755 index 0000000000..d499b4f365 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -0,0 +1,116 @@ +#!/bin/bash +# Evo2 1B SAE pipeline: convert -> predict_evo2 -> pt_to_parquet -> train. +# +# Assumes: +# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and +# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge. +# - The sae workspace package is importable in that same venv. +# - HF_TOKEN is set if Savanna checkpoint repo is gated. +# +# Override any of these by exporting before invocation. + +set -euo pipefail + +EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}" +RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)" + +MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}" +MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}" +LAYER="${LAYER:-12}" +# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended). 
+CHUNK_BP="${CHUNK_BP:-8192}" + +FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}" +WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}" + +CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge" +PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt" +PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet" +OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}" + +source "${EVO2_MEGATRON_DIR}/.venv/bin/activate" + +echo "============================================================" +echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)" +echo "============================================================" +# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed. +INPUT_STEM="$(basename "$FASTA")" +INPUT_STEM="${INPUT_STEM%.gz}" +INPUT_STEM="${INPUT_STEM%.fasta}" +CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta" +if [[ -f "$CHUNKED_FASTA" ]]; then + echo "Reusing existing chunked FASTA: $CHUNKED_FASTA" +else + python "${RECIPE_DIR}/scripts/chunk_fasta.py" \ + --input "$FASTA" \ + --output "$CHUNKED_FASTA" \ + --window "$CHUNK_BP" +fi +FASTA="$CHUNKED_FASTA" + +echo "============================================================" +echo "STEP 1: Convert Savanna -> MBridge" +echo "============================================================" +if [[ ! 
-f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path "$MODEL" \ + --mbridge-ckpt-dir "$CKPT_DIR" \ + --model-size "$MODEL_SIZE" \ + --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512" +else + echo "Reusing existing checkpoint at $CKPT_DIR" +fi + +echo "============================================================" +echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)" +echo "============================================================" +mkdir -p "$PREDICT_DIR" +if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then + echo "Reusing existing .pt files in $PREDICT_DIR" +else + predict_evo2 \ + --fasta "$FASTA" \ + --ckpt-dir "$CKPT_DIR" \ + --output-dir "$PREDICT_DIR" \ + --embedding-layer "$LAYER" \ + --micro-batch-size 1 \ + --devices 1 \ + --write-interval batch +fi + +echo "============================================================" +echo "STEP 3: Convert .pt -> parquet ActivationStore" +echo "============================================================" +if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "Reusing existing parquet shards at $PARQUET_DIR" +else + python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \ + --predict-dir "$PREDICT_DIR" \ + --output "$PARQUET_DIR" \ + --model-name "$MODEL" \ + --layer "$LAYER" +fi + +echo "============================================================" +echo "STEP 4: Train TopK SAE" +echo "============================================================" +python "${RECIPE_DIR}/scripts/train.py" \ + --cache-dir "$PARQUET_DIR" \ + --model-path "$MODEL" \ + --layer "$LAYER" \ + --model-type topk \ + --expansion-factor 8 --top-k 32 \ + --auxk 64 --auxk-coef 0.03125 \ + --init-pre-bias \ + --n-epochs 3 \ + --batch-size 4096 \ + --lr 3e-4 \ + --log-interval 50 \ + --no-wandb \ + --output-dir "$OUTPUT_DIR" \ + --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \ + --checkpoint-steps 999999 + +echo 
"============================================================" +echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" +echo "============================================================" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py new file mode 100644 index 0000000000..55b26cad30 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context. + +Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena +fftconv path (intermediates scale super-linearly with L). For 7B/40B raise +--window to whatever those checkpoints were context-extended to. + +Non-overlapping windows by default. Each chunk gets a header of the form +">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped. 
+""" + +import argparse +import gzip +from pathlib import Path + + +def parse_fasta(path: Path): + """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" + opener = gzip.open if path.suffix == ".gz" else open + seq_id, parts = None, [] + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if seq_id is not None: + yield seq_id, "".join(parts) + seq_id = line[1:].split()[0] + parts = [] + else: + parts.append(line) + if seq_id is not None: + yield seq_id, "".join(parts) + + +def main(): + """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA.""" + p = argparse.ArgumentParser() + p.add_argument("--input", type=Path, required=True) + p.add_argument("--output", type=Path, required=True) + p.add_argument("--window", type=int, default=8192) + args = p.parse_args() + + n_in = n_out = bp_out = 0 + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as out: + for seq_id, seq in parse_fasta(args.input): + n_in += 1 + for start in range(0, len(seq), args.window): + end = min(start + args.window, len(seq)) + chunk = seq[start:end] + out.write(f">{seq_id}:{start}-{end}\n{chunk}\n") + n_out += 1 + bp_out += len(chunk) + + print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/make_mockup_features.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/make_mockup_features.py new file mode 100644 index 0000000000..f3ccbccc99 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/make_mockup_features.py @@ -0,0 +1,479 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate synthetic parquet fixtures + logo PNGs (and optionally features.json) for the evo2 SAE mockup dashboard. + +Run once, commit outputs as fixtures. No real SAE involved — this is a v1 demo of the +visualization shell. The data shape is the contract the real eval pipeline will target later. +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import pandas as pd + + +# DNA-native labels for evo2 features, each with a real central signature spliced into +# the middle of the 200bp window so the mockup features are visually distinguishable. +LABELS = [ + "Start codon (ATG) context", + "TATA box", + "Polyadenylation signal", + "Bacterial promoter -10 box", + "CpG island", + "Shine-Dalgarno sequence", + "Bacterial promoter -35 box", + "Splice donor site", + "Splice acceptor site", + "Stop codon (TAA) context", + "Stop codon (TAG) context", +] + +# Plausible accessions to rotate across examples. +SEQ_IDS = ["NC_000913.3", "NC_002695.2", "chr1", "chr17"] + +# Central motif spliced into the middle ~20bp of each top-activating window.
+CENTRAL_MOTIFS = { + "Start codon (ATG) context": "GCCACCATGGCC", + "TATA box": "TATAAA", + "Polyadenylation signal": "AATAAA", + "Bacterial promoter -10 box": "TATAAT", + "CpG island": "CGCGCGCGCGCGCGCG", + "Shine-Dalgarno sequence": "AGGAGGT", + "Bacterial promoter -35 box": "TTGACA", + "Splice donor site": "GTAAGT", # GT at exon-intron boundary, with consensus context + "Splice acceptor site": "TTTTCAGG", # AG at intron-exon boundary, with pyrimidine tract + "Stop codon (TAA) context": "GCCTAAGCC", # TAA in coding context + "Stop codon (TAG) context": "GCCTAGGCC", # TAG in coding context +} + +# 19bp PWM window centered on the activation peak (positions -9..+9). +PWM_WINDOW = 19 +PWM_PEAK = 9 +PWM_BASES = ["A", "C", "G", "T"] + + +# Per-position base probabilities for each feature's central signature. Positions +# outside the signature are filled with exactly uniform columns (0.25 each, 0 bits), +# so rendered logos have blank, zero-information flanks. +PWM_SIGNATURES: dict[str, list[dict[str, float]]] = { + # Kozak-like GCCACCATGG — ATG at signature positions 6..8. + "Start codon (ATG) context": [ + {"G": 0.70}, {"C": 0.70}, {"C": 0.70}, {"A": 0.60}, {"C": 0.65}, + {"C": 0.60}, {"A": 0.95}, {"T": 0.95}, {"G": 0.95}, {"G": 0.75}, + ], + "TATA box": [ # TATAAA + {"T": 0.90}, {"A": 0.90}, {"T": 0.90}, + {"A": 0.70, "G": 0.20}, {"A": 0.95}, {"A": 0.80, "T": 0.15}, + ], + "Polyadenylation signal": [ # AATAAA + {"A": 0.90}, {"A": 0.90}, {"T": 0.85}, + {"A": 0.85}, {"A": 0.85}, {"A": 0.80}, + ], + "Bacterial promoter -10 box": [ # TATAAT + {"T": 0.90}, {"A": 0.85}, {"T": 0.85}, + {"A": 0.75}, {"A": 0.75}, {"T": 0.85}, + ], + # CpG-rich: alternating GC bias across 12 positions, no sharp single peak.
+ "CpG island": [ + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + {"C": 0.55, "G": 0.35}, {"G": 0.55, "C": 0.35}, + ], + "Shine-Dalgarno sequence": [ # AGGAGGT + {"A": 0.80}, {"G": 0.90}, {"G": 0.90}, + {"A": 0.75}, {"G": 0.90}, {"G": 0.85}, {"T": 0.60}, + ], + "Bacterial promoter -35 box": [ # TTGACA + {"T": 0.90}, {"T": 0.85}, {"G": 0.85}, + {"A": 0.80}, {"C": 0.85}, {"A": 0.80}, + ], + # GT at the exon|intron boundary is essentially invariant. + "Splice donor site": [ # GT.AAGT + {"G": 0.99}, {"T": 0.99}, {"A": 0.60}, + {"A": 0.70}, {"G": 0.80}, {"T": 0.60}, + ], + # Pyrimidine tract leading into an invariant AG at the intron|exon boundary. + "Splice acceptor site": [ + {"T": 0.80}, {"T": 0.80}, {"T": 0.80}, {"T": 0.80}, + {"C": 0.70}, {"A": 0.99}, {"G": 0.99}, {"G": 0.55}, + ], + "Stop codon (TAA) context": [ # GCC.TAA.GCC — coding-context flanks + {"G": 0.45, "C": 0.40}, {"C": 0.55, "G": 0.30}, {"C": 0.50, "G": 0.35}, + {"T": 0.95}, {"A": 0.90}, {"A": 0.90}, + {"G": 0.45, "C": 0.40}, {"C": 0.55, "G": 0.30}, {"C": 0.50, "G": 0.35}, + ], + "Stop codon (TAG) context": [ + {"G": 0.45, "C": 0.40}, {"C": 0.55, "G": 0.30}, {"C": 0.50, "G": 0.35}, + {"T": 0.95}, {"A": 0.90}, {"G": 0.90}, + {"G": 0.45, "C": 0.40}, {"C": 0.55, "G": 0.30}, {"C": 0.50, "G": 0.35}, + ], +} + + +# Annotation-database source for each feature label. 
+DB_SOURCES = { + "Start codon (ATG) context": "RefSeq", + "TATA box": "JASPAR / ENCODE", + "Polyadenylation signal": "RefSeq UTR", + "Bacterial promoter -10 box": "bacterial annotation", + "CpG island": "ENCODE / RefSeq", + "Shine-Dalgarno sequence": "bacterial annotation", + "Bacterial promoter -35 box": "bacterial annotation", + "Splice donor site": "RefSeq", + "Splice acceptor site": "RefSeq", + "Stop codon (TAA) context": "RefSeq", + "Stop codon (TAG) context": "RefSeq", +} + + +def _random_dna(rng: np.random.Generator, length: int) -> str: + """Generate a length-N DNA string by uniform-sampling A/C/G/T.""" + return "".join(rng.choice(list("ACGT"), size=length)) + + +def _build_pwm(rng: np.random.Generator, label: str | None) -> np.ndarray: + """Build a (PWM_WINDOW, 4) probability PWM for one feature label. + + Central signature pulled from PWM_SIGNATURES (or exactly uniform for unlabeled + features). Flanks are exactly uniform (0 bits — blank logomaker columns) + so the logo reads as "this is the motif, everything else is background" + instead of a sea of tiny speckle letters. + """ + pwm = np.zeros((PWM_WINDOW, 4)) + uniform = np.full(4, 0.25) + signature = PWM_SIGNATURES.get(label) if label else None + + if signature is None: + # No signature for this label: every column is exactly uniform (no random + # draw), giving a flat, blank logo with no spurious consensus. + return np.tile(uniform, (PWM_WINDOW, 1)) + + sig_len = len(signature) + sig_start = PWM_PEAK - sig_len // 2 # center the signature on the activation peak + for i in range(PWM_WINDOW): + sig_idx = i - sig_start + if 0 <= sig_idx < sig_len: + spec = signature[sig_idx] + row = np.zeros(4) + for base, prob in spec.items(): + row[PWM_BASES.index(base)] = prob + # Distribute the remainder evenly across unspecified bases — no + # randomness, so secondary letters stay symmetric and quiet.
+ remainder = max(0.0, 1.0 - sum(spec.values())) + unspec = [b for b in PWM_BASES if b not in spec] + if unspec and remainder > 0: + share = remainder / len(unspec) + for b in unspec: + row[PWM_BASES.index(b)] = share + row = np.clip(row, 1e-6, None) + row /= row.sum() + pwm[i] = row + else: + # Flank: exactly uniform -> 0 bits -> blank column in the logo. + pwm[i] = uniform + + return pwm + + +def _render_logo(pwm: np.ndarray, feature_id: int, out_dir: Path) -> Path: + """Render a WebLogo-style PNG for one feature's PWM using logomaker. + + The information transform produces letter heights in bits (0..2). Position + labels are relative to the activation peak (-PWM_PEAK..+PWM_PEAK). + """ + import matplotlib + + matplotlib.use("Agg") # headless backend — safe for cron/CI + import logomaker + import matplotlib.pyplot as plt + + df = pd.DataFrame(pwm, columns=PWM_BASES) + info_df = logomaker.transform_matrix(df, from_type="probability", to_type="information") + + fig, ax = plt.subplots(figsize=(6, 1.8)) + logomaker.Logo(info_df, ax=ax, color_scheme="classic") + ax.set_xticks(range(PWM_WINDOW)) + ax.set_xticklabels([str(i - PWM_PEAK) for i in range(PWM_WINDOW)], fontsize=8) + ax.set_ylabel("Bits") + ax.set_ylim(0, 2) + ax.set_xlabel("Position relative to peak") + fig.tight_layout() + + out_path = out_dir / f"feature_{feature_id}.png" + fig.savefig(out_path, dpi=120, bbox_inches="tight") + plt.close(fig) + return out_path + + +def _make_example(rng: np.random.Generator, label: str, feature_max: float, window: int = 200) -> dict: + """Build one top-activating example: 200bp window with a central motif + a gaussian activation bump.""" + seq = list(_random_dna(rng, window)) + + # Splice the feature's central motif into the middle ± a few bp jitter. 
+ motif = CENTRAL_MOTIFS[label] + center = window // 2 + int(rng.integers(-5, 6)) + motif_start = center - len(motif) // 2 + for i, base in enumerate(motif): + pos = motif_start + i + if 0 <= pos < window: + seq[pos] = base + + # Activation bump: gaussian centered in [80, 120], sigma ~= 8 bp, peak = feature_max * U(0.5, 1.0). + bump_center = int(rng.integers(80, 121)) + sigma = 8.0 + peak = float(feature_max * rng.uniform(0.5, 1.0)) + positions = np.arange(window) + activations = peak * np.exp(-((positions - bump_center) ** 2) / (2 * sigma**2)) + activations[activations < 0.01] = 0.0 # zero out the tails so the JSON is sparse-ish + + seq_id = SEQ_IDS[int(rng.integers(0, len(SEQ_IDS)))] + start = int(rng.integers(1, 5_000_001)) + + return { + "sequence_id": seq_id, + "start": start, + "end": start + window, + "sequence": "".join(seq), + "activations": [round(float(a), 3) for a in activations], + "max_activation": round(float(activations.max()), 4), + "max_activation_position": int(activations.argmax()), + } + + +def _make_features(rng: np.random.Generator) -> list[dict]: + """Build one synthetic labeled feature entry (with 30 examples) per label in LABELS.""" + features = [] + for fid, label in enumerate(LABELS): + activation_freq = float(np.exp(rng.uniform(np.log(0.001), np.log(0.1)))) + max_activation = float(rng.uniform(5.0, 30.0)) + examples = [_make_example(rng, label, max_activation) for _ in range(30)] + + features.append( + { + "feature_id": fid, + "description": label, + "label": label, + "db_source": DB_SOURCES[label], + "activation_freq": round(activation_freq, 6), + "max_activation": round(max_activation, 4), + "top_positive_logits": [], + "top_negative_logits": [], + "examples": examples, + "logo_path": f"/logos/feature_{fid}.png", + } + ) + return features + + +def _make_atlas(rng: np.random.Generator, features: list[dict]) -> pd.DataFrame: + """Build features_atlas.parquet — UMAP coords grouped into thematic clusters.
+ + Labeled features sit in 3 clusters: eukaryotic regulatory (0), bacterial regulatory (1), + codon context (2). Unlabeled features (label==None) land in a 4th "uninterpreted" cluster (3) + spread more diffusely between the others — mimicking what a real SAE would look like. + """ + cluster_assignments = { + "Start codon (ATG) context": 2, + "TATA box": 0, + "Polyadenylation signal": 0, + "Bacterial promoter -10 box": 1, + "CpG island": 0, + "Shine-Dalgarno sequence": 1, + "Bacterial promoter -35 box": 1, + "Splice donor site": 0, + "Splice acceptor site": 0, + "Stop codon (TAA) context": 2, + "Stop codon (TAG) context": 2, + } + cluster_centers = { + 0: (-3.0, 1.5), + 1: (3.0, 1.5), + 2: (0.0, -3.0), + 3: (0.0, 0.5), # uninterpreted: between the other clusters + } + + coords = [] + cluster_ids = [] + for f in features: + if f["label"] is None: + cid = 3 + sigma = 1.4 # diffuse for the unlabeled cloud + else: + cid = cluster_assignments[f["label"]] + sigma = 0.4 + cx, cy = cluster_centers[cid] + x = cx + rng.normal(0, sigma) + y = cy + rng.normal(0, sigma) + coords.append((x, y)) + cluster_ids.append(cid) + coords = np.array(coords) + + return pd.DataFrame( + { + "feature_id": [f["feature_id"] for f in features], + "x": coords[:, 0].round(4), + "y": coords[:, 1].round(4), + "label": [f["label"] for f in features], + "db_source": [f["db_source"] for f in features], + "activation_freq": [f["activation_freq"] for f in features], + "log_frequency": [round(float(np.log10(f["activation_freq"])), 4) for f in features], + "max_activation": [f["max_activation"] for f in features], + "cluster_id": cluster_ids, + "logo_path": [f["logo_path"] for f in features], + } + ) + + +def _make_unlabeled_example(rng: np.random.Generator, feature_max: float, window: int = 200) -> dict: + """A top-activating example for an unlabeled feature: random sequence + gaussian activation bump.""" + seq = _random_dna(rng, window) + bump_center = int(rng.integers(60, 141)) + sigma = 8.0 + peak = 
float(feature_max * rng.uniform(0.5, 1.0)) + positions = np.arange(window) + activations = peak * np.exp(-((positions - bump_center) ** 2) / (2 * sigma**2)) + activations[activations < 0.01] = 0.0 + + seq_id = SEQ_IDS[int(rng.integers(0, len(SEQ_IDS)))] + start = int(rng.integers(1, 5_000_001)) + + return { + "sequence_id": seq_id, + "start": start, + "end": start + window, + "sequence": seq, + "activations": [round(float(a), 3) for a in activations], + "max_activation": round(float(activations.max()), 4), + "max_activation_position": int(activations.argmax()), + } + + +def _make_unlabeled_features(rng: np.random.Generator, n: int, start_id: int) -> list[dict]: + """Build n unlabeled features — no semantic label, random top-activator sequences.""" + out = [] + for i in range(n): + fid = start_id + i + activation_freq = float(np.exp(rng.uniform(np.log(0.001), np.log(0.1)))) + max_activation = float(rng.uniform(5.0, 30.0)) + examples = [_make_unlabeled_example(rng, max_activation) for _ in range(30)] + out.append( + { + "feature_id": fid, + "description": None, + "label": None, + "db_source": None, + "activation_freq": round(activation_freq, 6), + "max_activation": round(max_activation, 4), + "top_positive_logits": [], + "top_negative_logits": [], + "examples": examples, + # No motif means no logo — the UI will fall back to its + # "no logo available" empty-state for unlabeled features. + "logo_path": None, + } + ) + return out + + +def _make_examples_table(features: list[dict]) -> pd.DataFrame: + """Flatten per-feature examples into a long table for feature_examples.parquet. + + One row per (feature_id, example_rank). The dashboard lazy-loads these via DuckDB. 
+ """ + rows = [] + for feature in features: + for rank, ex in enumerate(feature["examples"]): + rows.append( + { + "feature_id": feature["feature_id"], + "example_rank": rank, + "sequence_id": ex["sequence_id"], + "start": ex["start"], + "end": ex["end"], + "sequence": ex["sequence"], + "activations": ex["activations"], + "max_activation": ex["max_activation"], + "max_activation_position": ex["max_activation_position"], + "best_annotation": feature["db_source"], + } + ) + return pd.DataFrame(rows) + + +def main(): + """Generate synthetic parquet fixtures (atlas + metadata + examples) for the mockup dashboard.""" + p = argparse.ArgumentParser() + p.add_argument( + "--output-dir", + type=Path, + default=Path(__file__).resolve().parent.parent / "evo2_dashboard_mockup" / "public", + help="Where to write the three parquet fixtures", + ) + p.add_argument( + "--write-json", + action="store_true", + help="Also write features.json (only useful if you point the dashboard at the legacy JSON path)", + ) + p.add_argument("--n-unlabeled", type=int, default=9, help="How many unlabeled features to add alongside the labeled ones") + p.add_argument("--seed", type=int, default=42) + args = p.parse_args() + + rng = np.random.default_rng(args.seed) + args.output_dir.mkdir(parents=True, exist_ok=True) + + features = _make_features(rng) + features += _make_unlabeled_features(rng, n=args.n_unlabeled, start_id=len(features)) + + # Render one WebLogo PNG per labeled feature into /logos/. Unlabeled + # features get no logo — an empty WebLogo reads as a render bug, so we let + # the dashboard show its "no logo" empty state instead. 
+ logo_dir = args.output_dir / "logos" + logo_dir.mkdir(parents=True, exist_ok=True) + rendered = 0 + for f in features: + if f["label"] is None: + continue + pwm = _build_pwm(rng, f["label"]) + _render_logo(pwm, f["feature_id"], logo_dir) + rendered += 1 + print(f"Wrote {rendered} logo PNGs -> {logo_dir}") + + atlas = _make_atlas(rng, features) + atlas.to_parquet(args.output_dir / "features_atlas.parquet", index=False) + # feature_metadata is the same shape as the atlas for the mockup — the dashboard + # loads them as two tables but the queried columns are identical. + atlas.to_parquet(args.output_dir / "feature_metadata.parquet", index=False) + + examples = _make_examples_table(features) + examples.to_parquet(args.output_dir / "feature_examples.parquet", index=False) + + if args.write_json: + with open(args.output_dir / "features.json", "w") as f: + json.dump({"features": features}, f) + print(f"Wrote {len(features)} features -> {args.output_dir / 'features.json'}") + + print(f"Wrote {len(atlas)} atlas rows -> {args.output_dir / 'features_atlas.parquet'}") + print(f"Wrote {len(atlas)} metadata rows -> {args.output_dir / 'feature_metadata.parquet'}") + print(f"Wrote {len(examples)} example rows -> {args.output_dir / 'feature_examples.parquet'}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py new file mode 100644 index 0000000000..6a182b575d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _valid_token_activations(payload, source):
    """Extract the non-padding activations from one predict_evo2 payload.

    Args:
        payload: Mapping loaded from a ``predictions__*.pt`` file. Must contain
            ``hidden_embeddings`` and ``pad_mask``; other keys are ignored.
        source: Path of the file the payload came from (used in error messages).

    Returns:
        Tuple ``(flat, n_sequences)`` where ``flat`` is the float32 tensor of
        rows selected by the mask and ``n_sequences`` is the size of the first
        dimension of ``hidden_embeddings``.

    Raises:
        KeyError: If a required key is missing from the payload.
        ValueError: If ``pad_mask`` does not match the leading dimensions of
            ``hidden_embeddings``.
    """
    missing = [key for key in ("hidden_embeddings", "pad_mask") if key not in payload]
    if missing:
        raise KeyError(f"{source}: missing key(s) {missing}; was predict_evo2 run with --embedding-layer?")

    hidden = payload["hidden_embeddings"]
    mask = payload["pad_mask"].bool()
    # Boolean indexing needs the mask to cover every leading dimension; check
    # it here so the failure names the offending file instead of surfacing as
    # an opaque IndexError.
    if tuple(mask.shape) != tuple(hidden.shape[:-1]):
        raise ValueError(
            f"{source}: pad_mask shape {tuple(mask.shape)} does not match "
            f"hidden_embeddings shape {tuple(hidden.shape)}"
        )
    # Cast after masking so only the kept rows are up-converted to float32.
    return hidden[mask].float(), hidden.shape[0]


def main():
    """Walk predict_evo2 .pt files, mask padding, and write to an ActivationStore.

    Command-line options:
        --predict-dir  Directory searched recursively for ``predictions__*.pt``.
        --output       ActivationStore output directory.
        --model-name   Recorded in the store metadata.
        --layer        Recorded in the store metadata.
        --shard-size   Passed to ``ActivationStoreConfig`` (must be positive).

    Raises:
        FileNotFoundError: If no ``predictions__*.pt`` file is found.
        KeyError, ValueError: If a prediction file has an unexpected layout.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt")
    p.add_argument("--output", type=Path, required=True, help="ActivationStore output dir")
    p.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json")
    p.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json")
    p.add_argument("--shard-size", type=int, default=100_000)
    args = p.parse_args()

    if args.shard_size <= 0:
        p.error("--shard-size must be a positive integer")

    # Sorted so the shard contents are reproducible across runs.
    pt_files = sorted(args.predict_dir.rglob("predictions__*.pt"))
    if not pt_files:
        raise FileNotFoundError(f"No predictions__*.pt under {args.predict_dir}")

    store = ActivationStore(args.output, ActivationStoreConfig(shard_size=args.shard_size))
    n_sequences = 0
    for pt in tqdm(pt_files, desc="pt->parquet"):
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only point --predict-dir at outputs you generated yourself.
        # TODO(review): switch to weights_only=True once the predict_evo2
        # payload is confirmed to contain only tensors and plain containers.
        d = torch.load(pt, map_location="cpu", weights_only=False)
        flat, batch_size = _valid_token_activations(d, pt)
        store.append(flat)
        n_sequences += batch_size

    store.finalize(metadata={"model_name": args.model_name, "layer": args.layer, "n_sequences": n_sequences})
    print(json.dumps(store.metadata, indent=2))


if __name__ == "__main__":
    main()
def parse_args():
    """Parse the command-line options for SAE training on cached Evo2 activations.

    Returns:
        The ``argparse.Namespace`` holding the three required cache options
        (``cache_dir``, ``model_path``, ``layer``) plus the SAE, training,
        Weights & Biases, checkpointing and infrastructure options below.
    """
    p = argparse.ArgumentParser(
        description="Train SAE from cached Evo2 activations",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required
    p.add_argument("--cache-dir", type=str, required=True, help="Path to activation cache (from pt_to_parquet.py)")
    p.add_argument("--model-path", type=str, required=True, help="Evo2 model name/path (for cache validation)")
    p.add_argument("--layer", type=int, required=True, help="Layer index (for cache validation)")

    # SAE architecture
    sae_group = p.add_argument_group("SAE model")
    sae_group.add_argument("--model-type", type=str, default="topk", choices=["topk", "relu"])
    sae_group.add_argument("--expansion-factor", type=int, default=8)
    sae_group.add_argument("--top-k", type=int, default=32)
    sae_group.add_argument("--normalize-input", action=argparse.BooleanOptionalAction, default=False)
    sae_group.add_argument("--auxk", type=int, default=None)
    sae_group.add_argument("--auxk-coef", type=float, default=1 / 32)
    sae_group.add_argument("--dead-tokens-threshold", type=int, default=10_000_000)
    sae_group.add_argument("--init-pre-bias", action=argparse.BooleanOptionalAction, default=False)
    sae_group.add_argument("--l1-coeff", type=float, default=1e-2, help="L1 coefficient (relu only)")

    # Training
    train_group = p.add_argument_group("Training")
    train_group.add_argument("--lr", type=float, default=3e-4)
    train_group.add_argument("--n-epochs", type=int, default=3)
    train_group.add_argument("--batch-size", type=int, default=4096)
    train_group.add_argument("--log-interval", type=int, default=50)
    train_group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, default=True)
    train_group.add_argument("--num-workers", type=int, default=0)
    train_group.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=False)
    train_group.add_argument("--max-grad-norm", type=float, default=None)
    train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False)
    train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048)
    train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps")
    train_group.add_argument(
        "--lr-schedule",
        type=str,
        default="constant",
        choices=["constant", "cosine", "linear"],
        help="LR schedule after warmup",
    )
    train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules")
    train_group.add_argument(
        "--lr-decay-steps",
        type=int,
        default=None,
        help="Total steps for LR decay (None = full training)",
    )

    # W&B
    wb_group = p.add_argument_group("Weights & Biases")
    wb_group.add_argument("--wandb", action=argparse.BooleanOptionalAction, default=False, dest="wandb_enabled")
    # NOTE(review): default project name is inherited from the CodonFM recipe;
    # left unchanged so existing launch scripts keep logging to the same place.
    wb_group.add_argument("--wandb-project", type=str, default="sae_codonfm_recipe")
    wb_group.add_argument("--wandb-run-name", type=str, default=None)
    wb_group.add_argument("--wandb-group", type=str, default=None)
    wb_group.add_argument("--wandb-job-type", type=str, default=None)

    # Checkpointing
    ckpt_group = p.add_argument_group("Checkpointing")
    ckpt_group.add_argument("--checkpoint-dir", type=str, default=None)
    ckpt_group.add_argument("--checkpoint-steps", type=int, default=None)
    ckpt_group.add_argument("--resume-from", type=str, default=None)

    # Infrastructure
    p.add_argument("--dp-size", type=int, default=1)
    p.add_argument("--output-dir", type=str, default="./outputs")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default=None)
    p.add_argument(
        "--num-sequences",
        type=int,
        default=None,
        help="Subset cached activations to this many sequences' worth of shards",
    )

    return p.parse_args()


def build_sae(args, input_dim: int) -> torch.nn.Module:
    """Instantiate the SAE selected by ``args.model_type``.

    Args:
        args: Parsed options from :func:`parse_args`.
        input_dim: Width of the cached activations.

    Returns:
        A ``TopKSAE`` or ``ReLUSAE`` whose hidden width is
        ``input_dim * args.expansion_factor``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``"topk"`` nor ``"relu"``.
    """
    hidden_dim = input_dim * args.expansion_factor

    if args.model_type == "topk":
        return TopKSAE(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            top_k=args.top_k,
            normalize_input=args.normalize_input,
            auxk=args.auxk,
            auxk_coef=args.auxk_coef,
            dead_tokens_threshold=args.dead_tokens_threshold,
        )
    if args.model_type == "relu":
        return ReLUSAE(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            l1_coeff=args.l1_coeff,
        )
    raise ValueError(f"Unknown model type: {args.model_type}")


def build_training_config(args, device: str) -> TrainingConfig:
    """Copy the optimisation, scheduling and checkpoint options into a ``TrainingConfig``.

    Args:
        args: Parsed options from :func:`parse_args`.
        device: Device string forwarded to the trainer.
    """
    return TrainingConfig(
        lr=args.lr,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        device=device,
        log_interval=args.log_interval,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_steps=args.checkpoint_steps,
        lr_scale_with_latents=args.lr_scale_with_latents,
        lr_reference_hidden_dim=args.lr_reference_hidden_dim,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        lr_schedule=args.lr_schedule,
        lr_min=args.lr_min,
        lr_decay_steps=args.lr_decay_steps,
    )


def build_wandb_config(args) -> WandbConfig:
    """Build the ``WandbConfig``; the full argument namespace is attached as the run config."""
    return WandbConfig(
        enabled=args.wandb_enabled,
        project=args.wandb_project,
        run_name=args.wandb_run_name,
        group=args.wandb_group,
        job_type=args.wandb_job_type,
        config=vars(args),
    )


def build_parallel_config(args) -> ParallelConfig:
    """Build the ``ParallelConfig`` from ``--dp-size``."""
    return ParallelConfig(dp_size=args.dp_size)


def main():
    """Load a cached ActivationStore, build the SAE and train it.

    Data is streamed from disk when the estimated float32 footprint of the
    selected shards exceeds 50 GB; otherwise all selected shards are
    concatenated in memory before training.

    Raises:
        FileNotFoundError: If ``<cache-dir>/metadata.json`` does not exist.
        ValueError: If the layer recorded in the cache metadata (or its
            absence) does not match ``--layer``.
    """
    args = parse_args()

    set_seed(args.seed)
    device = args.device or get_device()
    print(f"Using device: {device}")
    print(f"Config: {vars(args)}")

    # Load cached activations
    cache_path = Path(args.cache_dir)
    if not (cache_path / "metadata.json").exists():
        raise FileNotFoundError(f"No cache found at {cache_path}. Run pt_to_parquet.py first.")

    store = load_activations(cache_path)
    meta = store.metadata

    # Validate cache matches config. A model mismatch only warns; a layer
    # mismatch is fatal.
    cached_model = meta.get("model_path", meta.get("model_name", ""))
    if cached_model and cached_model != args.model_path:
        print(f"WARNING: Cache model '{cached_model}' != '{args.model_path}'")
    # Read the layer once with .get(): indexing meta["layer"] inside the error
    # message would raise KeyError (masking the real problem) when the key is
    # absent from the metadata.
    cached_layer = meta.get("layer")
    if cached_layer != args.layer:
        raise ValueError(f"Cache layer mismatch: {cached_layer} vs {args.layer}")

    # Compute subsetting: keep the fraction of shards matching the requested
    # fraction of sequences, rounded up and never fewer than one shard.
    cached_sequences = meta.get("n_sequences", None)
    max_shards = None
    if args.num_sequences and cached_sequences and args.num_sequences < cached_sequences:
        keep_ratio = args.num_sequences / cached_sequences
        max_shards = max(1, int(np.ceil(keep_ratio * meta["n_shards"])))
        print(
            f"Subsetting: {args.num_sequences}/{cached_sequences} sequences "
            f"-> using {max_shards}/{meta['n_shards']} shards (~{keep_ratio:.1%})"
        )

    # Estimate memory (4 bytes per value, i.e. float32) to choose between the
    # in-memory and streaming paths.
    n_shards_to_use = max_shards or meta["n_shards"]
    shard_size = meta.get("shard_size", 100_000)
    est_tokens = n_shards_to_use * shard_size
    est_gb = est_tokens * meta["hidden_dim"] * 4 / (1024**3)
    use_streaming = est_gb > 50

    input_dim = meta["hidden_dim"]
    sae = build_sae(args, input_dim)
    print(f"SAE: {args.model_type}, input_dim={input_dim}, hidden_dim={sae.hidden_dim}")

    # Initialize pre_bias from up to 32768 rows of the first shard, for SAE
    # classes that expose init_pre_bias_from_data.
    if args.init_pre_bias and hasattr(sae, "init_pre_bias_from_data"):
        print("Initializing pre_bias from geometric median of data...")
        # NOTE(review): relies on the private ActivationStore._load_shard helper.
        first_shard = torch.from_numpy(store._load_shard(0)).float()
        sample_size = min(32768, len(first_shard))
        sae.init_pre_bias_from_data(first_shard[:sample_size])
        print(f"  pre_bias initialized (mean={sae.pre_bias.mean().item():.4f})")
        del first_shard

    # Build configs
    training_config = build_training_config(args, device)
    wandb_config = build_wandb_config(args)
    parallel_config = build_parallel_config(args)

    perf_logger = PerfLogger(
        log_interval=args.log_interval,
        use_wandb=args.wandb_enabled,
        print_logs=True,
        device=device,
    )

    # Train
    trainer = Trainer(
        sae,
        training_config,
        wandb_config=wandb_config,
        perf_logger=perf_logger,
        parallel_config=parallel_config,
    )

    if use_streaming:
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        print(
            f"Streaming from disk (~{est_gb:.0f}GB). "
            f"Peak RAM: ~{shard_size * meta['hidden_dim'] * 4 / (1024**3):.1f}GB/process"
        )

        dataloader = store.get_streaming_dataloader(
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            seed=args.seed,
            rank=rank,
            world_size=world_size,
            max_shards=max_shards,
        )
        # Compute min batch count across all ranks to keep DDP in sync.
        # Only parquet footers are read here, not the shard data.
        if world_size > 1:
            import pyarrow.parquet as pq_meta

            dataset = dataloader.dataset
            per_rank = len(dataset.shard_indices)
            # Assumes rank r was given the contiguous shard range
            # [r * per_rank, (r + 1) * per_rank) — TODO confirm against
            # ActivationStore.get_streaming_dataloader.
            min_batches = None
            for r in range(world_size):
                total_rows = sum(
                    pq_meta.read_metadata(store.path / f"shard_{idx:05d}.parquet").num_rows
                    for idx in range(r * per_rank, (r + 1) * per_rank)
                )
                batches = total_rows // args.batch_size
                if min_batches is None or batches < min_batches:
                    min_batches = batches
            dataset.max_batches = min_batches
            print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync")
        trainer.fit(
            dataloader,
            resume_from=args.resume_from,
            data_sharded=True,
        )
    else:
        shards = []
        for i, shard in enumerate(store.iter_shards(shuffle_shards=False)):
            if max_shards is not None and i >= max_shards:
                break
            shards.append(torch.from_numpy(shard).float())
        activations_flat = torch.cat(shards)
        print(f"Loaded {activations_flat.shape[0]:,} cached activations into memory")

        trainer.fit(
            activations_flat,
            resume_from=args.resume_from,
        )

    print("Training complete.")


if __name__ == "__main__":
    main()
b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py @@ -135,7 +135,7 @@ def load_savanna_state_dict(path: Path) -> dict[str, torch.Tensor]: Returns: Flat state dict with keys like 'sequential.{i}.xxx'. """ - raw = torch.load(str(path), map_location="cpu", weights_only=True, mmap=True) + raw = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True) if "module" in raw: raw = raw["module"]