Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m pca2 -o output

# Run all 6 DR methods on sample data
uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m "pca2,tsne2,umap2,pacmap2,mds2,localmap2" -o output --no-scores -v

# Compare UMAP with different parameters in a single run
uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output --no-scores
```

## CLI Commands
Expand Down Expand Up @@ -60,6 +63,8 @@ protspace prepare -i <input> -m <methods> -o <output> [options]
# Multi-embedding (different names → intersection): protspace prepare -i esm2.h5 -i prott5.h5 -m pca2 -o output
# With similarity: protspace prepare -i emb.h5 -f seq.fasta -s -m pca2,mds2 -o output
# Name override: protspace prepare -i emb.h5:custom_name -m pca2 -o output
# Parameter sweep: protspace prepare -i emb.h5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output
# Inline params: protspace prepare -i emb.h5 -m "pca2,umap2:n_neighbors=50;min_dist=0.3" -o output
```

### Supported Embedders (via Biocentral API)
Expand Down Expand Up @@ -204,7 +209,7 @@ uv run pytest tests/ --cov=src/protspace # With coverage
| `test_interpro_annotation_retriever.py` | 46 | InterPro API mocking, parsing |
| `test_settings_converter.py` | 31 | Settings table ↔ visualization state conversion |
| `test_uniprot_annotation_retriever.py` | 24 | UniProt API mocking, inactive entry resolution |
| `test_pipeline_utils.py` | 41 | ReductionPipeline, EmbeddingSet, method parsing, multi-input merging |
| `test_pipeline_utils.py` | 70 | ReductionPipeline, EmbeddingSet, method parsing, multi-input merging, inline param overrides |
| `test_biocentral_embedder.py` | 23 | Biocentral API client, embedding flow |
| `test_fasta.py` | 17 | FASTA parsing, edge cases, CSV annotation loading |
| `test_biocentral_retriever.py` | 14 | Biocentral prediction retriever (TMbed parsing, per-sequence) |
Expand Down
8 changes: 7 additions & 1 deletion docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ protspace prepare -i emb.h5 -f seq.fasta -s -m pca2,mds2 -o output

# External HDF5 without model_name attribute — use colon syntax
protspace prepare -i external.h5:prot_t5 -m pca2 -o output

# Compare UMAP with different parameters in a single run
protspace prepare -i emb.h5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output

