diff --git a/CLAUDE.md b/CLAUDE.md index 8a624f4c..f3a8aeb7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -31,6 +31,9 @@ uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m pca2 -o output # Run all 6 DR methods on sample data uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m "pca2,tsne2,umap2,pacmap2,mds2,localmap2" -o output --no-scores -v + +# Compare UMAP with different parameters in a single run +uv run protspace prepare -i data/sizes/phosphatase.h5:prot_t5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output --no-scores ``` ## CLI Commands @@ -60,6 +63,8 @@ protspace prepare -i -m -o [options] # Multi-embedding (different names → intersection): protspace prepare -i esm2.h5 -i prott5.h5 -m pca2 -o output # With similarity: protspace prepare -i emb.h5 -f seq.fasta -s -m pca2,mds2 -o output # Name override: protspace prepare -i emb.h5:custom_name -m pca2 -o output +# Parameter sweep: protspace prepare -i emb.h5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output +# Inline params: protspace prepare -i emb.h5 -m "pca2,umap2:n_neighbors=50;min_dist=0.3" -o output ``` ### Supported Embedders (via Biocentral API) @@ -204,7 +209,7 @@ uv run pytest tests/ --cov=src/protspace # With coverage | `test_interpro_annotation_retriever.py` | 46 | InterPro API mocking, parsing | | `test_settings_converter.py` | 31 | Settings table ↔ visualization state conversion | | `test_uniprot_annotation_retriever.py` | 24 | UniProt API mocking, inactive entry resolution | -| `test_pipeline_utils.py` | 41 | ReductionPipeline, EmbeddingSet, method parsing, multi-input merging | +| `test_pipeline_utils.py` | 70 | ReductionPipeline, EmbeddingSet, method parsing, multi-input merging, inline param overrides | | `test_biocentral_embedder.py` | 23 | Biocentral API client, embedding flow | | `test_fasta.py` | 17 | FASTA parsing, edge cases, CSV annotation loading | | `test_biocentral_retriever.py` | 14 | Biocentral prediction retriever 
(TMbed parsing, per-sequence) | diff --git a/docs/cli.md b/docs/cli.md index b51c339a..7fa4972d 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -36,6 +36,12 @@ protspace prepare -i emb.h5 -f seq.fasta -s -m pca2,mds2 -o output # External HDF5 without model_name attribute — use colon syntax protspace prepare -i external.h5:prot_t5 -m pca2 -o output + +# Compare UMAP with different parameters in a single run +protspace prepare -i emb.h5 -m "umap2:n_neighbors=15" -m "umap2:n_neighbors=50" -m pca2 -o output + +# Inline params with semicolons, comma-separated methods +protspace prepare -i emb.h5 -m "pca2,umap2:n_neighbors=50;min_dist=0.3,tsne2" -o output ``` ### Options @@ -63,7 +69,7 @@ protspace prepare -i external.h5:prot_t5 -m pca2 -o output | Flag | Description | Default | | ---- | ----------- | ------- | -| `-m, --methods` | DR methods (comma-separated): `pca2`, `umap2`, `tsne2`, `pacmap2`, `mds2`, `localmap2` | `pca2` | +| `-m, --methods` | DR methods (comma-sep or repeat). Inline params: `-m 'umap2:n_neighbors=50;min_dist=0.1'`. Methods: `pca2`, `umap2`, `tsne2`, `pacmap2`, `mds2`, `localmap2` | `pca2` | | `-s, --similarity` | Also compute sequence similarity DR from FASTA. | off | | `--metric` | Distance metric (`euclidean`, `cosine`, `manhattan`). | `euclidean` | | `--random-state` | Random seed. | `42` | diff --git a/notebooks/ProtSpace_Preparation.ipynb b/notebooks/ProtSpace_Preparation.ipynb index f25cda72..cc85dad4 100644 --- a/notebooks/ProtSpace_Preparation.ipynb +++ b/notebooks/ProtSpace_Preparation.ipynb @@ -25,39 +25,7 @@ "cellView": "form" }, "outputs": [], - "source": [ - "# @title 1. 
Install & Setup (~30s)\n", - "%%capture\n", - "!pip install -qqq protspace\n", - "\n", - "import urllib.request\n", - "from pathlib import Path\n", - "\n", - "import h5py\n", - "import ipywidgets as widgets\n", - "\n", - "# Patch tqdm for notebook-style progress bars BEFORE importing protspace\n", - "import tqdm as _tqdm_mod\n", - "import tqdm.notebook as _tqdm_nb\n", - "from google.colab import files\n", - "from IPython.display import HTML as IHTML\n", - "from IPython.display import clear_output, display\n", - "\n", - "_tqdm_mod.tqdm = _tqdm_nb.tqdm\n", - "\n", - "# Pre-import protspace (lightweight now — heavy reducers load per-method)\n", - "import time as _time\n", - "\n", - "import pandas as _pd\n", - "\n", - "from protspace.data.loaders import embed_fasta, load_h5\n", - "from protspace.data.loaders.query import query_uniprot\n", - "from protspace.data.processors.pipeline import (\n", - " PipelineConfig,\n", - " ReducerParams,\n", - " ReductionPipeline,\n", - ")" - ] + "source": "# @title 1. 
Install & Setup (~30s)\n%%capture\n!pip install -qqq protspace\n\nimport urllib.request\nfrom pathlib import Path\n\nimport h5py\nimport ipywidgets as widgets\n\n# Patch tqdm for notebook-style progress bars BEFORE importing protspace\nimport tqdm as _tqdm_mod\nimport tqdm.notebook as _tqdm_nb\nfrom google.colab import files\nfrom IPython.display import HTML as IHTML\nfrom IPython.display import clear_output, display\n\n_tqdm_mod.tqdm = _tqdm_nb.tqdm\n\n# Pre-import protspace (lightweight now — heavy reducers load per-method)\nimport time as _time\n\nimport pandas as _pd\n\nfrom protspace.data.loaders import embed_fasta, load_h5\nfrom protspace.data.loaders.query import query_uniprot\nfrom protspace.data.processors.pipeline import (\n PipelineConfig,\n ReducerParams,\n ReductionPipeline,\n parse_methods_arg,\n)" }, { "cell_type": "code", @@ -469,7 +437,8 @@ " if not input_args:\n", " print(\"Select input data in cell above first.\")\n", " return\n", - " method_specs = [m.lower().replace(\"-\", \"\") + \"2\" for m, t in method_toggles.items() if t.value]\n", + " method_strs = [m.lower().replace(\"-\", \"\") + \"2\" for m, t in method_toggles.items() if t.value]\n", + " method_specs = parse_methods_arg(method_strs)\n", " if not method_specs:\n", " print(\"Select at least one method.\")\n", " return\n", @@ -662,4 +631,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/src/protspace/cli/common_options.py b/src/protspace/cli/common_options.py index 57b86869..f691203a 100644 --- a/src/protspace/cli/common_options.py +++ b/src/protspace/cli/common_options.py @@ -27,11 +27,14 @@ class Metric(str, Enum): # Projection options (shared by prepare and project) Opt_Methods = Annotated[ - str, + list[str] | None, typer.Option( "-m", "--methods", - help="DR methods, comma-separated: pca2,umap2,tsne2,pacmap2,mds2,localmap2.", + help=( + "DR methods. Comma-sep or repeat: -m pca2,umap2 or -m pca2 -m umap2. 
" + "Inline params: -m 'umap2:n_neighbors=50;min_dist=0.1'." + ), rich_help_panel="Projection", ), ] diff --git a/src/protspace/cli/prepare.py b/src/protspace/cli/prepare.py index c697d304..4831f6b2 100644 --- a/src/protspace/cli/prepare.py +++ b/src/protspace/cli/prepare.py @@ -274,7 +274,7 @@ def prepare( embedder: Opt_Embedder = None, batch_size: Opt_BatchSize = 1000, # Projection - methods: Opt_Methods = "pca2", + methods: Opt_Methods = None, similarity: Opt_Similarity = False, metric: Opt_Metric = Metric.euclidean, random_state: Opt_RandomState = 42, @@ -481,8 +481,11 @@ def prepare( PipelineConfig, ReducerParams, ReductionPipeline, + parse_methods_arg, ) + method_specs = parse_methods_arg(methods or ["pca2"]) + reducer_params = ReducerParams( metric=metric.value, random_state=random_state, @@ -497,7 +500,7 @@ def prepare( eps=eps, ) config = PipelineConfig( - methods=methods.split(","), + methods=method_specs, output_path=output_path, bundled=bundled, keep_tmp=keep_tmp, @@ -642,7 +645,7 @@ def _write_run_log( lines.append(f"{key}: {val}") lines += ["", "## Projection"] - lines.append(f"methods: {', '.join(pipeline_config.methods)}") + lines.append(f"methods: {', '.join(str(m) for m in pipeline_config.methods)}") lines.append(f"similarity: {similarity}") for key, val in rp.items(): lines.append(f"{key}: {val}") diff --git a/src/protspace/cli/project.py b/src/protspace/cli/project.py index 2565dfaa..123c3e27 100644 --- a/src/protspace/cli/project.py +++ b/src/protspace/cli/project.py @@ -39,7 +39,7 @@ def project( help="HDF5 file(s). Repeat for multi-embedding. 
Colon syntax: -i file.h5:name", ), ], - methods: Opt_Methods = "pca2", + methods: Opt_Methods = None, output: Annotated[ Path, typer.Option( @@ -69,12 +69,20 @@ def project( """ setup_logging(verbose) + from collections import Counter + import pyarrow.parquet as pq from protspace.cli.prepare import _parse_input_specs from protspace.data.loaders import EmbeddingSet, compute_similarity, load_h5 + from protspace.data.loaders.embedding_set import format_projection_name from protspace.data.processors.base_processor import BaseProcessor - from protspace.data.processors.pipeline import parse_method_spec + from protspace.data.processors.pipeline import ( + ReducerParams, + _run_with_overridden_config, + disambiguation_suffix, + parse_methods_arg, + ) from protspace.utils import get_reducers from protspace.utils.constants import MDS_NAME @@ -95,7 +103,7 @@ def project( from dataclasses import asdict - from protspace.data.processors.pipeline import ReducerParams + method_specs = parse_methods_arg(methods or ["pca2"]) reducer_params = ReducerParams( metric=metric.value, @@ -110,14 +118,18 @@ def project( max_iter=max_iter, eps=eps, ) + global_params = asdict(reducer_params) reducers = get_reducers() - base = BaseProcessor(asdict(reducer_params), reducers) + base = BaseProcessor(global_params, reducers) + + # Pre-compute which (method, dims) pairs appear multiple times + method_counts = Counter((s.method, s.dims) for s in method_specs) all_reductions = [] headers = embedding_sets[0].headers for emb_set in embedding_sets: - for method_spec in methods.split(","): - method, dims = parse_method_spec(method_spec) + for spec in method_specs: + method, dims = spec.method, spec.dims if emb_set.precomputed and method != MDS_NAME: logger.warning( f"Skipping {method} for '{emb_set.name}' (only MDS for precomputed)" @@ -126,13 +138,21 @@ def project( if method not in reducers: logger.warning(f"Unknown method: {method}. 
Skipping.") continue + + effective_params = {**global_params, **spec.overrides_dict} if emb_set.precomputed: - base.config["precomputed"] = True - else: - base.config.pop("precomputed", None) + effective_params["precomputed"] = True + logger.info(f"Applying {method.upper()}{dims} to '{emb_set.name}'") - reduction = base.process_reduction(emb_set.data, method, dims) - reduction["name"] = f"{emb_set.name} — {reduction['name']}" + reduction = _run_with_overridden_config( + base, effective_params, method, dims, emb_set.data + ) + reduction["name"] = format_projection_name( + emb_set.name, + method, + dims, + disambiguation_suffix(spec, method_counts), + ) all_reductions.append(reduction) output.mkdir(parents=True, exist_ok=True) diff --git a/src/protspace/data/loaders/embedding_set.py b/src/protspace/data/loaders/embedding_set.py index 5c064eed..c90ba698 100644 --- a/src/protspace/data/loaders/embedding_set.py +++ b/src/protspace/data/loaders/embedding_set.py @@ -38,17 +38,51 @@ } -def format_projection_name(source: str, method: str, dims: int) -> str: +# Abbreviations for DR parameters in projection names +_PARAM_ABBREVS: dict[str, str] = { + "n_neighbors": "n", + "min_dist": "d", + "perplexity": "p", + "learning_rate": "lr", + "mn_ratio": "mn", + "fp_ratio": "fp", + "metric": "m", + "random_state": "rs", + "n_init": "ni", + "max_iter": "mi", + "eps": "e", +} + + +def format_param_suffix(overrides: dict[str, int | float | str]) -> str: + """Format parameter overrides into a compact suffix string. + + Examples: + {"n_neighbors": 50, "min_dist": 0.1} → "d=0.1, n=50" + {"metric": "cosine"} → "m=cosine" + """ + parts = [] + for key in sorted(overrides): + abbr = _PARAM_ABBREVS.get(key, key) + parts.append(f"{abbr}={overrides[key]}") + return ", ".join(parts) + + +def format_projection_name( + source: str, method: str, dims: int, param_suffix: str = "" +) -> str: """Format a human-readable projection name. 
Examples: ("prot_t5", "pca", 2) → "ProtT5 — PCA 2" - ("esm2_650m", "umap", 2) → "ESM2-650M — UMAP 2" - ("MMseqs2", "mds", 2) → "MMseqs2 — MDS 2" + ("esm2_650m", "umap", 2, "n=50, d=0.1") → "ESM2-650M — UMAP 2 (n=50, d=0.1)" """ source_display = MODEL_DISPLAY_NAMES.get(source, source) method_display = METHOD_DISPLAY_NAMES.get(method, method.upper()) - return f"{source_display} — {method_display} {dims}" + name = f"{source_display} — {method_display} {dims}" + if param_suffix: + name += f" ({param_suffix})" + return name @dataclass diff --git a/src/protspace/data/processors/pipeline.py b/src/protspace/data/processors/pipeline.py index 29e6157d..10ff2a8a 100644 --- a/src/protspace/data/processors/pipeline.py +++ b/src/protspace/data/processors/pipeline.py @@ -7,7 +7,8 @@ import json import logging import shutil -from dataclasses import asdict, dataclass, field +from collections import Counter +from dataclasses import asdict, dataclass, field, fields from pathlib import Path from typing import Any @@ -15,7 +16,10 @@ import pandas as pd from protspace.data.loaders import EmbeddingSet -from protspace.data.loaders.embedding_set import format_projection_name +from protspace.data.loaders.embedding_set import ( + format_param_suffix, + format_projection_name, +) from protspace.data.processors.base_processor import BaseProcessor from protspace.utils import get_reducers from protspace.utils.constants import MDS_NAME @@ -40,11 +44,31 @@ class ReducerParams: eps: float = 1e-6 +@dataclass(frozen=True) +class MethodSpec: + """A single DR method with its dimension count and parameter overrides.""" + + method: str # e.g. "umap" + dims: int # e.g. 2 + overrides: tuple[tuple[str, int | float | str], ...] 
= () + + def __str__(self) -> str: + base = f"{self.method}{self.dims}" + if self.overrides: + params = ";".join(f"{k}={v}" for k, v in self.overrides) + return f"{base}:{params}" + return base + + @property + def overrides_dict(self) -> dict[str, int | float | str]: + return dict(self.overrides) + + @dataclass class PipelineConfig: """Configuration for a ReductionPipeline run.""" - methods: list[str] + methods: list[MethodSpec] output_path: Path bundled: bool = True keep_tmp: bool = False @@ -55,11 +79,120 @@ class PipelineConfig: reducer_params: ReducerParams = field(default_factory=ReducerParams) -def parse_method_spec(method_spec: str) -> tuple[str, int]: - """Parse 'pca2' into ('pca', 2).""" - method = "".join(filter(str.isalpha, method_spec)) - dims = int("".join(filter(str.isdigit, method_spec))) - return method, dims +# Valid override parameter names (from ReducerParams fields) +_VALID_OVERRIDE_KEYS = {f.name for f in fields(ReducerParams)} +# Field types for coercion +_FIELD_TYPES = {f.name: f.type for f in fields(ReducerParams)} + + +def _coerce_value(key: str, raw: str) -> int | float | str: + """Coerce a string value to the appropriate type for the given parameter.""" + expected = _FIELD_TYPES.get(key) + if expected is int: + return int(raw) + if expected is float: + return float(raw) + return raw + + +def parse_method_spec(method_spec: str) -> MethodSpec: + """Parse a method spec string into a MethodSpec. + + Examples: + 'pca2' → MethodSpec('pca', 2) + 'umap2:n_neighbors=50;min_dist=0.1' → MethodSpec('umap', 2, overrides=...) 
+ """ + # Split on first ':' to separate method from overrides + if ":" in method_spec: + base, params_str = method_spec.split(":", 1) + else: + base, params_str = method_spec, "" + + method = "".join(filter(str.isalpha, base)) + dims = int("".join(filter(str.isdigit, base))) + + overrides = {} + if params_str: + for pair in params_str.split(";"): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + raise ValueError( + f"Invalid parameter format '{pair}' in '{method_spec}'. " + f"Expected key=value." + ) + key, val = pair.split("=", 1) + key = key.strip() + if key not in _VALID_OVERRIDE_KEYS: + raise ValueError( + f"Unknown parameter '{key}' in '{method_spec}'. " + f"Valid parameters: {', '.join(sorted(_VALID_OVERRIDE_KEYS))}" + ) + overrides[key] = _coerce_value(key, val.strip()) + + return MethodSpec( + method=method, + dims=dims, + overrides=tuple(sorted(overrides.items())), + ) + + +def parse_methods_arg(raw: list[str]) -> list[MethodSpec]: + """Parse repeatable -m arguments into a deduplicated MethodSpec list. + + Each element may be comma-separated: "pca2,umap2:n_neighbors=50" + Semicolons separate parameters within a method override. + """ + specs: list[MethodSpec] = [] + seen: set[MethodSpec] = set() + for item in raw: + for part in item.split(","): + part = part.strip() + if not part: + continue + spec = parse_method_spec(part) + if spec not in seen: + seen.add(spec) + specs.append(spec) + return specs + + +def disambiguation_suffix(spec: MethodSpec, method_counts: Counter) -> str: + """Return a parameter suffix for projection name disambiguation. + + When the same (method, dims) pair appears multiple times in a run AND the + given spec carries parameter overrides, return the abbreviated parameter + string (e.g. "d=0.1, n=50"). Otherwise return "". + + A plain spec sitting alongside an override spec returns "" — the override + spec alone carries the disambiguating suffix, and the plain spec keeps the + default name (e.g. 
"ProtT5 — UMAP 2"). + """ + if method_counts[(spec.method, spec.dims)] > 1 and spec.overrides: + return format_param_suffix(spec.overrides_dict) + return "" + + +def _run_with_overridden_config( + base: BaseProcessor, + effective_params: dict[str, Any], + method: str, + dims: int, + data: Any, +) -> dict[str, Any]: + """Run base.process_reduction with effective_params, restoring the prior + base.config afterwards. + + Centralizes the save/restore pattern so a leaked `precomputed` flag (or + any other temporary key) cannot survive across reduction calls. + """ + saved = base.config + base.config = effective_params + try: + return base.process_reduction(data, method, dims) + finally: + base.config = saved class ReductionPipeline: @@ -392,7 +525,11 @@ def _merge_csv(api_df: pd.DataFrame, csv_df: pd.DataFrame | None) -> pd.DataFram # --- Projection caching helpers --- def _projection_cache_path( - self, embedding_name: str, method: str, dims: int + self, + embedding_name: str, + method: str, + dims: int, + effective_params: dict[str, Any] | None = None, ) -> Path | None: cache_dir = self.config.intermediate_dir if not cache_dir or not self.config.keep_tmp: @@ -401,16 +538,23 @@ def _projection_cache_path( "embedding": embedding_name, "method": method, "dims": dims, - "params": asdict(self.config.reducer_params), + "params": effective_params or asdict(self.config.reducer_params), } key_json = json.dumps(key_dict, sort_keys=True, default=str) h = hashlib.sha256(key_json.encode()).hexdigest()[:12] return cache_dir / f"proj_{embedding_name}_{method}{dims}_{h}.npz" def _load_cached_projection( - self, embedding_name: str, method: str, dims: int + self, + embedding_name: str, + method: str, + dims: int, + effective_params: dict[str, Any] | None = None, + param_suffix: str = "", ) -> dict[str, Any] | None: - path = self._projection_cache_path(embedding_name, method, dims) + path = self._projection_cache_path( + embedding_name, method, dims, effective_params + ) if ( path is 
None or not path.exists() @@ -426,16 +570,23 @@ def _load_cached_projection( cached = np.load(path, allow_pickle=False) info = json.loads(str(cached["info"])) return { - "name": format_projection_name(embedding_name, method, dims), + "name": format_projection_name(embedding_name, method, dims, param_suffix), "dimensions": dims, "info": info, "data": cached["data"], } def _save_projection_cache( - self, embedding_name: str, method: str, dims: int, reduction: dict + self, + embedding_name: str, + method: str, + dims: int, + reduction: dict, + effective_params: dict[str, Any] | None = None, ) -> None: - path = self._projection_cache_path(embedding_name, method, dims) + path = self._projection_cache_path( + embedding_name, method, dims, effective_params + ) if path is None: return np.savez( @@ -452,31 +603,51 @@ def _run_reductions( cached_projections: list[str] = [] # e.g. "PCA 2 (prot_t5)" computed_count = 0 + # Pre-compute which (method, dims) pairs appear multiple times + method_counts = Counter( + (spec.method, spec.dims) for spec in self.config.methods + ) + + global_params = asdict(self.config.reducer_params) + for emb_set in embedding_sets: if emb_set.precomputed: - cached = self._load_cached_projection(emb_set.name, MDS_NAME, 2) + cached = self._load_cached_projection( + emb_set.name, MDS_NAME, 2, global_params + ) if cached: all_reductions.append(cached) cached_projections.append(f"MDS 2 ({emb_set.name})") continue - self.base.config["precomputed"] = True logger.info(f"Applying MDS 2 to '{emb_set.name}' (precomputed)") - reduction = self.base.process_reduction(emb_set.data, MDS_NAME, 2) + effective_params = {**global_params, "precomputed": True} + reduction = _run_with_overridden_config( + self.base, effective_params, MDS_NAME, 2, emb_set.data + ) reduction["name"] = format_projection_name(emb_set.name, MDS_NAME, 2) all_reductions.append(reduction) - self._save_projection_cache(emb_set.name, MDS_NAME, 2, reduction) - self.base.config.pop("precomputed", None) 
+ self._save_projection_cache( + emb_set.name, MDS_NAME, 2, reduction, global_params + ) computed_count += 1 continue - for method_spec in self.config.methods: - method, dims = parse_method_spec(method_spec) + for spec in self.config.methods: + method, dims = spec.method, spec.dims if method not in self.base.reducers: logger.warning(f"Unknown method: {method}. Skipping.") continue - cached = self._load_cached_projection(emb_set.name, method, dims) + # Merge global defaults with per-method overrides + effective_params = {**global_params, **spec.overrides_dict} + + # Build param suffix for disambiguation + param_suffix = disambiguation_suffix(spec, method_counts) + + cached = self._load_cached_projection( + emb_set.name, method, dims, effective_params, param_suffix + ) if cached: all_reductions.append(cached) cached_projections.append( @@ -485,10 +656,17 @@ def _run_reductions( continue logger.info(f"Applying {method.upper()} {dims} to '{emb_set.name}'") - reduction = self.base.process_reduction(emb_set.data, method, dims) - reduction["name"] = format_projection_name(emb_set.name, method, dims) + reduction = _run_with_overridden_config( + self.base, effective_params, method, dims, emb_set.data + ) + + reduction["name"] = format_projection_name( + emb_set.name, method, dims, param_suffix + ) all_reductions.append(reduction) - self._save_projection_cache(emb_set.name, method, dims, reduction) + self._save_projection_cache( + emb_set.name, method, dims, reduction, effective_params + ) computed_count += 1 if cached_projections: diff --git a/tests/test_pipeline_utils.py b/tests/test_pipeline_utils.py index 1608eb42..5ec6fc29 100644 --- a/tests/test_pipeline_utils.py +++ b/tests/test_pipeline_utils.py @@ -1,17 +1,24 @@ """Tests for pipeline utility functions.""" +from collections import Counter + import numpy as np import pytest from protspace.data.loaders.embedding_set import ( EmbeddingSet, + format_param_suffix, format_projection_name, merge_same_name_sets, ) from 
protspace.data.processors.pipeline import ( + MethodSpec, PipelineConfig, ReductionPipeline, + _run_with_overridden_config, + disambiguation_suffix, parse_method_spec, + parse_methods_arg, ) # --------------------------------------------------------------------------- @@ -21,28 +28,164 @@ class TestParseMethodSpec: def test_pca2(self): - assert parse_method_spec("pca2") == ("pca", 2) + spec = parse_method_spec("pca2") + assert spec.method == "pca" + assert spec.dims == 2 + assert spec.overrides == () def test_umap3(self): - assert parse_method_spec("umap3") == ("umap", 3) + spec = parse_method_spec("umap3") + assert spec.method == "umap" + assert spec.dims == 3 def test_tsne2(self): - assert parse_method_spec("tsne2") == ("tsne", 2) + spec = parse_method_spec("tsne2") + assert spec.method == "tsne" + assert spec.dims == 2 def test_pacmap2(self): - assert parse_method_spec("pacmap2") == ("pacmap", 2) + spec = parse_method_spec("pacmap2") + assert spec.method == "pacmap" + assert spec.dims == 2 def test_mds2(self): - assert parse_method_spec("mds2") == ("mds", 2) + spec = parse_method_spec("mds2") + assert spec.method == "mds" + assert spec.dims == 2 def test_localmap2(self): - assert parse_method_spec("localmap2") == ("localmap", 2) + spec = parse_method_spec("localmap2") + assert spec.method == "localmap" + assert spec.dims == 2 def test_invalid_no_digits(self): with pytest.raises(ValueError): parse_method_spec("pca") +# --------------------------------------------------------------------------- +# parse_method_spec with overrides +# --------------------------------------------------------------------------- + + +class TestParseMethodSpecWithOverrides: + def test_single_override(self): + spec = parse_method_spec("umap2:n_neighbors=50") + assert spec.method == "umap" + assert spec.dims == 2 + assert spec.overrides_dict == {"n_neighbors": 50} + + def test_multiple_overrides_semicolon(self): + spec = parse_method_spec("umap2:n_neighbors=50;min_dist=0.1") + assert 
spec.overrides_dict == {"n_neighbors": 50, "min_dist": 0.1} + + def test_int_coercion(self): + spec = parse_method_spec("umap2:n_neighbors=100") + assert isinstance(spec.overrides_dict["n_neighbors"], int) + + def test_float_coercion(self): + spec = parse_method_spec("umap2:min_dist=0.5") + assert isinstance(spec.overrides_dict["min_dist"], float) + + def test_string_metric(self): + spec = parse_method_spec("umap2:metric=cosine") + assert spec.overrides_dict["metric"] == "cosine" + + def test_unknown_param_raises(self): + with pytest.raises(ValueError, match="Unknown parameter 'bogus'"): + parse_method_spec("umap2:bogus=5") + + def test_missing_value_raises(self): + with pytest.raises(ValueError, match="Invalid parameter format"): + parse_method_spec("umap2:n_neighbors") + + def test_empty_params_after_colon(self): + spec = parse_method_spec("umap2:") + assert spec.overrides == () + + def test_overrides_are_sorted(self): + spec = parse_method_spec("umap2:min_dist=0.1;n_neighbors=50") + keys = [k for k, _ in spec.overrides] + assert keys == sorted(keys) + + +# --------------------------------------------------------------------------- +# MethodSpec +# --------------------------------------------------------------------------- + + +class TestMethodSpec: + def test_str_no_overrides(self): + spec = MethodSpec("pca", 2) + assert str(spec) == "pca2" + + def test_str_with_overrides(self): + spec = MethodSpec("umap", 2, (("min_dist", 0.1), ("n_neighbors", 50))) + assert str(spec) == "umap2:min_dist=0.1;n_neighbors=50" + + def test_overrides_dict(self): + spec = MethodSpec("umap", 2, (("n_neighbors", 50),)) + assert spec.overrides_dict == {"n_neighbors": 50} + + def test_equality(self): + a = parse_method_spec("umap2:n_neighbors=50") + b = parse_method_spec("umap2:n_neighbors=50") + assert a == b + + def test_hashable(self): + spec = parse_method_spec("umap2:n_neighbors=50") + assert hash(spec) == hash(spec) + + +# 
--------------------------------------------------------------------------- +# parse_methods_arg +# --------------------------------------------------------------------------- + + +class TestParseMethodsArg: + def test_single_comma_separated(self): + result = parse_methods_arg(["pca2,umap2"]) + assert len(result) == 2 + assert result[0].method == "pca" + assert result[1].method == "umap" + + def test_repeated(self): + result = parse_methods_arg(["pca2", "umap2"]) + assert len(result) == 2 + + def test_mixed_with_overrides(self): + result = parse_methods_arg(["pca2", "umap2:n_neighbors=50;min_dist=0.1"]) + assert len(result) == 2 + assert result[1].overrides_dict == {"n_neighbors": 50, "min_dist": 0.1} + + def test_comma_separated_with_overrides(self): + result = parse_methods_arg(["pca2,umap2:n_neighbors=50;min_dist=0.1,tsne2"]) + assert len(result) == 3 + assert result[0].method == "pca" + assert result[1].overrides_dict == {"n_neighbors": 50, "min_dist": 0.1} + assert result[2].method == "tsne" + + def test_deduplicates(self): + result = parse_methods_arg(["umap2", "umap2"]) + assert len(result) == 1 + + def test_different_overrides_not_deduped(self): + result = parse_methods_arg(["umap2:n_neighbors=50", "umap2:n_neighbors=100"]) + assert len(result) == 2 + + def test_backward_compatible(self): + result = parse_methods_arg(["pca2,umap2,tsne2"]) + assert len(result) == 3 + + def test_strips_whitespace(self): + result = parse_methods_arg([" pca2 , umap2 "]) + assert len(result) == 2 + + def test_skips_empty_parts(self): + result = parse_methods_arg(["pca2,,umap2"]) + assert len(result) == 2 + + # --------------------------------------------------------------------------- # format_projection_name # --------------------------------------------------------------------------- @@ -71,6 +214,78 @@ def test_unknown_method_uppercased(self): def test_3d(self): assert format_projection_name("prot_t5", "tsne", 3) == "ProtT5 — t-SNE 3" + def test_with_param_suffix(self): + 
assert ( + format_projection_name("prot_t5", "umap", 2, "n=50, d=0.1") + == "ProtT5 — UMAP 2 (n=50, d=0.1)" + ) + + def test_empty_suffix_no_parens(self): + assert format_projection_name("prot_t5", "pca", 2, "") == "ProtT5 — PCA 2" + + +# --------------------------------------------------------------------------- +# format_param_suffix +# --------------------------------------------------------------------------- + + +class TestFormatParamSuffix: + def test_single_param(self): + assert format_param_suffix({"n_neighbors": 50}) == "n=50" + + def test_multiple_params_sorted(self): + assert ( + format_param_suffix({"n_neighbors": 50, "min_dist": 0.1}) == "d=0.1, n=50" + ) + + def test_string_value(self): + assert format_param_suffix({"metric": "cosine"}) == "m=cosine" + + def test_unknown_key_passthrough(self): + assert format_param_suffix({"custom_key": 42}) == "custom_key=42" + + +# --------------------------------------------------------------------------- +# disambiguation_suffix +# --------------------------------------------------------------------------- + + +class TestDisambiguationSuffix: + def test_unique_method_returns_empty(self): + spec = parse_method_spec("umap2:n_neighbors=50") + counts = Counter([(spec.method, spec.dims)]) + assert disambiguation_suffix(spec, counts) == "" + + def test_duplicates_with_overrides_return_suffixes(self): + a = parse_method_spec("umap2:n_neighbors=15") + b = parse_method_spec("umap2:n_neighbors=50") + counts = Counter([(a.method, a.dims), (b.method, b.dims)]) + assert disambiguation_suffix(a, counts) == "n=15" + assert disambiguation_suffix(b, counts) == "n=50" + + def test_plain_spec_alongside_override_returns_empty(self): + """Mixed case: -m umap2 -m umap2:n_neighbors=50. + + The plain spec has no overrides, so its suffix is empty even though + the (method, dims) pair is duplicated. The override spec carries the + disambiguating suffix. 
+ """ + plain = MethodSpec("umap", 2) + override = parse_method_spec("umap2:n_neighbors=50") + counts = Counter([(plain.method, plain.dims), (override.method, override.dims)]) + assert disambiguation_suffix(plain, counts) == "" + assert disambiguation_suffix(override, counts) == "n=50" + + def test_duplicate_method_no_overrides_anywhere(self): + """If two specs collide with no overrides at all, suffix is empty. + + This case cannot occur via parse_methods_arg (it dedupes), but the + helper should still behave sanely if called directly. + """ + spec = MethodSpec("umap", 2) + counts = Counter([(spec.method, spec.dims), (spec.method, spec.dims)]) + assert disambiguation_suffix(spec, counts) == "" + # --------------------------------------------------------------------------- # _resolve_annotation_names @@ -80,7 +295,7 @@ def test_3d(self): class TestResolveAnnotationNames: def _resolve(self, annotations): config = PipelineConfig( - methods=["pca2"], + methods=[MethodSpec("pca", 2)], output_path=None, annotations=annotations, ) @@ -137,7 +352,7 @@ def test_tsv_path(self): class TestValidateHeaders: def _make_pipeline(self): - config = PipelineConfig(methods=["pca2"], output_path=None) + config = PipelineConfig(methods=[MethodSpec("pca", 2)], output_path=None) return ReductionPipeline(config) def _make_es(self, name, headers): @@ -299,10 +514,131 @@ def test_empty_list(self): def test_same_name_no_overlap_through_pipeline(self): """Regression test for issue #44: same name, disjoint keys should work.""" - config = PipelineConfig(methods=["pca2"], output_path=None) + config = PipelineConfig(methods=[MethodSpec("pca", 2)], output_path=None) pipeline = ReductionPipeline(config) es1 = _make_es("prot_t5", ["A", "B"]) es2 = _make_es("prot_t5", ["C", "D"]) # Before the fix, this raised "No common protein identifiers" result = pipeline._validate_headers(merge_same_name_sets([es1, es2])) assert set(result) == {"A", "B", "C", "D"} + + +# 
--------------------------------------------------------------------------- +# project.py base.config isolation contract +# --------------------------------------------------------------------------- + + +class TestRunWithOverriddenConfig: + """The shared helper must restore base.config between iterations so a + `precomputed` flag (or any temporary key) cannot leak from one spec to + the next. + """ + + def test_base_config_restored_after_call(self): + original = {"metric": "euclidean", "n_neighbors": 15} + + class FakeBase: + def __init__(self, cfg): + self.config = cfg + self.seen_config = None + + def process_reduction(self, data, method, dims): + self.seen_config = dict(self.config) + return { + "name": f"{method}{dims}", + "dimensions": dims, + "data": [], + "info": {}, + } + + base = FakeBase(dict(original)) + effective = {**original, "n_neighbors": 50, "precomputed": True} + + result = _run_with_overridden_config(base, effective, "umap", 2, data=None) + + assert base.seen_config == effective, ( + "process_reduction should observe the overridden config" + ) + assert base.config == original, ( + f"base.config leaked override; expected {original}, got {base.config}" + ) + assert result["name"] == "umap2" + + def test_config_restored_on_exception(self): + original = {"metric": "euclidean"} + + class BoomBase: + def __init__(self, cfg): + self.config = cfg + + def process_reduction(self, data, method, dims): + raise RuntimeError("boom") + + base = BoomBase(dict(original)) + with pytest.raises(RuntimeError, match="boom"): + _run_with_overridden_config( + base, {"metric": "cosine"}, "umap", 2, data=None + ) + + assert base.config == original + + +# --------------------------------------------------------------------------- +# precomputed-MDS branch: base.config isolation +# --------------------------------------------------------------------------- + + +class TestPrecomputedMDSConfigIsolation: + """Regression tests for the precomputed-MDS branch in + 
ReductionPipeline._run_reductions: a `precomputed` flag must not survive + in self.base.config after the branch finishes — including when + process_reduction raises. + """ + + def _make_pipeline(self): + config = PipelineConfig(methods=[], output_path=None) + return ReductionPipeline(config) + + def _make_precomputed_es(self): + return EmbeddingSet( + name="MMseqs2", + data=np.eye(2, dtype=np.float32), + headers=["A", "B"], + precomputed=True, + ) + + def test_precomputed_mds_does_not_leak_precomputed_flag_on_success(self): + pipeline = self._make_pipeline() + + def fake_reduce(data, method, dims): + assert method == "mds" + assert pipeline.base.config.get("precomputed") is True + return { + "name": "stub", + "dimensions": dims, + "data": np.zeros((2, dims), dtype=np.float32), + "info": {}, + } + + pipeline.base.process_reduction = fake_reduce + + pipeline._run_reductions([self._make_precomputed_es()]) + + assert "precomputed" not in pipeline.base.config + + def test_precomputed_mds_restores_config_on_exception(self): + pipeline = self._make_pipeline() + + def boom(data, method, dims): + raise RuntimeError("boom") + + pipeline.base.process_reduction = boom + original_config_id = id(pipeline.base.config) + + with pytest.raises(RuntimeError, match="boom"): + pipeline._run_reductions([self._make_precomputed_es()]) + + assert "precomputed" not in pipeline.base.config + assert id(pipeline.base.config) == original_config_id, ( + "base.config reference should be the original dict, not a replacement" + )