# Inline params with semicolons, comma-separated methods
protspace prepare -i emb.h5 -m "pca2,umap2:n_neighbors=50;min_dist=0.3,tsne2" -o output
```

### Options
Expand Down Expand Up @@ -63,7 +69,7 @@ protspace prepare -i external.h5:prot_t5 -m pca2 -o output

| Flag | Description | Default |
| ---- | ----------- | ------- |
| `-m, --methods` | DR methods (comma-separated): `pca2`, `umap2`, `tsne2`, `pacmap2`, `mds2`, `localmap2` | `pca2` |
| `-m, --methods` | DR methods (comma-separated, or repeat the flag). Inline parameters follow the method after a colon and are separated by semicolons: `-m 'umap2:n_neighbors=50;min_dist=0.1'`. Methods: `pca2`, `umap2`, `tsne2`, `pacmap2`, `mds2`, `localmap2` | `pca2` |
| `-s, --similarity` | Also compute sequence similarity DR from FASTA. | off |
| `--metric` | Distance metric (`euclidean`, `cosine`, `manhattan`). | `euclidean` |
| `--random-state` | Random seed. | `42` |
Expand Down
39 changes: 4 additions & 35 deletions notebooks/ProtSpace_Preparation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,7 @@
"cellView": "form"
},
"outputs": [],
"source": [
"# @title 1. Install & Setup (~30s)\n",
"%%capture\n",
"!pip install -qqq protspace\n",
"\n",
"import urllib.request\n",
"from pathlib import Path\n",
"\n",
"import h5py\n",
"import ipywidgets as widgets\n",
"\n",
"# Patch tqdm for notebook-style progress bars BEFORE importing protspace\n",
"import tqdm as _tqdm_mod\n",
"import tqdm.notebook as _tqdm_nb\n",
"from google.colab import files\n",
"from IPython.display import HTML as IHTML\n",
"from IPython.display import clear_output, display\n",
"\n",
"_tqdm_mod.tqdm = _tqdm_nb.tqdm\n",
"\n",
"# Pre-import protspace (lightweight now — heavy reducers load per-method)\n",
"import time as _time\n",
"\n",
"import pandas as _pd\n",
"\n",
"from protspace.data.loaders import embed_fasta, load_h5\n",
"from protspace.data.loaders.query import query_uniprot\n",
"from protspace.data.processors.pipeline import (\n",
" PipelineConfig,\n",
" ReducerParams,\n",
" ReductionPipeline,\n",
")"
]
"source": "# @title 1. Install & Setup (~30s)\n%%capture\n!pip install -qqq protspace\n\nimport urllib.request\nfrom pathlib import Path\n\nimport h5py\nimport ipywidgets as widgets\n\n# Patch tqdm for notebook-style progress bars BEFORE importing protspace\nimport tqdm as _tqdm_mod\nimport tqdm.notebook as _tqdm_nb\nfrom google.colab import files\nfrom IPython.display import HTML as IHTML\nfrom IPython.display import clear_output, display\n\n_tqdm_mod.tqdm = _tqdm_nb.tqdm\n\n# Pre-import protspace (lightweight now — heavy reducers load per-method)\nimport time as _time\n\nimport pandas as _pd\n\nfrom protspace.data.loaders import embed_fasta, load_h5\nfrom protspace.data.loaders.query import query_uniprot\nfrom protspace.data.processors.pipeline import (\n PipelineConfig,\n ReducerParams,\n ReductionPipeline,\n parse_methods_arg,\n)"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -469,7 +437,8 @@
" if not input_args:\n",
" print(\"Select input data in cell above first.\")\n",
" return\n",
" method_specs = [m.lower().replace(\"-\", \"\") + \"2\" for m, t in method_toggles.items() if t.value]\n",
" method_strs = [m.lower().replace(\"-\", \"\") + \"2\" for m, t in method_toggles.items() if t.value]\n",
" method_specs = parse_methods_arg(method_strs)\n",
" if not method_specs:\n",
" print(\"Select at least one method.\")\n",
" return\n",
Expand Down Expand Up @@ -662,4 +631,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
7 changes: 5 additions & 2 deletions src/protspace/cli/common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@ class Metric(str, Enum):

# Projection options (shared by prepare and project)
# ``-m`` may be repeated, and each value may itself be a comma-separated list,
# so the option is list-valued. ``None`` means the flag was omitted; the
# commands substitute their own default (``methods or ["pca2"]``).
Opt_Methods = Annotated[
    list[str] | None,
    typer.Option(
        "-m",
        "--methods",
        help=(
            "DR methods. Comma-sep or repeat: -m pca2,umap2 or -m pca2 -m umap2. "
            "Inline params: -m 'umap2:n_neighbors=50;min_dist=0.1'."
        ),
        rich_help_panel="Projection",
    ),
]
Expand Down
9 changes: 6 additions & 3 deletions src/protspace/cli/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def prepare(
embedder: Opt_Embedder = None,
batch_size: Opt_BatchSize = 1000,
# Projection
methods: Opt_Methods = "pca2",
methods: Opt_Methods = None,
similarity: Opt_Similarity = False,
metric: Opt_Metric = Metric.euclidean,
random_state: Opt_RandomState = 42,
Expand Down Expand Up @@ -481,8 +481,11 @@ def prepare(
PipelineConfig,
ReducerParams,
ReductionPipeline,
parse_methods_arg,
)

method_specs = parse_methods_arg(methods or ["pca2"])

reducer_params = ReducerParams(
metric=metric.value,
random_state=random_state,
Expand All @@ -497,7 +500,7 @@ def prepare(
eps=eps,
)
config = PipelineConfig(
methods=methods.split(","),
methods=method_specs,
output_path=output_path,
bundled=bundled,
keep_tmp=keep_tmp,
Expand Down Expand Up @@ -642,7 +645,7 @@ def _write_run_log(
lines.append(f"{key}: {val}")

lines += ["", "## Projection"]
lines.append(f"methods: {', '.join(pipeline_config.methods)}")
lines.append(f"methods: {', '.join(str(m) for m in pipeline_config.methods)}")
lines.append(f"similarity: {similarity}")
for key, val in rp.items():
lines.append(f"{key}: {val}")
Expand Down
42 changes: 31 additions & 11 deletions src/protspace/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def project(
help="HDF5 file(s). Repeat for multi-embedding. Colon syntax: -i file.h5:name",
),
],
methods: Opt_Methods = "pca2",
methods: Opt_Methods = None,
output: Annotated[
Path,
typer.Option(
Expand Down Expand Up @@ -69,12 +69,20 @@ def project(
"""
setup_logging(verbose)

from collections import Counter

import pyarrow.parquet as pq

from protspace.cli.prepare import _parse_input_specs
from protspace.data.loaders import EmbeddingSet, compute_similarity, load_h5
from protspace.data.loaders.embedding_set import format_projection_name
from protspace.data.processors.base_processor import BaseProcessor
from protspace.data.processors.pipeline import parse_method_spec
from protspace.data.processors.pipeline import (
ReducerParams,
_run_with_overridden_config,
disambiguation_suffix,
parse_methods_arg,
)
from protspace.utils import get_reducers
from protspace.utils.constants import MDS_NAME

Expand All @@ -95,7 +103,7 @@ def project(

from dataclasses import asdict

from protspace.data.processors.pipeline import ReducerParams
method_specs = parse_methods_arg(methods or ["pca2"])

reducer_params = ReducerParams(
metric=metric.value,
Expand All @@ -110,14 +118,18 @@ def project(
max_iter=max_iter,
eps=eps,
)
global_params = asdict(reducer_params)
reducers = get_reducers()
base = BaseProcessor(asdict(reducer_params), reducers)
base = BaseProcessor(global_params, reducers)

# Pre-compute which (method, dims) pairs appear multiple times
method_counts = Counter((s.method, s.dims) for s in method_specs)

all_reductions = []
headers = embedding_sets[0].headers
for emb_set in embedding_sets:
for method_spec in methods.split(","):
method, dims = parse_method_spec(method_spec)
for spec in method_specs:
method, dims = spec.method, spec.dims
if emb_set.precomputed and method != MDS_NAME:
logger.warning(
f"Skipping {method} for '{emb_set.name}' (only MDS for precomputed)"
Expand All @@ -126,13 +138,21 @@ def project(
if method not in reducers:
logger.warning(f"Unknown method: {method}. Skipping.")
continue

effective_params = {**global_params, **spec.overrides_dict}
if emb_set.precomputed:
base.config["precomputed"] = True
else:
base.config.pop("precomputed", None)
effective_params["precomputed"] = True

logger.info(f"Applying {method.upper()}{dims} to '{emb_set.name}'")
reduction = base.process_reduction(emb_set.data, method, dims)
reduction["name"] = f"{emb_set.name} — {reduction['name']}"
reduction = _run_with_overridden_config(
base, effective_params, method, dims, emb_set.data
)
reduction["name"] = format_projection_name(
emb_set.name,
method,
dims,
disambiguation_suffix(spec, method_counts),
)
all_reductions.append(reduction)

output.mkdir(parents=True, exist_ok=True)
Expand Down
42 changes: 38 additions & 4 deletions src/protspace/data/loaders/embedding_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,51 @@
}


def format_projection_name(source: str, method: str, dims: int) -> str:
# Abbreviations for DR parameters in projection names
_PARAM_ABBREVS: dict[str, str] = {
"n_neighbors": "n",
"min_dist": "d",
"perplexity": "p",
"learning_rate": "lr",
"mn_ratio": "mn",
"fp_ratio": "fp",
"metric": "m",
"random_state": "rs",
"n_init": "ni",
"max_iter": "mi",
"eps": "e",
}


def format_param_suffix(overrides: dict[str, int | float | str]) -> str:
"""Format parameter overrides into a compact suffix string.

Examples:
{"n_neighbors": 50, "min_dist": 0.1} → "n=50, d=0.1"
{"metric": "cosine"} → "m=cosine"
"""
parts = []
for key in sorted(overrides):
abbr = _PARAM_ABBREVS.get(key, key)
parts.append(f"{abbr}={overrides[key]}")
return ", ".join(parts)


def format_projection_name(
    source: str, method: str, dims: int, param_suffix: str = ""
) -> str:
    """Format a human-readable projection name.

    ``source`` is looked up in ``MODEL_DISPLAY_NAMES`` and used verbatim when
    it has no entry; ``method`` is looked up in ``METHOD_DISPLAY_NAMES`` and
    upper-cased when it has no entry.

    Args:
        source: Embedding/source identifier.
        method: Reduction method identifier.
        dims: Number of output dimensions, appended after the method name.
        param_suffix: Optional pre-formatted parameter description. When
            non-empty it is appended in parentheses; when empty the name is
            returned without any suffix.

    Returns:
        The display name ``"<source> — <method> <dims>"``, optionally followed
        by ``" (<param_suffix>)"``.

    Examples:
        ("prot_t5", "pca", 2) → "ProtT5 — PCA 2"
        ("esm2_650m", "umap", 2) → "ESM2-650M — UMAP 2"
        ("MMseqs2", "mds", 2) → "MMseqs2 — MDS 2"
        ("esm2_650m", "umap", 2, "n=50, d=0.1") → "ESM2-650M — UMAP 2 (n=50, d=0.1)"
    """
    source_display = MODEL_DISPLAY_NAMES.get(source, source)
    method_display = METHOD_DISPLAY_NAMES.get(method, method.upper())
    name = f"{source_display} — {method_display} {dims}"
    # Only decorate the name when a suffix was supplied, so callers that do
    # not pass one get the undecorated form.
    if param_suffix:
        name += f" ({param_suffix})"
    return name


@dataclass
Expand Down
Loading