diff --git a/applications/airtable/configs/prepare_config.yml b/applications/airtable/configs/prepare_config.yml new file mode 100644 index 000000000..da9eb5f7b --- /dev/null +++ b/applications/airtable/configs/prepare_config.yml @@ -0,0 +1,47 @@ +# Dataset preparation pipeline: NFS -> VAST rechunked zarr v3 +# Usage: prepare run -c prepare_config.yml [--dry-run] + +nfs_root: /hpc/projects/intracellular_dashboard/organelle_dynamics +vast_root: /hpc/projects/organelle_phenotyping/datasets +workspace_dir: /hpc/mydata/eduardo.hirata/repos/viscy + +concatenate: + # null = auto-detect raw channels (Phase3D + raw *). Set explicitly to override. + channel_names: null + chunks_czyx: [1, 16, 256, 256] + shards_ratio: [1, 1, 8, 8, 8] + output_ome_zarr_version: "0.5" + conda_env: biahub + # Override biahub's internal SLURM settings (passed via -sb flag) + # Set to null to use biahub defaults + sbatch_overrides: + partition: cpu + +qc: + channel_names: [Phase3D] + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + midband_fractions: [0.125, 0.25] + device: cuda + num_workers: 16 + +preprocess: + channel_names: -1 + num_workers: 32 + block_size: 32 + +# biahub concatenate submits its own SLURM jobs via submitit (no config needed) +# QC and preprocess run as separate SLURM jobs (no race condition) +slurm: + qc: + partition: gpu + gres: "gpu:1" + cpus_per_task: 16 + mem_per_cpu: 4G + time: "00:30:00" + preprocess: + partition: cpu + cpus_per_task: 32 + mem_per_cpu: 4G + time: "04:00:00" diff --git a/applications/airtable/scripts/write_experiment_metadata.py b/applications/airtable/scripts/write_experiment_metadata.py index 192bff024..6b0ce2853 100644 --- a/applications/airtable/scripts/write_experiment_metadata.py +++ b/applications/airtable/scripts/write_experiment_metadata.py @@ -68,6 +68,9 @@ def register(position_paths: list[Path], dry_run: bool = False, dataset: str | N if result.updated: db.batch_update(result.updated) logger.info("Updated %d existing records", len(result.updated)) + if result.template_ids_to_delete: + db.batch_delete(result.template_ids_to_delete) + logger.info("Deleted %d well template records", len(result.template_ids_to_delete)) print(format_register_summary(result, dry_run=dry_run)) diff --git a/applications/airtable/src/airtable_utils/database.py b/applications/airtable/src/airtable_utils/database.py index 1cb9ffd06..c1fd19a70 100644 --- a/applications/airtable/src/airtable_utils/database.py +++ b/applications/airtable/src/airtable_utils/database.py @@ -143,3 +143,18 @@ def batch_create(self, records: list[dict]) -> list[dict]: Created records as returned by the Airtable API. """ return self._table.batch_create([r["fields"] for r in records]) + + def batch_delete(self, record_ids: list[str]) -> list[dict]: + """Batch-delete records by ID. + + Parameters + ---------- + record_ids : list[str] + Airtable record IDs to delete. + + Returns + ------- + list[dict] + Deletion confirmations from the Airtable API. + """ + return self._table.batch_delete(record_ids) diff --git a/applications/airtable/src/airtable_utils/prepare.py b/applications/airtable/src/airtable_utils/prepare.py new file mode 100644 index 000000000..36f5d077d --- /dev/null +++ b/applications/airtable/src/airtable_utils/prepare.py @@ -0,0 +1,670 @@ +"""Config-driven dataset preparation: NFS -> VAST rechunked zarr v3.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from textwrap import dedent + +import yaml +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Pydantic config models +# --------------------------------------------------------------------------- + + +class ConcatenateConfig(BaseModel): + """Parameters for biahub concatenate.""" + + channel_names: list[str] | None = None + chunks_czyx: list[int] = [1, 16, 256, 256] + shards_ratio: list[int] = [1, 1, 8, 8, 8] + output_ome_zarr_version: str = "0.5" + conda_env: str = "biahub" + sbatch_overrides: dict[str, str] | None = None + + +class QCParams(BaseModel): + """Focus-slice QC parameters.""" + + channel_names: list[str] = ["Phase3D"] + NA_det: float = 1.35 + lambda_ill: float = 0.450 + pixel_size: float = 0.1494 + midband_fractions: tuple[float, float] = (0.125, 0.25) + device: str = "cuda" + num_workers: int = 16 + + +class PreprocessParams(BaseModel): + """Normalization preprocessing parameters.""" + + channel_names: int | list[str] = -1 + num_workers: int = 48 + block_size: int = 32 + + +class SlurmStageConfig(BaseModel): + """SLURM resource settings for one job stage.""" + + partition: str + cpus_per_task: int = 24 + mem_per_cpu: str = "4G" + time: str = "06:00:00" + gres: str | None = None + constraint: str | None = None + + +class SlurmConfig(BaseModel): + """SLURM settings for QC and preprocess stages (separate jobs). + + The concatenation stage is not a SLURM job — ``biahub concatenate`` + submits its own SLURM jobs internally via submitit. + """ + + qc: SlurmStageConfig = Field( + default_factory=lambda: SlurmStageConfig( + partition="gpu", + gres="gpu:1", + cpus_per_task=16, + mem_per_cpu="4G", + time="00:30:00", + ) + ) + preprocess: SlurmStageConfig = Field( + default_factory=lambda: SlurmStageConfig( + partition="preempted", + cpus_per_task=16, + mem_per_cpu="4G", + time="04:00:00", + ) + ) + + +class PrepareConfig(BaseModel): + """Top-level prepare pipeline configuration.""" + + nfs_root: Path = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") + vast_root: Path = Path("/hpc/projects/organelle_phenotyping/datasets") + workspace_dir: Path = Path("/hpc/mydata/eduardo.hirata/repos/viscy") + concatenate: ConcatenateConfig = Field(default_factory=ConcatenateConfig) + qc: QCParams = Field(default_factory=QCParams) + preprocess: PreprocessParams = Field(default_factory=PreprocessParams) + slurm: SlurmConfig = Field(default_factory=SlurmConfig) + + +# --------------------------------------------------------------------------- +# Path resolution +# --------------------------------------------------------------------------- + + +def resolve_nfs_paths(dataset_name: str, nfs_root: Path) -> dict[str, Path]: + """Return NFS zarr and tracking paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset identifier, e.g. ``"2025_01_22_A549_G3BP1_ZIKV_DENV"``. + nfs_root : Path + Root of organelle_dynamics on NFS. + + Returns + ------- + dict[str, Path] + Keys: ``zarr``, ``tracking``. + + Raises + ------ + FileNotFoundError + If the assembled zarr does not exist on NFS. + """ + zarr_path = nfs_root / dataset_name / "2-assemble" / f"{dataset_name}.zarr" + tracking_path = nfs_root / dataset_name / "1-preprocess" / "label-free" / "3-track" / f"{dataset_name}_cropped.zarr" + if not zarr_path.exists(): + raise FileNotFoundError(f"NFS zarr not found: {zarr_path}") + return {"zarr": zarr_path, "tracking": tracking_path} + + +def resolve_vast_paths(dataset_name: str, vast_root: Path) -> dict[str, Path]: + """Return expected VAST output paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset identifier. + vast_root : Path + Root of datasets directory on VAST. + + Returns + ------- + dict[str, Path] + Keys: ``output_dir``, ``zarr``, ``tracking``. + """ + output_dir = vast_root / dataset_name + return { + "output_dir": output_dir, + "zarr": output_dir / f"{dataset_name}.zarr", + "tracking": output_dir / "tracking.zarr", + } + + +# --------------------------------------------------------------------------- +# Zarr version validation +# --------------------------------------------------------------------------- + + +def check_zarr_version(zarr_path: Path) -> dict[str, int | str | None]: + """Check zarr format version and OME-Zarr version of an existing store. + + Parameters + ---------- + zarr_path : Path + Path to the zarr store root. + + Returns + ------- + dict[str, int | str | None] + Keys: ``zarr_format`` (2, 3, or None), ``ome_version`` (e.g. "0.5" or None). + """ + result: dict[str, int | str | None] = {"zarr_format": None, "ome_version": None} + + zarr_json = zarr_path / "zarr.json" + zgroup = zarr_path / ".zgroup" + + if zarr_json.exists(): + with open(zarr_json) as f: + meta = json.load(f) + result["zarr_format"] = meta.get("zarr_format", 3) + ome = meta.get("attributes", {}).get("ome", {}) + result["ome_version"] = ome.get("version") + elif zgroup.exists(): + with open(zgroup) as f: + meta = json.load(f) + result["zarr_format"] = meta.get("zarr_format", 2) + zattrs = zarr_path / ".zattrs" + if zattrs.exists(): + with open(zattrs) as f: + attrs = json.load(f) + result["ome_version"] = attrs.get("plate", {}).get("version") + + return result + + +def check_preprocessed(zarr_path: Path) -> bool: + """Check if normalization metadata has been written to the zarr store. + + Parameters + ---------- + zarr_path : Path + Path to the zarr store root. + + Returns + ------- + bool + True if normalization stats are present. + """ + zarr_json = zarr_path / "zarr.json" + zattrs = zarr_path / ".zattrs" + + if zarr_json.exists(): + with open(zarr_json) as f: + meta = json.load(f) + return "normalization" in meta.get("attributes", {}) + elif zattrs.exists(): + with open(zattrs) as f: + attrs = json.load(f) + return "normalization" in attrs + + return False + + +# --------------------------------------------------------------------------- +# Discovery (reads NFS zarr via iohub) +# --------------------------------------------------------------------------- + + +def discover_wells(nfs_zarr_path: Path) -> list[str]: + """Enumerate well paths from an NFS OME-Zarr plate. + + Returns well-level paths (e.g. ``"B/1"``) not full position paths. + The ``crop_concat.yml`` format expects ``{zarr}/{well}/*`` globs + so that biahub concatenate can discover positions within each well. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the assembled zarr on NFS. + + Returns + ------- + list[str] + Sorted well paths like ``["A/1", "B/1", "C/2"]``. + """ + from iohub import open_ome_zarr + + wells: list[str] = [] + with open_ome_zarr(str(nfs_zarr_path), mode="r") as plate: + for pos_path, _pos in plate.positions(): + # pos_path is like "A/1/000000" — extract well as "A/1" + well = "/".join(pos_path.split("/")[:2]) + if well not in wells: + wells.append(well) + return sorted(wells) + + +def discover_channels(nfs_zarr_path: Path) -> list[str]: + """Read channel names from an NFS OME-Zarr plate. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the assembled zarr on NFS. + + Returns + ------- + list[str] + Channel names, e.g. ``["Phase3D", "raw GFP EX488 EM525-45", ...]``. + """ + from iohub import open_ome_zarr + + with open_ome_zarr(str(nfs_zarr_path), mode="r") as plate: + return list(plate.channel_names) + + +RAW_CHANNEL_PREFIXES = ("Phase3D", "raw ") + + +def filter_raw_channels(channel_names: list[str]) -> list[str]: + """Filter to only raw imaging channels (Phase3D and raw fluorescence). + + Excludes virtual stains (``nuclei_prediction``, ``membrane_prediction``), + deconvolved channels (``GFP EX488 ...`` without ``raw`` prefix), and + other derived channels (``BF``). + + Parameters + ---------- + channel_names : list[str] + All channel names from the zarr. + + Returns + ------- + list[str] + Only channels starting with ``"Phase3D"`` or ``"raw "``. + """ + return [ch for ch in channel_names if ch.startswith(RAW_CHANNEL_PREFIXES)] + + +# --------------------------------------------------------------------------- +# Config generation +# --------------------------------------------------------------------------- + + +def generate_crop_concat_config( + nfs_zarr_path: Path, + wells: list[str], + channel_names: list[str], + concat_cfg: ConcatenateConfig, +) -> dict: + """Build a crop_concat.yml dict for biahub concatenate. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the source zarr on NFS. + wells : list[str] + Well paths like ``["A/1", "B/2"]`` (row/col level). + Each becomes ``"{zarr}/{well}/*"`` so biahub globs positions within. + channel_names : list[str] + Channel names (repeated once per well entry). + concat_cfg : ConcatenateConfig + Concatenation parameters. + + Returns + ------- + dict + Config dict ready to write as YAML. + """ + concat_data_paths = [f"{nfs_zarr_path}/{well}/*" for well in wells] + return { + "concat_data_paths": concat_data_paths, + "time_indices": "all", + "channel_names": [channel_names] * len(wells), + "X_slice": "all", + "Y_slice": "all", + "Z_slice": "all", + "chunks_czyx": concat_cfg.chunks_czyx, + "shards_ratio": concat_cfg.shards_ratio, + "output_ome_zarr_version": concat_cfg.output_ome_zarr_version, + } + + +def generate_qc_config(data_path: Path, qc_params: QCParams) -> dict: + """Build a QC config dict compatible with ``qc run -c``. + + Parameters + ---------- + data_path : Path + Path to the VAST zarr (target of QC). + qc_params : QCParams + Focus-slice QC parameters. + + Returns + ------- + dict + Config dict ready to write as YAML. + """ + return { + "data_path": str(data_path), + "num_workers": qc_params.num_workers, + "focus_slice": { + "channel_names": qc_params.channel_names, + "NA_det": qc_params.NA_det, + "lambda_ill": qc_params.lambda_ill, + "pixel_size": qc_params.pixel_size, + "midband_fractions": list(qc_params.midband_fractions), + "device": qc_params.device, + }, + } + + +def write_yaml(config: dict, output_path: Path) -> None: + """Write a dict to a YAML file. + + Parameters + ---------- + config : dict + Config to serialize. + output_path : Path + Destination file path. + """ + # Use a Dumper that avoids YAML anchors/aliases for repeated lists. + dumper = yaml.Dumper + dumper.ignore_aliases = lambda self, data: True + with open(output_path, "w") as f: + yaml.dump(config, f, Dumper=dumper, default_flow_style=False, sort_keys=False) + + +# --------------------------------------------------------------------------- +# SLURM script generation +# --------------------------------------------------------------------------- + + +def _slurm_header(job_name: str, output_dir: Path, cfg: SlurmStageConfig) -> str: + """Build SBATCH header lines.""" + lines = [ + "#!/bin/bash", + f"#SBATCH --job-name={job_name}", + "#SBATCH --nodes=1", + "#SBATCH --ntasks-per-node=1", + f"#SBATCH --partition={cfg.partition}", + f"#SBATCH --cpus-per-task={cfg.cpus_per_task}", + f"#SBATCH --mem-per-cpu={cfg.mem_per_cpu}", + f"#SBATCH --time={cfg.time}", + f"#SBATCH --output={output_dir}/slurm_{job_name}_%j.out", + ] + if cfg.gres: + lines.append(f"#SBATCH --gres={cfg.gres}") + if cfg.constraint: + lines.append(f'#SBATCH --constraint="{cfg.constraint}"') + return "\n".join(lines) + + +def generate_sbatch_override_file(overrides: dict[str, str]) -> str: + """Generate content for a biahub sbatch override file. + + Parameters + ---------- + overrides : dict[str, str] + SLURM directive keys and values, e.g. + ``{"partition": "preempted", "mem-per-cpu": "16G"}``. + + Returns + ------- + str + File content with ``#SBATCH`` lines. + """ + lines = ["#!/bin/bash"] + for key, value in overrides.items(): + lines.append(f"#SBATCH --{key}={value}") + return "\n".join(lines) + "\n" + + +def generate_concatenate_script( + crop_concat_path: Path, + vast_zarr_path: Path, + nfs_tracking_path: Path, + vast_tracking_path: Path, + conda_env: str, + sbatch_override_path: Path | None = None, +) -> str: + """Generate a bash script for biahub concatenate + tracking copy. + + This is NOT a SLURM script. ``biahub concatenate`` submits its own + SLURM jobs internally via submitit. The ``-m`` flag makes it block + until those jobs complete. After concatenation, tracking is rsynced. + + Parameters + ---------- + crop_concat_path : Path + Path to the generated crop_concat.yml. + vast_zarr_path : Path + Target zarr output path. + nfs_tracking_path : Path + Source tracking zarr on NFS. + vast_tracking_path : Path + Target tracking zarr on VAST. + conda_env : str + Conda environment name for biahub. + sbatch_override_path : Path or None + Path to sbatch override file for biahub's internal SLURM jobs. + + Returns + ------- + str + Bash script content. + """ + # Build the biahub command as a single line to avoid conda run + # swallowing backslash continuations. + cmd_parts = [ + f"conda run -n {conda_env} biahub concatenate", + f'-c "{crop_concat_path}"', + f'-o "{vast_zarr_path}"', + "-m", + ] + if sbatch_override_path: + cmd_parts.append(f'-sb "{sbatch_override_path}"') + biahub_cmd = " ".join(cmd_parts) + + return dedent(f"""\ + #!/bin/bash + set -euo pipefail + + echo "=== Step 1: biahub concatenate (submits SLURM jobs via submitit) ===" + {biahub_cmd} + echo "Concatenation complete." + + echo "=== Step 2: Copy tracking zarr ===" + if [ -d "{nfs_tracking_path}" ]; then + rsync -a --copy-links "{nfs_tracking_path}/" "{vast_tracking_path}/" + echo "Tracking copy complete." + else + echo "WARNING: NFS tracking zarr not found at {nfs_tracking_path}, skipping." + fi + """) + + +def generate_qc_slurm( + dataset_name: str, + vast_output_dir: Path, + qc_config_path: Path, + workspace_dir: Path, + slurm_cfg: SlurmStageConfig, +) -> str: + """Generate SLURM script for focus-slice QC (needs GPU). + + Parameters + ---------- + dataset_name : str + Dataset identifier (used for job name). + vast_output_dir : Path + Output directory on VAST. + qc_config_path : Path + Path to the generated qc_config.yml. + workspace_dir : Path + Path to the viscy repo root. + slurm_cfg : SlurmStageConfig + SLURM resource parameters. + + Returns + ------- + str + Complete SLURM script content. + """ + header = _slurm_header(f"qc_{dataset_name}", vast_output_dir, slurm_cfg) + body = dedent(f"""\ + + export PYTHONNOUSERSITE=1 + + echo "=== QC: focus slice detection ===" + uv run --project "{workspace_dir}" --package qc \ + qc run -c "{qc_config_path}" + echo "QC complete." + """) + return header + "\n" + body + + +def generate_preprocess_slurm( + dataset_name: str, + vast_output_dir: Path, + vast_zarr_path: Path, + workspace_dir: Path, + preprocess_params: PreprocessParams, + slurm_cfg: SlurmStageConfig, +) -> str: + """Generate SLURM script for normalization preprocessing (CPU only). + + Parameters + ---------- + dataset_name : str + Dataset identifier (used for job name). + vast_output_dir : Path + Output directory on VAST. + vast_zarr_path : Path + Path to the rechunked zarr on VAST. + workspace_dir : Path + Path to the viscy repo root. + preprocess_params : PreprocessParams + Normalization preprocessing parameters. + slurm_cfg : SlurmStageConfig + SLURM resource parameters. + + Returns + ------- + str + Complete SLURM script content. + """ + header = _slurm_header(f"preprocess_{dataset_name}", vast_output_dir, slurm_cfg) + + ch_arg = preprocess_params.channel_names + if isinstance(ch_arg, int): + ch_flag = f"--channel_names={ch_arg}" + else: + ch_flag = " ".join(f"--channel_names={c}" for c in ch_arg) + + body = dedent(f"""\ + + export PYTHONNOUSERSITE=1 + + echo "=== Preprocess: normalization stats ===" + echo "Data: {vast_zarr_path}" + uv run --project "{workspace_dir}" --package dynaclr \ + viscy preprocess --data_path "{vast_zarr_path}" \ + {ch_flag} --num_workers {preprocess_params.num_workers} \ + --block_size {preprocess_params.block_size} + echo "Preprocess complete." + """) + return header + "\n" + body + + +# --------------------------------------------------------------------------- +# Status check +# --------------------------------------------------------------------------- + + +def check_dataset_status(dataset_name: str, nfs_root: Path, vast_root: Path) -> dict[str, str]: + """Check existence and version info for a dataset across NFS and VAST. + + Parameters + ---------- + dataset_name : str + Dataset identifier. + nfs_root : Path + NFS root directory. + vast_root : Path + VAST root directory. + + Returns + ------- + dict[str, str] + Status fields for the dataset. + """ + nfs_zarr = nfs_root / dataset_name / "2-assemble" / f"{dataset_name}.zarr" + vast = resolve_vast_paths(dataset_name, vast_root) + + nfs_exists = nfs_zarr.exists() + vast_zarr_exists = vast["zarr"].exists() + vast_tracking_exists = vast["tracking"].exists() + + zarr_fmt: str = "-" + ome_ver: str = "-" + preprocessed: str = "-" + + if vast_zarr_exists: + ver = check_zarr_version(vast["zarr"]) + zarr_fmt = str(ver["zarr_format"]) if ver["zarr_format"] else "?" + ome_ver = str(ver["ome_version"]) if ver["ome_version"] else "?" + preprocessed = "yes" if check_preprocessed(vast["zarr"]) else "no" + + return { + "dataset": dataset_name, + "nfs": "yes" if nfs_exists else "no", + "vast_zarr": "yes" if vast_zarr_exists else "no", + "zarr_version": zarr_fmt, + "ome_version": ome_ver, + "tracking": "yes" if vast_tracking_exists else "no", + "preprocessed": preprocessed, + } + + +def format_status_table(rows: list[dict[str, str]]) -> str: + """Format dataset status rows as a markdown table. + + Parameters + ---------- + rows : list[dict[str, str]] + Each dict from :func:`check_dataset_status`. + + Returns + ------- + str + Markdown table string. + """ + headers = [ + "dataset", + "nfs", + "vast_zarr", + "zarr_version", + "ome_version", + "tracking", + "preprocessed", + ] + col_widths = {h: max(len(h), *(len(r[h]) for r in rows)) for h in headers} + + header_line = "| " + " | ".join(h.ljust(col_widths[h]) for h in headers) + " |" + sep_line = "| " + " | ".join("-" * col_widths[h] for h in headers) + " |" + data_lines = ["| " + " | ".join(r[h].ljust(col_widths[h]) for h in headers) + " |" for r in rows] + return "\n".join([header_line, sep_line, *data_lines]) diff --git a/applications/airtable/src/airtable_utils/prepare_cli.py b/applications/airtable/src/airtable_utils/prepare_cli.py new file mode 100644 index 000000000..c4e9486bb --- /dev/null +++ b/applications/airtable/src/airtable_utils/prepare_cli.py @@ -0,0 +1,259 @@ +"""CLI for config-driven dataset preparation (NFS -> VAST).""" + +from __future__ import annotations + +import logging +import re +import subprocess + +import click + +from airtable_utils.prepare import ( + PrepareConfig, + check_dataset_status, + check_preprocessed, + check_zarr_version, + discover_channels, + discover_wells, + filter_raw_channels, + format_status_table, + generate_concatenate_script, + generate_crop_concat_config, + generate_preprocess_slurm, + generate_qc_config, + generate_qc_slurm, + generate_sbatch_override_file, + resolve_nfs_paths, + resolve_vast_paths, + write_yaml, +) + +logger = logging.getLogger(__name__) + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + + +def _load_prepare_config(config_path: str) -> PrepareConfig: + """Load and validate a prepare config YAML.""" + from viscy_utils.cli_utils import load_config + + raw = load_config(config_path) + return PrepareConfig(**raw) + + +def _parse_slurm_job_id(sbatch_output: str) -> str: + """Extract job ID from sbatch stdout like 'Submitted batch job 12345'.""" + match = re.search(r"Submitted batch job (\d+)", sbatch_output) + if not match: + raise RuntimeError(f"Could not parse sbatch output: {sbatch_output}") + return match.group(1) + + +@click.group(context_settings=CONTEXT_SETTINGS) +def prepare(): + """Prepare datasets for training on VAST storage.""" + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@prepare.command() +@click.argument("dataset_name") +@click.option( + "-c", + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to prepare config YAML.", +) +@click.option("--dry-run", is_flag=True, help="Generate configs without submitting SLURM jobs.") +@click.option("--force", is_flag=True, help="Overwrite existing VAST zarr even if it is zarr v2.") +def run(dataset_name: str, config_path: str, dry_run: bool, force: bool) -> None: + """Run the full preparation pipeline for DATASET_NAME. + + Steps: Airtable validation -> discover positions/channels -> generate + crop_concat.yml + qc_config.yml + SLURM scripts -> submit jobs. + """ + cfg = _load_prepare_config(config_path) + + # 1. Validate dataset is registered in Airtable + click.echo(f"Validating {dataset_name} in Airtable...") + from airtable_utils.database import AirtableDatasets + + db = AirtableDatasets() + records = db.get_dataset_records(dataset_name) + if not records: + raise click.ClickException( + f"Dataset '{dataset_name}' not found in Airtable. Register it first with the airtable-register workflow." + ) + click.echo(f" Found {len(records)} FOV records in Airtable.") + + # 2. Resolve NFS paths + nfs = resolve_nfs_paths(dataset_name, cfg.nfs_root) + click.echo(f" NFS zarr: {nfs['zarr']}") + + # 3. Resolve VAST paths + vast = resolve_vast_paths(dataset_name, cfg.vast_root) + click.echo(f" VAST output: {vast['output_dir']}") + + # 4. Check existing VAST zarr + if vast["zarr"].exists(): + ver = check_zarr_version(vast["zarr"]) + is_v3 = ver["zarr_format"] == 3 + is_ome05 = ver["ome_version"] == "0.5" + is_preprocessed = check_preprocessed(vast["zarr"]) + + if is_v3 and is_ome05 and is_preprocessed: + click.echo( + f" VAST zarr already exists: zarr v{ver['zarr_format']}, " + f"OME {ver['ome_version']}, preprocessed. Skipping." + ) + return + + if not force: + msg = ( + f"VAST zarr already exists at {vast['zarr']} " + f"(zarr v{ver['zarr_format']}, OME {ver['ome_version']}, " + f"preprocessed={is_preprocessed}). " + "Use --force to overwrite." + ) + raise click.ClickException(msg) + + click.echo(f" WARNING: Overwriting existing VAST zarr (zarr v{ver['zarr_format']}, OME {ver['ome_version']}).") + + # 5. Discover wells and resolve channels from NFS zarr + click.echo("Discovering wells and channels from NFS zarr...") + wells = discover_wells(nfs["zarr"]) + zarr_channels = discover_channels(nfs["zarr"]) + + if cfg.concatenate.channel_names is not None: + concat_channels = cfg.concatenate.channel_names + missing = [ch for ch in concat_channels if ch not in zarr_channels] + if missing: + raise click.ClickException(f"Channels {missing} from config not found in zarr. Available: {zarr_channels}") + else: + concat_channels = filter_raw_channels(zarr_channels) + if not concat_channels: + raise click.ClickException(f"No raw channels found in zarr. Available: {zarr_channels}") + + click.echo(f" Wells: {wells}") + click.echo(f" Zarr channels: {zarr_channels}") + click.echo(f" Extracting: {concat_channels}") + + # 6. Create output directory + vast["output_dir"].mkdir(parents=True, exist_ok=True) + + # 7. Generate crop_concat.yml + crop_concat_cfg = generate_crop_concat_config(nfs["zarr"], wells, concat_channels, cfg.concatenate) + crop_concat_path = vast["output_dir"] / "crop_concat.yml" + write_yaml(crop_concat_cfg, crop_concat_path) + click.echo(f" Wrote: {crop_concat_path}") + + # 8. Generate qc_config.yml + qc_cfg = generate_qc_config(vast["zarr"], cfg.qc) + qc_config_path = vast["output_dir"] / "qc_config.yml" + write_yaml(qc_cfg, qc_config_path) + click.echo(f" Wrote: {qc_config_path}") + + # 9. Generate scripts + sbatch_override_path = None + if cfg.concatenate.sbatch_overrides: + sbatch_content = generate_sbatch_override_file(cfg.concatenate.sbatch_overrides) + sbatch_override_path = vast["output_dir"] / "sbatch_overrides.sh" + sbatch_override_path.write_text(sbatch_content) + click.echo(f" Wrote: {sbatch_override_path}") + + concat_script = generate_concatenate_script( + crop_concat_path=crop_concat_path, + vast_zarr_path=vast["zarr"], + nfs_tracking_path=nfs["tracking"], + vast_tracking_path=vast["tracking"], + conda_env=cfg.concatenate.conda_env, + sbatch_override_path=sbatch_override_path, + ) + concat_script_path = vast["output_dir"] / "01_concatenate.sh" + concat_script_path.write_text(concat_script) + click.echo(f" Wrote: {concat_script_path}") + + qc_script = generate_qc_slurm( + dataset_name=dataset_name, + vast_output_dir=vast["output_dir"], + qc_config_path=qc_config_path, + workspace_dir=cfg.workspace_dir, + slurm_cfg=cfg.slurm.qc, + ) + qc_script_path = vast["output_dir"] / "02_qc.sh" + qc_script_path.write_text(qc_script) + click.echo(f" Wrote: {qc_script_path}") + + preprocess_script = generate_preprocess_slurm( + dataset_name=dataset_name, + vast_output_dir=vast["output_dir"], + vast_zarr_path=vast["zarr"], + workspace_dir=cfg.workspace_dir, + preprocess_params=cfg.preprocess, + slurm_cfg=cfg.slurm.preprocess, + ) + preprocess_script_path = vast["output_dir"] / "03_preprocess.sh" + preprocess_script_path.write_text(preprocess_script) + click.echo(f" Wrote: {preprocess_script_path}") + + if dry_run: + click.echo("\n--dry-run: configs and scripts generated, nothing executed.") + return + + # 10. Run concatenation (biahub submits its own SLURM jobs via submitit) + click.echo("\nRunning biahub concatenate + tracking copy...") + click.echo(" (biahub will submit SLURM jobs internally and -m will monitor them)") + subprocess.run(["bash", str(concat_script_path)], check=True) + click.echo("Concatenation and tracking copy complete.") + + # 11. Submit QC and preprocess as separate SLURM jobs (no dependency, no race condition) + click.echo("\nSubmitting QC and preprocess SLURM jobs...") + result_qc = subprocess.run( + ["sbatch", str(qc_script_path)], + capture_output=True, + text=True, + check=True, + ) + qc_job_id = _parse_slurm_job_id(result_qc.stdout) + click.echo(f" QC job: {qc_job_id} (GPU, ~5-20 min)") + + result_pp = subprocess.run( + ["sbatch", str(preprocess_script_path)], + capture_output=True, + text=True, + check=True, + ) + pp_job_id = _parse_slurm_job_id(result_pp.stdout) + click.echo(f" Preprocess job: {pp_job_id} (CPU, ~3 hrs)") + + click.echo(f"\nPipeline running for {dataset_name}.") + click.echo(f" Output: {vast['output_dir']}") + click.echo(f" Monitor: squeue -j {qc_job_id},{pp_job_id}") + + +@prepare.command() +@click.argument("dataset_names", nargs=-1, required=True) +@click.option( + "-c", + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to prepare config YAML.", +) +def status(dataset_names: tuple[str, ...], config_path: str) -> None: + """Check NFS/VAST existence and version status for one or more datasets.""" + cfg = _load_prepare_config(config_path) + + rows = [check_dataset_status(name, cfg.nfs_root, cfg.vast_root) for name in dataset_names] + click.echo(format_status_table(rows)) + + +def main() -> None: + """Entry point for the prepare CLI.""" + prepare() + + +if __name__ == "__main__": + main() diff --git a/applications/airtable/src/airtable_utils/registration.py b/applications/airtable/src/airtable_utils/registration.py index c189ff1fa..e35072659 100644 --- a/applications/airtable/src/airtable_utils/registration.py +++ b/applications/airtable/src/airtable_utils/registration.py @@ -35,6 +35,10 @@ "seeding_density", "treatment_concentration_nm", "fluorescence_modality", + "microscope", + "labelfree_modality", + "treatment", + "hours_post_treatment", ) @@ -49,6 +53,7 @@ class RegisterResult: channel_names: list[str] = field(default_factory=list) pixel_size_xy_um: float | None = None pixel_size_z_um: float | None = None + template_ids_to_delete: list[str] = field(default_factory=list) def parse_position_path(position_path: Path) -> tuple[Path, str]: @@ -264,6 +269,7 @@ def format_register_summary(result: RegisterResult, dry_run: bool = False) -> st f"| created | {len(result.created)} |", f"| updated | {len(result.updated)} |", f"| unmatched | {len(result.unmatched)} |", + f"| templates_to_delete | {len(result.template_ids_to_delete)} |", f"| pixel_size_xy_um | {xy} |", f"| pixel_size_z_um | {z} |", f"| status | {status} |", @@ -421,8 +427,8 @@ def register_fovs( result = RegisterResult(dataset=dataset_name) - # Filter to directories only — glob("*/*/*") also picks up .zattrs/.zgroup files - pos_names = [p for p in pos_names if not Path(zarr_root / p).name.startswith(".")] + # Filter to directories only — glob("*/*/*") also picks up zarr.json, .zattrs, .zgroup files + pos_names = [p for p in pos_names if (zarr_root / p).is_dir()] with open_ome_zarr(str(zarr_root), mode="r") as plate: result.channel_names = plate.channel_names @@ -453,7 +459,13 @@ def register_fovs( # Resolve cell_line linked records -> registry entries -> marker rec_for_marker = fov_records.get((well_id, fov)) or well_templates.get(well_id) - if rec_for_marker is not None and rec_for_marker.cell_line: + if rec_for_marker is not None: + if not rec_for_marker.cell_line: + raise ValueError( + f"Well '{well_id}' has no cell_line set in Airtable. " + "cell_line is required for channel marker derivation — " + "fill it in the platemap before registering." + ) marker_entries = [registry[rid] for rid in rec_for_marker.cell_line if rid in registry] marker_fields = derive_channel_marker(result.channel_names, marker_entries) zarr_fields.update(marker_fields) @@ -478,4 +490,11 @@ def register_fovs( } result.created.append({"fields": fields}) + # Collect well template record IDs to delete — only for wells where at least + # one FOV was created from the template in this batch. + used_wells: set[str] = {rec["fields"]["well_id"] for rec in result.created} + for well_id, template in well_templates.items(): + if well_id in used_wells and template.record_id: + result.template_ids_to_delete.append(template.record_id) + return result diff --git a/applications/airtable/src/airtable_utils/schemas.py b/applications/airtable/src/airtable_utils/schemas.py index c84dd2930..1d608178b 100644 --- a/applications/airtable/src/airtable_utils/schemas.py +++ b/applications/airtable/src/airtable_utils/schemas.py @@ -131,7 +131,7 @@ class DatasetRecord(FOVRecord): @model_validator(mode="after") def _derive_channel_names(self) -> DatasetRecord: - """Populate ``channel_names`` from ``channel_0..7_name`` fields.""" + """Populate ``channel_names`` and ``channel_markers`` from ``channel_0..7_name/marker`` fields.""" if not self.channel_names: names = [] for i in range(MAX_CHANNELS): @@ -139,6 +139,14 @@ def _derive_channel_names(self) -> DatasetRecord: if name is not None: names.append(name) self.channel_names = names + if not self.channel_markers: + markers: dict[str, str] = {} + for i in range(MAX_CHANNELS): + name = getattr(self, f"channel_{i}_name") + marker = getattr(self, f"channel_{i}_marker") + if name is not None and marker is not None: + markers[name] = marker + self.channel_markers = markers return self @classmethod @@ -191,6 +199,10 @@ def _multi_select_val(v): data_path=fields.get("data_path"), tracks_path=fields.get("tracks_path"), fluorescence_modality=_select_val(fields.get("fluorescence_modality")), + microscope=_select_val(fields.get("microscope")), + labelfree_modality=_select_val(fields.get("labelfree_modality")), + treatment=_select_val(fields.get("treatment")), + hours_post_treatment=fields.get("hours post treatment"), t_shape=fields.get("t_shape"), c_shape=fields.get("c_shape"), z_shape=fields.get("z_shape"), diff --git a/applications/airtable/tests/conftest.py b/applications/airtable/tests/conftest.py index 728f0016a..2a3b7fddd 100644 --- a/applications/airtable/tests/conftest.py +++ b/applications/airtable/tests/conftest.py @@ -37,6 +37,10 @@ "channel_3_marker": None, "data_path": "/hpc/datasets/alpha.zarr", "fluorescence_modality": {"name": "widefield"}, + "microscope": {"name": "mantis"}, + "labelfree_modality": {"name": "widefield"}, + "treatment": {"name": "DMSO"}, + "hours post treatment": 2.0, "t_shape": 50, "c_shape": 2, "z_shape": 30, @@ -70,6 +74,10 @@ "channel_3_marker": None, "data_path": "/hpc/datasets/beta.zarr", "fluorescence_modality": None, + "microscope": "dragonfly", + "labelfree_modality": "oblique", + "treatment": None, + "hours post treatment": None, "t_shape": 100, "c_shape": 2, "z_shape": 15, diff --git a/applications/airtable/tests/test_database.py b/applications/airtable/tests/test_database.py index 15cbb5634..42f483fba 100644 --- a/applications/airtable/tests/test_database.py +++ b/applications/airtable/tests/test_database.py @@ -22,8 +22,8 @@ def test_init_with_env_vars(self, mock_env, mock_api): AirtableDatasets() # Api was called with the fake key mock_api.assert_called_once_with("patFAKEKEY123") - # .table() was called with the fake base id and TABLE_NAME - mock_api.return_value.table.assert_called_once_with("appFAKEBASE456", "Datasets") + # .table() is called twice: once for Datasets, once for Marker Registry + mock_api.return_value.table.assert_any_call("appFAKEBASE456", "Datasets") def test_init_raises_when_api_key_missing(self, monkeypatch): """ValueError is raised when AIRTABLE_API_KEY is not set.""" @@ -183,15 +183,43 @@ def test_dataframe_columns(self, airtable_datasets, mock_table, sample_airtable_ "seeding_density", "treatment_concentration_nm", "channel_names", + "channel_markers", *(f"channel_{i}_{attr}" for i in range(8) for attr in ("name", "marker")), "data_path", "tracks_path", "fluorescence_modality", + "microscope", + "labelfree_modality", + "treatment", + "hours_post_treatment", "t_shape", "c_shape", "z_shape", "y_shape", "x_shape", + "pixel_size_xy_um", + "pixel_size_z_um", "record_id", } assert set(df.columns) == expected_cols + + +# --------------------------------------------------------------------------- +# batch_delete +# --------------------------------------------------------------------------- + + +class TestBatchDelete: + """Test AirtableDatasets.batch_delete().""" + + def test_delegates_to_table(self, airtable_datasets, mock_table): + mock_table.batch_delete.return_value = [{"id": "rec001", "deleted": True}] + result = airtable_datasets.batch_delete(["rec001"]) + mock_table.batch_delete.assert_called_once_with(["rec001"]) + assert result == [{"id": "rec001", "deleted": True}] + + def test_passes_multiple_ids(self, airtable_datasets, mock_table): + ids = ["rec001", "rec002", "rec003"] + mock_table.batch_delete.return_value = [] + airtable_datasets.batch_delete(ids) + mock_table.batch_delete.assert_called_once_with(ids) diff --git a/applications/airtable/tests/test_register_fovs.py b/applications/airtable/tests/test_register_fovs.py index 0e9964c2f..aaddcd8e0 100644 --- a/applications/airtable/tests/test_register_fovs.py +++ b/applications/airtable/tests/test_register_fovs.py @@ -29,6 +29,7 @@ def _make_well_template(well_id: str, record_id: str | None = None, **overrides) "fov": None, "cell_type": "A549", "cell_state": "Live", + "cell_line": ["recCELLLINE1"], "marker": "TOMM20", "organelle": "mitochondria", "perturbation": "ZIKV", @@ -36,6 +37,10 @@ def _make_well_template(well_id: str, record_id: str | None = None, **overrides) "moi": 5.0, "time_interval_min": 30.0, "fluorescence_modality": "Light-sheet", + "microscope": "mantis", + "labelfree_modality": "widefield", + "treatment": "DMSO", + "hours_post_treatment": 2.0, "channel_0_marker": "brightfield", "channel_1_marker": "mitochondria", "record_id": record_id, @@ -51,6 +56,7 @@ def _make_fov_record(well_id: str, fov: str, record_id: str, **overrides) -> Dat "well_id": well_id, "fov": fov, "cell_type": "A549", + "cell_line": ["recCELLLINE1"], "marker": "TOMM20", "organelle": "mitochondria", "record_id": record_id, @@ -133,7 +139,10 @@ def test_creates_new_fov_records_from_well_templates(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert result.dataset == "test_dataset" @@ -159,8 +168,13 @@ def test_creates_new_fov_records_from_well_templates(self): assert rec0["organelle"] == "mitochondria" assert rec0["perturbation"] == "ZIKV" assert rec0["moi"] == 5.0 + assert rec0["microscope"] == "mantis" + assert rec0["labelfree_modality"] == "widefield" + assert rec0["treatment"] == "DMSO" + assert rec0["hours_post_treatment"] == 2.0 assert rec0["channel_0_marker"] == "brightfield" assert rec0["channel_1_marker"] == "mitochondria" + assert result.template_ids_to_delete == ["recWELL1"] def test_updates_existing_fov_records(self): """Existing per-FOV records get updated with zarr-derived fields only.""" @@ -172,7 +186,10 @@ def test_updates_existing_fov_records(self): mock_plate = _make_mock_plate(positions) paths = [Path("/data/test_dataset.zarr/A/1/000000")] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.created) == 0 @@ -202,7 +219,10 @@ def test_unmatched_positions(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/B/2/000000"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.created) == 1 @@ -226,7 +246,10 @@ def test_mixed_create_and_update(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.updated) == 1 @@ -259,6 +282,23 @@ def test_raises_on_mixed_zarr_stores(self): with pytest.raises(ValueError, match="same zarr store"): register_fovs(paths, db=db) + def test_raises_when_cell_line_missing(self): + """ValueError raised when a well template has no cell_line set.""" + template_no_cell_line = _make_well_template("A/1", cell_line=None) + db = MagicMock() + db.get_dataset_records.return_value = [template_no_cell_line] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + with pytest.raises(ValueError, match="cell_line is required"): + register_fovs(paths, db=db) + def test_all_records_already_per_fov_no_templates(self): """When all records are per-FOV and no templates exist, only updates happen.""" existing = _make_fov_record("A/1", "000000", record_id="recFOV1") @@ -275,7 +315,10 @@ def test_all_records_already_per_fov_no_templates(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.updated) == 1 @@ -341,12 +384,112 @@ def test_copies_non_none_fields(self): assert fields["perturbation"] == "ZIKV" assert fields["moi"] == 5.0 assert fields["time_interval_min"] == 30.0 + assert fields["microscope"] == "mantis" + assert fields["labelfree_modality"] == "widefield" + assert fields["treatment"] == "DMSO" + assert fields["hours_post_treatment"] == 2.0 assert fields["channel_0_marker"] == "brightfield" assert fields["channel_1_marker"] == "mitochondria" def test_skips_none_fields(self): - template = _make_well_template("A/1", seeding_density=None, treatment_concentration_nm=None) + template = _make_well_template( + "A/1", + seeding_density=None, + treatment_concentration_nm=None, + microscope=None, + labelfree_modality=None, + ) fields = copy_well_template_fields(template) assert "seeding_density" not in fields assert "treatment_concentration_nm" not in fields + assert "microscope" not in fields + assert "labelfree_modality" not in fields + + +# --------------------------------------------------------------------------- +# template deletion tracking +# --------------------------------------------------------------------------- + + +class TestTemplateDeletion: + """Tests for template_ids_to_delete population in register_fovs.""" + + def test_template_deleted_when_fov_created(self): + """Template record ID appears in deletion list when FOVs are created from it.""" + template_a1 = _make_well_template("A/1", record_id="recWELL1") + db = MagicMock() + db.get_dataset_records.return_value = [template_a1] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + assert result.template_ids_to_delete == ["recWELL1"] + + def test_template_not_deleted_when_all_positions_unmatched(self): + """Template with no created FOVs is not in deletion list.""" + template_a1 = _make_well_template("A/1", record_id="recWELL1") + db = MagicMock() + db.get_dataset_records.return_value = [template_a1] + + # B/2 has no template — will be unmatched + positions = {"B/2/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/B/2/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.unmatched) == 1 + assert result.template_ids_to_delete == [] + + def test_only_used_templates_deleted(self): + """Only templates where at least one FOV was created appear in deletion list.""" + template_a1 = _make_well_template("A/1", record_id="recWELL_A1") + template_b2 = _make_well_template("B/2", record_id="recWELL_B2") + db = MagicMock() + db.get_dataset_records.return_value = [template_a1, template_b2] + + # A/1 gets a FOV; B/2 gets no positions in this batch + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + assert result.template_ids_to_delete == ["recWELL_A1"] + + def test_template_without_record_id_not_added(self): + """Template with no record_id is skipped in deletion list.""" + template_a1 = _make_well_template("A/1", record_id=None) + db = MagicMock() + db.get_dataset_records.return_value = [template_a1] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + assert result.template_ids_to_delete == [] diff --git a/applications/airtable/tests/test_schemas.py b/applications/airtable/tests/test_schemas.py index 7917af8ba..11e611355 100644 --- a/applications/airtable/tests/test_schemas.py +++ b/applications/airtable/tests/test_schemas.py @@ -164,6 +164,10 @@ def test_full_record_with_select_dicts(self, sample_airtable_records): assert rec.channel_1_marker == "Endoplasmic Reticulum" assert rec.data_path == "/hpc/datasets/alpha.zarr" assert rec.fluorescence_modality == "widefield" + assert rec.microscope == "mantis" + assert rec.labelfree_modality == "widefield" + assert rec.treatment == "DMSO" + assert rec.hours_post_treatment == 2.0 assert rec.t_shape == 50 assert rec.c_shape == 2 assert rec.z_shape == 30 @@ -181,6 +185,10 @@ def test_record_with_plain_string_fields(self, sample_airtable_records): assert rec.perturbation == "ZIKV" assert rec.moi == 0.5 assert rec.cell_line is None + assert rec.microscope == "dragonfly" + assert rec.labelfree_modality == "oblique" + assert rec.treatment is None + assert rec.hours_post_treatment is None def test_minimal_record(self): """Record with only required fields.""" diff --git a/applications/dynaclr/configs/cellanome/embed_all.sh b/applications/dynaclr/configs/cellanome/embed_all.sh new file mode 100755 index 000000000..3671d4bd5 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_all.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# SLURM array job: generate DINOv3 + DynaCLR embeddings for all 5 cellanome datasets. +# Array index: 0-9 (5 datasets × 2 models) +# 0-4 → DINOv3 +# 5-9 → DynaCLR +# +# Usage: +# sbatch embed_all.sh +# # or a single task interactively: +# SLURM_ARRAY_TASK_ID=0 bash embed_all.sh + +#SBATCH --job-name=cellanome_embed +#SBATCH --array=0-9 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=4:00:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/logs/cellanome_embed_%A_%a.out +#SBATCH --error=/hpc/mydata/eduardo.hirata/logs/cellanome_embed_%A_%a.err + +export PYTHONNOUSERSITE=1 + +REPO=/home/eduardo.hirata/repos/viscy +CFG_ROOT="${REPO}/applications/dynaclr/configs/cellanome" + +DATASETS=( + "20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes" + "20260211112411_P-05_R000439_FC_2026_02_11_manual_loading_mixed_GFP+RFP" + "20260220144306_P-05_R000476_FC_2026_02_20_A549_GFP_RFP_Org_Cells" + "20260310112219_P-05_R000486_FC_2026_03_10_A549_pAL27+ISG15_off_on_DENV" + "20260324133209_P-05_R000497_FC_2026_03_24_A549_SEC61B_G3BP1_pAL40_DENV_rerun" +) + +TASK=${SLURM_ARRAY_TASK_ID} +N=${#DATASETS[@]} # 5 + +DATASET_IDX=$(( TASK % N )) +MODEL_IDX=$(( TASK / N )) # 0 = DINOv3, 1 = DynaCLR + +DATASET="${DATASETS[$DATASET_IDX]}" + +if [ "$MODEL_IDX" -eq 0 ]; then + SCRIPT="${REPO}/applications/dynaclr/scripts/cellanome/embed_dinov3.py" + CONFIG="${CFG_ROOT}/${DATASET}/embed_dinov3.yml" +else + SCRIPT="${REPO}/applications/dynaclr/scripts/cellanome/embed_dynaclr.py" + CONFIG="${CFG_ROOT}/${DATASET}/embed_dynaclr.yml" +fi + +echo "Task ${TASK}: dataset=${DATASET} model_idx=${MODEL_IDX}" +echo "Config: ${CONFIG}" + +cd "${REPO}" +uv run python "${SCRIPT}" "${CONFIG}" diff --git a/applications/dynaclr/configs/cellanome/embed_dinov3.yml b/applications/dynaclr/configs/cellanome/embed_dinov3.yml new file mode 100644 index 000000000..0314593a6 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_dinov3.yml @@ -0,0 +1,38 @@ +# DINOv3 embedding extraction for cellanome dataset. +# Reads primary_analysis.csv directly, outputs cell-level anndata. +# Usage: uv run python applications/dynaclr/scripts/cellanome/embed_dinov3.py configs/cellanome/embed_dinov3.yml + +# --- Data paths --- +zarr_store: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes.zarr +analysis_base: /hpc/instruments/cm.r3200/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/image_analysis_output-12032025-143316 +transcriptome_anndata: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/anndata/seurat-bc3a-l_all.zarr +output_path: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/2-embeddings/dinov3-convnext-tiny-BF.zarr + +# --- Experiment --- +# Omit to auto-discover all scans/lanes under analysis_base. +# scan_ids: [5] +# lane_ids: [3, 4, 5, 6] + +# --- Model --- +model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + +# --- Channels --- +channels: + - White + +# --- Crop --- +patch_size: 96 +reference_pixel_size: 1.0 +source_pixel_size: 1.0 + +# --- Filtering --- +# Dict of column_name: {min, max, eq, isin} applied to primary_analysis.csv. +filters: + object_class: + isin: [cell, cell-adhered] + object_radius_px: + min: 39 + +# --- Inference --- +batch_size: 128 +device: cuda diff --git a/applications/dynaclr/configs/cellanome/embed_dynaclr.yml b/applications/dynaclr/configs/cellanome/embed_dynaclr.yml new file mode 100644 index 000000000..a6023b717 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_dynaclr.yml @@ -0,0 +1,46 @@ +# DynaCLR embedding extraction for cellanome dataset. +# Reads primary_analysis.csv directly, outputs cell-level anndata. +# Usage: uv run python applications/dynaclr/scripts/cellanome/embed_dynaclr.py configs/cellanome/embed_dynaclr.yml + +# --- Data paths --- +zarr_store: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes.zarr +analysis_base: /hpc/instruments/cm.r3200/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/image_analysis_output-12032025-143316 +transcriptome_anndata: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/anndata/seurat-bc3a-l_all.zarr +output_path: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/2-embeddings/dynaclr-2d-boc-BF.zarr + +# --- Experiment --- +# scan_ids: [5] +# lane_ids: [3, 4, 5, 6] + +# --- Model --- +ckpt_path: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/saved_checkpoints/epoch=104-step=53760.ckpt +encoder_config: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + embedding_dim: 768 + projection_dim: 32 + +# --- Channel --- +channel_name: White + +# --- Crop --- +# Trained on 160x160 at 0.149 µm/px (~23.8 µm physical). +# Cellanome is 20x at 0.247 µm/px. +# raw_crop = 160 * 0.149 / 0.247 = 96 px, resized to 160. +patch_size: 160 +reference_pixel_size: 0.149 +source_pixel_size: 0.247 + +# --- Filtering --- +filters: + object_class: + isin: [cell, cell-adhered] + object_radius_px: + min: 39 + +# --- Inference --- +batch_size: 128 +device: cuda diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml index 95047fdc5..415c9238f 100644 --- a/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml @@ -1,5 +1,5 @@ name: DynaCLR-2D-BagOfChannels-v3 -description: "Multi-organelle bag-of-channels DynaCLR training collection. Includes SEC61B (ER), TOMM20 (mitochondria), and G3BP1 (stress granules) experiments with ZIKV/DENV infection. All 3 channels (Phase3D, GFP, mCherry) trained jointly." +description: "[LEGACY] Multi-organelle bag-of-channels DynaCLR training collection. Includes SEC61B (ER), TOMM20 (mitochondria), and G3BP1 (stress granules) experiments with ZIKV/DENV infection. All 3 channels (Phase3D, GFP, mCherry) trained jointly." provenance: airtable_base_id: app8vqaoWyOwa0sB5 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml new file mode 100644 index 000000000..960779080 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml @@ -0,0 +1,119 @@ +name: DynaCLR-2D-MIP-BagOfChannels-Annotated +description: "Subset of DynaCLR-2D-MIP-BagOfChannels-MultiCell with available cell annotations. + Includes 2025_01_28 G3BP1 and 2025_07_24 multi-channel experiments. + Used for linear classifier evaluation. ALFI excluded." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}))" + record_ids: [] + created_at: "2026-04-08T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) — 2025_01_28 ── + # Annotations: B/4 (uninfected), C/4 (infected) + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24 multi-channel — G3BP1, SEC61B, viral sensor, Phase3D ── + # Annotations: A/2, C/1, C/2 (TOMM20 wells B/1, B/2 not annotated — excluded) + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_viral_sensor + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_Phase3D + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml new file mode 100644 index 000000000..fb52e3f1e --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml @@ -0,0 +1,658 @@ +name: DynaCLR-2D-MIP-BagOfChannels-MultiCell +description: "Multi-cell-type bag-of-channels 2D DynaCLR training collection with z-reduction. Combines A549 infectomics (3D z-stacks from VAST, MIP for fluorescence / center-slice for Phase3D), microglia dynamorph (BF, Phase3D, Retardance), ALFI mitosis (DIC, U2OS/RPE-1/HeLa), and dragonfly confocal. All data paths point to /hpc/projects/organelle_phenotyping/datasets/." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}), SEARCH(\"20191107_1209_1_GW23_dynamorph\", {dataset}), SEARCH(\"ALFI\", {dataset}))" + record_ids: [] + created_at: "2026-03-30T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ══════════════════════════════════════════════════════════════════════ + # A549 infectomics — 3D z-stacks on VAST (single-channel bags) + # ══════════════════════════════════════════════════════════════════════ + + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── A549 Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: Phase3D + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ══════════════════════════════════════════════════════════════════════ + # Microglia dynamorph — 2D label-free (BF, Phase3D, Retardance) + # ══════════════════════════════════════════════════════════════════════ + + - name: 20191107_GW23_dynamorph_Brightfield + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Brightfield + marker: Brightfield + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Brightfield + organelle: Brightfield + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Retardance + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Retardance + marker: Retardance + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Retardance + organelle: Retardance + pixel_size_xy_um: 0.325 + + # ══════════════════════════════════════════════════════════════════════ + # ALFI — 2D DIC mitosis datasets (multiple cell types) + # ══════════════════════════════════════════════════════════════════════ + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml new file mode 100644 index 000000000..fb52e3f1e --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml @@ -0,0 +1,658 @@ +name: DynaCLR-2D-MIP-BagOfChannels-MultiCell +description: "Multi-cell-type bag-of-channels 2D DynaCLR training collection with z-reduction. Combines A549 infectomics (3D z-stacks from VAST, MIP for fluorescence / center-slice for Phase3D), microglia dynamorph (BF, Phase3D, Retardance), ALFI mitosis (DIC, U2OS/RPE-1/HeLa), and dragonfly confocal. All data paths point to /hpc/projects/organelle_phenotyping/datasets/." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}), SEARCH(\"20191107_1209_1_GW23_dynamorph\", {dataset}), SEARCH(\"ALFI\", {dataset}))" + record_ids: [] + created_at: "2026-03-30T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ══════════════════════════════════════════════════════════════════════ + # A549 infectomics — 3D z-stacks on VAST (single-channel bags) + # ══════════════════════════════════════════════════════════════════════ + + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── A549 Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: Phase3D + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ══════════════════════════════════════════════════════════════════════ + # Microglia dynamorph — 2D label-free (BF, Phase3D, Retardance) + # ══════════════════════════════════════════════════════════════════════ + + - name: 20191107_GW23_dynamorph_Brightfield + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Brightfield + marker: Brightfield + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Brightfield + organelle: Brightfield + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Retardance + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Retardance + marker: Retardance + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Retardance + organelle: Retardance + pixel_size_xy_um: 0.325 + + # ══════════════════════════════════════════════════════════════════════ + # ALFI — 2D DIC mitosis datasets (multiple cell types) + # ══════════════════════════════════════════════════════════════════════ + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml index d15d22bc6..c71b97a79 100644 --- a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml @@ -1,5 +1,6 @@ name: DynaCLR-3D-BagOfChannels-v2 description: "Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." +datasets_root: /hpc/projects/organelle_phenotyping provenance: airtable_base_id: app8vqaoWyOwa0sB5 @@ -11,8 +12,8 @@ provenance: experiments: # ── G3BP1 (stress granules) ── - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: G3BP1 @@ -30,9 +31,49 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2025_07_24_A549_G3BP1_ZIKV - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: G3BP1 @@ -51,8 +92,8 @@ experiments: # ── CAAX (membrane) ── - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: CAAX @@ -70,8 +111,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: CAAX @@ -90,8 +131,8 @@ experiments: # ── H2B (chromatin) ── - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr channels: - name: raw Cy5 EX639 EM698-70 marker: HIST2H2BE @@ -109,8 +150,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr channels: - name: raw Cy5 EX639 EM698-70 marker: HIST2H2BE @@ -129,8 +170,8 @@ experiments: # ── TOMM20 (mitochondria) ── - name: 2024_10_09_A549_TOMM20_ZIKV_DENV - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -147,9 +188,47 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -167,8 +246,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_07_24_A549_TOMM20_ZIKV - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -187,8 +266,8 @@ experiments: # ── SEC61B (endoplasmic reticulum) ── - name: 2024_10_16_A549_SEC61_ZIKV_DENV - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -205,9 +284,47 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -225,8 +342,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_07_24_A549_SEC61_ZIKV - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -244,9 +361,9 @@ experiments: pixel_size_z_um: 0.174 # ── Viral sensor (mCherry) ── - - name: 2025_07_24_A549_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -267,8 +384,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -286,8 +403,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -305,9 +422,9 @@ experiments: pixel_size_z_um: 0.174 # ── Phase3D (label-free) ── - - name: 2025_07_24_A549_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: Phase3D marker: Phase3D @@ -328,8 +445,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: Phase3D marker: Phase3D @@ -347,8 +464,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: Phase3D marker: Phase3D @@ -367,8 +484,8 @@ experiments: # ── Dragonfly confocal — viral sensor (pAL10) ── - name: 2024_08_14_ZIKV_pal17_48h_pAL10 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_08_14_ZIKV_pal17_48h/5-tracking/2024_08_14_ZIKV_pal17_48h.zarr + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr channels: - name: MultiCam_GFP_BF marker: pAL10 @@ -389,8 +506,8 @@ experiments: # ── Dragonfly confocal — Phase3D (label-free) ── - name: 2024_08_14_ZIKV_pal17_48h_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_08_14_ZIKV_pal17_48h/5-tracking/2024_08_14_ZIKV_pal17_48h.zarr + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr channels: - name: Phase3D marker: Phase3D diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml new file mode 100644 index 000000000..fc53aab45 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml @@ -0,0 +1,527 @@ +name: DynaCLR-3D-BagOfChannels-v3 +description: "Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}))" + record_ids: [] + created_at: "2026-04-10T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ── Dragonfly confocal — Phase3D (label-free) ── + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: label_free + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml new file mode 100644 index 000000000..23787e77d --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml @@ -0,0 +1,527 @@ +name: DynaCLR-3D-BagOfChannels-v2 +description: "Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}))" + record_ids: [] + created_at: "2026-03-27T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ── Dragonfly confocal — Phase3D (label-free) ── + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: label_free + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 diff --git a/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml new file mode 100644 index 000000000..0b03d5401 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml @@ -0,0 +1,174 @@ +name: DynaCLR-BoC-lc-evaluation-v1-test +description: "Minimal subset of DynaCLR-BoC-lc-evaluation-v1 for fast end-to-end + testing of MMD and linear classifier evaluation. Three markers (G3BP1, Phase3D, + viral_sensor) across two dates (2025_07_22 and 2025_07_24) + one G3BP1-only + experiment (2025_01_28). Enables cross-experiment MMD for all three markers." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}))" + record_ids: [] + created_at: "2026-04-09T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── 2025_01_28: G3BP1, ZIKV + DENV — 1 uninfected + 1 infected well ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_01_28: Phase3D, ZIKV + DENV — 1 uninfected + 1 infected well ── + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: G3BP1, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: Phase3D, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: G3BP1, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: Phase3D, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: viral_sensor, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: viral_sensor, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml new file mode 100644 index 000000000..f96572fd6 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml @@ -0,0 +1,401 @@ +name: DynaCLR-BoC-lc-evaluation-v1 +description: "Annotated experiments for linear classifier evaluation of bag-of-channels DynaCLR models. + Includes all datasets with infection_state / cell_division_state annotations and processed zarr stores." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_11_07\", {dataset}), SEARCH(\"2025_01_24\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_22\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_08_26\", {dataset}))" + record_ids: [] + created_at: "2026-04-09T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── 2025_01_28: G3BP1 (stress granules), ZIKV + DENV ── + # Annotated wells: B/4 (uninfected), C/4 (infected) + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: multi-channel (G3BP1, SEC61B, viral sensor, Phase3D), ZIKV ── + # Annotated wells: A/2 (infected), C/1 (uninfected), C/2 (infected) + # TOMM20 wells B/1, B/2 not annotated — excluded from this collection + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/1 + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/1 + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2024_11_07: SEC61B (ER), DENV ── + # Annotated wells: B/3 (uninfected), C/2 (infected+uninfected) + - name: 2024_11_07_A549_SEC61_DENV_SEC61B + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_07_A549_SEC61_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_07_A549_SEC61_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_01_24: G3BP1 (stress granules), DENV ── + # Annotated wells: B/1 (uninfected), B/2 (infected), B/3 (uninfected), C/2 (infected) + - name: 2025_01_24_A549_G3BP1_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/1 + - B/3 + infected: + - B/2 + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_24_A549_G3BP1_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_24_A549_G3BP1_DENV_Phase3D + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/1 + - B/3 + infected: + - B/2 + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: G3BP1 (stress granules) + pAL17 (viral sensor), ZIKV ── + # Annotated wells: C/1 (uninfected), C/2 (infected) + - name: 2025_07_22_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_22_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_22_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_08_26: SEC61B (ER), ZIKV ── + # Annotated wells: A/1 (uninfected), B/1 (infected+uninfected) + - name: 2025_08_26_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_08_26_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_08_26_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/alfi-eval.yml b/applications/dynaclr/configs/collections/alfi-eval.yml new file mode 100644 index 000000000..3b66830f8 --- /dev/null +++ b/applications/dynaclr/configs/collections/alfi-eval.yml @@ -0,0 +1,55 @@ +name: alfi-eval +description: "ALFI mitosis evaluation collection. All 3 cell lines (HeLa MI06, RPE1 MI07/MI08, U2OS MI01-MI05), DIC channel. Analysis done per cell line." +datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 diff --git a/applications/dynaclr/configs/collections/benchmark_2exp.yml b/applications/dynaclr/configs/collections/benchmark_2exp.yml new file mode 100644 index 000000000..eeada4a1c --- /dev/null +++ b/applications/dynaclr/configs/collections/benchmark_2exp.yml @@ -0,0 +1,36 @@ +name: benchmark_2exp +description: "Benchmark collection: G3BP1 (2025_07_24) + H2B (2025_04_15) for dataloader profiling" +datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml b/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml new file mode 100644 index 000000000..97953cea3 --- /dev/null +++ b/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml @@ -0,0 +1,54 @@ +name: example_mantis_dragonfly +description: "Example collection combining mantis (lightsheet) and dragonfly (confocal) datasets. SEC61B from 2025_07_24 ZIKV experiment and pAL10 viral sensor from 2024_08_14 ZIKV experiment." + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_07_24\", {dataset}), SEARCH(\"2024_08_14\", {dataset}))" + record_ids: [] + created_at: "2026-04-01T00:00:00" + created_by: "eduardo.hirata" + +experiments: + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr + channels: + - name: Phase3D + marker: Phase3D + - name: GFP EX488 EM525-45 + marker: SEC61B + - name: mCherry EX561 EM600-37 + marker: mCherry + perturbation_wells: + ZIKV: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + microscope: mantis + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + moi: 5.0 + + - name: 2024_08_14_ZIKV_pal17_48h + data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_timeaware_tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + microscope: dragonfly + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878069639205931 + moi: 5.0 diff --git a/applications/dynaclr/configs/collections/microglia-eval.yml b/applications/dynaclr/configs/collections/microglia-eval.yml new file mode 100644 index 000000000..db2c13f00 --- /dev/null +++ b/applications/dynaclr/configs/collections/microglia-eval.yml @@ -0,0 +1,73 @@ +name: microglia-eval +description: "Microglia dynamorph evaluation collection. All 3 label-free channels (Brightfield, Phase3D, Retardance), all 5 perturbation conditions." +datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: 20191107_GW23_dynamorph_Brightfield + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Brightfield + marker: Brightfield + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Brightfield + organelle: Brightfield + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Retardance + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Retardance + marker: Retardance + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Retardance + organelle: Retardance + pixel_size_xy_um: 0.325 diff --git a/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml b/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml new file mode 100644 index 000000000..878087e3f --- /dev/null +++ b/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml @@ -0,0 +1,31 @@ +datasets: + "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV": + hcs_plate: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_2.zarr + anndata: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_phase_160patch_104ckpt.zarr + "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV": + hcs_plate: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + anndata: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_phase_160patch_104ckpt.zarr + +# Usage: +# dynaclr combined-dim-reduction -c multi-dataset-dim-reduction.yml +# +# Notes: +# - `datasets[*].anndata` are the AnnData zarrs that will be concatenated to fit the joint reductions. +# - Remove any method section (pca/umap/phate) to skip computing it. +reduce_combined: + overwrite_keys: false + + # PCA configuration (remove this section to skip PCA) + pca: + # Number of components. null = keep all components. + n_components: 32 + normalize_features: true + + # PHATE configuration (remove this section to skip PHATE) + phate: + n_components: 2 + knn: 5 + decay: 40 + scale_embeddings: true + random_state: 42 + n_jobs: -1 diff --git a/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1.yaml b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1.yaml new file mode 100644 index 000000000..973b17068 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1.yaml @@ -0,0 +1,31 @@ +# Evaluation config for DINOv3-temporal-MLP-2D-BagOfChannels +# Experiments: 14 infectomics experiments (ZIKV + DENV, G3BP1/SEC61B/Phase3D/viral_sensor) +# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf \ +# --eval_config applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -profile local \ +# -resume + +base: + - recipes/predict.yml + - recipes/reduce.yml + - recipes/plot_infectomics.yml + - recipes/linear_classifiers_infectomics.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/config_updated.yaml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/DINOv3-temporal-MLP-2D-BagOfChannels-v1/20260319-235942/checkpoints/epoch=71-step=14040.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluation +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - plot + - smoothness + - linear_classifiers + - append_annotations + - append_predictions diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3.yaml new file mode 100644 index 000000000..a816a503e --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3.yaml @@ -0,0 +1,30 @@ +# Evaluation config for DynaCLR-2D-BagOfChannels-v3 +# Experiments: 14 infectomics experiments (ZIKV + DENV, G3BP1/SEC61B/Phase3D/viral_sensor) +# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - recipes/predict.yml + - recipes/reduce.yml + - recipes/plot_infectomics.yml + - recipes/linear_classifiers_infectomics.yml + +training_config: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/DynaCLR-2D-BagOfChannels-v3.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/saved_checkpoints/epoch=104-step=53760.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3 +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - plot + - smoothness + - linear_classifiers + - append_annotations + - append_predictions diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-v1.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-v1.yaml new file mode 100644 index 000000000..fe9544c12 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-v1.yaml @@ -0,0 +1,30 @@ +# Evaluation config for DynaCLR-2D-MIP-BagOfChannels +# Experiments: 14 infectomics experiments (ZIKV + DENV, G3BP1/SEC61B/Phase3D/viral_sensor) +# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-v1.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - recipes/predict.yml + - recipes/reduce.yml + - recipes/plot_infectomics.yml + - recipes/linear_classifiers_infectomics.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluation_lc_v1/ +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - smoothness + - linear_classifiers + - append_annotations + - append_predictions + - plot diff --git a/applications/dynaclr/configs/evaluation/alfi-eval.yaml b/applications/dynaclr/configs/evaluation/alfi-eval.yaml new file mode 100644 index 000000000..2bad7fef6 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/alfi-eval.yaml @@ -0,0 +1,72 @@ +# Evaluation config for ALFI mitosis datasets +# Checkpoint: DynaCLR-2D-MIP-BagOfChannels +# Data: HeLa (MI06), RPE1 (MI07/MI08), U2OS (MI01-MI05), DIC channel +# Annotations: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv +# Labels: cell_division_state (interphase / mitosis), cell_cycle_fine_state +# +# Steps: +# 1. Build cell index: +# dynaclr build-cell-index \ +# applications/dynaclr/configs/collections/alfi-eval.yml \ +# /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/alfi/cell_index.parquet +# +# 2. Run predict: +# viscy predict -c +# (or use the Nextflow orchestrator) + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/alfi-eval.parquet +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/alfi/ + +steps: + - predict + - reduce_dimensionality + - reduce_combined + - smoothness + - linear_classifiers + - append_annotations + - append_predictions + - plot + +predict: + batch_size: 256 + num_workers: 4 + precision: 32-true + devices: 1 + +reduce_dimensionality: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + umap: + n_components: 2 + n_neighbors: 15 + normalize: true + +reduce_combined: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + +smoothness: {} + +linear_classifiers: + annotations: + - experiment: "ALFI_HeLa_DMSO_MLN8237" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + - experiment: "ALFI_RPE1_untreated" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + - experiment: "ALFI_U2OS_DMSO_MLN8237" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + tasks: + - task: cell_division_state + - task: cell_death_state + use_scaling: true + use_pca: false + split_train_data: 0.8 + random_seed: 42 + +plot: {} diff --git a/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc.yaml b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc.yaml new file mode 100644 index 000000000..7c68590c7 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc.yaml @@ -0,0 +1,20 @@ +models: + - path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx + label: DynaCLR-2D-MIP + pixel_size_um: 0.149 # training pixel size (ALFI dragonfly) + - path: null + label: baseline-iou + +datasets: + - path: /hpc/reference/group.royer/CTC/training/DIC-C2DH-HeLa + sequences: ["01", "02"] + pixel_size_um: 0.190 # DIC-C2DH-HeLa from TIFF XResolution metadata + +ctc_metadata_path: /hpc/reference/group.royer/CTC/metadata.yaml +model_input_shape: [160, 160] +distance_threshold: 325.0 +n_neighbors: 10 +delta_t: 5 +batch_size: 128 +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/ctc_tracking/ +show_napari: false diff --git a/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.sh b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.sh new file mode 100644 index 000000000..10bb6ba95 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# CTC tracking accuracy benchmark — DynaCLR-2D-MIP vs IoU baseline +# Runs on all 9 2D CTC training datasets. +# +# sbatch applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.sh + +#SBATCH --job-name=ctc_tracking +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=64G +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --time=0-02:00:00 +#SBATCH --output=%x-%j.out + +export PYTHONNOUSERSITE=1 +export GRB_LICENSE_FILE=/home/eduardo.hirata/gurobi/gurobi.lic + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.yaml" + +uv run --project "$WORKSPACE" dynaclr evaluate-tracking-accuracy -c "$CONFIG" diff --git a/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.yaml b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.yaml new file mode 100644 index 000000000..9b81b3dca --- /dev/null +++ b/applications/dynaclr/configs/evaluation/ctc_tracking_2d_mip_boc_all.yaml @@ -0,0 +1,37 @@ +models: + - path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx + label: DynaCLR-2D-MIP + pixel_size_um: 0.149 # training pixel size (Mantis-v1 ) + - path: null + label: baseline-iou + +# 2D datasets only — 3D datasets excluded (model is 2D-only) +# pixel_size_um is auto-detected from TIFF XResolution metadata +datasets: + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-MuSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/DIC-C2DH-HeLa + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-C2DL-MSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-N2DH-GOWT1 + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-N2DH-SIM+ + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-N2DL-HeLa + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/PhC-C2DH-U373 + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/PhC-C2DL-PSC + sequences: ["01", "02"] + +ctc_metadata_path: /hpc/reference/group.royer/CTC/metadata.yaml +model_input_shape: [160, 160] +distance_threshold: 325.0 +n_neighbors: 10 +delta_t: 5 +batch_size: 128 +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/ctc_tracking_all/ +show_napari: false diff --git a/applications/dynaclr/configs/evaluation/eval_registry.yaml b/applications/dynaclr/configs/evaluation/eval_registry.yaml new file mode 100644 index 000000000..37e3901dc --- /dev/null +++ b/applications/dynaclr/configs/evaluation/eval_registry.yaml @@ -0,0 +1,21 @@ +# Eval registry — declarative list of models to evaluate. +# +# Each entry points to an eval config YAML. The `force_rerun` flag forces +# re-execution even when outputs already exist. Status is derived from the +# filesystem by `dynaclr check-evals`, not stored here. +# +# Usage: +# dynaclr check-evals -r applications/dynaclr/configs/evaluation/eval_registry.yaml + +models: + - name: DynaCLR-2D-MIP-BagOfChannels + eval_config: applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels_evaluation.yaml + force_rerun: false + + - name: DINOv3-temporal-MLP-2D-BagOfChannels + eval_config: applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels_evaluation.yaml + force_rerun: false + + - name: DynaCLR-2D-BagOfChannels-v3 + eval_config: applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3.yaml + force_rerun: false diff --git a/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml b/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml new file mode 100644 index 000000000..52d70fc76 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml @@ -0,0 +1,24 @@ +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.1 + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + init_args: + temperature: 0.2 + lr: 0.00002 + example_input_array_shape: [1, 1, 1, 160, 160] + +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt + +export_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx diff --git a/applications/dynaclr/configs/evaluation/microglia-eval.yaml b/applications/dynaclr/configs/evaluation/microglia-eval.yaml new file mode 100644 index 000000000..c2f3220e7 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/microglia-eval.yaml @@ -0,0 +1,50 @@ +# Evaluation config for microglia dynamorph dataset +# Checkpoint: DynaCLR-2D-MIP-BagOfChannels +# Data: 20191107_1209_1_GW23_dynamorph — Brightfield, Phase3D, Retardance +# Perturbations: untreated, IL17, IFN-beta, Rubella, Glioblastoma_supernatant +# +# Steps: +# 1. Build cell index: +# dynaclr build-cell-index /home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/evaluation/microglia-eval.yaml /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/cell_index.parquet +# +# 2. Run predict: +# viscy predict -c +# (or use the Nextflow orchestrator) + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/microglia-eval.parquet +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/ + +steps: + # - predict + - reduce_dimensionality + - reduce_combined + - smoothness + - plot + +predict: + batch_size: 256 + num_workers: 4 + precision: 32-true + devices: 1 + +reduce_dimensionality: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + umap: + n_components: 2 + n_neighbors: 15 + normalize: true + +reduce_combined: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + +smoothness: {} + +plot: {} diff --git a/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh b/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh new file mode 100644 index 000000000..14736dac4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Predict embeddings for microglia and ALFI datasets +# Uses DynaCLR-2D-MIP-BagOfChannels checkpoint. +# +# Usage: +# sbatch applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh + +#SBATCH --job-name=dynaclr_predict_microglia_alfi +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=16G +#SBATCH --time=3:00:00 + +export PYTHONNOUSERSITE=1 +WORKSPACE_DIR="/hpc/mydata/eduardo.hirata/repos/viscy" + +# echo "=== Microglia predict ===" +# srun uv run --project /hpc/mydata/eduardo.hirata/repos/viscy viscy predict --config /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/configs/predict.yml + +echo "=== ALFI predict ===" +srun uv run --project /hpc/mydata/eduardo.hirata/repos/viscy viscy predict --config /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/alfi/configs/predict.yml diff --git a/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml b/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml new file mode 100644 index 000000000..830ccb75a --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml @@ -0,0 +1,62 @@ +# Linear classifier settings for the infectomics benchmark. +# Covers ZIKV + DENV datasets across G3BP1, SEC61B, Phase3D, viral_sensor markers. +# Every experiment here needs an annotation CSV — when an experiment is listed +# without a matching CSV (or the CSV's tracks don't overlap the zarr obs), the +# LC step writes nothing and downstream scripts (Stage 3d label-timing) quietly +# get `predicted_* = NaN`. Add every zarr that needs predictions; missing the +# sensor-channel zarrs is how we lost ZIKV pool coverage in v1. +linear_classifiers: + annotations: + - experiment: "2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_viral_sensor_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_Phase3D_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_07_24_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_SEC61_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_SEC61B" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_Phase3D" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_viral_sensor" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_Phase3D" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_viral_sensor" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_07_22_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_SEC61_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + tasks: + - task: infection_state + - task: cell_division_state + - task: organelle_state + marker_filters: + - G3BP1 + - SEC61B + - task: cell_death_state + marker_filters: + - G3BP1 + - SEC61B + use_scaling: true + use_pca: false + split_train_data: 0.8 + random_seed: 42 diff --git a/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml b/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml new file mode 100644 index 000000000..8370035b4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml @@ -0,0 +1,29 @@ +# Default MMD algorithm settings shared across all MMD eval configs. +# Use as a base: reference in per-experiment or pooled MMD configs to avoid +# repeating these parameters. Override any field in the leaf config. +# +# Usage: +# base: recipes/mmd_defaults.yml +# input_path: /path/to/embeddings.zarr +# output_dir: /path/to/output +# comparisons: +# - cond_a: uninfected +# cond_b: ZIKV +# label: "uninfected vs ZIKV" + +group_by: perturbation +save_plots: true + +mmd: + n_permutations: 1000 + max_cells: 2000 + min_cells: 20 + seed: 42 + balance_samples: false + share_bandwidth_from: null + +map_settings: + enabled: false + distance: cosine + null_size: 10000 + seed: 0 diff --git a/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml b/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml new file mode 100644 index 000000000..bdf4e3da1 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml @@ -0,0 +1,15 @@ +# Default plot settings for infectomics DynaCLR evaluation. +plot: + embedding_keys: + - X_pca + combined_embedding_keys: + - X_pca_combined + - X_phate_combined + color_by: + - perturbation + - hours_post_perturbation + - experiment + - marker + point_size: 1.0 + components: [0, 1] + format: pdf diff --git a/applications/dynaclr/configs/evaluation/recipes/predict.yml b/applications/dynaclr/configs/evaluation/recipes/predict.yml new file mode 100644 index 000000000..1dcc4951e --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/predict.yml @@ -0,0 +1,6 @@ +# Default predict step settings for DynaCLR evaluation. +predict: + batch_size: 400 + num_workers: 4 + precision: 32-true + devices: 1 diff --git a/applications/dynaclr/configs/evaluation/recipes/reduce.yml b/applications/dynaclr/configs/evaluation/recipes/reduce.yml new file mode 100644 index 000000000..6923f4acd --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/reduce.yml @@ -0,0 +1,21 @@ +# Default dimensionality reduction settings for DynaCLR evaluation. +# PHATE runs only in reduce_combined; per-experiment reduce_dimensionality uses PCA only. +# Override n_jobs for reduce_combined.phate in the leaf config if needed. +reduce_dimensionality: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + +reduce_combined: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + phate: + n_components: 2 + knn: 5 + decay: 40 + scale_embeddings: false + random_state: 42 + n_jobs: 48 diff --git a/applications/dynaclr/configs/evaluation/test_evaluation.yaml b/applications/dynaclr/configs/evaluation/test_evaluation.yaml new file mode 100644 index 000000000..02646e1e0 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/test_evaluation.yaml @@ -0,0 +1,77 @@ +# Minimal test config for MMD + linear classifier evaluation. +# Collection: DynaCLR-BoC-lc-evaluation-v1-test (7 experiments, 3 markers x 2 dates) +# 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1: B/4 (uninfected), C/4 (infected) — G3BP1 +# 2025_07_22_A549_G3BP1_ZIKV: C/2 (infected) — G3BP1 +# 2025_07_22_A549_Phase3D_ZIKV: C/2 (infected) — Phase3D +# 2025_07_22_A549_viral_sensor_ZIKV: C/2 (infected) — viral_sensor +# 2025_07_24_A549_G3BP1_ZIKV: C/1 (uninfected), C/2 (infected) — G3BP1 +# 2025_07_24_A549_Phase3D_ZIKV: C/1 (uninfected), C/2 (infected) — Phase3D +# 2025_07_24_A549_viral_sensor_ZIKV: C/1 (uninfected), C/2 (infected) — viral_sensor +# +# nextflow run applications/dynaclr/nextflow/main.nf \ +# --eval_config applications/dynaclr/configs/evaluation/test_evaluation.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume -profile local + +base: + - recipes/predict.yml + - recipes/reduce.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1-test.parquet +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluation_test_lc_2 + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - linear_classifiers + - smoothness + +# Override n_jobs for smaller test run +reduce_combined: + phate: + n_jobs: 12 + +mmd: + - name: perturbation + group_by: perturbation + comparisons: + - cond_a: uninfected + cond_b: infected + label: "uninfected vs infected" + temporal_bin_size: 4.0 + combined_temporal_bin_size: null + combined_mode: true + +linear_classifiers: + annotations: + - experiment: "2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_Phase3D_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_07_24_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + tasks: + - task: infection_state + - task: cell_division_state + - task: organelle_state + marker_filters: + - G3BP1 + - task: cell_death_state + use_scaling: true + use_pca: false + split_train_data: 0.8 + random_seed: 42 diff --git a/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml b/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml deleted file mode 100644 index c2514d04e..000000000 --- a/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Example configuration for evaluate_dataset.py -# -# Usage: -# python evaluate_dataset.py -c configs/evaluate_dataset_example.yaml -# python evaluate_dataset.py -c configs/evaluate_dataset_example.yaml --report - -dataset_name: my_test_dataset -test_annotations_csv: /path/to/test_annotations.csv -output_dir: /path/to/output - -models: - 2D: - name: DynaCLR-2D-BagOfChannels-timeaware - version: v3 - wandb_project: linearclassifiers-DynaCLR-2D-BagOfChannels-timeaware-v3 - test_embeddings_dir: /path/to/2D/embeddings/ - train_datasets: - - embeddings_dir: /path/to/train_ds1/embeddings/ - annotations: /path/to/train_ds1/annotations.csv - - embeddings_dir: /path/to/train_ds2/embeddings/ - annotations: /path/to/train_ds2/annotations.csv - -# Optional: auto-detected from test CSV if omitted -task_channels: - infection_state: [phase, sensor] - cell_division_state: [phase] - -# Classifier hyperparams (all optional, shown with defaults) -use_scaling: true -n_pca_components: null -max_iter: 1000 -class_weight: balanced -solver: liblinear -split_train_data: 0.8 -random_seed: 42 - -# W&B logging (set to false for local-only runs) -wandb_logging: true diff --git a/applications/dynaclr/configs/prediction/predict.yml b/applications/dynaclr/configs/prediction/predict.yml index a76cf05c6..0f560fa8c 100644 --- a/applications/dynaclr/configs/prediction/predict.yml +++ b/applications/dynaclr/configs/prediction/predict.yml @@ -11,6 +11,9 @@ trainer: num_nodes: 1 precision: 32-true callbacks: + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: 10 - class_path: viscy_utils.callbacks.embedding_writer.EmbeddingWriter init_args: output_path: #TODO point to the path to save the embeddings diff --git a/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh new file mode 100644 index 000000000..1b6814553 --- /dev/null +++ b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# DINOv3-temporal-MLP-2D-BagOfChannels +# +# New run: +# sbatch applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh +# +# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR: +# sbatch /hpc/projects/.../DINOv3-temporal-MLP-2D-BagOfChannels.sh + +#SBATCH --job-name=dinov3_mlp_2d +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=2-00:00:00 + +# ── Run identity ────────────────────────────────────────────────────── +export PROJECT="DINOv3-temporal-MLP-2D-BagOfChannels-v1" +export RUN_NAME="dinov3-mlp-2d-mip-ntxent-t0p5-lr1e4-bs512" +export CONFIGS="applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml" + +# ── Resume (uncomment to continue from checkpoint) ──────────────────── +export CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels-v1/dinov3-mlp-2d-mip-ntxent-t0p5-lr1e4-bs512/DINOv3-temporal-MLP-2D-BagOfChannels-v1/20260403-223550/checkpoints/last.ckpt" +export WANDB_RUN_ID="20260403-223550" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml new file mode 100644 index 000000000..6f3ccda42 --- /dev/null +++ b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml @@ -0,0 +1,120 @@ +# DINOv3-temporal-MLP-2D-BagOfChannels +# ========================================= +# Frozen DINOv3 backbone + trainable MLP projection head. +# 2D bag-of-channels with MIP z-reduction (same data pipeline as +# DynaCLR-2D-MIP-BagOfChannels). +# +# Launch: +# sbatch applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh +# +# Resume: +# CKPT_PATH=.../last.ckpt sbatch .../DINOv3-temporal-MLP-2D-BagOfChannels.sh + +base: + - ../recipes/trainer.yml + - ../recipes/model/dinov3_frozen_mlp.yml + +trainer: + strategy: ddp + devices: 2 + precision: bf16-mixed + max_epochs: 100 + logger: + init_args: + project: DINOv3-temporal-MLP-2D-BagOfChannels-v1 + name: null + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true + - class_path: viscy_utils.callbacks.OnlineEvalCallback + init_args: + every_n_epochs: 5 + label_key: perturbation + k: 20 + track_id_key: global_track_id + timepoint_key: t + +model: + init_args: + pca_color_keys: [perturbation, hours_post_perturbation, experiment, marker] + log_negative_metrics_every_n_epochs: 2 + example_input_array_shape: [1, 1, 1, 160, 160] + +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-MultiCell.parquet + focus_channel: Phase3D + reference_pixel_size_xy_um: 0.1494 + z_window: 1 + z_extraction_window: 11 + z_focus_offset: 0.5 + yx_patch_size: [256, 256] + final_yx_patch_size: [160, 160] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [lineage_id] + positive_channel_source: same + tau_range: [0.5, 2.0] + tau_decay_rate: 2.0 + stratify_by: [perturbation, marker] + split_ratio: 0.8 + batch_size: 512 + num_workers: 2 + seed: 42 + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [channel_0] + level: timepoint_statistics + subtrahend: mean + divisor: std + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.8, 1.3] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.25, 0.50] + sigma_y: [0.25, 0.50] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.08 + # Z-reduction: MIP for fluorescence, center-slice for label-free. + # Must be LAST augmentation (before implicit final spatial crop). + - class_path: viscy_transforms.BatchedChannelWiseZReductiond + init_args: + keys: [channel_0] + allow_missing_keys: true diff --git a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh similarity index 87% rename from applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh rename to applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh index feb6edadd..3db90a813 100755 --- a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh @@ -2,7 +2,7 @@ # DynaCLR-2D-BagOfChannels-v3 # # New run: -# sbatch applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. @@ -18,7 +18,7 @@ # ── Run identity ────────────────────────────────────────────────────── export PROJECT="DynaCLR-2D-BagOfChannels-v3" export RUN_NAME="phase1-ntxent-temp0p2" -export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── # export CKPT_PATH="" diff --git a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml similarity index 78% rename from applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml rename to applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml index ff4eba7b5..e50e2ba10 100644 --- a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml @@ -5,22 +5,17 @@ # Temporal positive pairs (same lineage at t+tau), stratified by perturbation + marker. # # Launch: -# sbatch applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml trainer: - accelerator: gpu strategy: ddp devices: 2 - num_nodes: 1 precision: bf16-mixed max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -36,37 +31,27 @@ trainer: every_n_epochs: 5 label_key: perturbation k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder init_args: - backbone: convnext_tiny - in_channels: 1 in_stack_depth: 1 stem_kernel_size: [1, 4, 4] stem_stride: [1, 4, 4] - embedding_dim: 768 projection_dim: 32 drop_path_rate: 0.1 loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss init_args: temperature: 0.2 lr: 0.00002 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation]" example_input_array_shape: [1, 1, 1, 160, 160] data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/DynaCLR-2D-BagOfChannels-v3.parquet + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-BagOfChannels-v3.parquet z_window: 1 yx_patch_size: [192, 192] final_yx_patch_size: [160, 160] diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh new file mode 100644 index 000000000..2c6f52ad6 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels single-marker — A40 interactive single-GPU variant. +# For smoke-testing and small-scale iteration on the interactive partition +# without queueing on the gpu partition. +# +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh + +#SBATCH --job-name=dynaclr_2d_sm_a40 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:a40:1 +#SBATCH --partition=interactive +#SBATCH --cpus-per-task=16 +#SBATCH --mem-per-cpu=14G +#SBATCH --time=4-00:00:00 + +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs128-A40-single-marker" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml new file mode 100644 index 000000000..1a85a68a5 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml @@ -0,0 +1,12 @@ +# Single-GPU A40 override for DynaCLR-2D-MIP-BagOfChannels single-marker. +# Chains on top of the 4-GPU base + single-marker override; strips DDP and +# halves batch size to fit the A40's 48 GB VRAM. + +trainer: + strategy: auto + devices: 1 + +data: + init_args: + batch_size: 128 + num_workers: 1 diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh new file mode 100755 index 000000000..593319e2d --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels SINGLE-MARKER variant. +# Every batch contains only one marker (OPS-style). +# +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh + +#SBATCH --job-name=dynaclr_2d_sm +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=3-00:00:00 + +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml new file mode 100644 index 000000000..27ab67d85 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml @@ -0,0 +1,10 @@ +# Override: single-marker batches for DynaCLR-2D-MIP-BoC. +# Matches the OPS strategy — every batch is one marker, forcing the model +# to learn cellular features instead of channel-filter shortcuts. + +data: + init_args: + batch_group_by: marker + stratify_by: null + # Equal weighting across markers as a first pass. Switch to + # sqrt(cell_count) weights after measuring marker distribution. diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh new file mode 100644 index 000000000..8ca88d4c0 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels +# Multi-cell-type 2D contrastive learning with channel-wise z-reduction. +# +# New run: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh +# +# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch. + +#SBATCH --job-name=dynaclr_2d_mip +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=3-00:00:00 + +# ── Run identity ────────────────────────────────────────────────────── +# Fresh retrain after FOV cache collision fix (commit 1435f493) and +# dataloader vectorization. Prior run 2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11 +# trained on 157 collided samples that silently read from the wrong zarr; +# retraining from scratch is cleaner than warm-starting a partially-corrupt +# encoder. +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml" + +# ── Resume (uncomment to continue from checkpoint) ──────────────────── +# export CKPT_PATH="/path/to/last.ckpt" +# export WANDB_RUN_ID="" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml new file mode 100644 index 000000000..f4799624e --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml @@ -0,0 +1,139 @@ +# DynaCLR-2D-MIP-BagOfChannels +# ============================== +# 2D bag-of-channels contrastive learning with channel-wise z-reduction. +# Extracts a 20-slice z-stack around focus, randomly crops to 10 slices +# (Z-invariance), then applies MIP for fluorescence and center-slice for +# label-free (Phase3D, BF, DIC, Retardance). +# Multi-cell-type: A549 infectomics, microglia dynamorph, ALFI mitosis. +# +# Launch: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh +# +# Resume: +# CKPT_PATH=.../last.ckpt sbatch .../DynaCLR-2D-MIP-BagOfChannels.sh + +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml + +trainer: + strategy: ddp + devices: 4 + precision: bf16-mixed + max_epochs: 150 + logger: + init_args: + project: DynaCLR-2D-MIP-BagOfChannels + name: null + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true + - class_path: viscy_utils.callbacks.OnlineEvalCallback + init_args: + every_n_epochs: 5 + label_key: perturbation + k: 20 + track_id_key: global_track_id + timepoint_key: t + +model: + init_args: + encoder: + init_args: + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + projection_dim: 32 + drop_path_rate: 0.1 + loss_function: + init_args: + temperature: 0.2 + lr: 0.00002 + pca_color_keys: "[perturbation,hours_post_perturbation,experiment,marker]" + log_negative_metrics_every_n_epochs: 2 + example_input_array_shape: [1, 1, 1, 160, 160] + +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v2.parquet + focus_channel: Phase3D + reference_pixel_size_xy_um: 0.1494 + z_window: 1 + z_extraction_window: 20 + z_focus_offset: 0.3 + yx_patch_size: [256, 256] + final_yx_patch_size: [160, 160] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [lineage_id] + positive_channel_source: same + tau_range: [0.5, 2.0] + tau_decay_rate: 2.0 + stratify_by: [perturbation, marker] + split_ratio: 0.8 + batch_size: 256 + num_workers: 2 + seed: 42 + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [channel_0] + level: timepoint_statistics + subtrahend: mean + divisor: std + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[0.8, 1.3], [0.8, 1.3], [0.8, 1.3]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.6, 1.6] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.25, 0.50] + sigma_y: [0.25, 0.50] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.1 + # Random Z crop: select 10 of 20 extracted slices for Z-invariance. + # Must come before ZReduction so MIP sees a variable sub-stack. + - class_path: viscy_transforms.BatchedRandSpatialCropd + init_args: + keys: [channel_0] + roi_size: [10, 192, 192] + # Z-reduction: MIP for fluorescence, center-slice for label-free. + # Must be LAST augmentation (before implicit final spatial crop). + - class_path: viscy_transforms.BatchedChannelWiseZReductiond + init_args: + keys: [channel_0] + allow_missing_keys: true diff --git a/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh new file mode 100755 index 000000000..2e7ae0927 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# DynaCLR-3D-BagOfChannels-v2 SINGLE-MARKER variant (fresh, no resume). +# +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh + +#SBATCH --job-name=dynaclr_3d_sm +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=12G +#SBATCH --time=4-00:00:00 + +export PROJECT="DynaCLR-3D-BagOfChannels-v2" +export RUN_NAME="3d-z32-256to228to160-ntxent-t0p2-single-marker" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml new file mode 100644 index 000000000..95ffb127e --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml @@ -0,0 +1,7 @@ +# Override: single-marker batches for DynaCLR-3D-BoC-v2. +# Matches the OPS strategy — every batch is one marker. + +data: + init_args: + batch_group_by: marker + stratify_by: null diff --git a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh similarity index 52% rename from applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh rename to applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh index d8f73fd63..80ca1a59c 100755 --- a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh @@ -2,28 +2,29 @@ # DynaCLR-3D-BagOfChannels-v2 # # New run: -# sbatch applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR: -# sbatch /hpc/projects/.../3d-z16-.../DynaCLR-3D-BagOfChannels-v2.sh +# sbatch /hpc/projects/.../3d-z32-.../DynaCLR-3D-BagOfChannels-v2.sh #SBATCH --job-name=dynaclr_3d_v2 #SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 #SBATCH --constraint="h100|h200" #SBATCH --partition=gpu #SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 +#SBATCH --mem-per-cpu=12G +#SBATCH --time=4-00:00:00 # ── Run identity ────────────────────────────────────────────────────── export PROJECT="DynaCLR-3D-BagOfChannels-v2" -export RUN_NAME="3d-z16-ntxent-t0p2-lr2e5-bs512-192to160-zext45" -export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml" +export RUN_NAME="3d-z32-256to228to160-ntxent-t0p2-mixed-markers" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DynaCLR-3D-BagOfChannels-v2/3d-z16-ntxent-t0p2-lr2e5-bs512-192to160-zext45/checkpoints/last.ckpt" -# export WANDB_RUN_ID="20260329-063341" +# Commented out for fresh A/B comparison run against single-marker variant. +# export CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DynaCLR-3D-BagOfChannels-v2/3d-z32-256to228to160-ntxent-t0p2/DynaCLR-3D-BagOfChannels-v2/20260402-185442/checkpoints/last.ckpt" +# export WANDB_RUN_ID="20260402-185442" -source "$(dirname "$0")/slurm/train.sh" +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml similarity index 73% rename from applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml rename to applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml index 3d6392ce7..b9272212d 100644 --- a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml @@ -1,35 +1,32 @@ # DynaCLR-3D-BagOfChannels-v2 # ============================ # 3D bag-of-channels contrastive learning. -# One random fluorescence channel per sample, 16-slice Z window. +# One random fluorescence channel per sample, 32-slice Z window. # Temporal positive pairs (same lineage at t+tau), stratified by perturbation. # +# Augmentation pipeline: +# extract (45,256,256) → normalize → affine → RandCrop (40,228,228) +# → flip/contrast/noise → CenterCrop (32,160,160) [auto-appended] +# # Launch: -# sbatch applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh # # Resume: # CKPT_PATH=.../last.ckpt sbatch .../DynaCLR-3D-BagOfChannels-v2.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml trainer: - accelerator: gpu strategy: ddp - devices: 4 - num_nodes: 1 + devices: 2 precision: bf16-mixed max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false logger: - class_path: lightning.pytorch.loggers.WandbLogger init_args: - entity: computational_imaging project: DynaCLR-3D-BagOfChannels-v2 - name: unnamed-run + name: 3d-z32-256to228to160-ntxent-t0p2 callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -47,46 +44,35 @@ trainer: k: 20 track_id_key: global_track_id timepoint_key: t - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 16 + in_stack_depth: 32 stem_kernel_size: [4, 4, 4] stem_stride: [4, 4, 4] - embedding_dim: 768 projection_dim: 32 drop_path_rate: 0.1 loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss init_args: temperature: 0.2 lr: 0.00002 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation,experiment,marker]" log_negative_metrics_every_n_epochs: 2 - example_input_array_shape: [1, 1, 16, 160, 160] + example_input_array_shape: [1, 1, 32, 160, 160] data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml - cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v2.parquet + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v4.parquet focus_channel: Phase3D reference_pixel_size_xy_um: 0.1494 reference_pixel_size_z_um: 0.174 - z_window: 16 + z_window: 32 z_extraction_window: 45 z_focus_offset: 0.3 - yx_patch_size: [192, 192] + yx_patch_size: [256, 256] final_yx_patch_size: [160, 160] channels_per_sample: 1 positive_cell_source: lookup @@ -96,8 +82,8 @@ data: tau_decay_rate: 2.0 stratify_by: [perturbation] split_ratio: 0.8 - batch_size: 512 - num_workers: 1 + batch_size: 256 + num_workers: 4 seed: 42 normalizations: - class_path: viscy_transforms.NormalizeSampled @@ -114,6 +100,13 @@ data: scale_range: [[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]] rotate_range: [3.14, 0.0, 0.0] shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + # Random crop: Z for focus invariance + YX for translation augmentation. + # The datamodule auto-appends a CenterCrop to [32, 160, 160] after this + # to remove rotation zero-fill artifacts at the edges. + - class_path: viscy_transforms.BatchedRandSpatialCropd + init_args: + keys: [channel_0] + roi_size: [40, 228, 228] - class_path: viscy_transforms.BatchedRandFlipd init_args: keys: [channel_0] diff --git a/applications/dynaclr/configs/training/OPS-1000genes-lite.sh b/applications/dynaclr/configs/training/OPS-1000genes-lite.sh deleted file mode 100755 index ebc569469..000000000 --- a/applications/dynaclr/configs/training/OPS-1000genes-lite.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# OPS 1000-gene DynaCLR with cosine gene classifier head (lite dataset) -# -# New run: -# sbatch applications/dynaclr/configs/training/OPS-1000genes-lite.sh -# -# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. - -#SBATCH --job-name=dynaclr_ops_1k -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 -#SBATCH --partition=gpu -#SBATCH --constraint="h100|h200" -#SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 - -# ── Run identity ────────────────────────────────────────────────────── -export PROJECT="OPS" -export RUN_NAME="OPS-1000genes-lite-CosineClassifier" -export EXTRA_ARGS="--trainer.logger.init_args.project=OPS-1000genes-lite-CosineClassifier" -export CONFIGS="applications/dynaclr/configs/training/OPS-1000genes-lite.yml" - -# ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export CKPT_PATH="" -# export WANDB_RUN_ID="" - -WORKSPACE_DIR="${WORKSPACE_DIR:-/hpc/mydata/eduardo.hirata/repos/viscy}" -source "${WORKSPACE_DIR}/applications/dynaclr/configs/training/slurm/train.sh" diff --git a/applications/dynaclr/configs/training/OPS-1000genes-lite.yml b/applications/dynaclr/configs/training/OPS-1000genes-lite.yml deleted file mode 100644 index 88b9ef4c9..000000000 --- a/applications/dynaclr/configs/training/OPS-1000genes-lite.yml +++ /dev/null @@ -1,143 +0,0 @@ -# OPS 1000-gene DynaCLR with cosine gene classifier head (lite dataset) -# ====================================================================== -# Lite dataset: 11M cells, 1001 perturbations, 22 reporters, 74 experiments. -# Percentile normalization (50-99), bag-of-channels, gene+reporter positive pairs. -# -# Launch: -# sbatch applications/dynaclr/configs/training/OPS-1000genes-lite.sh - -seed_everything: 42 - -trainer: - accelerator: gpu - strategy: ddp - devices: 4 - num_nodes: 1 - precision: bf16-mixed - max_epochs: 300 - limit_train_batches: 400 - limit_val_batches: 100 - log_every_n_steps: 5 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: loss/val - every_n_epochs: 1 - save_top_k: 5 - save_last: true - - class_path: viscy_utils.callbacks.OnlineEvalCallback - init_args: - every_n_epochs: 5 - label_key: perturbation - k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb - -model: - class_path: dynaclr.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 256 - drop_path_rate: 0.0 - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - auxiliary_heads: - gene: - class_path: viscy_models.components.heads.ClassificationHead - init_args: - head_name: gene - batch_key: gene_label - in_dims: 768 - hidden_dims: 256 - num_classes: 1001 - cosine_classifier: true - loss_weight: 0.5 - top_k: 5 - weight_schedule: cosine - weight_start: 0.0 - weight_warmup_epochs: 30 - lr: 0.0002 - log_batches_per_epoch: 8 - log_samples_per_batch: 1 - log_embeddings_every_n_epochs: 10 - example_input_array_shape: [1, 1, 1, 128, 128] - -data: - class_path: dynaclr.data.datamodule.MultiExperimentDataModule - init_args: - cell_index_path: /hpc/projects/organelle_phenotyping/datasets/ops/training_labels_1000genes_lite_v2_valid.parquet - z_window: 1 - yx_patch_size: [224, 224] - final_yx_patch_size: [128, 128] - channels_per_sample: 1 - positive_cell_source: lookup - positive_match_columns: [perturbation, marker] - stratify_by: marker - split_ratio: 0.8 - batch_size: 512 - num_workers: 4 - seed: 0 - shuffle_val: true - label_columns: - gene_label: perturbation - normalizations: - - class_path: viscy_transforms.BatchedScaleIntensityRangePercentilesd - init_args: - keys: [channel_0] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 - clip: true - augmentations: - - class_path: viscy_transforms.BatchedRandAffined - init_args: - keys: [channel_0] - prob: 0.8 - scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - class_path: viscy_transforms.BatchedRandFlipd - init_args: - keys: [channel_0] - spatial_axes: [1, 2] - prob: 0.5 - - class_path: viscy_transforms.BatchedRandAdjustContrastd - init_args: - keys: [channel_0] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy_transforms.BatchedRandScaleIntensityd - init_args: - keys: [channel_0] - prob: 0.5 - factors: 0.5 - - class_path: viscy_transforms.BatchedRandGaussianSmoothd - init_args: - keys: [channel_0] - prob: 0.5 - sigma_x: [0.2, 0.5] - sigma_y: [0.2, 0.5] - sigma_z: [0.0, 0.0] - - class_path: viscy_transforms.BatchedRandGaussianNoised - init_args: - keys: [channel_0] - prob: 0.5 - mean: 0.0 - std: 0.08 diff --git a/applications/dynaclr/configs/training/OPS-373genes.sh b/applications/dynaclr/configs/training/OPS-373genes.sh deleted file mode 100755 index 1a7134086..000000000 --- a/applications/dynaclr/configs/training/OPS-373genes.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# OPS 373-gene DynaCLR with gene classifier head -# -# New run: -# sbatch applications/dynaclr/configs/training/OPS-373genes.sh -# -# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. - -#SBATCH --job-name=dynaclr_ops -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 -#SBATCH --partition=gpu -#SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 - -# ── Run identity ────────────────────────────────────────────────────── -export PROJECT="dynaclr" -export RUN_NAME="OPS-373genes-GeneClassifier" -export CONFIGS="applications/dynaclr/configs/training/OPS-373genes.yml" - -# ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export CKPT_PATH="" -# export WANDB_RUN_ID="" - -source "$(dirname "$0")/slurm/train.sh" diff --git a/applications/dynaclr/configs/training/OPS-373genes.yml b/applications/dynaclr/configs/training/OPS-373genes.yml deleted file mode 100644 index 875f17714..000000000 --- a/applications/dynaclr/configs/training/OPS-373genes.yml +++ /dev/null @@ -1,124 +0,0 @@ -# OPS 373-gene DynaCLR with gene classifier head -# ================================================= -# Fine-tune from pre-trained OPS checkpoint with cosine classifier. -# Gene+reporter positive pairs, stratified by marker (reporter). -# -# Launch: -# sbatch applications/dynaclr/configs/training/OPS-373genes.sh - -seed_everything: 42 - -trainer: - accelerator: gpu - strategy: ddp - devices: 4 - num_nodes: 1 - precision: bf16-mixed - max_epochs: 300 - limit_train_batches: 400 - limit_val_batches: 100 - log_every_n_steps: 5 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: loss/val - every_n_epochs: 1 - save_top_k: 5 - save_last: true - - class_path: viscy_utils.callbacks.SaveConfigToWandb - -model: - class_path: dynaclr.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 256 - drop_path_rate: 0.0 - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - ckpt_path: /hpc/projects/intracellular_dashboard/ops/models/logs/dynaclr/ops_bagofchannels_gene_n_reporter_grouped_reporter_256proj_373genes_convnext_tiny_temp0p5_512bs_lr1e-4_pretrained_self/version_0/checkpoints/last.ckpt - lr: 0.0001 - log_batches_per_epoch: 8 - log_samples_per_batch: 1 - log_embeddings_every_n_epochs: 10 - example_input_array_shape: [1, 1, 1, 128, 128] - -data: - class_path: dynaclr.data.datamodule.MultiExperimentDataModule - init_args: - cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/ops_373genes.parquet - z_window: 1 - yx_patch_size: [224, 224] - final_yx_patch_size: [128, 128] - channels_per_sample: 1 - positive_cell_source: lookup - positive_match_columns: [perturbation, marker] - stratify_by: marker - split_ratio: 0.8 - batch_size: 512 - num_workers: 4 - seed: 0 - shuffle_val: true - label_columns: - gene_label: perturbation - normalizations: - - class_path: viscy_transforms.BatchedScaleIntensityRangePercentilesd - init_args: - keys: [channel_0] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 - clip: true - augmentations: - - class_path: viscy_transforms.BatchedRandAffined - init_args: - keys: [channel_0] - prob: 0.8 - scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - class_path: viscy_transforms.BatchedRandFlipd - init_args: - keys: [channel_0] - spatial_axes: [1, 2] - prob: 0.5 - - class_path: viscy_transforms.BatchedRandAdjustContrastd - init_args: - keys: [channel_0] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy_transforms.BatchedRandScaleIntensityd - init_args: - keys: [channel_0] - prob: 0.5 - factors: 0.5 - - class_path: viscy_transforms.BatchedRandGaussianSmoothd - init_args: - keys: [channel_0] - prob: 0.5 - sigma_x: [0.2, 0.5] - sigma_y: [0.2, 0.5] - sigma_z: [0.0, 0.0] - - class_path: viscy_transforms.BatchedRandGaussianNoised - init_args: - keys: [channel_0] - prob: 0.5 - mean: 0.0 - std: 0.08 diff --git a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh similarity index 93% rename from applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh rename to applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh index 6637b6634..96dbf1a99 100755 --- a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +++ b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh @@ -2,7 +2,7 @@ # Phase contrastive timeaware — DINOv3 frozen backbone + temporal MLP # # New run: -# sbatch applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +# sbatch applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. @@ -18,7 +18,7 @@ # ── Run identity ────────────────────────────────────────────────────── export PROJECT="Phase-contrastive-timeaware" export RUN_NAME="dinov3-mlp-temp0p5" -export CONFIGS="applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml" +export CONFIGS="applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── # export CKPT_PATH="" diff --git a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml similarity index 75% rename from applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml rename to applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml index 5f50eed02..d0007b902 100644 --- a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml +++ b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml @@ -5,22 +5,17 @@ # Reproduces legacy Phase contrastive timeaware ablations. # # Launch: -# sbatch applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +# sbatch applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/dinov3_frozen_mlp.yml trainer: - accelerator: gpu strategy: auto devices: 1 - num_nodes: 1 precision: 32-true max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -36,39 +31,15 @@ trainer: every_n_epochs: 5 label_key: perturbation k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: - encoder: - class_path: viscy_models.foundation.DINOv3Model - init_args: - model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m - freeze: true - projection: - class_path: viscy_models.components.heads.MLP - init_args: - in_dims: 768 - hidden_dims: 768 - out_dims: 128 - norm: ln - activation: relu - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - lr: 0.0001 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation]" example_input_array_shape: [1, 1, 30, 192, 192] data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/Phase-contrastive-timeaware.yml cell_index_path: applications/dynaclr/configs/cell_index/Phase-contrastive-timeaware.parquet z_window: 30 z_extraction_window: 40 diff --git a/applications/dynaclr/configs/training/README.md b/applications/dynaclr/configs/training/README.md index f9f933da4..599d1fd45 100644 --- a/applications/dynaclr/configs/training/README.md +++ b/applications/dynaclr/configs/training/README.md @@ -1,96 +1,110 @@ # DynaCLR Training Configs -Composable training configuration using LightningCLI `--config` stacking. -Each layer is a YAML fragment; later configs deep-merge into earlier ones -(dicts merge, lists replace). +Training configuration stack for LightningCLI `--config`. Later configs +deep-merge into earlier ones (dicts merge, lists replace). Each leaf +YAML declares a `base:` list of recipes to compose on top of. -## Structure +## Directory layout ``` configs/training/ - _base.yml Trainer + model defaults (callbacks, optimizer, encoder) - arch/ Encoder geometry (stem, z_depth, patch size) - 2d_z1.yml stem=[1,4,4], z_window=1 - 3d_z16.yml stem=[4,4,4], z_window=16, random Z crop - 3d_z30.yml stem=[5,4,4], z_window=30, 192px patch - data/ Data pipeline: sampling + normalization + augmentations - boc_{dim}_{positive_pair}_{batch_composition}.yml - demo/ Self-contained configs for smoke tests (single --config) - slurm/ SLURM experiment scripts (sbatch entry points) - train.sh Shared launcher (sourced, not sbatch'd directly) - _legacy/ Old monolithic configs (reference only) + DynaCLR-2D/ # 2D (and MIP) time-lapse contrastive runs + DynaCLR-2D-BagOfChannels-v3.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels-single-marker.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.{yml,sh} + DynaCLR-3D/ # 3D time-lapse contrastive runs + DynaCLR-3D-BagOfChannels-v2.{yml,sh} + DynaCLR-3D-BagOfChannels-v2-single-marker.{yml,sh} + DINOv3/ # DINOv3 frozen-encoder + MLP probes + DINOv3-temporal-MLP-2D-BagOfChannels.{yml,sh} + Phase-contrastive/ + Phase-contrastive-timeaware.{yml,sh} + + recipes/ # Reusable building blocks (referenced via base:) + trainer.yml Trainer + logger + common callbacks + model/ Encoder and head architectures + data/ Sampling / positive-pair strategies + augmentations/ Augmentation pipelines (ops_2d_mild, etc.) + + debug/ # Fast-dev-run / tiny configs for reproducing hangs / OOMs + demo/ # Self-contained single-file demos for smoke tests + slurm/ + train.sh Shared launcher sourced by every sbatch script + preprocess.yml Preprocessing config (not a training run) ``` -## Data config naming convention +Each top-level model family lives in its own folder. The `yml` and `sh` +for a given run share a name and a directory so `CONFIGS=` references +stay local. -``` -{channel_mode}_{dim}_{positive_pair_strategy}_{batch_composition}.yml -``` - -| Segment | Values | Meaning | -|---------|--------|---------| -| channel_mode | `boc` | bag-of-channels (1 random channel per sample) | -| dim | `2d`, `3d` | spatial dimensionality | -| positive_pair | `temporal` | same cell lineage at t+tau | -| | `gene-reporter` | same gene + same reporter (OPS) | -| | `self` | SimCLR-style (same crop, different augmentation) | -| batch_composition | `stratify-perturbation` | balance infected/uninfected | -| | `stratify-perturbation-marker` | balance perturbation and organelle marker | -| | `stratify-marker` | balance by reporter/marker only | - -## Composition +## Composition via `base:` -Stack three configs: `_base.yml` + `arch/*.yml` + `data/*.yml`, then -pass experiment-specific values as CLI overrides in the SLURM script. +Each leaf YAML starts with a `base:` list pointing at recipe fragments +(paths are relative to the YAML's directory; since all leaf YAMLs live +one level below `recipes/`, they use `../recipes/...`): -```bash -viscy fit \ - --config _base.yml \ - --config arch/3d_z16.yml \ - --config data/boc_3d_temporal_stratify-perturbation.yml \ - --trainer.devices 4 \ - --data.init_args.batch_size 512 \ - --data.init_args.collection_path path/to/collection.yml +```yaml +# DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml ``` +`viscy_utils.compose.load_composed_config` walks the `base:` chain, +deep-merges dicts, and replaces lists. + ## SLURM scripts -Each experiment is a thin `.sh` that sets `PROJECT`, `RUN_NAME`, `CONFIGS`, -experiment-specific `EXTRA_ARGS`, and sources `train.sh`: +Each experiment is a thin `.sh` that sets `PROJECT`, `RUN_NAME`, +`CONFIGS`, optional `EXTRA_ARGS`, and sources `slurm/train.sh`: ```bash -# Submit -sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh +sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh -# Override run name -RUN_NAME=phase2-hcl sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh +RUN_NAME=phase2-hcl sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh -# Parameter sweep for TEMP in 0.1 0.2 0.5; do RUN_NAME="sweep-temp${TEMP}" \ EXTRA_ARGS="--model.init_args.loss_function.init_args.temperature ${TEMP}" \ - sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh + sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh done ``` `train.sh` handles: -- `PYTHONNOUSERSITE=1` (prevents `~/.local/` shadowing conda) -- Creates `${MODEL_ROOT}/${PROJECT}/${RUN_NAME}/` output directory -- Copies config files into the run directory for reproducibility -- Sets WandB logger project/name/save_dir via CLI overrides -- Sets checkpoint dirpath via CLI override +- `export PYTHONNOUSERSITE=1` (prevents `~/.local/` shadowing conda) +- Creates `${MODEL_ROOT}/${PROJECT}/${RUN_NAME}/` output dir +- Rotates `config.yaml` from any previous run +- Copies the calling sbatch script into the run dir for reproducibility +- Sets WandB logger project / name / save_dir via CLI overrides +- Optional `CKPT_PATH` resume and `WANDB_RUN_ID` to continue a run -## Adding a new experiment +## Resuming a run -1. Check if an existing `data/*.yml` matches your sampling strategy. - If not, create a new one following the naming convention. -2. Create a new `slurm/.sh` with SBATCH directives and overrides. -3. Submit with `sbatch slurm/.sh`. +```bash +CKPT_PATH=/hpc/projects/.../checkpoints/last.ckpt \ +WANDB_RUN_ID= \ + sbatch --export=ALL,CKPT_PATH,WANDB_RUN_ID \ + applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh +``` -## Demo configs +`WANDB_RUN_ID` appends `--trainer.logger.init_args.id= +--trainer.logger.init_args.resume=must` so metrics land on the same +W&B timeline. -Self-contained single-file configs for quick testing: +## Adding a new experiment -```bash -viscy fit --config demo/demo_3d_fit.yml --trainer.fast_dev_run true -``` +1. Find the closest existing run in the matching model family + folder. Copy the `.yml` and `.sh` alongside it with a new name. +2. Edit `base:` in the YAML to pick the right recipes. +3. Override training-specific values in the YAML (or via `EXTRA_ARGS` + in the sbatch script for one-off sweeps). +4. `sbatch applications/dynaclr/configs/training//.sh`. + +## Debug / demo configs + +- `debug/` — fastdev, tiny, and DDP-reproducer configs used to isolate + SLURM hangs, memory spikes, and DDP sync issues. Launched with + `uv run viscy fit --config .yml --config debug/.yml`. +- `demo/` — self-contained single-file configs for quick local smoke + tests (no base chain). diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh new file mode 100755 index 000000000..b42f5a35e --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Fast-dev-run smoke test of BoC training on 4-GPU DDP. +# Goal: validate sampler generator + FOV split + NCCL init + first +# batch end-to-end with the 20k-row boc_tiny parquet. + +#SBATCH --job-name=boc_fastdev_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=4G +#SBATCH --time=0-00:30:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_fastdev_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd /hpc/mydata/eduardo.hirata/repos/viscy + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml \ + --config applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml new file mode 100644 index 000000000..59e42559b --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml @@ -0,0 +1,29 @@ +# SLURM fast-dev-run override for DynaCLR-2D-MIP-BagOfChannels. +# Tests DDP end-to-end on a 20k-row slice: 4 ranks × sampler __iter__ + +# NCCL init + first batch + backward + val. +# +# Launch: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh + +trainer: + strategy: ddp + devices: 4 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # 20k-row slice, enough to exercise sampler/dataset without load cost. + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_tiny.parquet + batch_size: 8 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml new file mode 100644 index 000000000..ad9aa883a --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml @@ -0,0 +1,32 @@ +# Local fast-dev-run override for DynaCLR-2D-MIP-BagOfChannels. +# Goal: verify training starts and completes ≥1 train + val batch end-to-end +# on a single GPU. Uses the smallest DynaCLR parquet (3.4M rows). +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels.yml \ +# --config applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # ~20k-row slice of the full BoC parquet — enough to exercise every + # sampling path without a 3-minute parquet load. + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh new file mode 100755 index 000000000..a38ff1d8c --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Minimal OPS fast_dev_run on 4-GPU DDP to localize the post-LOCAL_RANK hang. +# Strips callbacks, logger, wandb. + +#SBATCH --job-name=ops_fastdev_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --partition=gpu +#SBATCH --constraint="h100|h200" +#SBATCH --exclude=gpu-h-5 +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=14G +#SBATCH --time=0-01:00:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_fastdev_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd /hpc/mydata/eduardo.hirata/repos/viscy + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml \ + --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-fastdev-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml new file mode 100644 index 000000000..e2689417c --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml @@ -0,0 +1,31 @@ +# SLURM fast_dev_run=5 override for OPS to localize the post-LOCAL_RANK hang. +# Strips all callbacks, logger, wandb — just data + model + one train + val batch. +# +# Launch: +# sbatch applications/dynaclr/configs/training/OPS-1000genes-allmarkers-fastdev-ddp.sh + +trainer: + strategy: ddp + devices: 4 + num_nodes: 1 + # Narrowing: wandb logger (31264775) + OnlineEvalCallback (31264776) + # both confirmed harmless. Now testing val_check_interval and limit_* + # knobs. Dropping fast_dev_run so these actually take effect. + callbacks: [] + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + +data: + init_args: + batch_size: 16 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml new file mode 100644 index 000000000..5d37feb38 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml @@ -0,0 +1,30 @@ +# Local fast-dev-run override for OPS-1000genes-allmarkers. +# Goal: reproduce the OOM path from job 31264591 on a single A40 locally. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-fastdev.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: true + logger: null + callbacks: [] + # fast_dev_run already caps batches/epochs; the explicit limits are defensive. + limit_train_batches: 1 + limit_val_batches: 1 + max_epochs: 1 + +data: + init_args: + batch_size: 8 + num_workers: 0 + prefetch_factor: null + # Skip warm-start checkpoint to keep this self-contained. + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml new file mode 100644 index 000000000..002ff6e6d --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml @@ -0,0 +1,33 @@ +# Local single-GPU "DDP" reproducer: strategy=ddp, devices=1. +# Exercises the DDP wrap path without needing 2 GPUs. Keeps wandb ENABLED +# (inherited from the parent OPS config) so we can test whether DDP-wrap +# + wandb is the hang, without SLURM. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-ddp-local.yml + +trainer: + strategy: ddp + devices: 1 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + callbacks: [] + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh new file mode 100755 index 000000000..f6e280869 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# 4-GPU DDP on OPS tiny (346k rows) WITHOUT fast_dev_run. +# Isolates DDP+wandb+val_check_interval from dataset-size effects. + +#SBATCH --job-name=ops_tiny_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --partition=gpu +# Drop GPU-type constraint to clear the queue faster. nodes=1 guarantees +# the two ranks share a single GPU model, which is what matters for DDP. +#SBATCH --exclude=gpu-h-5 +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=0-00:30:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd /hpc/mydata/eduardo.hirata/repos/viscy + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml \ + --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml new file mode 100644 index 000000000..d6328fe24 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml @@ -0,0 +1,33 @@ +# 4-GPU DDP test on OPS tiny (346k rows) WITHOUT fast_dev_run. +# Purpose: narrow the hang — does it need full OPS scale (55M), or does +# DDP+wandb+val_check_interval on any OPS-flavored data reproduce it? +# +# Launch: +# sbatch applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-ddp.sh + +trainer: + strategy: ddp + devices: 2 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + callbacks: [] + # 31265169 TIMEOUT with wandb logger on. Now disabling it to isolate + # whether wandb + DDP + no-fastdev is the bug. + logger: null + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml new file mode 100644 index 000000000..73411a4ef --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml @@ -0,0 +1,34 @@ +# Local reproducer for the post-LOCAL_RANK hang on OPS tiny (346k rows). +# Same as OPS-1000genes-allmarkers-tiny.yml but DROPS fast_dev_run — because +# every passing run used fast_dev_run and every hanging run did not. +# Single GPU to rule out DDP as the variable. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-full.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + logger: null + callbacks: [] + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml new file mode 100644 index 000000000..4bf557a2b --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml @@ -0,0 +1,32 @@ +# Local fast_dev_run override for OPS-1000genes-allmarkers on a tiny slice. +# Purpose: reproduce the full-config hang on a single GPU locally, where +# iteration is ~60 sec/test instead of ~10 min/SLURM-cycle. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # 346k-row slice: 2 experiments × 5 markers × 20 genes + # Preserves [gene_name, marker] SupCon pairing and batch_group_by=marker. + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/demo/demo_2d_fit.yml b/applications/dynaclr/configs/training/demo/demo_2d_fit.yml index b35f31e98..5d2febb41 100644 --- a/applications/dynaclr/configs/training/demo/demo_2d_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_2d_fit.yml @@ -58,11 +58,6 @@ data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: # ── Data source ────────────────────────────────────────────────────────── - # For production: use the full v3 collection + parquet - # collection_path: applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml - # cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-BagOfChannels-v3.parquet - # For demo: single zarr, fast startup - collection_path: null cell_index_path: applications/dynaclr/configs/cell_index/example_flat.parquet # ── Patch extraction ───────────────────────────────────────────────────── diff --git a/applications/dynaclr/configs/training/demo/demo_3d_fit.yml b/applications/dynaclr/configs/training/demo/demo_3d_fit.yml index b5a0a0573..c809968f9 100644 --- a/applications/dynaclr/configs/training/demo/demo_3d_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_3d_fit.yml @@ -58,13 +58,6 @@ data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: # ── Data source ────────────────────────────────────────────────────────── - # Provide one of collection_path or cell_index_path. - # cell_index_path is faster (skips zarr enumeration at startup). - # For production: use the full v2 collection + parquet - # collection_path: applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml - # cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v2.parquet - # For demo: single zarr, fast startup - collection_path: null cell_index_path: applications/dynaclr/configs/cell_index/example_flat.parquet # ── Patch extraction ───────────────────────────────────────────────────── diff --git a/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml b/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml index 8f29ecad4..96eb260b0 100644 --- a/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml @@ -53,8 +53,7 @@ model: data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/demo_bag_of_channels_v3.yml - cell_index_path: null + cell_index_path: applications/dynaclr/configs/cell_index/demo_bag_of_channels_v3.parquet z_window: 30 yx_patch_size: [288, 288] final_yx_patch_size: [192, 192] diff --git a/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml b/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml new file mode 100644 index 000000000..763d1cc68 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml @@ -0,0 +1,41 @@ +# Augmentation recipe: mild 2D augmentations for OPS data. +# Lighter affine (no Z scaling, no shear), narrower gamma, lower noise std. + +data: + init_args: + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.2, 0.5] + sigma_y: [0.2, 0.5] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.08 diff --git a/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml b/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml new file mode 100644 index 000000000..f93941c8f --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml @@ -0,0 +1,20 @@ +# Data recipe: OPS gene+reporter contrastive learning defaults. +# Leaf configs override: cell_index_path, normalizations (lower percentile differs). + +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + z_window: 1 + yx_patch_size: [224, 224] + final_yx_patch_size: [128, 128] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [perturbation, marker] + stratify_by: marker + split_ratio: 0.8 + batch_size: 512 + num_workers: 4 + seed: 0 + shuffle_val: true + label_columns: + gene_label: perturbation diff --git a/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml b/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml new file mode 100644 index 000000000..3b70366e8 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml @@ -0,0 +1,18 @@ +# Model recipe: ContrastiveModule with ConvNeXt-Tiny encoder. +# Leaf configs override: in_stack_depth, stem_kernel_size, stem_stride, +# projection_dim, drop_path_rate, temperature, lr, and logging args. + +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + embedding_dim: 768 + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings_every_n_epochs: 10 diff --git a/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml b/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml new file mode 100644 index 000000000..1e2e71699 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml @@ -0,0 +1,27 @@ +# Model recipe: Frozen DINOv3-ConvNeXt-Tiny backbone + trainable MLP projection. +# Leaf configs override: pca_color_keys, example_input_array_shape. + +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.foundation.DINOv3Model + init_args: + model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + freeze: true + projection: + class_path: viscy_models.components.heads.MLP + init_args: + in_dims: 768 + hidden_dims: 768 + out_dims: 128 + norm: ln + activation: relu + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + init_args: + temperature: 0.5 + lr: 0.0001 + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings_every_n_epochs: 10 diff --git a/applications/dynaclr/configs/training/recipes/trainer.yml b/applications/dynaclr/configs/training/recipes/trainer.yml new file mode 100644 index 000000000..be2e63364 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/trainer.yml @@ -0,0 +1,34 @@ +# Trainer recipe: DynaCLR shared trainer defaults. +# Includes WandB logger (project/name/save_dir set by train.sh CLI overrides), +# LR monitor, and model checkpoint. Config is saved to trainer.log_dir by +# Lightning's default SaveConfigCallback; the wandb files tab picks it up +# automatically when save_dir matches. +# +# Leaf configs override: strategy, devices, precision, max_epochs, +# logger.init_args.project/name, and optionally re-list callbacks +# to add OnlineEvalCallback (callbacks is a list — it replaces entirely). + +seed_everything: 42 + +trainer: + accelerator: gpu + num_nodes: 1 + log_every_n_steps: 10 + enable_checkpointing: true + enable_model_summary: false + inference_mode: true + use_distributed_sampler: false + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + entity: computational_imaging + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true diff --git a/applications/dynaclr/docs/DAGs/ai_ready_datasets.md b/applications/dynaclr/docs/DAGs/ai_ready_datasets.md new file mode 100644 index 000000000..8e000769f --- /dev/null +++ b/applications/dynaclr/docs/DAGs/ai_ready_datasets.md @@ -0,0 +1,163 @@ +# Data Preparation DAG + +## Entry point + +`prepare run -c prepare_config.yaml` (from `airtable_utils`) discovers wells and +channels from NFS, generates all configs and SLURM scripts, and submits the pipeline. + +```bash +prepare run 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV \ + -c /path/to/prepare_config.yaml + +# Dry-run: generate configs/scripts without submitting +prepare run 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV \ + -c /path/to/prepare_config.yaml \ + --dry-run +``` + +## Step-by-step detail + +``` +NFS assembled zarr (intracellular_dashboard/organelle_dynamics/{dataset}/2-assemble/) + │ + ▼ +prepare run # discovers wells + channels from NFS zarr + │ airtable_utils.prepare_cli # validates dataset is in Airtable + │ airtable_utils.prepare # generates all configs and scripts + ▼ +{vast_output_dir}/ + ├── crop_concat.yml # biahub concatenate config (wells × channels) + ├── qc_config.yml # focus-slice QC config + ├── sbatch_overrides.sh # optional SLURM overrides for biahub's internal jobs + ├── 01_concatenate.sh # bash (not SLURM): runs biahub + rsync tracking + ├── 02_qc.sh # SLURM: GPU focus-slice detection + └── 03_preprocess.sh # SLURM: CPU normalization stats + │ + ▼ +bash 01_concatenate.sh # NOT a SLURM job — runs interactively + │ Step 1: conda run biahub concatenate -c crop_concat.yml -o {dataset}.zarr -m + │ biahub submits its own SLURM jobs internally via submitit; -m blocks until done + │ Step 2: rsync tracking zarr (NFS → VAST) + ▼ +{dataset}.zarr (OME-Zarr v0.5 / zarr v3, rechunked) +tracking.zarr (cell tracking results) + │ + ├──► sbatch 02_qc.sh # GPU (~30 min) + │ qc run -c qc_config.yml # focus-slice detection on Phase3D channel + │ → writes focus_slice metadata into {dataset}.zarr + │ + └──► sbatch 03_preprocess.sh # CPU, preempted partition (~4 hrs) + viscy preprocess # computes per-channel normalization stats + --data_path {dataset}.zarr + → writes normalization metadata into {dataset}.zarr +``` + +## Pipeline DAG (process dependency) + +``` +NFS zarr (assembled) + │ + ▼ +prepare run ──── generates configs + scripts + │ + ▼ +01_concatenate.sh (interactive bash, blocks until biahub SLURM jobs finish) + │ + ▼ +{dataset}.zarr + tracking.zarr + │ + ├──► 02_qc.sh (SLURM, GPU) → focus_slice metadata in zarr + └──► 03_preprocess.sh (SLURM, CPU) → normalization metadata in zarr +``` + +02_qc and 03_preprocess run in parallel (no dependency between them). +Both write metadata back to the same zarr; their outputs are checked by +`check_preprocessed()` before downstream training or evaluation. + +## Key commands + + +| Step | Command | Input | Output | +| ----------------- | ------------------------------------------------- | ------------------ | --------------------------------------------------------------- | +| Generate + submit | `prepare run -c prepare_config.yaml` | NFS assembled zarr | scripts + configs, submits jobs | +| Status check | `prepare status -c prepare_config.yaml` | - | markdown table (NFS/VAST existence, zarr version, preprocessed) | +| Concatenate | `bash 01_concatenate.sh` | crop_concat.yml | {dataset}.zarr + tracking.zarr | +| QC | `sbatch 02_qc.sh` | qc_config.yml | focus_slice metadata in zarr | +| Preprocess | `sbatch 03_preprocess.sh` | {dataset}.zarr | normalization metadata in zarr | + + +## prepare_config.yaml format + +```yaml +nfs_root: /hpc/projects/intracellular_dashboard/organelle_dynamics +vast_root: /hpc/projects/organelle_phenotyping/datasets +workspace_dir: /hpc/mydata/eduardo.hirata/repos/viscy + +concatenate: + channel_names: null # null = auto-detect raw channels (Phase3D + "raw " prefix) + chunks_czyx: [1, 16, 256, 256] + shards_ratio: [1, 1, 8, 8, 8] + output_ome_zarr_version: "0.5" + conda_env: biahub + sbatch_overrides: # optional: overrides for biahub's internal SLURM jobs + partition: preempted + mem-per-cpu: 16G + +qc: + channel_names: [Phase3D] + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + midband_fractions: [0.125, 0.25] + device: cuda + num_workers: 16 + +preprocess: + channel_names: -1 # -1 = all channels + num_workers: 32 + block_size: 32 + +slurm: + qc: + partition: gpu + gres: gpu:1 + cpus_per_task: 16 + mem_per_cpu: 4G + time: "00:30:00" + preprocess: + partition: preempted + cpus_per_task: 32 + mem_per_cpu: 4G + time: "04:00:00" +``` + +## Notes + +- `prepare run` validates the dataset exists in Airtable before generating anything. +Use `--force` to overwrite an existing VAST zarr (e.g. to upgrade from zarr v2 to v0.5). +- `01_concatenate.sh` is an interactive bash script, not a SLURM job. Run it from a login +node or an interactive session; it blocks until biahub's internal SLURM jobs finish (`-m` flag). +- `02_qc.sh` and `03_preprocess.sh` are independent — submit both immediately after +`01_concatenate.sh` completes; no need to wait for QC before running preprocess. +- Channel auto-detection (`channel_names: null`) keeps channels with prefix `Phase3D` or `raw` . +Virtual stains (`nuclei_prediction`, `membrane_prediction`) and deconvolved channels are excluded. +- `check_preprocessed()` checks for `normalization` key in zarr metadata; used by `prepare status` +and as a gate before evaluation. +- Raw channel names written to `crop_concat.yml` are repeated once per well entry — this is a +biahub concatenate requirement. + +## Path convention + +All AI-ready data lives under `/hpc/projects/organelle_phenotyping/`: + + +| Directory | Contents | +| -------------------------- | --------------------------------------------- | +| `datasets//` | Zarr v3 store + `tracking.zarr` | +| `datasets/annotations/` | Per-experiment annotation CSVs | +| `models/collections/` | Cell index parquets (one per collection YAML) | +| `models//` | Training runs (checkpoints, WandB configs) | + + +Collection YAMLs use `datasets_root: /hpc/projects/organelle_phenotyping` and +`${datasets_root}/datasets/...` placeholders — resolved at load time by `load_collection()`. diff --git a/applications/dynaclr/docs/DAGs/evaluation.md b/applications/dynaclr/docs/DAGs/evaluation.md new file mode 100644 index 000000000..79f26d95e --- /dev/null +++ b/applications/dynaclr/docs/DAGs/evaluation.md @@ -0,0 +1,472 @@ +# Evaluation DAG + +## Running with Nextflow (recommended) + +```bash +module load nextflow/24.10.5 + +nextflow run applications/dynaclr/nextflow/main.nf \ + --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels.yaml \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -resume +``` + +`-resume` makes Nextflow skip steps whose outputs already exist. Re-run the same command after a failure — Nextflow picks up from where it left off. + +### Local test (no SLURM) + +```bash +nextflow run applications/dynaclr/nextflow/main.nf \ + --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels_test.yaml \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -profile local \ + -resume +``` + +## Pipeline entry point + +`dynaclr prepare-eval-configs` (also aliased as `dynaclr evaluate`) generates all YAML configs +under `output_dir/configs/` and prints a JSON manifest to stdout. Nextflow reads the manifest +to wire steps together. + +``` +eval_config.yaml + │ + ▼ +dynaclr prepare-eval-configs -c eval_config.yaml # writes configs/ + manifest JSON + │ + ▼ +output_dir/configs/ + ├── eval.yaml # copy of input config (for re-runs) + ├── predict.yml # GPU step: viscy predict + ├── reduce.yaml # template: dynaclr reduce-dimensionality (per-experiment) + ├── reduce_combined.yaml # CPU step: dynaclr combined-dim-reduction (joint) + ├── smoothness.yaml # template: dynaclr evaluate-smoothness (per-experiment) + ├── plot.yaml # template: dynaclr plot-embeddings (per-experiment) + ├── plot_combined.yaml # CPU step: dynaclr plot-embeddings (all experiments) + ├── {block_name}.yaml # template: dynaclr compute-mmd (per-experiment, per-block) + ├── {block_name}_cross_exp.yaml # CPU step: dynaclr compute-mmd --combined (per-block) + └── linear_classifiers.yaml # CPU step (optional) +``` + +## Step-by-step detail + +``` +checkpoint.ckpt + cell_index.parquet + │ + ▼ +viscy predict -c predict.yml # MultiExperimentDataModule predict mode + │ EmbeddingWriter callback # normalizations + z_reduction, no augmentations + ▼ # obs: fov_name, id, t, track_id, +embeddings/embeddings.zarr # experiment, marker, perturbation, + │ (AnnData: .X=features, # hours_post_perturbation, organelle, well, microscope + │ .obs=cell metadata) + │ + ▼ +dynaclr split-embeddings \ + --input embeddings/embeddings.zarr \ + --output-dir embeddings/ + │ Splits by obs["experiment"], deletes combined zarr + │ Also writes configs/viewer.yaml (datasets: {exp: {hcs_plate, anndata}}) + │ hcs_plate read from obs["store_path"] of each split zarr + ▼ +embeddings/{experiment_A}.zarr +embeddings/{experiment_B}.zarr + ... +configs/viewer.yaml # nd-embedding viewer config (also valid input + ... # for combined-dim-reduction via datasets: key) + │ + ├──► dynaclr reduce-dimensionality # PCA only (per experiment, parallel SLURM jobs) + │ -c reduce.yaml # __ZARR_PATH__ substituted by Nextflow + │ → {experiment}.zarr (obsm: X_pca) + │ NOTE: skip PHATE here to avoid computing it twice + │ + │ (after reduce-dimensionality finishes for ALL experiments) + │ + ├──► dynaclr combined-dim-reduction # joint PCA + PHATE across all experiments + │ -c reduce_combined.yaml # fits on concatenated embeddings + │ → {experiment}.zarr (obsm: X_pca_combined, X_phate_combined) + │ + │ (after combined-dim-reduction finishes) + │ + ├──► dynaclr plot-embeddings # per-experiment PCA scatter (X_pca) + │ -c plot.yaml # parallel SLURM jobs, one per experiment + │ → plots/{experiment}/*.pdf + │ + ├──► dynaclr plot-embeddings # all-experiments combined (X_pca_combined, X_phate_combined) + │ -c plot_combined.yaml # concatenates all zarrs into one figure + │ → plots/combined/*.pdf + │ + ├──► dynaclr evaluate-smoothness # temporal smoothness + dynamic range + │ -c smoothness.yaml # parallel SLURM jobs, one per experiment + │ → smoothness/{model}_per_marker_smoothness.csv # one row per marker + │ → smoothness/{model}_smoothness_stats.csv # mean ± std across markers + │ → smoothness/*.pdf # per-marker + per-model plots + │ + ├──► dynaclr compute-mmd # one SLURM job per (experiment, block) + │ -c {block_name}.yaml # __ZARR_PATH__ substituted by Nextflow + │ → mmd/{block_name}/mmd_results.csv + │ → mmd/{block_name}/kinetics.pdf + │ → mmd/{block_name}/activity_heatmap.pdf + │ + ├──► dynaclr compute-mmd --combined # pairwise cross-experiment batch effect detection + │ -c {block_name}_cross_exp.yaml # only generated when combined_mode: true + │ # For each marker shared by a pair of experiments, runs MMD per + │ # (condition, time_bin) after per-pair mean centering. + │ # Conditions are auto-discovered from data intersection. + │ → mmd/{block_name}_cross_exp/combined_mmd_results.csv + │ → mmd/{block_name}_cross_exp/kinetics.pdf + │ → mmd/{block_name}_cross_exp/activity_heatmap.pdf + │ + ├──► dynaclr run-linear-classifiers # logistic regression probe + │ -c linear_classifiers.yaml # reads per-experiment zarrs directory + annotation CSVs + │ # joins annotations on (fov_name, t, track_id); trains one LogisticRegression + │ # per (task, marker); marker_filters omitted → auto-discovers all markers + │ # also saves trained pipelines to linear_classifiers/pipelines/ for append-predictions + │ → linear_classifiers/metrics_summary.csv + │ → linear_classifiers/{task}_summary.pdf + │ → linear_classifiers/pipelines/{task}_{marker}.joblib + │ → linear_classifiers/pipelines/manifest.json + │ + ├──► dynaclr append-annotations # persist ground truth labels to per-experiment zarrs + │ -c append_annotations.yaml # reads annotation CSVs + writes task columns to zarr obs + │ # only experiments with AnnotationSource entries are processed; others skipped + │ → {experiment}.zarr (obs: infection_state, organelle_state, ...) + │ + └──► dynaclr append-predictions # (after linear_classifiers) apply saved classifiers + -c append_predictions.yaml # predicts on ALL cells per marker, not just annotated ones + # loads pipelines/manifest.json, applies each pipeline to matching marker cells + → {experiment}.zarr (obs: predicted_infection_state, ...) + → {experiment}.zarr (obsm: predicted_infection_state_proba, ...) + → {experiment}.zarr (uns: predicted_infection_state_classes, ...) + +checkpoint.ckpt (independent of predict/split — runs in parallel) + │ + ▼ +viscy export -c export_onnx.yml # export backbone to ONNX + │ + ▼ +model.onnx + CTC datasets ({seq}_ERR_SEG/, {seq}/, {seq}_GT/TRA/) + │ + ▼ +dynaclr evaluate-tracking-accuracy \ # ILP tracking on CTC benchmarks + -c tracking_accuracy.yaml # loops over (model, dataset, sequence) + │ builds tracksdata graph from segmentation masks + │ runs ONNX inference on cell crops → dynaclr_similarity edge cost + │ solves ILP; compares to GT via evaluate_ctc_metrics + │ set show_napari: true for interactive inspection + ▼ +tracking_accuracy/results.csv # one row per (model, dataset, sequence) +tracking_accuracy/ # grouped mean summary printed to stdout +``` + +After all enrichment steps complete, per-experiment zarrs contain: + +- `.obs`: embeddings metadata + annotations (`infection_state`, etc.) + predictions (`predicted_infection_state`, etc.) +- `.obsm`: `X_pca`, `X_pca_combined`, `X_phate_combined`, `predicted_{task}_proba` +- `.uns`: `predicted_{task}_classes` + +This enables plots colored by experiment, perturbation, annotation, and prediction from a single zarr. + +## Nextflow DAG (process dependency graph) + +``` +checkpoint.ckpt ──────────────────────────────────────────────────────────────┐ + │ │ + ▼ ▼ +PREPARE_CONFIGS EXPORT_ONNX (optional) + │ │ + ▼ ▼ +PREDICT (GPU) model.onnx + CTC datasets + │ │ + ▼ ▼ +SPLIT (CPU light) TRACKING_ACCURACY (CPU) + │ → results.csv + ├─[scatter]─► REDUCE ─[gather]─► REDUCE_COMBINED ─┐ + │ │ + ├─► APPEND_ANNOTATIONS ───────────────────────────►├─[scatter]─► PLOT + │ │ [gather]─► PLOT_COMBINED + ├─► LINEAR_CLASSIFIERS ─► APPEND_PREDICTIONS ─────►┘ + │ + ├─[scatter]─► SMOOTHNESS ─[gather]─► SMOOTHNESS_GATHER + ├─[scatter per (exp,block)]─► MMD ─[gather]─► MMD_PLOT_HEATMAP + └─[gather per block]─► MMD_COMBINED +``` + +Key: **scatter** = one SLURM job per experiment (parallel). **gather** = waits for all scatter jobs. + +`TRACKING_ACCURACY` is independent of the embedding pipeline — it reads directly from an ONNX +model and CTC-format data. Run it manually or as a separate Nextflow job alongside the main DAG. + +`APPEND_ANNOTATIONS` and `APPEND_PREDICTIONS` emit a `'skip'` signal when not present in +`steps`, so `PLOT` and `PLOT_COMBINED` always proceed once `REDUCE_COMBINED` finishes. + +## CTC Tracking Accuracy Benchmark + +Standalone benchmark that evaluates whether DynaCLR embeddings improve cell tracking +accuracy on [Cell Tracking Challenge](https://celltrackingchallenge.net/) datasets. +**Not part of the Nextflow embedding pipeline** — run independently after exporting an ONNX model. + +### Approach + +``` +CTC segmentation masks + raw images + │ + ▼ +tracksdata graph (RegionPropsNodes + DistanceEdges) + │ + ├── baseline: IoU edge weights (no model) + │ + └── DynaCLR: ONNX inference on cell crops + → dynaclr_similarity × spatial_dist_weight as ILP edge cost + │ + ▼ +ILPSolver → tracked graph + │ + ▼ +evaluate_ctc_metrics vs. ground truth + │ + ▼ +results.csv (model × dataset × sequence × CTC metrics) +``` + +### Usage + +```bash +dynaclr evaluate-tracking-accuracy -c tracking_accuracy_config.yaml +``` + +### Config format + +```yaml +models: + - path: /hpc/projects/.../model_ckpt146.onnx + label: DynaCLR-classical + - path: /hpc/projects/.../model_ckpt185.onnx + label: DynaCLR-timeaware + - path: null # baseline: IoU + spatial distance only + label: baseline-iou + +datasets: + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-C2DL-Huh7 + sequences: ["01", "02"] + +crop_shape: [64, 64] # must match the model's training resolution +distance_threshold: 325.0 # spatial candidate edge threshold (pixels) +n_neighbors: 10 +delta_t: 5 # max frame gap for candidate edges +batch_size: 128 +output_dir: /path/to/tracking_accuracy_results +``` + +### Output + +**`results.csv`** — one row per (model, dataset, sequence): + +| Column | Description | +|--------|-------------| +| `model` | Model label | +| `dataset` | CTC dataset name | +| `sequence` | Sequence number (01, 02) | +| `LNK` | CTC Linking metric | +| `TRA` | Tracking metric | +| `DET` | Detection metric | +| `CHOTA` | Cell-specific HOTA | +| `HOTA` | Higher Order Tracking Accuracy | +| `MOTA` | Multiple Object Tracking Accuracy | +| `IDF1` | ID F1 score | +| `BIO(0)` | Biological metric | +| `OP_CLB(0)` | Combined linking+bio score | + +Prints a grouped summary (mean over sequences) at the end. + +### Prerequisites + +1. Export the model to ONNX: + ```bash + viscy export -c export_onnx.yml + ``` +2. CTC datasets must have `{seq}_ERR_SEG/`, `{seq}/`, and `{seq}_GT/TRA/` subdirectories. +3. Install eval dependencies: `uv sync --all-packages --extra eval` + +## Cross-model comparison + +After running evals for multiple models, compare results with: + +```bash +python applications/dynaclr/scripts/evaluation/compare_evals.py -c eval_registry.yml +``` + +Registry format: + +```yaml +models: + - name: DynaCLR-v3 + eval_dir: /path/to/eval_v3 + - name: DINOv3-MLP + eval_dir: /path/to/eval_dino +output_dir: /path/to/comparison_output +fdr_threshold: 0.05 +``` + +Auto-discovers results from each `eval_dir` and produces overlaid plots and summary CSVs for +smoothness, linear classifiers, and MMD. + +## Key commands + +| Step | Command | Input | Output | +|------|---------|-------|--------| +| Config gen | `dynaclr prepare-eval-configs -c eval.yaml` | eval config | configs/ + manifest JSON | +| Predict | `viscy predict -c predict.yml` | checkpoint + parquet | embeddings/embeddings.zarr | +| Split | `dynaclr split-embeddings --input ... --output-dir ...` | combined zarr | per-experiment zarrs + `configs/viewer.yaml` | +| Dim reduction | `dynaclr reduce-dimensionality -c reduce.yaml` | {experiment}.zarr | zarr with X_pca | +| Combined reduction | `dynaclr combined-dim-reduction -c reduce_combined.yaml` | all {experiment}.zarr | zarrs with X_pca_combined/X_phate_combined | +| Plots (per-exp) | `dynaclr plot-embeddings -c plot.yaml` | {experiment}.zarr | plots/{experiment}/*.pdf | +| Plots (combined) | `dynaclr plot-embeddings -c plot_combined.yaml` | all {experiment}.zarr | plots/combined/*.pdf | +| Smoothness | `dynaclr evaluate-smoothness -c smoothness.yaml` | {experiment}.zarr | per_marker_smoothness.csv, smoothness_stats.csv | +| MMD (per-exp) | `dynaclr compute-mmd -c {block}.yaml` | {experiment}.zarr | mmd/{block}/mmd_results.csv | +| MMD (combined) | `dynaclr compute-mmd --combined -c {block}_cross_exp.yaml` | all {experiment}.zarr | mmd/{block}_cross_exp/combined_mmd_results.csv | +| MMD (pooled) | `dynaclr compute-mmd --pooled -c pooled.yaml` | all {experiment}.zarr | mmd_results.csv | +| Linear probe | `dynaclr run-linear-classifiers -c clf.yaml` | per-experiment zarrs + annotations | metrics_summary.csv, {task}_summary.pdf, pipelines/ | +| Append annotations | `dynaclr append-annotations -c append_annotations.yaml` | per-experiment zarrs + annotation CSVs | zarrs with obs annotation columns | +| Append predictions | `dynaclr append-predictions -c append_predictions.yaml` | per-experiment zarrs + pipelines/ | zarrs with predicted_{task} in obs/obsm/uns | +| Compare models | `python compare_evals.py -c eval_registry.yml` | multiple eval dirs | comparison CSVs + plots | +| CTC tracking | `dynaclr evaluate-tracking-accuracy -c tracking_accuracy.yaml` | ONNX model + CTC datasets | tracking_accuracy/results.csv | + +## Placeholder pattern + +Template YAMLs (`reduce.yaml`, `smoothness.yaml`, `{block}.yaml`, `plot.yaml`) contain `__ZARR_PATH__` +as a placeholder for `input_path`. `plot.yaml` also contains `__PLOT_DIR__`. Nextflow process +scripts substitute these inline with Python one-liners before calling the CLI command: + +```python +import yaml +with open('reduce.yaml') as f: + cfg = yaml.safe_load(f) +cfg['input_path'] = '/path/to/experiment.zarr' +with open('reduce_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +``` + +For `reduce_combined.yaml`, `plot_combined.yaml`, and `{block}_cross_exp.yaml`, Nextflow collects +all zarr paths and writes the `input_paths` list directly. + +## Notes + +- `MultiExperimentDataModule` supports `stage="predict"` since the eval orchestrator was added. + It uses the full cell index (no train/val split), applies only normalizations + z-reduction (no augmentations). +- `BatchedChannelWiseZReductiond` is architecturally required for 2D models even at inference time + (converts 3D z-stack → 2D MIP/center-slice). The orchestrator moves it from `augmentations` + to `normalizations` in the generated predict config. +- Dimensionality reductions (PCA, PHATE) are **not** computed inline during predict. + They run as separate CPU steps after splitting, keeping predict fast. +- The `combined-dim-reduction` step fits reductions on all experiments jointly and writes + `X_pca_combined` / `X_phate_combined` back to each per-experiment zarr. +- PHATE is not computed per-experiment by default (`reduce_dimensionality.phate: null`). Run it only jointly via `reduce_combined`. +- `configs/viewer.yaml` is generated after split and can be passed directly to `dynaclr combined-dim-reduction`. +- MMD reads `.X` (raw backbone embeddings) by default. It can also run on `X_pca` or `X_pca_combined` via `embedding_key`. +- Embeddings obs carries `organelle`, `well`, and `microscope` in addition to `experiment`, `marker`, `perturbation`, `hours_post_perturbation`. + +## MMD config format + +Use `configs/evaluation/recipes/mmd_defaults.yml` as a base to avoid repeating MMD algorithm parameters: + +```yaml +# Per-experiment (template — __ZARR_PATH__ substituted at runtime) +base: recipes/mmd_defaults.yml +input_path: __ZARR_PATH__ +output_dir: /path/to/evaluation/mmd/perturbation/ +group_by: perturbation +comparisons: + - cond_a: uninfected + cond_b: ZIKV + label: "uninfected vs ZIKV" +embedding_key: null # null = raw .X; or "X_pca", "X_pca_combined" +temporal_bin_size: 4.0 # uniform bin width in hours (null = aggregate) +# temporal_bins: [0, 6, 12, 24] # alternative: explicit bin edges (mutually exclusive) +mmd: + balance_samples: true # subsample larger group to match smaller + share_bandwidth_from: "uninfected vs uninfected" # reuse bandwidth from baseline comparison +map_settings: + enabled: true # compute mAP via copairs alongside MMD + +# Cross-experiment ({block}_cross_exp.yaml — input_paths substituted at runtime) +# No comparisons — conditions auto-discovered from data intersection. +base: recipes/mmd_defaults.yml +input_paths: [__ZARR_PATH__] +output_dir: /path/to/evaluation/mmd/perturbation_cross_exp/ +group_by: perturbation +temporal_bin_size: 4.0 + +# Pooled (standalone CLI only — not generated by orchestrator) +base: recipes/mmd_defaults.yml +input_paths: + - /path/to/exp_A.zarr + - /path/to/exp_B.zarr +output_dir: /path/to/evaluation/mmd/pooled/ +comparisons: + - cond_a: uninfected + cond_b: ZIKV + label: "uninfected vs ZIKV" +condition_aliases: + uninfected: [uninfected, uninfected1, uninfected2] # map variants to canonical name +``` + +## MMD output columns + +### Per-experiment and pooled (`mmd_results.csv`) + +| Column | Description | +|--------|-------------| +| `experiment` | Experiment name (absent in pooled output) | +| `marker` | Organelle marker (e.g., "TOMM20", "G3BP1") | +| `cond_a` | Reference/control condition | +| `cond_b` | Treatment condition | +| `label` | Human-readable comparison label | +| `hours_bin_start` | Start of temporal bin (NaN if no binning) | +| `hours_bin_end` | End of temporal bin (NaN if no binning) | +| `n_a` | Cells from `cond_a` used after subsampling | +| `n_b` | Cells from `cond_b` used after subsampling | +| `mmd2` | Unbiased MMD² estimate | +| `p_value` | Permutation test p-value (Phipson & Smyth smoothed) | +| `q_value` | BH-corrected FDR (pooled mode only) | +| `bandwidth` | Gaussian RBF bandwidth | +| `effect_size` | mmd2 / bandwidth (scale-free) | +| `activity_zscore` | (mmd2 − null_mean) / null_std — normalized against permutation null | +| `map_value` | Mean Average Precision (NaN if map_settings.enabled=false) | +| `map_p_value` | mAP permutation p-value (NaN if map_settings.enabled=false) | +| `embedding_key` | Embedding used ("X" or obsm key) | + +### Cross-experiment (`combined_mmd_results.csv`) + +| Column | Description | +|--------|-------------| +| `marker` | Organelle marker | +| `exp_a` | First experiment in the pair | +| `exp_b` | Second experiment in the pair | +| `condition` | Condition value matched across experiments | +| `hours_bin_start` | Start of temporal bin (NaN if no binning) | +| `hours_bin_end` | End of temporal bin (NaN if no binning) | +| `n_a` | Cells from `exp_a` used | +| `n_b` | Cells from `exp_b` used | +| `mmd2` | Unbiased MMD² estimate | +| `p_value` | Permutation test p-value | +| `bandwidth` | Gaussian RBF bandwidth | +| `effect_size` | mmd2 / bandwidth | +| `activity_zscore` | (mmd2 − null_mean) / null_std | +| `embedding_key` | Embedding used | + +## Linear classifiers output columns + +| Column | Description | +|--------|-------------| +| `task` | Classification task (e.g., `infection_state`) | +| `marker_filter` | Marker used to filter cells (one row per marker per task) | +| `n_samples` | Total annotated cells used | +| `val_accuracy` | Validation accuracy | +| `val_weighted_f1` | Validation weighted F1 | +| `val_auroc` | Validation AUROC (OvR macro for multiclass) | +| `train_*` | Training set counterparts of the above | +| `val_{class}_f1` | Per-class F1 on validation set | diff --git a/applications/dynaclr/docs/DAGs/pseudotime.md b/applications/dynaclr/docs/DAGs/pseudotime.md new file mode 100644 index 000000000..5862d443f --- /dev/null +++ b/applications/dynaclr/docs/DAGs/pseudotime.md @@ -0,0 +1,793 @@ +# Pseudotime DAG + +Pipeline for DTW-based pseudotime alignment of cell trajectories. +Each stage is a standalone Python script; outputs from one stage feed the next. + +The pipeline is organised around three explicit axes: + +- **Task** — the event that anchors the time-alignment (`t_key_event`). Derived from the anchor label (e.g. first `infected` frame for `infection_onset`). Planned tasks: cell division, cell death. +- **Channel** — which embedding zarr to align (`phase`, `sensor`, `organelle_sec61`, `organelle_g3bp1`). +- **Annotated candidates** — which cells to build the template from, plus per-frame labels. Expressed as an annotations CSV so users can inspect, curate, or hand-write the list. + +### What `t_rel = 0` actually means (infection templates) + +The `infection_state` label is derived from the **NS3 protease sensor translocating to the nucleus**: a viral-protease-cleavable reporter that gets transcribed, and once NS3 is expressed and cleaves it, the reporter moves nucleus-ward. So `infected = True` at the single-cell level means "NS3 protease is active in this cell," which is downstream of: + +1. Virus entry, endocytosis, RNA release (minutes–hours earlier) +2. Initial translation of the viral polyprotein +3. ER membrane invagination to form replication organelles +4. NS3 accumulation to a level high enough to cleave the sensor +5. Sensor translocation past the detection threshold + +**Implication for organelle remodeling.** Organelle changes that happen at `t_rel < 0` are not alignment artifacts or noise — they are the upstream biology of infection. For ZIKV/DENV specifically, ER (SEC61) remodeling *must* precede `t_rel = 0` because the replication organelles are what let NS3 be made in the first place. G3BP1 (stress granules) may show biphasic kinetics: mild rise while the virus suppresses SG formation, then a sharp rise once antiviral response breaks through. See Hofstadter & Cristea 2025 (Annu. Rev. Virol., DOI 10.1146/annurev-virology-092623-094221) for the review. The sensor template gives us a **reproducible, late-stage, cell-intrinsic anchor** — not the start of infection, but a reliable clock we can measure other events against. + +## Directory layout + +``` +applications/dynaclr/ +├── configs/pseudotime/ +│ ├── datasets.yaml # shared infra: datasets + embedding patterns (loaded by every stage via --datasets) +│ ├── build_template.yaml # Stage 1 recipe: candidate_sets + templates +│ └── align_cells.yaml # Stage 2 recipe: query_sets +├── docs/DAGs/pseudotime.md # this file +└── scripts/pseudotime/ + ├── utils.py # shared helpers (load_stage_config, read_focus_slice) + ├── sweep_pcs.py # PCA sweep — build × align × compare for multiple n_components + ├── 0-select_candidates/ + │ ├── select_candidates.py # Stage 0: auto path (from annotations) + │ ├── manual_candidates.py # Stage 0: manual path (hand-picked tracks) + │ ├── inspect_candidates.py # Stage 0 QC: per-track-anchored image montage + QC CSV + │ ├── refine_candidates.py # Stage 0.5: bootstrap-rank candidates by DTW cost, keep top-N + │ └── candidates/ # output: {set}_annotations.csv + _montage.png + _qc.csv + _ranking.csv + ├── 1-build_template/ + │ ├── build_template.py # build DBA template (raw + PCA flavors) + │ ├── evaluate_template.py # self-align (build-set only, sanity check) + │ ├── plot_pcs.py # PCs over pseudotime (post-hoc PCA, self-align) + │ ├── templates/ # output: template_*.zarr + │ ├── alignments/ # output: self-align parquet + │ └── plots/ # output: self-align montages + PC plots + ├── 2-align_cells/ + │ ├── align_cells.py # subsequence-DTW scan template over query tracks + │ ├── rank_by_cost.py # DTW cost histogram + duration scatter + │ ├── plot_top_n_montage.py # montage of top/worst-N cells anchored at template t=0 + │ ├── plot_pcs_aligned.py # PCs vs real time: pre/aligned/post (query cells) + │ ├── alignments/ # output: {template}_{flavor}_on_{query_set}.parquet + │ └── plots/ # output: cost diagnostics + montages + PC plots + pca_sweep_*.png/.md + └── 3-organelle-remodeling/ + ├── plot_organelle_remodeling.py # Stage 3a: organelle-channel remodeling vs sensor-aligned t_rel + ├── plot_aligned_montage.py # Stage 3b: dual-channel (organelle + sensor) montage, orange border on remodel frames + ├── compute_timing_metrics.py # Stage 3c: per-cell timing scalars from embedding cosine distance (t_onset_abs, t50, t_peak, Δpeak, rise_rate_per_hour) + ├── compute_label_timing.py # Stage 3d: per-cell timing scalars from LC predictions (t_first_pos, t_run_start, pos_fraction, flips) + ├── plots/ # output: organelle_remodeling_*.png, aligned_montage_*.png + ├── timing/ # output: compute_timing_metrics per-cell parquet + summary.md + compare_*.png/.md + └── timing_labels/ # output: compute_label_timing per-cell parquet + summary.md + compare_*.png/.md +``` + +## DAG + +``` + ┌──────────────── AUTO ────────────────┐ ┌──────── MANUAL (debug/test) ────────┐ + │ │ │ │ + │ [annotations.csv] [embedding .zarr] │ │ user-observed phenotypes │ + │ │ │ │ │ │ │ + │ └──────────┬─────────┘ │ │ ▼ │ + │ ▼ │ │ manual_candidates.py │ + │ select_candidates.py --candidate-set │ │ (hand-picked track specs with │ + │ (filter tracks, emit per-frame │ │ [t_on, t_off] label intervals) │ + │ labels over the crop window) │ │ │ + └──────────────┬───────────────────────┘ └──────────────┬──────────────────────┘ + │ │ + └───────────────────┬──────────────────────┘ + ▼ + candidates/{candidate_set}_annotations.csv + (dataset_id, fov_name, track_id, t, + infection_state, organelle_state, cell_division_state) + one row per (cell, frame) over the crop window + │ + ▼ + 1-build_template/build_template.py --template {name} + (join CSV with embedding zarr on (dataset_id, fov_name, track_id, t), + derive per-cell crop window and t_key_event from the annotations, + apply optional per-experiment z-score + L2-normalize, + run DTW-DBA (cosine metric) to build TWO template flavors in parallel: + raw/ — template in 768-D embedding space + pca/ — template after PCA to N components + save template zarr) + │ + ▼ + templates/template_{name}.zarr + ├── raw/template (T, 768) DBA template, 768-D + ├── raw/time_calibration (T,) minutes relative to t_key_event + ├── raw/template_labels/{col} (T,) per-position label fractions + ├── pca/template (T, N) DBA template, PCA-reduced + ├── pca/time_calibration (T,) + ├── pca/template_labels/{col} (T,) + ├── pca/components (N, D) build-time PCA model + ├── pca/mean (D,) + ├── pca/explained_variance_ratio (N,) + ├── zscore_params/{ds_id}/* (D,) only if zscore=per_dataset + ├── t_key_event (N_cells,) per-cell anchor frame + └── attrs: template_cell_ids, l2_normalize, metric, aggregator + │ + ▼ + 1-build_template/evaluate_template.py --template {name} --flavor {raw|pca} + (self-consistency check — re-align the same cells used to build + the template. Not subsequence DTW; closed-endpoint on both sides.) + │ + ├──► 1-build_template/alignments/template_alignments_{name}_{flavor}.parquet + └──► 1-build_template/plots/realtime_montage_{name}_{flavor}_{channel}.png + plots/pcs_over_pseudotime_{name}_{flavor}.png + (via plot_pcs.py — diagnostic post-hoc PCA on build-set cells) + │ + ▼ +──── Stage 2: scan template across query tracks ──────────────────────────────── + │ + 2-align_cells/align_cells.py \ + --template {name} --flavor {raw|pca} --query-set {qset} + (for every query cell track — NOT in the build set — run + SUBSEQUENCE DTW: template (length T) must match fully, query + (length Q ≥ T) endpoints float. Scans the template across the + query's time axis and picks the window with minimum cost. + Preprocessing: apply the build-time zscore + PCA + L2 from the + template zarr — never refit at alignment time.) + │ + ▼ + 2-align_cells/alignments/{template}_{flavor}_on_{qset}.parquet + (one row per query (dataset_id, fov_name, track_id, t): + pseudotime ∈ [0, 1] template position from warp path + alignment_region "pre" | "aligned" | "post" + estimated_t_rel_minutes time_calibration[template_pos] + NaN outside alignment_region == "aligned" + dtw_cost per-track total cost (repeated on rows) + length_normalized_cost dtw_cost / len(warp_path) + match_q_start, match_q_end absolute query frames bounding the match + match_duration_minutes (q_end - q_start) * frame_interval_minutes + ) + │ + ├──► 2-align_cells/rank_by_cost.py --template {name} --flavor {..} --query-set {..} + │ (histogram of length_normalized_cost, + │ scatter match_duration_minutes vs cost. + │ Use to pick a cost cutoff before montage.) + │ plots/cost_ranking_{template}_{flavor}_{qset}.png + │ + ├──► 2-align_cells/plot_top_n_montage.py \ + │ --template {..} --flavor {..} --query-set {..} \ + │ --top-n 30 --worst-n 10 + │ (rows = query cells sorted by length_normalized_cost + │ ascending; top-N at top, worst-N at bottom for + │ contrast. Columns = real time anchored at each + │ cell's warped t=0, i.e. the frame where + │ estimated_t_rel_minutes crosses 0. Red border at t=0. + │ Frames in "pre"/"post" are shown faded.) + │ plots/realtime_montage_{template}_{flavor}_{qset}.png + │ + └──► 2-align_cells/plot_pcs_aligned.py \ + --template {..} --flavor {..} --query-set {..} \ + --top-n 50 + (fit diagnostic post-hoc PCA on aligned-region + frames of top-N query cells. Plot PCs vs minutes: + left = unaligned: PC vs (t - match_q_start) * frame_interval + — each cell anchored at its own match start; + traces scatter in shape. + right = aligned: PC vs estimated_t_rel_minutes; + traces collapse onto a shared curve. + Bottom row: query-truth label fraction (solid red) on + BOTH axes + template-build-cells fraction (grey dashed, + secondary). A sharper right-panel truth curve = real + alignment, not just embedding-shape collapse.) + plots/pcs_over_pseudotime_{template}_{flavor}_{qset}.png + +──── Stage 3: organelle remodeling vs sensor-aligned t_rel ─────────────────── + (consumes Stage 2's alignment parquet; no new DTW) + + 3-organelle-remodeling/plot_organelle_remodeling.py \ + --template {..} --flavor {..} --query-set {..} \ + --organelle-channel {organelle_sec61 | organelle_g3bp1 | phase} + (REUSE the sensor alignment parquet as a timing skeleton and + project organelle-channel embeddings onto the sensor-derived + t_rel_minutes. No new DTW. For each (dataset, fov, track, t) + in the sensor parquet, look up the organelle embedding from + its zarr, compute distance-from-pre-baseline (cosine, + per-cell), and plot vs t_rel. + Three rows: (A) per-cell organelle distance traces, + (B) post-hoc PC1/PC2 of organelle embeddings over t_rel, + (C) ground-truth organelle_state fraction (when available). + Report remodeling onset offset in title: "SEC61 remodels at + t_rel = +X min".) + plots/organelle_remodeling_{template}_{flavor}_{organelle_channel}_{qset}.png +``` + +## How to run + +Run each script from its own directory — scripts resolve output paths relative to their own location. + +### Stage 0 — Select candidates + +Stage 0 emits a single artifact: `{candidate_set}_annotations.csv`, one row per `(dataset_id, fov_name, track_id, t)` with per-frame label columns. Two independent scripts produce this file — downstream consumers treat the outputs identically. + +**Auto — `select_candidates.py`** (from annotations) + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python select_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/build_template.yaml \ + --candidate-set infection_transitioning_nondiv +``` + +Filters tracks per `config["candidate_sets"][NAME]["filter"]` (anchor label, anchor_positive, min_pre/post_minutes, crop_window_minutes), then expands each selected track into per-frame rows over its crop window, copying real annotation labels onto each row. Writes `candidates/{candidate_set}_annotations.csv`. + +**Manual — `manual_candidates.py`** (user-written, for debugging / hand-curated cells) + +Each track spec is a `{t_before, t_after, labels: {label_col: [[t_on, t_off], ...]}}` entry in a Python dict. For every frame in `[t_before, t_after]`, the script emits the positive label if that frame falls inside any interval, otherwise the negative label. Columns with no intervals are left blank. + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +python manual_candidates.py +``` + +This path shares no code with `select_candidates.py`; the CSV schema is the only contract. + +**Inspect — `inspect_candidates.py`** (per-track-anchored QC montage + stats CSV) + +Reads the candidate annotations CSV and renders a montage where every row is anchored at that cell's `t_key_event` (red border at offset 0), so scanning down rows makes bad candidates obvious. Also writes a sidecar `{candidate_set}_qc.csv` with per-track stats (n_frames, pre_frames, post_frames, fov) for non-visual QC. + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python inspect_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/build_template.yaml \ + --candidate-set infection_transitioning_nondiv +``` + +### Stage 0.5 — Refine candidates (bootstrap) + +`refine_candidates.py` handles the common case of noisy annotations producing a broad candidate set that contains some bad (mislabeled / wrong-cell) tracks. Two-pass filter: + +1. **Strict headroom inside the crop**: drops tracks whose `t_key_event` is too close to the window start/end (the "annotation starts at transition" cases where the cell has no genuine uninfected baseline). +2. **Bootstrap self-alignment**: builds an initial DBA template from the surviving tracks, self-aligns each cell against it, ranks by `length_normalized_cost`, and keeps the top-N. + +Produces a **refined candidate-set CSV** that the final template build consumes. Cells surviving both filters are simultaneously well-annotated *and* consistent with the population consensus trajectory. + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python refine_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/build_template.yaml \ + --candidate-set infection_transitioning_nondiv_top20 +``` + +The refined set is declared in `build_template.yaml` as a candidate entry with `refine_from: `, `min_pre_frames`, `min_post_frames`, and `top_n_by_cost`. See "Example refined-candidate entry" below. + +Outputs: `candidates/{refined_set}_annotations.csv`, `{refined_set}_ranking.csv` (full ranking with kept/rejected flags). Run `inspect_candidates.py` on the refined set afterwards to visually QC the surviving cells. + +### Stage 1 — Build template + +```bash +cd applications/dynaclr/scripts/pseudotime/1-build_template +uv run python build_template.py --config ../../../configs/pseudotime/multi_template.yaml --template infection_nondividing_sensor +``` + +Outputs `templates/template_{name}.zarr` with **both flavors** (raw and PCA) built from the same input cells. The downstream picks which flavor to use at alignment time. + +**What the builder does** + +1. Reads `candidates/{candidate_set}_annotations.csv`. +2. Groups by `(dataset_id, fov_name, track_id)`; pulls embedding rows from the channel's zarr. +3. Derives each cell's crop window from `[min(t), max(t)]` and `t_key_event` from the first frame where the anchor label is positive. +4. Applies optional per-dataset z-score. +5. Builds **two templates from the same cells**, in parallel: + - `raw/` — optional L2-normalize, then DTW-DBA with cosine metric. + - `pca/` — fits PCA (`n_components`), transforms, optional L2, then DTW-DBA. +6. Saves the combined zarr. + +#### 1a — Self-consistency check (`evaluate_template.py`, `plot_pcs.py`) + +Both scripts live under `1-build_template/` and operate on the **build set only** — they re-align the cells that built the template onto itself. They are **not** subsequence DTW (template and cell share endpoints). Treat outputs as a sanity check, not as evaluation of generalization. + +```bash +cd applications/dynaclr/scripts/pseudotime/1-build_template +uv run python evaluate_template.py --config ../../../configs/pseudotime/multi_template.yaml --template infection_nondividing_sensor --flavor raw +uv run python plot_pcs.py --config ../../../configs/pseudotime/multi_template.yaml --template infection_nondividing_sensor --flavor raw --n-pcs 5 +``` + +Outputs: `alignments/template_alignments_{name}_{flavor}.parquet`, `plots/realtime_montage_{name}_{flavor}_*.png`, `plots/pcs_over_pseudotime_{name}_{flavor}.png`. + +Montage optional args: `--pre-minutes 180`, `--post-minutes 420`, `--crop-half 80`, `--n-cells 50` (sorted by DTW cost). +PC plot optional args: `--n-pcs 5`, `--n-bins 20`. + +### Stage 2 — Align query cells to the template (subsequence DTW) + +This stage takes the template built in Stage 1 and scans it across **new** cell tracks from any dataset (not necessarily the ones used to build the template). Subsequence DTW finds, per query track, the time window where the template best matches — i.e. the time when that cell traverses the same canonical event. + +The template's `time_calibration` provides the real-time clock. Once a cell's best-matching window is found, each frame inside the window is mapped to template-relative minutes; frames before/after stay untouched but are labeled `"pre"` / `"post"` for downstream pre-vs-post analysis. + +**All alignment, evaluation, and plotting for Stage 2 live under `2-align_cells/` — same convention as Stage 1.** + +#### 2a — Align (`align_cells.py`) + +```bash +cd applications/dynaclr/scripts/pseudotime/2-align_cells +uv run python align_cells.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor \ + --flavor raw \ + --query-set sensor_all_07_24 \ + --min-match-minutes 360 --max-skew 0.7 +``` + +What it does: + +1. Loads `templates/template_{name}.zarr` and reconstructs a `TemplateResult` for the chosen flavor. **Reuses the build-time zscore + PCA + L2 stored in the zarr — never refits at alignment time.** +2. Loads the query set's embedding zarr(s), restricted to the template's channel. +3. For each query track, calls `dtw_align_tracks(..., subsequence=True, frame_interval_minutes=..., max_psi_minutes=...)` so psi is frame-rate invariant — same wall-clock freedom on 10 min/frame and 30 min/frame tracks. The template (length T) must match fully while the query (length Q ≥ T) floats; returns a warp path, best-match window `[q_start, q_end]`, cost, and `path_skew`. +4. Applies guards (see "Guards and frame-rate invariance" below) and writes one row per `(dataset_id, fov_name, track_id, t)`. + +#### 2b — Rank cells by DTW cost (`rank_by_cost.py`) + +Diagnostic before rendering montages. Length-normalized cost (`dtw_cost / len(path)`) is the correct rank for subsequence DTW because matched windows have variable length. + +```bash +uv run python rank_by_cost.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw --query-set sensor_all_07_24 +``` + +Outputs `plots/cost_ranking_{template}_{flavor}_{qset}.png` (histogram + duration-vs-cost scatter). + +#### 2c — Top-N realtime montage (`plot_top_n_montage.py`) + +```bash +uv run python plot_top_n_montage.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw --query-set sensor_all_07_24 \ + --top-n 30 --worst-n 10 +``` + +Rows = query cells ranked by length-normalized cost; columns = real time anchored at each cell's warped `t=0` (the frame where `estimated_t_rel_minutes` crosses 0). Top-N at top, worst-N at bottom for contrast. Pre/post frames are shown faded. Red border at `t=0`. + +Outputs `plots/realtime_montage_{template}_{flavor}_{qset}.png`. + +#### 2d — PCs over real time, pre / aligned / post (`plot_pcs_aligned.py`) + +Fits a diagnostic post-hoc PCA on the **aligned-region** frames of the top-N query cells, then projects pre / aligned / post frames through the same basis so trajectories extend on both sides of the event window. Plots top PCs vs minutes: + +- **Left (unaligned):** PC vs `(t - match_q_start) * frame_interval_minutes` — anchored at each cell's own match start. +- **Right (aligned):** PC vs `estimated_t_rel_minutes`; pre/post frames are extrapolated off either end using `time_calibration[0]` / `time_calibration[-1]` as anchors. +- Points are coloured by `alignment_region` (grey = pre, blue = aligned, red = post); legend is written to a separate `*.legend.png` so the main grid isn't squeezed. + +`--exclude-template-cells` drops query cells that match the template build-set (honest generalization reporting). Without it, build-set cells will always score best since they're matching themselves. + +The bottom row carries **two** curves so alignment quality can be judged honestly: + +- **Solid red — query truth**: fraction of query cells where `obs[truth_column] == truth_positive` at each bin. Present on BOTH axes. Left bins by `(t - match_q_start) * frame_interval`; right bins by `estimated_t_rel_minutes` restricted to `alignment_region == "aligned"`. A sharper right-panel curve than left = DTW is genuinely moving the annotated transition into alignment with template t=0. +- **Dashed grey — template fraction**: label fractions stored in the template zarr (`raw/template_labels/{col}`). This is a property of the build-set cells only, not the query. Included as a secondary reference; do NOT treat it as evidence of query-side alignment. + +`--truth-column` / `--truth-positive` pick the label. Use human `infection_state` on 07_22/07_24 when available; `predicted_infection_state` on 08_26/01_28 (or 07_22 where human labels are sparse). + +```bash +uv run python plot_pcs_aligned.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw --query-set sensor_all_07_24 \ + --top-n 50 --n-pcs 5 --exclude-template-cells \ + --truth-column infection_state --truth-positive infected +``` + +Outputs `plots/pcs_over_pseudotime_{template}_{flavor}_{qset}.png` + `.legend.png`. + +### Stage 3 — Organelle remodeling vs sensor-aligned t_rel (`3-organelle-remodeling/plot_organelle_remodeling.py`) + +Stage 3 is a **consumer** of Stage 2's alignment parquet. It runs no new DTW — it joins the sensor-channel alignment parquet with an organelle-channel embedding zarr and plots organelle dynamics on the sensor-derived time axis. Lives in its own `3-organelle-remodeling/` directory so the scope (read Stage 2 artifacts, write new plots) is obvious. + +**Scientific question.** The sensor channel tells us *when* the NS3 protease sensor translocates to the nucleus (via template alignment; see "what t=0 actually means" note at the top of the doc). Do the organelle channels (SEC61 ER, G3BP1 stress granules) show coordinated remodeling around that same t=0, and at what offset — before, after, or simultaneous? + +**Design decision: reuse the sensor alignment as a timing skeleton** (option a, not build a separate organelle template). Rationale: the claim is "organelle remodeling *relative to* infection onset," which requires a single shared clock. A sensor-derived t=0 is meaningful; a SEC61-derived t=0 would be tautological. + +**Inputs** + +- Sensor alignment parquet: `infection_nondividing_sensor_{raw|pca}_on_{qset}.parquet`. +- One organelle embedding zarr resolved via `datasets.yaml.embeddings.{organelle_channel}`. Supported channels: `organelle_sec61`, `organelle_g3bp1`, `phase`. + +**Organelle channels live in disjoint FOV groups.** Each fluorophore was only acquired in its dedicated wells — on 07_24, SEC61 is only in A/1 + A/2 and G3BP1 only in C/1 + C/2. A sensor-query row from `2025_07_24_G3BP1` therefore has **no** SEC61 embedding and vice versa; those rows are dropped at join time. This is not a bug — it's the microscopy design. The per-organelle plot effectively restricts to the subset of sensor-aligned cells that were imaged in that organelle's wells. + +**Pipeline** + +1. Join the sensor parquet with the organelle embedding on `(dataset_id, fov_name, track_id, t)`. +2. Compute **distance-from-baseline** per frame. Baseline = mean organelle embedding across `alignment_region == "pre"` frames per cell. Per-frame scalar = cosine distance from that per-cell baseline. +3. Render three panels stacked: + - **Panel A**: per-cell organelle-distance traces vs `estimated_t_rel_minutes`, colored by pre/aligned/post. Binned median + IQR overlay. + - **Panel B**: post-hoc PC1/PC2 of the organelle embeddings (fitted on aligned-region frames, projected onto pre + post) vs `estimated_t_rel_minutes`. Mirror of `plot_pcs_aligned.py` but in organelle-embedding space. + - **Panel C**: `organelle_state` fraction vs `estimated_t_rel_minutes` (when the query obs has the column). Same truth-binning convention as Stage 2d. +4. Compute the **remodeling onset offset**: the `t_rel_minutes` where Panel A's binned median crosses a threshold (default: 2σ above the pre-baseline distance distribution). Report in the plot title — e.g. `"SEC61 remodels at t_rel = +60 min"`. + +**Preprocessing** + +Organelle embeddings are used as-is. `--flavor` only selects which sensor alignment parquet to join on (different warp paths yield different t_rel mappings); the organelle distance metric (and Panel B's post-hoc PCA) are computed per-run on the joined organelle embeddings. + +**Template cells are not excluded by default.** The sensor template was built on sensor embeddings — organelle embeddings from the same cells aren't "self-alignment" in any meaningful sense. Keep the full top-N by sensor DTW cost. + +**How to run (Phase 1 — Panel A)** + +```bash +cd applications/dynaclr/scripts/pseudotime/3-organelle-remodeling +uv run python plot_organelle_remodeling.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw \ + --query-set sensor_all_07_24 \ + --organelle-channel organelle_sec61 \ + --top-n 30 +``` + +Outputs `plots/organelle_remodeling_{template}_{flavor}_{organelle_channel}_{qset}.png`. + +**Delivery plan** + +1. ✅ Panel A only, SEC61 + G3BP1 on `sensor_all_07_24` — sanity-check the join + baseline subtraction. +2. Add Panels B + C (organelle-space PCA + `organelle_state` truth curve; CLI grows `--n-pcs`, `--truth-column`, `--truth-positive`). +3. Sweep across organelle channels × query sets (07_24 + 07_22 + 01_28; 08_26 missing labels). +4. Replicate check: does the remodeling offset hold across datasets? Emit a summary table analogous to the cross-dataset sensor results above. + +**Phase 1 results (Apr 2026, infection_nondividing_sensor, raw flavor, 07_24, top-30 by sensor cost, template cells NOT excluded)** + +Two distinct organelle kinetics visible in cosine distance from per-cell pre-baseline: + +| Organelle | Cells kept | Pre (t≈-400) | Onset of divergence | At sensor t=0 | Post | +|---|---:|---:|---:|---:|---:| +| **SEC61 (ER)** | 15 / 30 (A/2 only) | ~0.025 | **~-250 min** — gradual, monotonic | ~0.09 | ~0.24, still rising | +| **G3BP1 (stress granules)** | 15 / 30 (C/2 only) | ~0.03 | biphasic: gentle rise from ~-300 min, plateau around t=0, **sharp kink at ~+200 min** | ~0.10 | ~0.28, plateaus ~0.28 by t≈+400 | + +**Two qualitatively different kinetics.** + +- **SEC61 (ER) — steady, one-way remodeling.** The cosine distance from baseline rises monotonically from ~-250 min through the entire post window, with no return toward baseline. This matches the biology of ER-derived replication organelles: once the ER is restructured into invagination-type ROs for flavivirus replication, it stays restructured for as long as the virus is replicating. We don't expect the ER to "snap back" during the observation window — SEC61 remodeling is a persistent, one-way structural change upstream of the NS3 sensor signal. +- **G3BP1 (stress granules) — transient, comes-and-goes.** The distance curve shows small, repeated up-and-down excursions through the pre + early-aligned region (gentle rises, mini-plateaus), then a sharp rise around t≈+200 min, and finally a plateau rather than continued growth. This matches the biology of stress granules: they are phase-separated membraneless condensates that **assemble and disassemble** on minute timescales. Flavivirus NS3 and capsid proteins actively suppress SG formation early (so translation of viral proteins can continue) — hence the low, flickering pre-phase — and then once the antiviral response overwhelms that suppression, SGs form persistently and the signal jumps. The plateau (not continued rise) is expected: SG mass is bounded by the available G3BP1 pool, unlike ER membrane area. + +The **SEC61 steady climb vs G3BP1 transient-then-step** contrast is exactly the kind of temporal signature the pipeline was built to surface. Same sensor clock, different organelle grammars. + +Per Hofstadter & Cristea 2025 (Annu. Rev. Virol., DOI 10.1146/annurev-virology-092623-094221): "Flaviviruses (including ZIKV) actively suppress stress granule formation to maintain translation of viral proteins" — consistent with the suppressed early G3BP1 signal and the late breakthrough. ER invagination happening before the sensor readout is consistent with "ZIKV/DENV form replication organelles from ER membranes" being an upstream prerequisite for NS3 expression. + +### Stage 3c — Per-cell embedding-timing metrics (`compute_timing_metrics.py`) + +Reduces each cell's per-frame cosine-distance-from-pre-baseline curve to five scalars, then pools into a per-organelle distribution so distributions (not cells, since FOVs are disjoint) can be compared across organelles. + +**Per-cell scalars (computed on the aligned region only, with interior restriction):** + +| metric | definition | why | +|---|---|---| +| `t_onset_abs` | first `t_rel` where `distance − pre_median` crosses `+0.10` (cosine units) | SNR-robust: cells with small Δpeak can't fake an early onset by their noise floor crossing a normalized fraction | +| `t50` | first `t_rel` where distance crosses `pre_median + 0.5 × Δpeak`, last 2 aligned frames excluded | half-rise timing, interior-restricted to dodge DTW endpoint pile-up | +| `t_peak` | `argmax` of distance over interior aligned region | time of maximum embedding divergence | +| `delta_peak` | `max(aligned distance) − median(pre distance)` | amplitude of remodeling in cosine units | +| `rise_rate_per_hour` | OLS slope of distance vs `t_rel` over aligned region × 60 | per-cell aggregate speed of change | + +**Outputs:** `timing/{stem}_per_cell.parquet` + `timing/{stem}_summary.md` (per-well medians + pooled bootstrap CI). Run `compute_timing_metrics.py compare` on multiple per-cell parquets to emit strip plots + pairwise rank-sum tests (writes `timing/{out_stem}.png/.md`). + +### Stage 3d — Per-cell label-timing metrics (`compute_label_timing.py`) + +Parallel to Stage 3c but uses **linear classifier predictions** (`predicted_{state}`, the dense LC output per frame) instead of embedding distance. Supervised projection → collapses off-axis embedding noise (cell cycle, focus, photobleaching) that cosine distance would catch. + +**Per-cell scalars on the binarized predicted-label trajectory (1 = positive):** + +| metric | definition | region | +|---|---|---| +| `t_first_pos` | first `t_rel` with a positive prediction | whole track | +| `t_run_start` | first `t_rel` entering a run of ≥ `min_run` (default 3) consecutive positives | whole track | +| `t_run_end` | last `t_rel` in the run | whole track | +| `pos_duration` | `t_run_end − t_run_start` | whole track | +| `pos_fraction` | fraction of aligned frames predicted positive | **aligned only** | +| `flips` | number of 0↔1 transitions across the track | whole track | + +**Aligned-vs-whole-track asymmetry is intentional** — `pos_fraction` is the aligned-period fingerprint (density of the positive state during DTW-mappable frames); the timing scalars run across the whole track so "LC fires before sensor translocation" can be measured as a negative `t_first_pos`. + +**Example: SEC61 vs G3BP1 `predicted_organelle_state==remodel` on `sensor_all_07_24` (n=15 each)** + +| metric | SEC61 median [CI] | G3BP1 median [CI] | p (MW-U) | +|---|---|---|---| +| `t_first_pos` (min) | **-207 [-354, -158]** | +221 [+198, +341] | **4.7e-4** | +| `t_run_start` (min) | **-72 [-170, +3]** | +221 [+198, +221] | 0.048 | +| `pos_fraction` | **0.81 [0.52, 0.93]** | 0.00 [0.00, 0.03] | **1.6e-4** | +| `flips` | 3 [3, 6] | 1 [0, 4] | 0.028 | + +Signal that was suggestive but not significant in Stage 3c (embedding-timing ΔT ≈ 120 min, p ≈ 0.4) becomes sharp in Stage 3d because the LC was trained on the `remodel` label directly. Biologically consistent with Hofstadter & Cristea 2025: SEC61 (ER) remodels early for replication-organelle formation; G3BP1 (stress granules) is actively suppressed by flavivirus NS3/capsid during infection. + +**Caveat.** A near-zero G3BP1 `pos_fraction` could be real suppression or LC blind spot (if trained on SEC61-dominated data). Before interpreting as biology, verify the LC's training set covered the G3BP1 channel and morphology. + +### Delivery plan for modular multi-dataset + virus comparison (next) + +Current Stage 3c/3d take one `--query-set` (one alignment parquet → one population). Next iteration moves the pooling to be **dataset-group-aware** so the same templates can be evaluated across: + +1. **A configurable dataset pool** — pass a list of datasets to pool (all will have LC predictions; only some have human annotations). The script should error softly when a requested label column is missing from a dataset rather than silently NaN-ing those rows. +2. **Virus-stratified comparison** — ZIKV vs DENV. Cells from `2025_01_28_ZIKV_DENV` carry a `perturbation` column (`infected`, `mock`) plus a `virus` column; per-organelle distributions should split on `virus` and the compare step should render side-by-side strips. +3. **Artifact caching** — because each stage writes its own parquet, re-running only the comparison step on different pool/virus filters should be cheap (no re-computation of per-cell metrics). Confirm this already holds with the current output layout. + +### Guards and frame-rate invariance + +Subsequence DTW with generous psi relaxation can collapse the template onto a single query frame (near-zero cost, no biological meaning). Four guards prevent and surface this: + +| guard | CLI flag | default | what it rejects | +|---|---|---|---| +| Non-finite cost | (always on) | — | tracks too short for the solver to find any valid path | +| Minimum match window | `--min-match-minutes` or `--min-match-ratio` | ratio 0.5 | template compressed onto a tiny real-time window | +| Path skewness | `--max-skew` | 0.8 | L-shaped / non-diagonal warps that slip past psi | +| Pre/post headroom | query-set YAML `min_pre_minutes` / `min_post_minutes` | 0 | cells without real footage on either side of the event | + +**Minute-based guards supersede frame-based ones when both are set.** When query datasets have heterogeneous frame intervals (e.g. 07_22 at 10 min/frame vs 07_24 at 30 min/frame), use `--min-match-minutes` and `--max-psi-minutes` instead of `--min-match-ratio` and the implicit `t_template // 2` psi: minute-based thresholds apply the same wall-clock requirement regardless of frame rate. + +`--max-psi-minutes` defaults to **half the template duration**, read from `template_duration_minutes` in the template zarr attrs. Per-track psi is then `round(max_psi_minutes / dataset_frame_interval_minutes)`. + +### PCA sweep — finding the sweet spot (`sweep_pcs.py`) + +Sweeps `n_components` for one template, rebuilding the template at each value and re-running Stage 2a against a fixed query set. Produces a 2×2 summary plot + a markdown table sidecar: + +- Cost distribution vs n_components (boxplot) +- Tracks kept vs n_components +- Spearman rank correlation to the RAW 768-D reference (the sweet-spot indicator) +- PCA explained variance vs n_components + +```bash +cd applications/dynaclr/scripts/pseudotime +uv run python sweep_pcs.py \ + --datasets ../../configs/pseudotime/datasets.yaml \ + --build-config ../../configs/pseudotime/build_template.yaml \ + --align-config ../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor \ + --query-set sensor_all_07_24 \ + --n-components 5,10,20,30,50 \ + --min-match-ratio 0.7 --max-skew 0.7 +``` + +Outputs `plots/pca_sweep_{template}_{qset}.png` and `.md`. + +## Key config fields + +Three YAMLs split across `configs/pseudotime/`, each loaded alongside `datasets.yaml` via the `--datasets` + `--config` CLI pair: + +| File | Contains | Used by | +|---|---|---| +| `datasets.yaml` | `data_zarr`, `embeddings` glob patterns, `datasets` list (pred_dir, annotations_path, fov_pattern, `frame_interval_minutes`) | every stage (passed via `--datasets`) | +| `build_template.yaml` | `candidate_sets.{name}`, `templates.{name}` | Stage 0 (auto), Stage 1 | +| `align_cells.yaml` | `query_sets.{name}` | Stage 2 | + +Field reference: + +| Field | Purpose | +|---|---| +| `data_zarr` (top-level) | source image zarr for cell crop montages (Stage 0 inspect, Stage 2c) | +| `embeddings.{channel}` | glob pattern → zarr per channel | +| `datasets[].frame_interval_minutes` | real-time spacing between adjacent `t` values; used for minute→frame conversions | +| `datasets[].fov_pattern` | substring selecting FOVs from that dataset's zarr (e.g. `A/2`) | +| `candidate_sets.{name}` | anchor label + minute-based filters + `crop_window_minutes` + `max_tracks` | +| `templates.{name}` | candidate_set reference, channel, anchor label, preprocessing, DBA params | +| `query_sets.{name}` | channel (must match template), datasets, `min_pre_minutes` / `min_post_minutes`, optional `track_filter` | + +### Example candidate-set entry + +```yaml +candidate_sets: + infection_transitioning_nondiv: + datasets: ["2025_07_24_SEC61", "2025_07_24_G3BP1"] + filter: + anchor_label: infection_state + anchor_positive: infected + anchor_negative: uninfected + min_pre_minutes: 120 # need ~4 frames before onset (at 30 min/frame) + min_post_minutes: 180 + crop_window_minutes: 240 # ± half-window around the onset + max_tracks: 50 # cap for speed +``` + +### Example template entry + +```yaml +templates: + infection_nondividing_sensor: + candidate_set: infection_transitioning_nondiv # → candidates/{..}_annotations.csv + channel: sensor # key in datasets.yaml embeddings: + anchor_label: infection_state # determines t_key_event + anchor_positive: infected + + preprocessing: + zscore: none # {none, per_dataset} + pca: + n_components: 20 # pca/ flavor; raw/ always built. Use sweep_pcs.py to pick. + l2_normalize: true # applied last — on both flavors + + aggregator: dba # {dba, median} + dba: + max_iter: 30 + tol: 1.0e-5 + init: medoid + metric: cosine # {cosine, euclidean} +``` + +`track_filter`, `min_track_minutes`, `crop_window_minutes`, per-template `datasets` are all **gone** — they're baked into the annotations CSV by Stage 0. + +### Example query-set entry (Stage 2) + +Query sets describe which cells to **scan the template over** — typically cells from other datasets, or cells you deliberately withheld from the build set. + +```yaml +query_sets: + sensor_all_07_24: + channel: sensor # must match templates.{name}.channel + datasets: + - dataset_id: "2025_07_24_SEC61" + - dataset_id: "2025_07_24_G3BP1" + # Pre/post headroom (minutes, per-cell). Pass 1 (_load_query_embeddings) + # requires the track to hold template + pre + post frames; pass 2 (after DTW) + # requires the matched window to sit with real footage on both sides. + min_pre_minutes: 120 + min_post_minutes: 180 + min_track_minutes: 120 # floor; the template+headroom calculation takes the max + track_filter: {} # optional obs-column equality filters +``` + +Unlike `candidate_sets`, query sets do **not** require an `anchor_label` — we are *estimating* `t_key_event` for each query cell via DTW, not reading it off annotations. + +## Annotations CSV schema + +One file per candidate set, at `0-select_candidates/candidates/{candidate_set}_annotations.csv`. One row per `(dataset_id, fov_name, track_id, t)` over the hand-picked or auto-selected crop window. + +| column | type | notes | +|---|---|---| +| `dataset_id` | str | matches a key in `config["datasets"]` | +| `fov_name` | str | e.g. `A/2/000000` | +| `track_id` | int | | +| `t` | int | absolute frame index | +| `infection_state` | str | `"infected"` / `"uninfected"` / blank | +| `organelle_state` | str | `"remodeled"` / `"noremodeled"` / blank | +| `cell_division_state` | str | `"mitosis"` / `"interphase"` / blank | + +Positive/negative values per label are defined in `manual_candidates.py::LABEL_VALUES`. Additional label columns can be added by extending that dict. + +### Derived at read time (not stored in the CSV) + +Stage 1 computes the following from the annotations CSV; they are **not** CSV columns: + +- **Crop window** per cell: `[t_before, t_after] = [min(t), max(t)]` across that cell's rows. +- **`t_key_event`** per cell: the first `t` where the anchor label (configured per template) takes its positive value. + +## Template zarr contents + +Every build produces **both flavors** from the same input cells. + +| Path | Type | Description | +|---|---|---| +| `raw/template` | (T, D) array | DBA template in raw embedding space (D = 768 after optional z-score + L2). | +| `raw/time_calibration` | (T,) array | mean `t_relative_minutes` at each raw-template position | +| `raw/template_labels/{col}` | (T,) array | per-position label fraction for each label column | +| `pca/template` | (T, N) array | DBA template in PCA-reduced space | +| `pca/time_calibration` | (T,) array | analogous, warping paths differ | +| `pca/template_labels/{col}` | (T,) array | analogous | +| `pca/components` | (N, D) array | build-time PCA components (downstream alignment must apply these) | +| `pca/mean` | (D,) array | build-time PCA mean | +| `pca/explained_variance_ratio` | (N,) array | fraction of variance per component | +| `zscore_params/{ds_id}/mean` | (D,) array | only present when `zscore=per_dataset`. Shared across flavors. | +| `zscore_params/{ds_id}/std` | (D,) array | only present when `zscore=per_dataset` | +| `t_key_event` | (N_cells,) array | per-cell anchor frame | +| attrs `template_cell_ids` | list | `[dataset_id, fov_name, track_id]` per input cell | +| attrs `l2_normalize` | bool | whether L2 was applied before DTW | +| attrs `metric` | str | `"cosine"` — downstream alignment must match | +| attrs `aggregator` | str | `"dba"` or `"median"` | +| attrs `template_duration_minutes` | float | `time_calibration[-1] - time_calibration[0]`; used by Stage 2 to default `max_psi_minutes = template_duration_minutes / 2` | +| attrs `build_frame_intervals_minutes` | dict | `{dataset_id: frame_interval_minutes}` — records the real-time scale of each build dataset | + +The `pca/` entries are the **build-time** PCA that maps raw embeddings into the `pca/` flavor's feature space. This is distinct from the Stage 2d diagnostic PCA (`plot_pcs_aligned.py`), which is fit post-hoc on the aligned-region frames of query cells for plotting only and is not stored in the template zarr. + +## Stage 2 alignment parquet schema + +One row per `(dataset_id, fov_name, track_id, t)`. Per-track columns (`dtw_cost`, `length_normalized_cost`, `path_skew`, `match_q_start`, `match_q_end`, `match_duration_minutes`) are repeated on every frame so downstream scripts can filter rows without a separate join. + +| column | type | per-track? | notes | +|---|---|---|---| +| `dataset_id`, `fov_name`, `track_id`, `t` | ids | per-frame | identifiers | +| `pseudotime` | float ∈ [0, 1] | per-frame | warp-path template position, unit-free | +| `alignment_region` | str | per-frame | `"pre"` / `"aligned"` / `"post"` | +| `estimated_t_rel_minutes` | float | per-frame | `time_calibration[template_pos]`; `NaN` outside `aligned` (see `plot_pcs_aligned.py` for the extrapolation it uses for plotting only) | +| `dtw_cost` | float | yes | raw DTW cost at the best-path endpoint | +| `length_normalized_cost` | float | yes | `dtw_cost / len(warp_path)` — the correct ranking signal | +| `path_skew` | float ∈ [0, 1] | yes | mean deviation of warp path from ideal diagonal; ported from the old `find_best_match_dtw_bernd_clifford` | +| `match_q_start`, `match_q_end` | int | yes | absolute query frames bounding the matched window | +| `match_duration_minutes` | float | yes | `(q_end - q_start) * dataset.frame_interval_minutes` | +| `warping_speed` | float | per-frame | discrete derivative of `pseudotime` | +| `propagated_{label}_label` | float | per-frame | template label fraction propagated via warp path; `NaN` outside `aligned` | +| `template_id` | str | per-frame | UUID linking to template zarr | + +## Example refined-candidate entry (Stage 0.5) + +```yaml +candidate_sets: + infection_transitioning_nondiv_top20: + refine_from: infection_transitioning_nondiv # parent candidate set + channel: sensor # channel used for bootstrap alignment + min_pre_frames: 4 # stricter than the parent's min_pre_minutes + min_post_frames: 6 + top_n_by_cost: 20 # keep cells with lowest DTW cost against the initial template +``` + +The final template entry references the *refined* set: + +```yaml +templates: + infection_nondividing_sensor: + candidate_set: infection_transitioning_nondiv_top20 + channel: sensor + anchor_label: infection_state + anchor_positive: infected + preprocessing: + pca: + n_components: 20 + l2_normalize: true + dba: + max_iter: 30 + init: medoid + metric: cosine +``` + +## Cross-dataset results (reference — refined 20-cell template, Apr 2026) + +Template built from 20 hand-picked+bootstrap-refined cells from 07_24 (SEC61 A/2 + G3BP1 C/2), 17 frames × 30 min = 455 min. + +| Query set | Frame rate | Virus | Tracks kept | Cost p50 | +|---|---:|---|---:|---:| +| `sensor_all_07_24` (build datasets) | 30 min | ZIKV | 96 | **0.206** | +| `sensor_07_22_zikv` (cross frame rate) | 10 min | ZIKV | 49 | 0.207 | +| `sensor_08_26_zikv` (new replicate) | 30 min | ZIKV | 92 | 0.232 | +| `sensor_01_28_zikv_denv` (cross-virus) | 30 min | ZIKV+DENV | 136 | 0.292 | + +Ordering is the expected signal: build ≈ cross-frame-rate < cross-replicate < cross-virus. + +### Template selection (Apr 2026): keep both `manual_debug_sensor` and `infection_nondividing_sensor` + +Both templates are maintained. They serve different purposes: + +| Template | Build set | Use case | +|---|---|---| +| `manual_debug_sensor` | 4 hand-picked cells on 07_24 A/2 | Debug / smoke-test. Sharpest in-distribution PC collapse; useful for verifying new code paths. | +| `infection_nondividing_sensor` | 20 bootstrap-refined cells on 07_24 (A/2 + C/2) | Production. Monotonic query-truth curves on every dataset with per-frame labels. Use this for organelle-remodeling and cross-dataset analyses. | + +**Honest query-truth comparison with the updated Stage 2d plot** (raw flavor, query-truth curve binned by `estimated_t_rel_minutes` on `alignment_region == "aligned"`): + +| Query set | Truth col | `manual_debug` right-panel | `infection_nondiv` right-panel | +|---|---|---|---| +| `sensor_all_07_24` | `infection_state` | sharp rise to ~0.95, width ~200 min | rise to ~0.75, width ~350 min | +| `sensor_07_22_zikv` | `predicted_infection_state` | modest rise 0.2 → 0.85 | sharp rise 0.1 → 0.95 | +| `sensor_08_26_zikv` | — (no per-frame labels) | — | — | +| `sensor_01_28_zikv_denv` | `predicted_infection_state` | **non-monotonic** (rises, falls, rises; overfits to ZIKV-only trajectory) | roughly monotonic rise 0.15 → 0.7 | + +So `manual_debug` wins in-distribution but breaks on cross-virus; `infection_nondiv` gives monotonic alignment everywhere the labels exist. Neither is "the right answer" universally — pick based on the analysis target. For organelle remodeling we use `infection_nondividing_sensor` because the question spans multiple replicates. + +08_26 is currently uninformative for truth-curve evaluation because its embedding zarr obs lacks `predicted_infection_state`. Running the infection classifier on that zarr is the gating step to close the cross-replicate picture. + +## Next steps & known gaps + +### Outstanding + +- **Stage 2e — Organelle remodeling (the main goal).** Design locked (option a: reuse the sensor alignment parquet as the timing skeleton; no separate organelle template). Full spec in the "Stage 2e" section above. Implementation delivered in phases: Panel A → add Panels B + C → sweep channels × query sets → cross-dataset offset replication. +- **UMAP/PHATE colored by pseudotime.** Once organelle plots land, this is the natural next exploratory step. +- **Run infection classifier on 08_26 embedding zarr.** The 2025_08_26 sensor zarr obs lacks `predicted_infection_state`, so Stage 2d/2e truth curves can't be evaluated on that dataset. Gating step for closing the cross-replicate picture. +- **Stage 1a PC plots** still use the old closed-endpoint `evaluate_template.py` pipeline. Works, but the left-column "unaligned" curve now uses the true annotation (fixed Apr 2026). When ready for a deeper refactor, switch Stage 1a to use subsequence DTW like Stage 2 for consistency. +- **07_22 build-set integration.** 07_22 annotations use an older tracking version that doesn't match the embedding zarr's track_ids. Re-tracking 07_22 with the current version would let us include it in the template build (not just as a query set). +- **Cleanup of swept template zarrs.** `sweep_pcs.py` leaves `template_*_pc5/10/20/30/50.zarr` (~50 MB each) under `1-build_template/templates/`. Add a `--cleanup` flag or document manual deletion. + +### Followups / fragility + +- **Sakoe-Chiba band** (`--sakoe-chiba-ratio`) as an optional 4th guard alongside psi, skew, min_match — only wire up if we see more collapse symptoms. +- **Per-dataset `data_zarr` in `datasets.yaml`** is populated for 07_22/07_24 but not for 08_26/01_28 (query-only — no image montages needed). Adding them would enable Stage 2c montages on those datasets. +- **Annotation noise (±2-3 frames around true onset)** is handled by DBA averaging, but a systematic bias across annotators would shift the template's t=0. No known bias today; worth re-checking if a new annotator starts contributing. +- **Stage 2d truth curve** (`plot_pcs_aligned.py --truth-column`) falls back gracefully to a placeholder when the query obs doesn't have the requested label column. 08_26/01_28 have `predicted_infection_state` only; 07_24/07_22 have human `infection_state`. Use `--truth-column infection_state --truth-positive infected` when human labels exist; `predicted_infection_state` otherwise. + +### Bugs fixed this cycle (Apr 2026) + +- **Psi collapse**: unconstrained psi let DTW collapse the template onto a single query frame (cost ~0, no biology). Capped at `t_template // 2`. +- **Minute-based psi was wrong**: initial `max_psi_minutes` scaling used the *query* frame interval, which over-relaxed on cross-frame-rate datasets. Psi is a template-axis budget; the frame-unit default handles all frame rates correctly. +- **Label propagation** set pre/post frames to 0.0/1.0; now `NaN` (matches `estimated_t_rel_minutes` convention). +- **Stage 1a truth curve** was using `propagated_*_label` (template-warped) instead of the candidate CSV ground-truth. Fixed to read from CSV directly. +- **Stage 2d truth curve** rendered a placeholder "(no ground-truth for query cells)"; now reads from query obs (`--truth-column`). +- **Stage 2d right-panel was misleading**: the bottom-right curve plotted the template's own stored `template_labels/{col}` fraction, which is a property of 4-20 build-set cells and always looks sharp (goes from 0 to 1 within one template step). This read as "alignment is perfect" when in reality no query labels were involved. Fix: right panel now plots query-truth binned by `estimated_t_rel_minutes` (restricted to `alignment_region == "aligned"`) as the primary solid-red curve, and demotes the template fraction to a dashed grey secondary reference. This is how we caught the `manual_debug_sensor` cross-virus failure on 01_28 (right-panel truth curve became non-monotonic — the real signal). +- **DBA medoid init** subsampled randomly; could pick a short track as medoid, truncating the template. Now picks the longest N. +- **Dead code deleted**: `evaluation.py` (broken `onset_concordance` metric) and untracked `classification.py`. diff --git a/applications/dynaclr/docs/DAGs/training.md b/applications/dynaclr/docs/DAGs/training.md new file mode 100644 index 000000000..2e93c903d --- /dev/null +++ b/applications/dynaclr/docs/DAGs/training.md @@ -0,0 +1,160 @@ +# Training DAG + +## Prerequisites + +Datasets must be AI-ready before building a collection. See [ai_ready_datasets.md](ai_ready_datasets.md) +for the full data preparation pipeline (`prepare run` → concatenate → QC → preprocess). + +A dataset is ready when `prepare status` shows `preprocessed: yes` — meaning both +`normalization` and `focus_slice` metadata exist in the zarr zattrs. + +## Step-by-step detail + +``` +zarr stores (preprocessed: normalization + focus_slice in zattrs) +tracking.zarr (per-dataset, synced from NFS) + │ + ├──► collection.yml # defines experiments, channels, perturbation_wells + │ # versioned in git under configs/collections/ + ▼ +dynaclr build-cell-index \ + configs/collections/.yml \ + /hpc/projects/organelle_phenotyping/models/collections/.parquet \ + --num-workers 8 + │ reads tracking CSVs + zarr shape metadata + │ one row per (cell, timepoint, channel) + │ sets z=0 placeholder (overwritten in next step) + ▼ +.parquet (raw: shape columns, z=0, no norm stats) + │ + ▼ +dynaclr preprocess-cell-index \ + /hpc/.../collections/.parquet \ + --focus-channel Phase3D + │ opens each unique FOV once from zarr zattrs: + │ norm_mean/std/median/iqr/max/min — per (cell, timepoint, channel) + │ z_focus_mean — per FOV (mean across timepoints) + │ z — per timepoint focus slice index + │ drops empty frames (max == 0) + ▼ +.parquet (ready: self-contained, no zarr reads at training time) + │ + ▼ +viscy fit --config configs/training/.yml + │ OR: sbatch configs/training/.sh (SLURM, recommended) + │ MultiExperimentDataModule reads parquet only at init + │ tensorstore opens zarr lazily on first batch + │ ExperimentRegistry reads plate.zattrs["focus_slice"] once at startup + │ for z_ranges (z_extraction_window centered on dataset z_focus_mean) + ▼ +checkpoints/ + wandb logs +``` + +## Pipeline DAG (process dependency) + +``` +collection.yml + │ + ▼ +build-cell-index (CPU, ~1 min) + │ + ▼ +preprocess-cell-index (CPU, ~5 min, I/O bound) + │ + ▼ +viscy fit (GPU, hours–days) +``` + +## Key commands + + +| Step | Command | Input | Output | +| --------------------- | ------------------------------------------------------------------------- | -------------------------------------- | --------------------------------------------------------- | +| Build cell index | `dynaclr build-cell-index --num-workers 8` | collection YAML + zarr + tracking CSVs | parquet with TCZYX shape columns | +| Preprocess cell index | `dynaclr preprocess-cell-index --focus-channel Phase3D` | parquet + zarr zattrs | parquet with norm stats, per-timepoint z, empties removed | +| Train (interactive) | `uv run viscy fit --config configs/training/.yml` | training config + parquet | checkpoints + logs | +| Train (SLURM) | `sbatch configs/training/.sh` | training config + parquet | checkpoints + logs | +| Resume (SLURM) | `CKPT_PATH=.../last.ckpt sbatch configs/training/.sh` | checkpoint path env var | resumed checkpoints | + + +## What lives where + + +| Data | Location | When written | +| --------------------------------------- | --------------------------------------------------------- | -------------------------------------------- | +| Pixel data (TCZYX arrays) | zarr store on VAST | `prepare run` → concatenate | +| Cell tracking (y, x, t, track_id) | tracking.zarr on VAST | `prepare run` → concatenate | +| Normalization stats (per FOV/timepoint) | zarr zattrs → parquet `norm_*` columns | `viscy preprocess` → `preprocess-cell-index` | +| Focus slice (per timepoint) | zarr zattrs → parquet `z` column | `viscy preprocess` → `preprocess-cell-index` | +| Focus slice mean (per FOV) | zarr zattrs → parquet `z_focus_mean` | `viscy preprocess` → `preprocess-cell-index` | +| TCZYX shape per FOV | parquet columns | `build-cell-index` | +| Collection definition | `configs/collections/.yml` in git | manually authored | +| Parquet | `/hpc/projects/organelle_phenotyping/models/collections/` | `build-cell-index` | + + +## collection.yml format + +```yaml +name: +description: "..." + +experiments: + - name: # {date}_{cell}_{marker}_{perturbation} + data_path: /hpc/projects/.../dataset.zarr + tracks_path: /hpc/projects/.../tracking.zarr + channels: + - name: "raw GFP EX488 EM525-45" # zarr channel name (exact match) + marker: G3BP1 # protein label used in parquet + perturbation_wells: + uninfected: [C/1] + infected: [C/2] + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 +``` + +Experiment name convention: `{date}_{cell_line}_{marker}_{perturbation}` — +perturbation suffix is always included (e.g., `_ZIKV`, `_DENV`, `_ZIKV_DENV`). + +## Training config structure + +Training configs use Lightning CLI `base:` inheritance: + +```yaml +base: + - recipes/trainer.yml # seed, accelerator, logger, callbacks + - recipes/model/contrastive_encoder_convnext_tiny.yml # or dinov3_frozen_mlp.yml + +trainer: + strategy: ddp + devices: 2 + precision: bf16-mixed + max_epochs: 150 + +data: + cell_index_path: /hpc/.../collections/.parquet + ... +``` + +SLURM `.sh` scripts export `PYTHONNOUSERSITE=1` and launch via `srun` for DDP. + +## Reproducibility + +Version `collection.yml` in git. The parquet is derived deterministically from: + +1. The collection YAML (experiment definitions, channels, wells) +2. Tracking zarrs (cell positions) +3. Zarr zattrs (normalization + focus stats from `viscy preprocess` + `qc run`) + +To reproduce: `build-cell-index` → `preprocess-cell-index` from the same collection YAML. + +## Notes + +- `preprocess-cell-index` overwrites the parquet in-place by default. Pass `--output` to write elsewhere. +- `--focus-channel Phase3D` selects which channel's `per_timepoint` focus indices are written to the `z` column. Use the channel that has the sharpest axial contrast (label-free Phase3D for most experiments). +- At training time, `ExperimentRegistry.__post_init__` reads `plate.zattrs["focus_slice"][channel]["dataset_statistics"]["z_focus_mean"]` to compute per-experiment z_ranges for patch extraction. This is the only zarr metadata read at training startup; the parquet is self-contained for all per-cell data. +- The `z` column in the parquet is carried through to embeddings obs during predict — downstream consumers (e.g., visualization) can use it to recover the in-focus plane for each cell at each timepoint. diff --git a/applications/dynaclr/docs/linear_classifiers/README.md b/applications/dynaclr/docs/linear_classifiers/README.md index a4f893c9c..a0486d299 100644 --- a/applications/dynaclr/docs/linear_classifiers/README.md +++ b/applications/dynaclr/docs/linear_classifiers/README.md @@ -9,14 +9,13 @@ This directory contains: | File | Description | |------|-------------| | `src/utils.py` | Shared functions for discovering predictions, annotations, channel resolution, and path utilities | -| `src/report.py` | PDF report generation for cross-validation and evaluation (optional) | +| `src/report.py` | PDF report generation for cross-validation (optional, `--report` flag) | | `scripts/generate_prediction_scripts.py` | Generates SLURM `.sh`/`.yml` scripts for datasets missing embeddings | | `scripts/generate_batch_predictions.py` | Batch prediction config & SLURM script generator with auto z-range | | `scripts/generate_train_config.py` | Generates training YAML configs for all valid task x channel combinations | | `scripts/train_linear_classifier.py` | CLI for training a classifier from a config | | `scripts/apply_linear_classifier.py` | CLI for applying a trained classifier to new embeddings | | `scripts/cross_validation.py` | Leave-one-dataset-out CV with impact scoring (helps/hurts/uncertain) | -| `scripts/evaluate_dataset.py` | Compare embedding models (e.g. 2D vs 3D) on a held-out test set | ## Prerequisites @@ -80,8 +79,8 @@ dynaclr apply-linear-classifier -c configs/example_linear_classifier_inference.y Determine which training datasets help or hurt classifier performance using rotating leave-one-dataset-out CV. Run from the `linear_classifiers/` directory: ```bash -python scripts/cross_validation.py -c configs/cross_validate_example.yaml -python scripts/cross_validation.py -c configs/cross_validate_example.yaml --report # with PDF +dynaclr cross-validate -c configs/cross_validate_example.yaml +dynaclr cross-validate -c configs/cross_validate_example.yaml --report # with PDF ``` Outputs: @@ -96,24 +95,6 @@ Each dataset is labeled as: - **uncertain** — delta within noise - **unsafe** — fold skipped due to insufficient class samples -### 6. Evaluate models on a held-out test set - -Compare embedding models by training classifiers and evaluating on a held-out dataset: - -```bash -python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml -python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml --report # with PDF -``` - -Outputs per model: -- `{model}/{task}_{channel}_pipeline.joblib` — trained classifier -- `{model}/{task}_{channel}_predictions.zarr` — test predictions -- `{model}/metrics_summary.csv` — per-model metrics - -Combined outputs: -- `train_metrics_comparison.csv` — validation metrics across models -- `test_metrics_comparison.csv` — test metrics across models - ## Training Configuration Create a YAML config file (see `configs/example_linear_classifier_train.yaml`): diff --git a/applications/dynaclr/pyproject.toml b/applications/dynaclr/pyproject.toml index dc31623a5..3258119d3 100644 --- a/applications/dynaclr/pyproject.toml +++ b/applications/dynaclr/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ optional-dependencies.eval = [ "anndata", + "dtaidistance", "natsort", "phate", "scikit-learn", @@ -51,6 +52,13 @@ optional-dependencies.eval = [ "umap-learn", "wandb", ] +optional-dependencies.tracking = [ + "gurobipy>=12.0.1,<13", + "onnxruntime-gpu", + "py-ctcmetrics", + "tabulate", + "tracksdata", +] urls.Homepage = "https://github.com/mehta-lab/VisCy" urls.Issues = "https://github.com/mehta-lab/VisCy/issues" urls.Repository = "https://github.com/mehta-lab/VisCy" diff --git a/applications/dynaclr/scripts/cellanome/embed_dinov3.py b/applications/dynaclr/scripts/cellanome/embed_dinov3.py new file mode 100644 index 000000000..c0f084c69 --- /dev/null +++ b/applications/dynaclr/scripts/cellanome/embed_dinov3.py @@ -0,0 +1,407 @@ +"""Extract DINOv3 embeddings for cellanome cells → cell-level AnnData. + +Reads primary_analysis.csv from the Cellanome pipeline, crops cell patches +from the OME-Zarr store, runs them through a frozen DINOv3 model, and writes +a new cell-level AnnData zarr where each row is one segmented cell. + +Usage +----- +uv run python embed_dinov3.py config.yaml +""" + +import argparse +import logging +import math +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import yaml +import zarr +from tqdm import tqdm + +from viscy_models.foundation import DINOv3Model + +CHANNEL_SHORT_NAMES = { + "White": "BF", + "Blue-FITC (520)": "FITC", + "Red-CY5 (700)": "CY5", + "Green-CY3 (605)": "CY3", +} + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def load_primary_analysis( + analysis_base: str, + scan_ids: list[int] | None = None, + lane_ids: list[int] | None = None, +) -> pd.DataFrame: + """Load and concatenate primary_analysis.csv for all scans/lanes. + + Parameters + ---------- + analysis_base : str + Path to the image_analysis_output directory. + scan_ids : list[int] or None + Scan IDs to include. If None, auto-discover. + lane_ids : list[int] or None + Lane IDs to include. If None, auto-discover. + + Returns + ------- + pd.DataFrame + Concatenated primary analysis with all columns. + """ + base = Path(analysis_base) + if scan_ids is None: + scan_ids = sorted(int(p.name.split("_")[1]) for p in base.glob("scan_*") if p.is_dir()) + if lane_ids is None: + all_lanes = set() + for scan_id in scan_ids: + scan_dir = base / f"scan_{scan_id}" + all_lanes.update(int(p.name.split("_")[1]) for p in scan_dir.glob("lane_*") if p.is_dir()) + lane_ids = sorted(all_lanes) + + frames = [] + for scan_id in scan_ids: + for lane_id in lane_ids: + csv_path = ( + base + / f"scan_{scan_id}" + / f"lane_{lane_id}" + / "processed" + / "CAGE_REGISTRATION" + / "primary_analysis.csv" + ) + if not csv_path.exists(): + logger.warning(f"Missing: {csv_path}") + continue + df = pd.read_csv(csv_path) + frames.append(df) + logger.info(f"scan_{scan_id}/lane_{lane_id}: {len(df)} objects") + + combined = pd.concat(frames, ignore_index=True) + logger.info(f"Total: {len(combined)} objects across {len(frames)} scan/lane combinations") + return combined + + +def derive_zarr_paths(df: pd.DataFrame) -> pd.DataFrame: + """Derive zarr position and path from cage_crop_file_name. + + Parameters + ---------- + df : pd.DataFrame + Must have columns: cage_crop_file_name, lane_id, scan_id. + + Returns + ------- + pd.DataFrame + With added zarr_position and zarr_path columns. + """ + + def _parse_position(cage_crop: str) -> str: + parts = str(cage_crop).split("_") + return f"{parts[4]}{parts[5]}" + + df["zarr_position"] = df["cage_crop_file_name"].apply(_parse_position) + df["zarr_path"] = df["lane_id"].astype(str) + "/" + df["scan_id"].astype(str) + "/" + df["zarr_position"] + return df + + +def build_barcode_lookup(anndata_path: str) -> dict[tuple[int, str], list[str]]: + """Build (global_cage_id_matched, lane) → [barcode_index, ...] lookup. + + Parameters + ---------- + anndata_path : str + Path to the transcriptome AnnData zarr. + + Returns + ------- + dict[tuple[int, str], list[str]] + Mapping from (cage_id, lane_string) to list of barcode obs_names. + """ + adata = ad.read_zarr(anndata_path) + obs = adata.obs.copy() + obs["_lane"] = obs.index.str.extract(r"(lane_\d)")[0].to_numpy() + obs["_cage_id"] = obs["global.cage.id.matched"].astype(int) + + lookup: dict[tuple[int, str], list[str]] = {} + for idx, row in obs.iterrows(): + key = (row["_cage_id"], row["_lane"]) + lookup.setdefault(key, []).append(idx) + + logger.info(f"Barcode lookup: {len(lookup)} unique (cage, lane) pairs from {adata.n_obs} barcodes") + return lookup + + +def join_barcodes(df: pd.DataFrame, lookup: dict[tuple[int, str], list[str]]) -> pd.DataFrame: + """Join barcode indices to cells via (global_cage_id_matched, lane). + + Parameters + ---------- + df : pd.DataFrame + Must have columns: global_cage_id_matched, lane_id. + lookup : dict + From build_barcode_lookup. + + Returns + ------- + pd.DataFrame + With added barcode_index and in_anndata columns. + """ + barcode_indices = [] + for _, row in df.iterrows(): + key = (int(row["global_cage_id_matched"]), f"lane_{int(row['lane_id'])}") + barcodes = lookup.get(key, []) + barcode_indices.append(";".join(barcodes) if barcodes else "") + df["barcode_index"] = barcode_indices + df["in_anndata"] = df["barcode_index"] != "" + return df + + +def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Apply column-level filters to the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame. + filters : dict + Mapping of column_name → {min, max, eq, isin}. + + Returns + ------- + pd.DataFrame + Filtered DataFrame. + """ + for col, conditions in filters.items(): + if col not in df.columns: + raise ValueError(f"Filter column '{col}' not found. Available: {list(df.columns)[:10]}...") + if "min" in conditions: + df = df[df[col] >= conditions["min"]] + if "max" in conditions: + df = df[df[col] <= conditions["max"]] + if "eq" in conditions: + df = df[df[col] == conditions["eq"]] + if "isin" in conditions: + df = df[df[col].isin(conditions["isin"])] + return df + + +def resolve_channel_indices(store: zarr.Group, zarr_path: str, channel_names: list[str]) -> list[int]: + """Resolve integer indices for named channels in an OME-Zarr FOV. + + Parameters + ---------- + store : zarr.Group + Opened zarr store. + zarr_path : str + Relative path to the FOV group. + channel_names : list[str] + Channel labels to look up. + + Returns + ------- + list[int] + Zero-based channel indices. + """ + fov_group = store[zarr_path] + channels = fov_group.attrs["omero"]["channels"] + labels = [ch.get("label", ch.get("name", "")) for ch in channels] + indices = [] + for name in channel_names: + if name not in labels: + raise ValueError(f"Channel '{name}' not found. Available: {labels}") + indices.append(labels.index(name)) + return indices + + +def crop_cell( + fov_array: np.ndarray, + cy: int, + cx: int, + half: int, + channels: list[int] | None = None, +) -> np.ndarray | None: + """Crop a square patch centered on (cy, cx) from a 2D FOV array. + + Parameters + ---------- + fov_array : np.ndarray + FOV image array of shape ``(C, H, W)``. + cy : int + Y centroid in FOV pixels. + cx : int + X centroid in FOV pixels. + half : int + Half the crop size in pixels. + channels : list[int] or None + Channel indices to select. If None, use all channels. + + Returns + ------- + np.ndarray or None + Cropped patch, or None if out of bounds. + """ + _, h, w = fov_array.shape + y0, y1 = cy - half, cy + half + x0, x1 = cx - half, cx + half + if y0 < 0 or x0 < 0 or y1 > h or x1 > w: + return None + patch = fov_array[:, y0:y1, x0:x1] + if channels is not None: + patch = patch[channels] + return patch + + +def main(): + """Extract DINOv3 embeddings for cellanome cells.""" + parser = argparse.ArgumentParser(description="Extract DINOv3 embeddings for cellanome cells.") + parser.add_argument("config", help="Path to YAML config file") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + zarr_store = cfg["zarr_store"] + analysis_base = cfg["analysis_base"] + transcriptome_anndata = cfg.get("transcriptome_anndata", None) + output_path = cfg["output_path"] + model_name = cfg.get("model_name", "facebook/dinov2-base") + channels = cfg.get("channels", None) + output_key = cfg.get("output_key", None) + patch_size = cfg.get("patch_size", 96) + reference_pixel_size = cfg.get("reference_pixel_size", 1.0) + source_pixel_size = cfg.get("source_pixel_size", 1.0) + batch_size = cfg.get("batch_size", 128) + device_str = cfg.get("device", "cuda") + scan_ids = cfg.get("scan_ids", None) + lane_ids = cfg.get("lane_ids", None) + filters = cfg.get("filters", {}) + + # --- Load and prepare data --- + df = load_primary_analysis(analysis_base, scan_ids, lane_ids) + n_raw = len(df) + df = apply_filters(df, filters) + logger.info(f"After filtering: {len(df)} cells (removed {n_raw - len(df)})") + + df = derive_zarr_paths(df) + if transcriptome_anndata is not None: + lookup = build_barcode_lookup(transcriptome_anndata) + df = join_barcodes(df, lookup) + n_matched = df["in_anndata"].sum() + logger.info(f"Barcode match: {n_matched}/{len(df)} cells ({100 * n_matched / len(df):.1f}%)") + else: + logger.info("No transcriptome_anndata provided; skipping barcode join") + + # --- Pixel size rescaling --- + # raw_crop covers the same physical area as patch_size at reference resolution. + # Larger source pixels → fewer pixels needed. + raw_half = round(patch_size * reference_pixel_size / source_pixel_size) // 2 + raw_crop_size = 2 * raw_half + logger.info(f"Raw crop: {raw_crop_size}x{raw_crop_size} -> model input: {patch_size}x{patch_size}") + + # --- Resolve channels --- + store = zarr.open(zarr_store, mode="r") + first_zarr_path = df["zarr_path"].iloc[0] + if channels is not None: + channel_indices = resolve_channel_indices(store, first_zarr_path, channels) + channel_labels = channels + else: + fov_group = store[first_zarr_path] + omero_channels = fov_group.attrs["omero"]["channels"] + channel_labels = [ch.get("label", ch.get("name", "")) for ch in omero_channels] + channel_indices = list(range(len(channel_labels))) + logger.info(f"Channels: {channel_labels} (indices {channel_indices})") + + short_names = [CHANNEL_SHORT_NAMES.get(ch, ch) for ch in channel_labels] + output_key = output_key or "dinov3_" + "_".join(short_names) + + # --- Load model --- + device = torch.device(device_str if torch.cuda.is_available() else "cpu") + model = DINOv3Model(model_name=model_name, freeze=True) + model = model.to(device) + model.eval() + logger.info(f"Loaded DINOv3 {model_name} on {device}") + + # --- Inference --- + df = df.sort_values("zarr_path").reset_index(drop=True) + current_fov_path: str | None = None + current_fov: np.ndarray | None = None + all_embeddings = [] + valid_indices = [] + skipped_border = 0 + + n_batches = math.ceil(len(df) / batch_size) + pbar = tqdm(range(0, len(df), batch_size), total=n_batches, desc="Embedding", unit="batch") + for batch_start in pbar: + batch_df = df.iloc[batch_start : batch_start + batch_size] + patches = [] + batch_valid = [] + + for idx, row in batch_df.iterrows(): + zarr_path = row["zarr_path"] + cy, cx = int(row["object_y_fov"]), int(row["object_x_fov"]) + + if zarr_path != current_fov_path: + current_fov = store[zarr_path]["0"][0, :, 0] + current_fov_path = zarr_path + + patch = crop_cell(current_fov, cy, cx, raw_half, channels=channel_indices) + if patch is None: + skipped_border += 1 + continue + + patches.append(patch) + batch_valid.append(idx) + + if not patches: + continue + + batch_tensor = torch.from_numpy(np.stack(patches)).float() + if raw_crop_size != patch_size: + batch_tensor = F.interpolate( + batch_tensor, size=(patch_size, patch_size), mode="bilinear", align_corners=False + ) + + # Per-sample z-score: zero mean, unit std + mean = batch_tensor.flatten(1).mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + std = batch_tensor.flatten(1).std(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1).clamp(min=1e-8) + batch_tensor = (batch_tensor - mean) / std + + batch_tensor = batch_tensor.unsqueeze(2).to(device) + + with torch.inference_mode(): + features, _ = model(batch_tensor) + + all_embeddings.append(features.cpu().numpy()) + valid_indices.extend(batch_valid) + pbar.set_postfix(cells=len(valid_indices), skipped=skipped_border) + + if skipped_border > 0: + logger.warning(f"Skipped {skipped_border} cells too close to FOV border") + + # --- Write cell-level anndata --- + embeddings = np.concatenate(all_embeddings, axis=0) + valid_df = df.iloc[valid_indices].reset_index(drop=True) + logger.info(f"Embeddings: {embeddings.shape}") + + pd.options.future.infer_string = False + obs = valid_df.copy() + for col in obs.select_dtypes(include=["string", "string[pyarrow]"]).columns: + obs[col] = obs[col].astype(object) + obs.index = obs["object_uuid"].astype(str) + + cell_adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) + cell_adata.write_zarr(output_path) + logger.info(f"Wrote {output_path}: {cell_adata.n_obs} cells x {cell_adata.n_vars} dims") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/cellanome/embed_dynaclr.py b/applications/dynaclr/scripts/cellanome/embed_dynaclr.py new file mode 100644 index 000000000..32e343856 --- /dev/null +++ b/applications/dynaclr/scripts/cellanome/embed_dynaclr.py @@ -0,0 +1,405 @@ +"""Extract DynaCLR embeddings for cellanome cells → cell-level AnnData. + +Reads primary_analysis.csv from the Cellanome pipeline, crops cell patches +(single channel) from the OME-Zarr store, runs them through a DynaCLR +contrastive encoder checkpoint, and writes a new cell-level AnnData zarr. + +Usage +----- +uv run python applications/dynaclr/scripts/cellanome/embed_dynaclr.py config.yaml +""" + +import argparse +import logging +import math +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import yaml +import zarr +from tqdm import tqdm + +from dynaclr.engine import ContrastiveEncoder + +CHANNEL_SHORT_NAMES = { + "White": "BF", + "Blue-FITC (520)": "FITC", + "Red-CY5 (700)": "CY5", + "Green-CY3 (605)": "CY3", +} + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def load_primary_analysis( + analysis_base: str, + scan_ids: list[int] | None = None, + lane_ids: list[int] | None = None, +) -> pd.DataFrame: + """Load and concatenate primary_analysis.csv for all scans/lanes. + + Parameters + ---------- + analysis_base : str + Path to the image_analysis_output directory. + scan_ids : list[int] or None + Scan IDs to include. If None, auto-discover. + lane_ids : list[int] or None + Lane IDs to include. If None, auto-discover. + + Returns + ------- + pd.DataFrame + Concatenated primary analysis with all columns. + """ + base = Path(analysis_base) + if scan_ids is None: + scan_ids = sorted(int(p.name.split("_")[1]) for p in base.glob("scan_*") if p.is_dir()) + if lane_ids is None: + all_lanes = set() + for scan_id in scan_ids: + scan_dir = base / f"scan_{scan_id}" + all_lanes.update(int(p.name.split("_")[1]) for p in scan_dir.glob("lane_*") if p.is_dir()) + lane_ids = sorted(all_lanes) + + frames = [] + for scan_id in scan_ids: + for lane_id in lane_ids: + csv_path = ( + base + / f"scan_{scan_id}" + / f"lane_{lane_id}" + / "processed" + / "CAGE_REGISTRATION" + / "primary_analysis.csv" + ) + if not csv_path.exists(): + logger.warning(f"Missing: {csv_path}") + continue + df = pd.read_csv(csv_path) + frames.append(df) + logger.info(f"scan_{scan_id}/lane_{lane_id}: {len(df)} objects") + + combined = pd.concat(frames, ignore_index=True) + logger.info(f"Total: {len(combined)} objects across {len(frames)} scan/lane combinations") + return combined + + +def derive_zarr_paths(df: pd.DataFrame) -> pd.DataFrame: + """Derive zarr position and path from cage_crop_file_name. + + Parameters + ---------- + df : pd.DataFrame + Must have columns: cage_crop_file_name, lane_id, scan_id. + + Returns + ------- + pd.DataFrame + With added zarr_position and zarr_path columns. + """ + + def _parse_position(cage_crop: str) -> str: + parts = str(cage_crop).split("_") + return f"{parts[4]}{parts[5]}" + + df["zarr_position"] = df["cage_crop_file_name"].apply(_parse_position) + df["zarr_path"] = df["lane_id"].astype(str) + "/" + df["scan_id"].astype(str) + "/" + df["zarr_position"] + return df + + +def build_barcode_lookup(anndata_path: str) -> dict[tuple[int, str], list[str]]: + """Build (global_cage_id_matched, lane) → [barcode_index, ...] lookup. + + Parameters + ---------- + anndata_path : str + Path to the transcriptome AnnData zarr. + + Returns + ------- + dict[tuple[int, str], list[str]] + Mapping from (cage_id, lane_string) to list of barcode obs_names. + """ + adata = ad.read_zarr(anndata_path) + obs = adata.obs.copy() + obs["_lane"] = obs.index.str.extract(r"(lane_\d)")[0].to_numpy() + obs["_cage_id"] = obs["global.cage.id.matched"].astype(int) + + lookup: dict[tuple[int, str], list[str]] = {} + for idx, row in obs.iterrows(): + key = (row["_cage_id"], row["_lane"]) + lookup.setdefault(key, []).append(idx) + + logger.info(f"Barcode lookup: {len(lookup)} unique (cage, lane) pairs from {adata.n_obs} barcodes") + return lookup + + +def join_barcodes(df: pd.DataFrame, lookup: dict[tuple[int, str], list[str]]) -> pd.DataFrame: + """Join barcode indices to cells via (global_cage_id_matched, lane). + + Parameters + ---------- + df : pd.DataFrame + Must have columns: global_cage_id_matched, lane_id. + lookup : dict + From build_barcode_lookup. + + Returns + ------- + pd.DataFrame + With added barcode_index and in_anndata columns. + """ + barcode_indices = [] + for _, row in df.iterrows(): + key = (int(row["global_cage_id_matched"]), f"lane_{int(row['lane_id'])}") + barcodes = lookup.get(key, []) + barcode_indices.append(";".join(barcodes) if barcodes else "") + df["barcode_index"] = barcode_indices + df["in_anndata"] = df["barcode_index"] != "" + return df + + +def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Apply column-level filters to the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame. + filters : dict + Mapping of column_name → {min, max, eq, isin}. + + Returns + ------- + pd.DataFrame + Filtered DataFrame. + """ + for col, conditions in filters.items(): + if col not in df.columns: + raise ValueError(f"Filter column '{col}' not found. Available: {list(df.columns)[:10]}...") + if "min" in conditions: + df = df[df[col] >= conditions["min"]] + if "max" in conditions: + df = df[df[col] <= conditions["max"]] + if "eq" in conditions: + df = df[df[col] == conditions["eq"]] + if "isin" in conditions: + df = df[df[col].isin(conditions["isin"])] + return df + + +def resolve_channel_index(store: zarr.Group, zarr_path: str, channel_name: str) -> int: + """Resolve the integer index of a named channel in an OME-Zarr FOV. + + Parameters + ---------- + store : zarr.Group + Opened zarr store. + zarr_path : str + Relative path to the FOV group. + channel_name : str + Channel label to look up. + + Returns + ------- + int + Zero-based channel index. + """ + fov_group = store[zarr_path] + channels = fov_group.attrs["omero"]["channels"] + labels = [ch.get("label", ch.get("name", "")) for ch in channels] + if channel_name not in labels: + raise ValueError(f"Channel '{channel_name}' not found. Available: {labels}") + return labels.index(channel_name) + + +def crop_cell( + fov_array: np.ndarray, + cy: int, + cx: int, + half: int, + channel_idx: int | None = None, +) -> np.ndarray | None: + """Crop a square patch centered on (cy, cx) from a 2D FOV array. + + Parameters + ---------- + fov_array : np.ndarray + FOV image array of shape ``(C, H, W)``. + cy : int + Y centroid in FOV pixels. + cx : int + X centroid in FOV pixels. + half : int + Half the crop size in pixels. + channel_idx : int or None + Channel index to select. If None, use all channels. + + Returns + ------- + np.ndarray or None + Cropped patch, or None if out of bounds. + """ + _, h, w = fov_array.shape + y0, y1 = cy - half, cy + half + x0, x1 = cx - half, cx + half + if y0 < 0 or x0 < 0 or y1 > h or x1 > w: + return None + if channel_idx is not None: + return fov_array[channel_idx : channel_idx + 1, y0:y1, x0:x1] + return fov_array[:, y0:y1, x0:x1] + + +def main(): + """Extract DynaCLR embeddings for cellanome cells.""" + parser = argparse.ArgumentParser(description="Extract DynaCLR embeddings for cellanome cells.") + parser.add_argument("config", help="Path to YAML config file") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + zarr_store = cfg["zarr_store"] + analysis_base = cfg["analysis_base"] + transcriptome_anndata = cfg.get("transcriptome_anndata", None) + output_path = cfg["output_path"] + ckpt_path = cfg["ckpt_path"] + encoder_config = cfg["encoder_config"] + channel_name = cfg.get("channel_name", "White") + output_key = cfg.get("output_key", None) + patch_size = cfg.get("patch_size", 96) + reference_pixel_size = cfg.get("reference_pixel_size", 1.0) + source_pixel_size = cfg.get("source_pixel_size", 1.0) + batch_size = cfg.get("batch_size", 128) + device_str = cfg.get("device", "cuda") + scan_ids = cfg.get("scan_ids", None) + lane_ids = cfg.get("lane_ids", None) + filters = cfg.get("filters", {}) + + # --- Load and prepare data --- + df = load_primary_analysis(analysis_base, scan_ids, lane_ids) + n_raw = len(df) + df = apply_filters(df, filters) + logger.info(f"After filtering: {len(df)} cells (removed {n_raw - len(df)})") + + df = derive_zarr_paths(df) + if transcriptome_anndata is not None: + lookup = build_barcode_lookup(transcriptome_anndata) + df = join_barcodes(df, lookup) + n_matched = df["in_anndata"].sum() + logger.info(f"Barcode match: {n_matched}/{len(df)} cells ({100 * n_matched / len(df):.1f}%)") + else: + logger.info("No transcriptome_anndata provided; skipping barcode join") + + # --- Pixel size rescaling --- + # raw_crop covers the same physical area as patch_size at reference resolution. + # Larger source pixels → fewer pixels needed. + raw_half = round(patch_size * reference_pixel_size / source_pixel_size) // 2 + raw_crop_size = 2 * raw_half + logger.info(f"Raw crop: {raw_crop_size}x{raw_crop_size} -> model input: {patch_size}x{patch_size}") + + # --- Resolve channel --- + store = zarr.open(zarr_store, mode="r") + first_zarr_path = df["zarr_path"].iloc[0] + channel_idx = resolve_channel_index(store, first_zarr_path, channel_name) + logger.info(f"Channel '{channel_name}' -> index {channel_idx}") + + short_name = CHANNEL_SHORT_NAMES.get(channel_name, channel_name) + output_key = output_key or f"dynaclr_{short_name}" + + # --- Load model --- + device = torch.device(device_str if torch.cuda.is_available() else "cpu") + encoder_config["stem_kernel_size"] = tuple(encoder_config["stem_kernel_size"]) + encoder_config["stem_stride"] = tuple(encoder_config["stem_stride"]) + encoder = ContrastiveEncoder(**encoder_config) + ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) + sd = {k.replace("model.", "", 1): v for k, v in ckpt["state_dict"].items() if k.startswith("model.")} + encoder.load_state_dict(sd) + encoder = encoder.to(device) + encoder.eval() + logger.info(f"Loaded DynaCLR encoder from {ckpt_path} on {device}") + + # --- Inference --- + df = df.sort_values("zarr_path").reset_index(drop=True) + current_fov_path: str | None = None + current_fov: np.ndarray | None = None + all_embeddings = [] + valid_indices = [] + skipped_border = 0 + + n_batches = math.ceil(len(df) / batch_size) + pbar = tqdm(range(0, len(df), batch_size), total=n_batches, desc="Embedding", unit="batch") + for batch_start in pbar: + batch_df = df.iloc[batch_start : batch_start + batch_size] + patches = [] + batch_valid = [] + + for idx, row in batch_df.iterrows(): + zarr_path = row["zarr_path"] + cy, cx = int(row["object_y_fov"]), int(row["object_x_fov"]) + + if zarr_path != current_fov_path: + current_fov = store[zarr_path]["0"][0, :, 0] + current_fov_path = zarr_path + + patch = crop_cell(current_fov, cy, cx, raw_half, channel_idx=channel_idx) + if patch is None: + skipped_border += 1 + continue + + patches.append(patch) + batch_valid.append(idx) + + if not patches: + continue + + batch_tensor = torch.from_numpy(np.stack(patches)).float() + if raw_crop_size != patch_size: + batch_tensor = F.interpolate( + batch_tensor, + size=(patch_size, patch_size), + mode="bilinear", + align_corners=False, + ) + + # Per-sample z-score: zero mean, unit std + mean = batch_tensor.flatten(1).mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + std = batch_tensor.flatten(1).std(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1).clamp(min=1e-8) + batch_tensor = (batch_tensor - mean) / std + + batch_tensor = batch_tensor.unsqueeze(2).to(device) + + with torch.inference_mode(): + embedding, _ = encoder(batch_tensor) + + all_embeddings.append(embedding.cpu().numpy()) + valid_indices.extend(batch_valid) + pbar.set_postfix(cells=len(valid_indices), skipped=skipped_border) + + if skipped_border > 0: + logger.warning(f"Skipped {skipped_border} cells too close to FOV border") + + # --- Write cell-level anndata --- + embeddings = np.concatenate(all_embeddings, axis=0) + valid_df = df.iloc[valid_indices].reset_index(drop=True) + logger.info(f"Embeddings: {embeddings.shape}") + + pd.options.future.infer_string = False + obs = valid_df.copy() + for col in obs.select_dtypes(include=["string", "string[pyarrow]"]).columns: + obs[col] = obs[col].astype(object) + obs.index = obs["object_uuid"].astype(str) + + cell_adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) + cell_adata.write_zarr(output_path) + logger.info(f"Wrote {output_path}: {cell_adata.n_obs} cells x {cell_adata.n_vars} dims") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py b/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py new file mode 100644 index 000000000..7668b4aa9 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py @@ -0,0 +1,117 @@ +"""Benchmark MultiExperimentDataModule setup time. + +Measures the time for _compute_valid_anchors and _build_match_lookup +on the DynaCLR-2D-MIP-BagOfChannels parquet (3.3M rows) to quantify +the speedup from the vectorized implementations. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py +""" + +from __future__ import annotations + +import time + +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/DynaCLR-2D-MIP-BagOfChannels.parquet" +TAU_RANGE = (0.5, 2.0) +YX_PATCH_SIZE = (256, 256) + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + if seconds < 60: + return f"{seconds:.2f} s" + return f"{seconds / 60:.1f} min" + + +def main() -> None: + """Run the MultiExperimentDataModule setup benchmark and print a timing summary.""" + from dynaclr.data.experiment import ExperimentRegistry + from dynaclr.data.index import MultiExperimentIndex + from viscy_data.cell_index import read_cell_index + + print("=" * 60) + print("MultiExperimentDataModule setup benchmark") + print(f"Parquet: {CELL_INDEX_PARQUET}") + print("=" * 60) + + # ---------------------------------------------------------------- + # Parquet read (shared cost) + # ---------------------------------------------------------------- + t0 = time.perf_counter() + df = read_cell_index(CELL_INDEX_PARQUET) + parquet_time = time.perf_counter() - t0 + print(f"\nParquet read: {_fmt(parquet_time)} ({len(df):,} rows)") + + # ---------------------------------------------------------------- + # Registry build (shared cost) + # ---------------------------------------------------------------- + t0 = time.perf_counter() + registry, _ = ExperimentRegistry.from_cell_index( + CELL_INDEX_PARQUET, + z_window=1, + z_extraction_window=20, + z_focus_offset=0.3, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + ) + registry_time = time.perf_counter() - t0 + print(f"Registry build: {_fmt(registry_time)} ({len(registry.experiments)} experiments)") + + # ---------------------------------------------------------------- + # MultiExperimentIndex (includes _compute_valid_anchors) + # ---------------------------------------------------------------- + print("\n--- MultiExperimentIndex (cell_index_df path) ---") + t0 = time.perf_counter() + index = MultiExperimentIndex( + registry=registry, + yx_patch_size=YX_PATCH_SIZE, + tau_range_hours=TAU_RANGE, + cell_index_df=df, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + index_time = time.perf_counter() - t0 + print(f" Total: {_fmt(index_time)}") + print(f" Tracks: {len(index.tracks):,} Valid anchors: {len(index.valid_anchors):,}") + + # ---------------------------------------------------------------- + # _build_match_lookup (MultiExperimentTripletDataset init) + # ---------------------------------------------------------------- + print("\n--- _build_match_lookup (dataset init) ---") + from dynaclr.data.dataset import MultiExperimentTripletDataset + + t0 = time.perf_counter() + MultiExperimentTripletDataset( + index=index, + fit=True, + tau_range_hours=TAU_RANGE, + cache_pool_bytes=0, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + dataset_time = time.perf_counter() - t0 + print(f" _build_match_lookup: {_fmt(dataset_time)}") + + # ---------------------------------------------------------------- + # Summary + # ---------------------------------------------------------------- + total = parquet_time + registry_time + index_time + dataset_time + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print("| Step | Time |") + print("|-------------------------|----------------|") + print(f"| Parquet read | {_fmt(parquet_time):>14} |") + print(f"| Registry build | {_fmt(registry_time):>14} |") + print(f"| Index (_valid_anchors) | {_fmt(index_time):>14} |") + print(f"| Dataset (_match_lookup) | {_fmt(dataset_time):>14} |") + print("|-------------------------|----------------|") + print(f"| **Total** | {_fmt(total):>14} |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py b/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py index af6a38a8b..766c0551d 100644 --- a/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py +++ b/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py @@ -13,7 +13,7 @@ Usage:: - python applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py + uv run python applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py """ # ruff: noqa: E402, D103 @@ -45,7 +45,7 @@ COLLECTION_PATH = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/collections/example_cell_index.yaml" Z_WINDOW = 1 -YX_PATCH_SIZE = (256, 256) +YX_PATCH_SIZE = (192, 192) FINAL_YX_PATCH_SIZE = (160, 160) BATCH_SIZE = 8 NUM_WORKERS = 4 @@ -164,7 +164,7 @@ def run_scenario( bi, name, checks=checks, - save_path=OUTPUT_DIR / f"{name.lower().replace(' ', '_')}_batch{bi}.png" if OUTPUT_DIR else None, + save_path=(OUTPUT_DIR / f"{name.lower().replace(' ', '_')}_batch{bi}.png" if OUTPUT_DIR else None), ) return batches @@ -183,7 +183,6 @@ def run_scenario( print("Building DataModule...") dm = MultiExperimentDataModule( - collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, @@ -367,7 +366,6 @@ def run_scenario( # %% dm_simclr = MultiExperimentDataModule( - collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, @@ -434,7 +432,6 @@ def run_scenario( def run_normalization_scenario(name: str, level: str) -> None: dm_n = MultiExperimentDataModule( - collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, diff --git a/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py b/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py index 0a2816438..e21015fab 100644 --- a/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py +++ b/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py @@ -1,29 +1,32 @@ """End-to-end proof that DynaCLR pixel-size normalization works. -Creates a temporary parquet with modified pixel sizes, feeds it through the -real ``MultiExperimentDataModule`` dataloader, and plots the output patches. +Builds the datamodule once to get sample metadata (cell coordinates), +then reads native zarr crops at different pixel-size-derived scales +and rescales them to show how the pipeline normalizes physical extent. -The Mantis experiment (0.1494 um/px) is the reference. The Dragonfly experiment -natively has 0.206 um/px — we test with both the real value and an artificial -override to show the dataloader responds correctly. +Row 0: Raw FOV with bounding boxes for each pixel-size variant. +Row 1: Native zarr crop → _rescale_patch → center crop = model input (160×160). Usage:: uv run python applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py """ +# %% # ruff: noqa: D103 from __future__ import annotations -import tempfile from pathlib import Path +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np -import pandas as pd +import torch +from iohub.ngff.nodes import open_ome_zarr from dynaclr.data.datamodule import MultiExperimentDataModule +from dynaclr.data.dataset import _rescale_patch from viscy_transforms._crop import BatchedCenterSpatialCrop # --------------------------------------------------------------------------- @@ -32,7 +35,7 @@ _ROOT = Path(__file__).resolve().parents[4] -CELL_INDEX_PATH = _ROOT / "applications/dynaclr/configs/cell_index/dragonfly_mantis_demo.parquet" +CELL_INDEX_PATH = _ROOT / "applications/dynaclr/configs/cell_index/example_mantis_dragonfly.parquet" OUTPUT_DIR = _ROOT / "applications/dynaclr/scripts/dataloader_inspection/output" OUTPUT_PATH = OUTPUT_DIR / "data_patch_resizing.png" @@ -40,116 +43,194 @@ YX_PATCH_SIZE = (200, 200) FINAL_YX_PATCH_SIZE = (160, 160) REFERENCE_PIXEL_SIZE_XY_UM = 0.1494 -REFERENCE_PIXEL_SIZE_Z_UM = 0.2878 CHANNEL_NAME = "Phase3D" DRAGONFLY_EXP = "2024_08_14_ZIKV_pal17_48h" -MANTIS_EXP = "2025_07_24_A549_SEC61B_ZIKV" +MANTIS_EXP = "2025_07_24_A549_SEC61_ZIKV" -# Pixel sizes to test for Dragonfly (real + artificial overrides) +# Pixel sizes to visualize for Dragonfly DRAGONFLY_PIXEL_SIZES = { "real (0.206)": 0.206, - "override (0.1494)": 0.1494, # same as reference — should be no-op - "override (0.7)": 0.7, # even coarser — should crop fewer pixels + "same as ref (0.1494)": 0.1494, + "coarser (0.7)": 0.7, } +BBOX_COLORS = ["#e74c3c", "#2ecc71", "#3498db"] +INCLUDE_WELLS = ["A/2", "0/4"] # --------------------------------------------------------------------------- -# Helpers +# Step 1: Build datamodule once to get sample metadata # --------------------------------------------------------------------------- +print("Building datamodule...") +dm = MultiExperimentDataModule( + cell_index_path=str(CELL_INDEX_PATH), + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + batch_size=8, + num_workers=0, + channels_per_sample=[CHANNEL_NAME], + reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM, + reference_pixel_size_z_um=None, + positive_cell_source="self", + tau_range=(0.0, 100.0), + stratify_by=None, + include_wells=INCLUDE_WELLS, +) +dm.setup("fit") + +registry = dm.train_dataset.index.registry + +print("Drawing samples for metadata...") +loader = dm.train_dataloader() +per_exp: dict[str, dict] = {} +needed = {e.name for e in registry.experiments} + +MAX_BATCHES = 200 +for batch_idx, batch in enumerate(loader): + anchor = batch["anchor"] + meta = batch["anchor_meta"] + for i in range(len(meta)): + exp_name = meta[i]["experiment"] + if exp_name not in per_exp: + per_exp[exp_name] = {"meta": meta[i], "patch": anchor[i]} + if per_exp.keys() >= needed: + break + if batch_idx >= MAX_BATCHES: + print(f" WARNING: only found experiments {set(per_exp.keys())} after {MAX_BATCHES} batches") + break + +for exp_name, d in per_exp.items(): + m = d["meta"] + print(f" {exp_name}: fov={m['fov_name']}, t={m['t']}, y={m['y_clamp']}, x={m['x_clamp']}") -def make_tmp_parquet(pixel_size_xy: float, pixel_size_z: float = REFERENCE_PIXEL_SIZE_Z_UM) -> str: - """Write a temp parquet with Dragonfly pixel sizes overridden.""" - df = pd.read_parquet(CELL_INDEX_PATH) - mask = df["experiment"] == DRAGONFLY_EXP - df.loc[mask, "pixel_size_xy_um"] = pixel_size_xy - df.loc[mask, "pixel_size_z_um"] = pixel_size_z - tmp = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) - df.to_parquet(tmp.name) - return tmp.name - - -def draw_one_sample(parquet_path: str) -> dict: - """Build a datamodule, draw one batch, return first anchor patch + metadata.""" - dm = MultiExperimentDataModule( - collection_path=None, - cell_index_path=parquet_path, - z_window=Z_WINDOW, - yx_patch_size=YX_PATCH_SIZE, - final_yx_patch_size=FINAL_YX_PATCH_SIZE, - batch_size=8, - num_workers=0, - channels_per_sample=[CHANNEL_NAME], - reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM, - reference_pixel_size_z_um=REFERENCE_PIXEL_SIZE_Z_UM, - positive_cell_source="self", - tau_range=(0.0, 100.0), - stratify_by=None, - ) - dm.setup("fit") - - registry = dm.train_dataset.index.registry - scale_factors = {e.name: registry.scale_factors[e.name] for e in registry.experiments} - - # Draw batches until we get one from each experiment - loader = dm.train_dataloader() - per_exp: dict[str, dict] = {} - needed = {e.name for e in registry.experiments} - for batch in loader: - anchor = batch["anchor"] - meta = batch["anchor_meta"] - for i in range(anchor.shape[0]): - exp_name = meta[i]["experiment"] - if exp_name not in per_exp: - per_exp[exp_name] = { - "patch": anchor[i], - "meta": meta[i], - "scale": scale_factors[exp_name], - } - if per_exp.keys() >= needed: - break +# --------------------------------------------------------------------------- +# Step 2: Read raw FOV slices and native crops from zarr +# --------------------------------------------------------------------------- - return per_exp +def read_fov_and_crop( + meta: dict, + pixel_size_xy: float, + z_focus: int, + channel_name: str = CHANNEL_NAME, +) -> tuple[np.ndarray, np.ndarray, int, int]: + """Read the focus Z-slice FOV and a native crop at the given pixel size. + + Returns + ------- + fov : np.ndarray + Full FOV 2D image at the focus Z-slice. + crop : np.ndarray + Native crop at the scale implied by pixel_size_xy. + y_half, x_half : int + Half-widths of the native crop in pixels. + """ + store_path = meta["store_path"] + fov_name = meta["fov_name"] + t = int(meta["t"]) + y_center = int(meta["y_clamp"]) + x_center = int(meta["x_clamp"]) + + scale_yx = REFERENCE_PIXEL_SIZE_XY_UM / pixel_size_xy + y_half = round((YX_PATCH_SIZE[0] // 2) * scale_yx) + x_half = round((YX_PATCH_SIZE[1] // 2) * scale_yx) + + fov_path = f"{store_path}/{fov_name}" + with open_ome_zarr(fov_path, mode="r") as pos: + ch_idx = list(pos.channel_names).index(channel_name) + _, _, _, img_h, img_w = pos.data.shape + + fov = pos.data.oindex[t, ch_idx, z_focus, :, :] + + y0 = max(0, y_center - y_half) + y1 = min(img_h, y_center + y_half) + x0 = max(0, x_center - x_half) + x1 = min(img_w, x_center + x_half) + crop = pos.data.oindex[t, ch_idx, z_focus, y0:y1, x0:x1] + + return fov, crop, y_half, x_half -# --------------------------------------------------------------------------- -# Run the dataloader for each Dragonfly pixel size configuration -# --------------------------------------------------------------------------- center_crop = BatchedCenterSpatialCrop(roi_size=(Z_WINDOW, FINAL_YX_PATCH_SIZE[0], FINAL_YX_PATCH_SIZE[1])) -all_results = {} -for label, px_size in DRAGONFLY_PIXEL_SIZES.items(): - print(f"\n--- Dragonfly pixel_size_xy_um = {px_size} ({label}) ---") - tmp_path = make_tmp_parquet(px_size) - per_exp = draw_one_sample(tmp_path) - - for exp_name, data in per_exp.items(): - scale = data["scale"] - patch = data["patch"] # (C, Z, Y, X) at yx_patch_size - final = center_crop(patch[None])[0] - key = f"{exp_name}\n{label}" if exp_name == DRAGONFLY_EXP else exp_name - if exp_name == MANTIS_EXP and label != "real (0.206)": - continue # Mantis is unchanged, only show once - print(f" {exp_name}: scale_yx={scale[1]:.3f}, patch={tuple(patch.shape)}") - all_results[key] = { - "patch_2d": patch[0, 0].numpy(), +z_focuses = {} +for e in registry.experiments: + zr = registry.z_ranges[e.name] + z_focuses[e.name] = (zr[0] + zr[1]) // 2 + print(f" {e.name}: z_range={zr}, z_focus={z_focuses[e.name]}") + +print("Reading zarr crops...") + +results: list[dict] = [] + +# Mantis (reference — scale ≈ 1.0) +m_meta = per_exp[MANTIS_EXP]["meta"] +m_fov, m_crop, m_yh, m_xh = read_fov_and_crop(m_meta, REFERENCE_PIXEL_SIZE_XY_UM, z_focuses[MANTIS_EXP]) +m_tensor = torch.from_numpy(m_crop).float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W) +m_rescaled = _rescale_patch(m_tensor, (1.0, 1.0, 1.0), (Z_WINDOW, YX_PATCH_SIZE[0], YX_PATCH_SIZE[1])) +m_final = center_crop(m_rescaled[None])[0] +m_dl_patch = per_exp[MANTIS_EXP]["patch"] +m_dl_final = center_crop(m_dl_patch[None])[0] +results.append( + { + "label": f"{MANTIS_EXP}\nreference ({REFERENCE_PIXEL_SIZE_XY_UM} µm/px)", + "exp": MANTIS_EXP, + "fov": m_fov, + "native_crop": m_crop, + "final_2d": m_final[0, 0].numpy(), + "dl_final_2d": m_dl_final[0, 0].numpy(), + "scale_yx": 1.0, + "pixel_size": REFERENCE_PIXEL_SIZE_XY_UM, + "y_half": m_yh, + "x_half": m_xh, + "meta": m_meta, + } +) + +# Dragonfly — one entry per pixel-size variant +d_meta = per_exp[DRAGONFLY_EXP]["meta"] +d_dl_patch = per_exp[DRAGONFLY_EXP]["patch"] +d_dl_final = center_crop(d_dl_patch[None])[0] +d_fov = None + +for i, (label, px_size) in enumerate(DRAGONFLY_PIXEL_SIZES.items()): + fov, crop, y_half, x_half = read_fov_and_crop(d_meta, px_size, z_focuses[DRAGONFLY_EXP]) + if d_fov is None: + d_fov = fov + + scale_yx = REFERENCE_PIXEL_SIZE_XY_UM / px_size + scale = (1.0, scale_yx, scale_yx) + target = (Z_WINDOW, YX_PATCH_SIZE[0], YX_PATCH_SIZE[1]) + + crop_tensor = torch.from_numpy(crop).float().unsqueeze(0).unsqueeze(0) + rescaled = _rescale_patch(crop_tensor, scale, target) + final = center_crop(rescaled[None])[0] + + print(f" {label}: scale_yx={scale_yx:.3f}, native_crop={crop.shape}, rescaled={tuple(rescaled.shape)}") + + results.append( + { + "label": f"{DRAGONFLY_EXP}\n{label}", + "exp": DRAGONFLY_EXP, + "fov": d_fov, + "native_crop": crop, "final_2d": final[0, 0].numpy(), - "scale": scale, - "pixel_size_label": label if exp_name == DRAGONFLY_EXP else "reference", + "dl_final_2d": d_dl_final[0, 0].numpy(), + "scale_yx": scale_yx, + "pixel_size": px_size, + "y_half": y_half, + "x_half": x_half, + "meta": d_meta, } + ) # --------------------------------------------------------------------------- -# Plot +# Step 3: Plot # --------------------------------------------------------------------------- -n = len(all_results) -fig, axes = plt.subplots(2, n, figsize=(5 * n, 10)) -if n == 1: - axes = axes[:, None] - def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): bar_px = bar_um / pixel_size_um @@ -165,7 +246,7 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ax.text( x0 + bar_px / 2, y - 8, - f"{bar_um:.0f} um", + f"{bar_um:.0f} µm", color="white", fontsize=9, ha="center", @@ -173,30 +254,73 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ) -for col, (key, r) in enumerate(all_results.items()): - scale = r["scale"] +def add_bbox(ax, y_center, x_center, y_half, x_half, color, label, img_shape): + y0 = max(0, y_center - y_half) + x0 = max(0, x_center - x_half) + h = min(y_center + y_half, img_shape[0]) - y0 + w = min(x_center + x_half, img_shape[1]) - x0 + rect = mpatches.Rectangle( + (x0, y0), + w, + h, + linewidth=2, + edgecolor=color, + facecolor="none", + linestyle="-", + label=label, + ) + ax.add_patch(rect) + + +n = len(results) +fig, axes = plt.subplots(3, n, figsize=(5 * n, 14)) +if n == 1: + axes = axes[:, None] + +for col, r in enumerate(results): + meta = r["meta"] + exp_name = r["exp"] + y_center = int(meta["y_clamp"]) + x_center = int(meta["x_clamp"]) - # Row 0: Dataloader output (yx_patch_size, after _rescale_patch) + # Row 0: Raw FOV with bounding box ax = axes[0, col] - patch = r["patch_2d"] - vmin, vmax = np.percentile(patch, (1, 99)) - ax.imshow(patch, cmap="gray", vmin=vmin, vmax=vmax) - add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, YX_PATCH_SIZE[0]) - ax.set_title( - f"{key}\nscale_yx=({scale[1]:.3f}, {scale[2]:.3f})\nDataloader: {YX_PATCH_SIZE[0]}x{YX_PATCH_SIZE[1]} px", - fontsize=9, - fontweight="bold", - ) + fov = r["fov"] + vmin_raw, vmax_raw = np.percentile(fov, (1, 99)) + ax.imshow(fov, cmap="gray", vmin=vmin_raw, vmax=vmax_raw) + + if exp_name == DRAGONFLY_EXP: + for i, (lbl, px_size) in enumerate(DRAGONFLY_PIXEL_SIZES.items()): + s = REFERENCE_PIXEL_SIZE_XY_UM / px_size + yh = round((YX_PATCH_SIZE[0] // 2) * s) + xh = round((YX_PATCH_SIZE[1] // 2) * s) + add_bbox(ax, y_center, x_center, yh, xh, BBOX_COLORS[i], lbl, fov.shape) + ax.legend(loc="upper left", fontsize=7, framealpha=0.7) + else: + add_bbox( + ax, + y_center, + x_center, + r["y_half"], + r["x_half"], + BBOX_COLORS[0], + "reference", + fov.shape, + ) + + ax.set_title(f"{r['label']}\nRaw FOV (mid-Z)", fontsize=9, fontweight="bold") ax.axis("off") - # Row 1: After center crop = MODEL INPUT + # Row 1: Model input (native crop → rescale → center crop) ax = axes[1, col] final = r["final_2d"] + vmin, vmax = np.percentile(final, (1, 99)) ax.imshow(final, cmap="gray", vmin=vmin, vmax=vmax) add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, FINAL_YX_PATCH_SIZE[0]) phys = FINAL_YX_PATCH_SIZE[0] * REFERENCE_PIXEL_SIZE_XY_UM ax.set_title( - f"Model input: {FINAL_YX_PATCH_SIZE[0]}x{FINAL_YX_PATCH_SIZE[1]} px | {phys:.1f} um", + f"Model input: {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} px | {phys:.1f} µm\n" + f"native crop: {r['native_crop'].shape} → scale_yx={r['scale_yx']:.3f}", fontsize=9, ) ax.axis("off") @@ -205,9 +329,27 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): spine.set_edgecolor("#2ecc71") spine.set_linewidth(3) + # Row 2: Actual dataloader output (for comparison with "real" variant) + ax = axes[2, col] + dl_final = r["dl_final_2d"] + vmin_dl, vmax_dl = np.percentile(dl_final, (1, 99)) + ax.imshow(dl_final, cmap="gray", vmin=vmin_dl, vmax=vmax_dl) + add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, FINAL_YX_PATCH_SIZE[0]) + ax.set_title( + f"Dataloader output: {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} px\n" + f"(same for all variants — real pixel size)", + fontsize=9, + ) + ax.axis("off") + for spine in ax.spines.values(): + spine.set_visible(True) + spine.set_edgecolor("#e67e22") + spine.set_linewidth(3) + row_labels = [ - "Dataloader output\n(after _rescale_patch)", - "Model input\n(after center crop)", + "Raw FOV + crop region", + "Expected\n(native crop → rescale → crop)", + "Dataloader output\n(real pixel size)", ] for row_idx, label in enumerate(row_labels): axes[row_idx, 0].annotate( @@ -222,8 +364,9 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ) fig.suptitle( - f"Pixel-size normalization proof: reference={REFERENCE_PIXEL_SIZE_XY_UM} um/px\n" - f"Same Dragonfly data with different declared pixel sizes -> different scale factors", + f"Pixel-size normalization: reference={REFERENCE_PIXEL_SIZE_XY_UM} µm/px\n" + f"Different pixel sizes → different native crops" + f" → same {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} model input", fontsize=12, fontweight="bold", y=0.99, @@ -233,3 +376,5 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight") print(f"\nSaved: {OUTPUT_PATH}") plt.close(fig) + +# %% diff --git a/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py b/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py new file mode 100644 index 000000000..3744014d6 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py @@ -0,0 +1,443 @@ +"""Dataloader demo: visualize raw, normalized, and augmented batches. + +Jupyter-style notebook (use ``# %%`` cells in VS Code or JupyterLab). + +Shows what the DynaCLR model actually receives as input. For each batch: + +- **Row 0 (anchor raw)**: raw patches from zarr (no transforms). +- **Row 1 (anchor aug)**: after normalization + augmentation + crop + (exactly what the model sees during training). +- **Row 2 (positive raw)**: positive pair raw patches. +- **Row 3 (positive aug)**: positive after transforms. + +Each column annotation shows experiment, marker, perturbation, timepoint, +and lineage/temporal checks. Batch composition is summarized in the title. + +Usage:: + + uv run python applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # DynaCLR Dataloader Demo +# +# Visualize anchor/positive pairs with normalization and augmentation. +# All parameters are inline — edit and re-run cells. +# +# ## Augmentation pipeline +# +# The augmentation order matters. The pipeline is: +# +# 1. **Normalize** on full extraction patch ``(45, 256, 256)`` +# 2. **Affine** (rotate/scale/shear) on ``(45, 256, 256)`` +# 3. **RandSpatialCrop** to ``(40, 228, 228)`` — random Z for focus +# invariance + random YX for translation augmentation +# 4. **Flip, contrast, scale, smooth, noise** on ``(40, 228, 228)`` +# 5. **CenterCrop** to ``(32, 160, 160)`` — auto-appended by datamodule, +# removes rotation zero-fill artifacts at the edges + +# %% +from __future__ import annotations + +import copy +from collections import Counter +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# %% [markdown] +# ## Configuration +# +# Everything is inline — edit and re-run. + +# %% +# --- Data source --- +CELL_INDEX_PATH = "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v4.parquet" + +# --- Patch extraction --- +Z_WINDOW = 32 +Z_EXTRACTION_WINDOW = 45 +Z_FOCUS_OFFSET = 0.3 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (160, 160) + +# --- Channel mode --- +# 1 = bag-of-channels (one random channel per sample, key="channel_0") +# None = all channels; ["Phase3D", "GFP"] = fixed list +CHANNELS_PER_SAMPLE = 1 +CHANNEL_NAMES = ["channel_0"] + +# --- Positive pair sampling --- +POSITIVE_CELL_SOURCE = "lookup" +POSITIVE_MATCH_COLUMNS = ["lineage_id"] +TAU_RANGE = (0.5, 2.0) +TAU_DECAY_RATE = 2.0 + +# --- Batch sampling --- +BATCH_SIZE = 10 +BATCH_GROUP_BY = None +STRATIFY_BY = ["perturbation"] +SEED = 42 + +# --- Pixel size normalization --- +REFERENCE_PIXEL_SIZE_XY_UM = 0.1494 +REFERENCE_PIXEL_SIZE_Z_UM = 0.174 +FOCUS_CHANNEL = "Phase3D" + +# --- Normalization --- +NORMALIZATIONS = [ + NormalizeSampled( + keys=CHANNEL_NAMES, + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), +] + +# --- Augmentations --- +# The RandSpatialCrop goes after the affine to trim rotation artifacts +# and provide random Z + XY translation. The datamodule auto-appends +# a CenterCrop to [Z_WINDOW, 160, 160] at the end. +AUGMENTATIONS = [ + BatchedRandAffined( + keys=CHANNEL_NAMES, + prob=1, + scale_range=[[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05], + ), + BatchedRandSpatialCropd( + keys=CHANNEL_NAMES, + roi_size=[40, 228, 228], + ), + BatchedRandFlipd(keys=CHANNEL_NAMES, spatial_axes=[1, 2], prob=0.5), + BatchedRandAdjustContrastd(keys=CHANNEL_NAMES, prob=0.5, gamma=(0.6, 1.6)), + BatchedRandScaleIntensityd(keys=CHANNEL_NAMES, prob=0.5, factors=0.5), + BatchedRandGaussianSmoothd( + keys=CHANNEL_NAMES, + prob=1, + sigma_x=[0.25, 0.50], + sigma_y=[0.25, 0.50], + sigma_z=[0.0, 0.2], + ), + BatchedRandGaussianNoised(keys=CHANNEL_NAMES, prob=1, mean=0.0, std=0.1), +] + +# --- Display --- +N_BATCHES = 4 +N_SHOW = 10 +NUM_WORKERS = 1 +SHOW_AUGMENTED = True +OUTPUT_DIR = Path("applications/dynaclr/scripts/dataloader_inspection/results/dataloader_demo") + + +# %% [markdown] +# ## Helpers + + +# %% +def _img_2d(tensor_5d: np.ndarray, sample_idx: int) -> np.ndarray: + """Extract a 2D slice from (B, C, Z, Y, X) for display.""" + img = tensor_5d[sample_idx] + if img.ndim == 4: + img = img[0, img.shape[1] // 2] + elif img.ndim == 3: + img = img[0] + return img + + +def plot_batch( + raw_batch: dict, + aug_batch: dict | None, + batch_idx: int, + n_show: int, + show_augmented: bool = True, + save_path: Path | None = None, +) -> None: + """Plot one batch: raw and augmented anchor/positive pairs.""" + anchor_raw = raw_batch["anchor"].numpy() + positive_raw = raw_batch.get("positive") + has_positive = positive_raw is not None + if has_positive: + positive_raw = positive_raw.numpy() + + anchor_meta = raw_batch["anchor_meta"] + positive_meta = raw_batch.get("positive_meta", [{}] * len(anchor_meta)) + n = min(n_show, len(anchor_meta)) + + row_labels = ["anchor (raw)"] + if show_augmented and aug_batch is not None: + row_labels.append("anchor (aug)") + if has_positive: + row_labels.append("positive (raw)") + if show_augmented and aug_batch is not None: + row_labels.append("positive (aug)") + n_rows = len(row_labels) + + fig, axes = plt.subplots(n_rows, n, figsize=(n * 2.0, n_rows * 2.4), squeeze=False) + + markers = Counter(m.get("marker", "?") for m in anchor_meta[:n]) + perts = Counter(m.get("perturbation", "?") for m in anchor_meta[:n]) + m_str = " ".join(f"{k}={v}" for k, v in markers.most_common(5)) + p_str = " ".join(f"{k}={v}" for k, v in perts.most_common(5)) + fig.suptitle( + f"Batch {batch_idx} | markers: {m_str} | pert: {p_str}", + fontsize=9, + fontweight="bold", + ) + + anchor_aug = aug_batch["anchor"].numpy() if (show_augmented and aug_batch) else None + positive_aug = None + if has_positive and show_augmented and aug_batch: + pa = aug_batch.get("positive") + positive_aug = pa.numpy() if pa is not None else None + + for i in range(n): + am = anchor_meta[i] + pm = positive_meta[i] if i < len(positive_meta) else {} + + row = 0 + img = _img_2d(anchor_raw, i) + vmin, vmax = np.percentile(img, [1, 99]) + axes[row, i].imshow(img, cmap="gray", vmin=vmin, vmax=vmax) + axes[row, i].set_xticks([]) + axes[row, i].set_yticks([]) + lines = [ + f"{am.get('experiment', '?')[:25]}", + f"fov={am.get('fov_name', '?')}", + f"track={am.get('global_track_id', '?')[-15:]}", + f"marker={am.get('marker', '?')}", + f"pert={am.get('perturbation', '?')}", + f"t={am.get('t', '?')}", + ] + if has_positive: + lin_ok = am.get("lineage_id") == pm.get("lineage_id") + dt_ok = am.get("t") != pm.get("t") + lines.append(f"lineage={'✓' if lin_ok else '✗'} Δt={'✓' if dt_ok else '✗'}") + axes[row, i].set_title("\n".join(lines), fontsize=5, linespacing=1.1) + + if anchor_aug is not None: + row += 1 + img_a = _img_2d(anchor_aug, i) + vmin_a, vmax_a = np.percentile(img_a, [1, 99]) + axes[row, i].imshow(img_a, cmap="gray", vmin=vmin_a, vmax=vmax_a) + axes[row, i].set_xticks([]) + axes[row, i].set_yticks([]) + axes[row, i].set_title(f"μ={img_a.mean():.2f} σ={img_a.std():.2f}", fontsize=5) + + if has_positive: + row += 1 + img_p = _img_2d(positive_raw, i) + vmin_p, vmax_p = np.percentile(img_p, [1, 99]) + axes[row, i].imshow(img_p, cmap="gray", vmin=vmin_p, vmax=vmax_p) + axes[row, i].set_xticks([]) + axes[row, i].set_yticks([]) + pos_lines = [ + f"fov={pm.get('fov_name', '?')}", + f"track={pm.get('global_track_id', '?')[-15:]}", + f"pert={pm.get('perturbation', '?')} t={pm.get('t', '?')}", + ] + axes[row, i].set_title("\n".join(pos_lines), fontsize=5, linespacing=1.1) + + if positive_aug is not None: + row += 1 + img_pa = _img_2d(positive_aug, i) + vmin_pa, vmax_pa = np.percentile(img_pa, [1, 99]) + axes[row, i].imshow(img_pa, cmap="gray", vmin=vmin_pa, vmax=vmax_pa) + axes[row, i].set_xticks([]) + axes[row, i].set_yticks([]) + axes[row, i].set_title(f"μ={img_pa.mean():.2f} σ={img_pa.std():.2f}", fontsize=5) + + for r, label in enumerate(row_labels): + axes[r, 0].set_ylabel(label, fontsize=7, fontweight="bold") + + plt.tight_layout() + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f" Saved: {save_path}") + else: + plt.show() + # plt.close(fig) + + +# %% [markdown] +# ## Build DataModule +# +# Passes normalizations + augmentations directly to the DataModule. +# ``on_after_batch_transfer`` applies: normalizations → augmentations +# (including RandSpatialCrop) → auto-appended CenterCrop to final size. + +# %% +dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PATH, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=CHANNELS_PER_SAMPLE, + positive_cell_source=POSITIVE_CELL_SOURCE, + positive_match_columns=POSITIVE_MATCH_COLUMNS, + tau_range=TAU_RANGE, + tau_decay_rate=TAU_DECAY_RATE, + batch_size=BATCH_SIZE, + batch_group_by=BATCH_GROUP_BY, + stratify_by=STRATIFY_BY, + num_workers=NUM_WORKERS, + seed=SEED, + focus_channel=FOCUS_CHANNEL, + reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM, + reference_pixel_size_z_um=REFERENCE_PIXEL_SIZE_Z_UM, + channel_dropout_prob=0.0, + normalizations=NORMALIZATIONS, + augmentations=AUGMENTATIONS, +) +dm.setup("fit") + + +# Fake a minimal trainer so on_after_batch_transfer can check .predicting +class _FakeTrainer: + predicting = False + training = True + + +dm.trainer = _FakeTrainer() +print("DataModule ready.\n") + +va = dm.train_dataset.index.valid_anchors +print(f"Anchors: {len(va):,} | Experiments: {va['experiment'].nunique()}") +for exp, g in va.groupby("experiment"): + markers = g["marker"].value_counts().to_dict() if "marker" in g.columns else {} + perts = g["perturbation"].value_counts().to_dict() + print(f" {exp}: {len(g):,} anchors, markers={markers}, perturbations={perts}") + +# %% [markdown] +# ## Draw batches +# +# The dataloader returns raw patches ``(B, C, 45, 256, 256)`` (no transforms). +# ``dm.on_after_batch_transfer`` applies the full pipeline: +# +# 1. Normalize ``(45, 256, 256)`` +# 2. Affine ``(45, 256, 256)`` +# 3. RandSpatialCrop ``(40, 228, 228)`` +# 4. Flip / contrast / noise ``(40, 228, 228)`` +# 5. CenterCrop ``(32, 160, 160)`` (auto-appended) +# +# We deepcopy each batch so we can show raw vs augmented side by side. + +# %% +if OUTPUT_DIR: + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +dl = dm.train_dataloader() +dl_iter = iter(dl) + +for batch_idx in range(N_BATCHES): + print(f"\n--- Batch {batch_idx} ---") + batch = next(dl_iter) + + meta = batch["anchor_meta"] + n = len(meta) + markers = Counter(m.get("marker", "?") for m in meta) + perts = Counter(m.get("perturbation", "?") for m in meta) + print(f" {n} samples, markers={dict(markers)}, perturbations={dict(perts)}") + + raw_batch = copy.deepcopy(batch) + aug_batch = dm.on_after_batch_transfer(batch, dataloader_idx=0) if SHOW_AUGMENTED else None + + save_path = OUTPUT_DIR / f"train_batch_{batch_idx}.png" if OUTPUT_DIR else None + plot_batch( + raw_batch=raw_batch, + aug_batch=aug_batch, + batch_idx=batch_idx, + n_show=N_SHOW, + show_augmented=SHOW_AUGMENTED, + save_path=save_path, + ) + +# %% +print("\nDone.") + +# %% [markdown] +# ## Validation dataloader +# +# The val dataloader uses the same dataset class but a different subset +# (train/val FOV split). Worth inspecting because DDP validation-epoch-end +# syncs `loss/val` across ranks — a bad val batch on any rank can stall +# the whole sync, or produce NaN features that poison metrics aggregation. +# +# We also scan the raw val batch for NaN/Inf before and after normalization, +# to catch any rows the preprocess step failed to filter. + +# %% +val_dl = dm.val_dataloader() +val_iter = iter(val_dl) + +nan_batches_raw = 0 +nan_batches_norm = 0 +for batch_idx in range(N_BATCHES): + print(f"\n--- Val batch {batch_idx} ---") + batch = next(val_iter) + + meta = batch["anchor_meta"] + n = len(meta) + markers = Counter(m.get("marker", "?") for m in meta) + perts = Counter(m.get("perturbation", "?") for m in meta) + print(f" {n} samples, markers={dict(markers)}, perturbations={dict(perts)}") + + raw_anchor = batch["anchor"] + raw_pos = batch.get("positive") + raw_bad = raw_anchor.isnan().any() or raw_anchor.isinf().any() + if raw_pos is not None: + raw_bad = raw_bad or raw_pos.isnan().any() or raw_pos.isinf().any() + if raw_bad: + nan_batches_raw += 1 + print(" ⚠ raw val batch contains NaN/Inf") + + raw_batch = copy.deepcopy(batch) + aug_batch = dm.on_after_batch_transfer(batch, dataloader_idx=1) if SHOW_AUGMENTED else None + + if aug_batch is not None: + aa = aug_batch["anchor"] + ap = aug_batch.get("positive") + norm_bad = aa.isnan().any() or aa.isinf().any() + if ap is not None: + norm_bad = norm_bad or ap.isnan().any() or ap.isinf().any() + if norm_bad and not raw_bad: + nan_batches_norm += 1 + print(" ⚠ post-normalize val batch contains NaN/Inf") + + save_path = OUTPUT_DIR / f"val_batch_{batch_idx}.png" if OUTPUT_DIR else None + plot_batch( + raw_batch=raw_batch, + aug_batch=aug_batch, + batch_idx=batch_idx, + n_show=N_SHOW, + show_augmented=SHOW_AUGMENTED, + save_path=save_path, + ) + +print(f"\nVal scan over {N_BATCHES} batches: raw NaN/Inf={nan_batches_raw}, post-norm NaN/Inf={nan_batches_norm}") + +# %% [markdown] +# ## Re-run additional batches +# +# Edit ``batch_idx`` and re-run this cell to inspect more batches +# without restarting the dataloader iterator. + +# %% diff --git a/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py b/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py new file mode 100644 index 000000000..c12fd32d3 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py @@ -0,0 +1,238 @@ +"""Minimal exploration of Zuben's gut cell classifier parquet with DynaCLR dataloader. + +Parquet: /hpc/projects/jacobo_group/zuben/proj/gutCellClassifier/data/dynaclr_cell_index.parquet + +Key findings: +- Flat schema: one row per (cell, t, channel). Compatible with MultiExperimentDataModule. +- NOT timelapse: all t=0, no temporal positives. Use positive_cell_source="self" (SimCLR). +- 25 experiments (AAY6/7/8 × day 0/1/2 × gut1-6), 4 channels, 6 perturbation stages. +- Missing: hours_post_perturbation (not needed for self-positive mode). + +Usage:: + + cd /home/eduardo.hirata/repos/viscy + uv run python applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # Gut Cell Parquet Explorer + +# %% +from __future__ import annotations + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import zarr + +# %% [markdown] +# ## 1. Parquet Summary + +# %% +PARQUET_PATH = "/hpc/projects/jacobo_group/zuben/proj/gutCellClassifier/data/dynaclr_cell_index.parquet" +OUTPUT_DIR = Path("applications/dynaclr/scripts/dataloader_inspection/output/gut_parquet") +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +df = pd.read_parquet(PARQUET_PATH) +print(f"Shape: {df.shape}") +print(f"Columns: {df.columns.tolist()}\n") + +print(f"Experiments ({df['experiment'].nunique()}): {sorted(df['experiment'].unique())}\n") +print(f"Channels: {df['channel_name'].unique().tolist()}") +print(f"Perturbations: {sorted(df['perturbation'].unique())}") +print(f"t values: {sorted(df['t'].unique())} <- all 0, not timelapse") +print(f"z range: {df['z'].min()} - {df['z'].max()}") + +# %% +# Per-experiment cell counts and stage breakdown +print("\n## Per-experiment cell counts (unique cells × 4 channels = rows)") +for exp, g in df.groupby("experiment"): + n_cells = g["cell_id"].nunique() + stages = g["perturbation"].value_counts().to_dict() + print(f" {exp}: {n_cells} cells | stages={stages}") + +# %% [markdown] +# ## 2. Sample random patches from zarr +# +# Direct zarr read bypasses the iohub channel_names issue. +# Array shape: (T, C, Z, Y, X) = (1, 4, ~98, H, W) +# Channel order: nuclear, septate, brush_border, SuH + +CHANNEL_NAMES = ["nuclear", "septate", "brush_border", "SuH"] +PATCH_SIZE = 128 # pixels around cell center +N_SAMPLES_PER_CHANNEL = 4 +N_STAGES = 3 # show first N stages + + +def read_patch(row: pd.Series, channel_idx: int, patch: int = PATCH_SIZE) -> np.ndarray | None: + """Read a 2D patch around the cell center from zarr.""" + store = zarr.open(row["store_path"], mode="r") + pos_path = f"{row['well']}/{row['fov']}" + arr = store[pos_path]["0"] # (T, C, Z, Y, X) + z = int(row["z"]) + y = int(row["y"]) + x = int(row["x"]) + H, W = arr.shape[3], arr.shape[4] + half = patch // 2 + y0, y1 = max(0, y - half), min(H, y + half) + x0, x1 = max(0, x - half), min(W, x + half) + t = int(row["t"]) + return arr[t, channel_idx, z, y0:y1, x0:x1] + + +# %% [markdown] +# ## 3. Grid: channels × perturbation stages + +# %% +stages = sorted(df["perturbation"].unique())[:N_STAGES] +n_cols = N_SAMPLES_PER_CHANNEL +n_rows = len(CHANNEL_NAMES) * len(stages) + +fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False) +fig.suptitle("Gut cell patches: rows=channel×stage, cols=random samples", fontsize=10) + +row_idx = 0 +for stage in stages: + stage_df = df[df["perturbation"] == stage] + for ch_i, ch_name in enumerate(CHANNEL_NAMES): + ch_df = stage_df[stage_df["channel_name"] == ch_name] + sampled = ch_df.sample(min(N_SAMPLES_PER_CHANNEL, len(ch_df)), random_state=42) + ax_row = axes[row_idx] + for col_i, (_, row) in enumerate(sampled.iterrows()): + patch = read_patch(row, ch_i) + ax = ax_row[col_i] + vmin, vmax = np.percentile(patch, [1, 99]) + ax.imshow(patch, cmap="gray", vmin=vmin, vmax=vmax) + ax.set_xticks([]) + ax.set_yticks([]) + if col_i == 0: + ax.set_ylabel(f"{ch_name}\n{stage}", fontsize=7) + row_idx += 1 + +plt.tight_layout() +save_path = OUTPUT_DIR / "patches_channel_by_stage.png" +fig.savefig(save_path, dpi=120, bbox_inches="tight") +print(f"Saved: {save_path}") + +# %% [markdown] +# ## 4. Stage distribution per experiment + +# %% +fig, ax = plt.subplots(figsize=(14, 4)) +pivot = ( + df.drop_duplicates(["cell_id", "perturbation"]).groupby(["experiment", "perturbation"]).size().unstack(fill_value=0) # noqa: PD010 +) +pivot.plot.bar(ax=ax, stacked=True, colormap="tab10") +ax.set_title("Cell counts by experiment and stage") +ax.set_xlabel("") +ax.tick_params(axis="x", rotation=45) +ax.legend(title="stage", bbox_to_anchor=(1, 1)) +plt.tight_layout() +save_path = OUTPUT_DIR / "stage_distribution.png" +fig.savefig(save_path, dpi=120, bbox_inches="tight") +print(f"Saved: {save_path}") + +# %% [markdown] +# ## 5. Channel distribution + +# %% +fig, axes = plt.subplots(1, 2, figsize=(10, 4)) +df.drop_duplicates(["cell_id", "channel_name"])["channel_name"].value_counts().plot.bar(ax=axes[0], color="steelblue") +axes[0].set_title("Cells per channel") +axes[0].tick_params(axis="x", rotation=30) + +df.drop_duplicates(["cell_id", "perturbation"])["perturbation"].value_counts().plot.bar(ax=axes[1], color="coral") +axes[1].set_title("Cells per stage") +axes[1].tick_params(axis="x", rotation=30) + +plt.tight_layout() +save_path = OUTPUT_DIR / "distributions.png" +fig.savefig(save_path, dpi=120, bbox_inches="tight") +print(f"Saved: {save_path}") + +# %% [markdown] +# ## 6. DynaCLR DataModule (self-positive / SimCLR) +# +# Not timelapse (t=0 only) so use positive_cell_source="self" — +# augmentation creates two views of the same cell. + +# %% +from dynaclr.data.datamodule import MultiExperimentDataModule + +Z_WINDOW = 1 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (224, 224) +BATCH_SIZE = 8 +NUM_WORKERS = 4 +N_BATCHES = 2 + +print("Building DataModule (self-positive, marker-grouped)...") +dm = MultiExperimentDataModule( + cell_index_path=PARQUET_PATH, + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + channel_dropout_prob=0.0, + positive_cell_source="self", + channels_per_sample=1, + batch_group_by=["marker"], + stratify_by="perturbation", +) +dm.setup("fit") +print("Done.\n") + +va = dm.train_dataset.index.valid_anchors +print(f"Valid anchors: {len(va):,}") +print(f"Channels: {va['marker'].value_counts().to_dict()}") +print(f"Perturbations: {va['perturbation'].value_counts().to_dict()}") + + +# %% +def plot_batch(batch: dict, batch_idx: int, title: str, save_path: Path | None = None) -> None: + """Grid of anchor images annotated with channel + perturbation.""" + anchor = batch["anchor"].numpy() + meta = batch["anchor_meta"] + n = len(meta) + + fig, axes = plt.subplots(1, n, figsize=(n * 2.2, 2.8), squeeze=False) + channels_in_batch = {m.get("marker", "?") for m in meta} + perts_in_batch = {m.get("perturbation", "?") for m in meta} + fig.suptitle( + f"{title} — Batch {batch_idx}\nchannel={channels_in_batch} | stages={perts_in_batch}", + fontsize=9, + ) + for i, (ax, m) in enumerate(zip(axes[0], meta)): + img = anchor[i] + if img.ndim == 4: + img = img[0, img.shape[1] // 2] + elif img.ndim == 3: + img = img[0] + vmin, vmax = np.percentile(img, [1, 99]) + ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f"{m.get('marker', '?')}\n{m.get('perturbation', '?')}", fontsize=6) + plt.tight_layout() + if save_path: + fig.savefig(save_path, dpi=120, bbox_inches="tight") + print(f" Saved: {save_path}") + + +dl = dm.train_dataloader() +for i, batch in enumerate(dl): + if i >= N_BATCHES: + break + meta = batch["anchor_meta"] + print(f"Batch {i}: {len(meta)} samples marker={{{meta[0].get('marker')}}} anchor shape={batch['anchor'].shape}") + plot_batch( + batch, i, "Gut: marker-grouped, perturbation-stratified", save_path=OUTPUT_DIR / f"dataloader_batch_{i}.png" + ) + +# %% +plt.show() diff --git a/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py b/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py new file mode 100644 index 000000000..5f9687ea5 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py @@ -0,0 +1,319 @@ +"""2D MIP augmentation demo — inspect and verify the pipeline. + +Jupyter-style notebook (use ``# %%`` cells in VS Code or JupyterLab). + +Shows what the 2D MIP model receives as input and verifies: + +- **Row 0 (anchor raw)**: center z-slice of the 20-slice raw extraction patch. +- **Row 1 (anchor aug)**: after normalize → affine → RandSpatialCrop(10) → MIP/center-slice → CenterCrop(160,160). + +Column annotations show marker, perturbation, and the z-reduction strategy +applied (MIP for fluorescence, center-slice for label-free). + +Pipeline: + extract (20, 192, 192) → normalize → affine → RandSpatialCrop(10, 192, 192) + → flip/contrast/noise → ZReduction (MIP or center-slice) → CenterCrop(1, 160, 160) + +Usage:: + + uv run python applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # 2D MIP Augmentation Demo +# +# Verify the z-reduction strategy per marker and visualize raw vs augmented. +# +# ## Pipeline +# +# 1. **Extract** 20 z-slices around focus +# 2. **Normalize** (subtract mean, divide std) +# 3. **Affine** (rotate/scale/shear) +# 4. **RandSpatialCrop** to (10, 192, 192) — random Z for focus invariance +# 5. **Flip, contrast, scale, smooth, noise** +# 6. **ZReduction**: MIP for fluorescence, center-slice for label-free +# 7. **CenterCrop** to (1, 160, 160) — auto-appended by datamodule + +# %% +from __future__ import annotations + +import copy +from collections import Counter +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_data._utils import _transform_channel_wise +from viscy_data.channel_utils import parse_channel_name +from viscy_transforms import ( + BatchedChannelWiseZReductiond, + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# %% [markdown] +# ## Configuration + +# %% +CELL_INDEX_PATH = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/test_2d_mip_mixed.parquet" + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 20 +Z_FOCUS_OFFSET = 0.5 +YX_PATCH_SIZE = (192, 192) +FINAL_YX_PATCH_SIZE = (160, 160) +CHANNEL_NAMES = ["channel_0"] + +BATCH_SIZE = 16 +N_BATCHES = 4 +N_SHOW = 10 +NUM_WORKERS = 4 +OUTPUT_DIR = Path("/home/eduardo.hirata/repos/viscy/applications/dynaclr/scripts/dataloader_inspection/results") + +# %% [markdown] +# ## Build DataModule + +# %% +normalizations = [ + NormalizeSampled( + keys=CHANNEL_NAMES, + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ) +] +augmentations = [ + BatchedRandAffined( + keys=CHANNEL_NAMES, + prob=0.8, + scale_range=[[0.8, 1.3], [0.8, 1.3], [0.8, 1.3]], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05], + ), + BatchedRandFlipd(keys=CHANNEL_NAMES, spatial_axes=[1, 2], prob=0.5), + BatchedRandAdjustContrastd(keys=CHANNEL_NAMES, prob=0.5, gamma=(0.6, 1.6)), + BatchedRandScaleIntensityd(keys=CHANNEL_NAMES, prob=0.5, factors=0.5), + BatchedRandGaussianSmoothd( + keys=CHANNEL_NAMES, + prob=0.5, + sigma_x=[0.25, 0.50], + sigma_y=[0.25, 0.50], + sigma_z=[0.0, 0.0], + ), + BatchedRandGaussianNoised(keys=CHANNEL_NAMES, prob=0.5, mean=0.0, std=0.1), + # Random Z crop: select 10 of 20 extracted slices for Z-invariance. + BatchedRandSpatialCropd(keys=CHANNEL_NAMES, roi_size=[10, 192, 192]), + # Z-reduction: MIP for fluorescence, center-slice for label-free. + BatchedChannelWiseZReductiond(keys=CHANNEL_NAMES, allow_missing_keys=True), +] + +dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PATH, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation", "marker"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + seed=42, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + channel_dropout_prob=0.0, + normalizations=normalizations, + augmentations=augmentations, +) +dm.setup("fit") + +va = dm.train_dataset.index.valid_anchors +print(f"Anchors: {len(va):,} | Experiments: {va['experiment'].nunique()}") +for exp, g in va.groupby("experiment"): + markers = g["marker"].value_counts().to_dict() if "marker" in g.columns else {} + print(f" {exp}: {len(g):,} anchors markers={markers}") + + +# %% [markdown] +# ## Helpers + + +# %% +def _apply_augmentations(batch: dict) -> torch.Tensor: + """Apply the full augmentation pipeline to a raw batch, return (B,C,1,H,W).""" + norm_meta = batch.get("anchor_norm_meta") + is_labelfree = torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in batch["anchor_meta"]], + dtype=torch.bool, + ) + return _transform_channel_wise( + transform=dm._augmentation_transform, + channel_names=dm._channel_names, + patch=batch["anchor"], + norm_meta=norm_meta, + extra={"_is_labelfree": is_labelfree}, + ) + + +def _img2d_raw(tensor: np.ndarray, sample_idx: int) -> np.ndarray: + """Center z-slice from raw (B, C, Z, Y, X) for display.""" + vol = tensor[sample_idx, 0] # (Z, Y, X) + return vol[vol.shape[0] // 2] + + +def _img2d_aug(tensor: np.ndarray, sample_idx: int) -> np.ndarray: + """2D image from augmented (B, C, 1, Y, X).""" + return tensor[sample_idx, 0, 0] + + +def _strategy(marker: str) -> str: + ct = parse_channel_name(marker)["channel_type"] + return "center-slice" if ct == "labelfree" else "MIP" + + +def plot_batch( + raw_batch: dict, + aug_patch: torch.Tensor, + batch_idx: int, + n_show: int = N_SHOW, + save_path: Path | None = None, +) -> None: + anchor_raw = raw_batch["anchor"].numpy() + anchor_aug = aug_patch.numpy() + meta = raw_batch.get("anchor_meta", []) + n = min(n_show, len(meta)) + + markers = Counter(m.get("marker", "?") for m in meta[:n]) + perts = Counter(m.get("perturbation", "?") for m in meta[:n]) + m_str = " ".join(f"{k}={v}" for k, v in markers.most_common(5)) + p_str = " ".join(f"{k}={v}" for k, v in perts.most_common(5)) + + fig, axes = plt.subplots(2, n, figsize=(n * 2.0, 2 * 2.4), squeeze=False) + fig.suptitle( + f"Batch {batch_idx} | markers: {m_str} | pert: {p_str}\n" + f"raw z-depth={anchor_raw.shape[2]} aug z-depth={anchor_aug.shape[2]}", + fontsize=8, + fontweight="bold", + ) + + for i in range(n): + am = meta[i] if i < len(meta) else {} + marker = am.get("marker", "?") + strategy = _strategy(marker) + + # Row 0: raw center z-slice + img_raw = _img2d_raw(anchor_raw, i) + vmin, vmax = np.percentile(img_raw, [1, 99]) + axes[0, i].imshow(img_raw, cmap="gray", vmin=vmin, vmax=vmax) + axes[0, i].set_xticks([]) + axes[0, i].set_yticks([]) + axes[0, i].set_title( + "\n".join( + [ + f"{am.get('experiment', '?')[:20]}", + f"marker={marker}", + f"pert={am.get('perturbation', '?')}", + f"t={am.get('t', '?')}", + f"z_reduction={strategy}", + ] + ), + fontsize=5, + linespacing=1.1, + ) + + # Row 1: augmented (post ZReduction) + img_aug = _img2d_aug(anchor_aug, i) + vmin_a, vmax_a = np.percentile(img_aug, [1, 99]) + axes[1, i].imshow(img_aug, cmap="gray", vmin=vmin_a, vmax=vmax_a) + axes[1, i].set_xticks([]) + axes[1, i].set_yticks([]) + axes[1, i].set_title(f"μ={img_aug.mean():.2f} σ={img_aug.std():.2f}", fontsize=5) + + axes[0, 0].set_ylabel("raw (center z)", fontsize=7, fontweight="bold") + axes[1, 0].set_ylabel("aug (MIP/center)", fontsize=7, fontweight="bold") + + plt.tight_layout() + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f" Saved: {save_path}") + else: + plt.show() + + +def check_batch(batch_idx: int, raw_batch: dict, aug_patch: torch.Tensor) -> None: + """Assert shape and z-reduction correctness, print summary.""" + meta = raw_batch.get("anchor_meta", []) + + assert aug_patch.shape[2] == 1, f"Batch {batch_idx}: z should be 1, got {aug_patch.shape}" + assert aug_patch.shape[3] == FINAL_YX_PATCH_SIZE[0], f"Y should be {FINAL_YX_PATCH_SIZE[0]}" + assert aug_patch.shape[4] == FINAL_YX_PATCH_SIZE[1], f"X should be {FINAL_YX_PATCH_SIZE[1]}" + print(f" [PASS] shape: {tuple(aug_patch.shape)}") + + n_lf, n_fl = 0, 0 + for i, m in enumerate(meta): + marker = m.get("marker", "") + ct = parse_channel_name(marker)["channel_type"] + assert not torch.all(aug_patch[i] == 0), f"Sample {i} ({marker}) is all zeros" + if ct == "labelfree": + n_lf += 1 + else: + n_fl += 1 + + raw_z = raw_batch["anchor"].shape[2] + print(f" [PASS] label-free (center-slice)={n_lf} fluorescence (MIP)={n_fl} raw_z={raw_z}") + print(f" [INFO] markers: {dict(Counter(m.get('marker', '?') for m in meta))}") + + +# %% [markdown] +# ## Draw batches + +# %% +if OUTPUT_DIR: + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +dl = dm.train_dataloader() +dl_iter = iter(dl) + +for batch_idx in range(N_BATCHES): + print(f"\n--- Batch {batch_idx} ---") + batch = next(dl_iter) + raw_batch = copy.deepcopy(batch) + aug_patch = _apply_augmentations(batch) + check_batch(batch_idx, raw_batch, aug_patch) + save_path = OUTPUT_DIR / f"batch_{batch_idx}.png" if OUTPUT_DIR else None + plot_batch(raw_batch, aug_patch, batch_idx, save_path=save_path) + +# %% +print("\nDone.") + +# %% [markdown] +# ## Re-run additional batches +# +# Edit ``batch_idx`` and re-run this cell to inspect more batches +# without restarting the dataloader iterator. + +# %% +batch_idx = N_BATCHES +batch = next(dl_iter) +raw_batch = copy.deepcopy(batch) +aug_patch = _apply_augmentations(batch) +check_batch(batch_idx, raw_batch, aug_patch) +plot_batch(raw_batch, aug_patch, batch_idx) + +# %% diff --git a/applications/dynaclr/scripts/evaluation/compare_evals.py b/applications/dynaclr/scripts/evaluation/compare_evals.py new file mode 100644 index 000000000..4bceb4dce --- /dev/null +++ b/applications/dynaclr/scripts/evaluation/compare_evals.py @@ -0,0 +1,332 @@ +"""Compare evaluation results across multiple model runs. + +Reads outputs produced by ``dynaclr evaluate`` from multiple model eval directories, +compares smoothness, linear classifier AUROC, and MMD activity z-scores side by side, +and writes summary CSVs and plots to a shared output directory. + +Usage +----- +python compare_evals.py -c eval_registry.yml + +Registry YAML format +-------------------- +models: + - name: DynaCLR-v3 + eval_dir: /path/to/eval_v3 + - name: DINOv3-MLP + eval_dir: /path/to/eval_dino +output_dir: /path/to/comparison_output +fdr_threshold: 0.05 # optional, default 0.05 +""" + +from __future__ import annotations + +from pathlib import Path + +import click +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import yaml +from matplotlib.lines import Line2D + +# --------------------------------------------------------------------------- +# Registry loading +# --------------------------------------------------------------------------- + + +def _load_registry(path: Path) -> tuple[list[dict], Path, float]: + with open(path) as f: + raw = yaml.safe_load(f) + output_dir = Path(raw["output_dir"]) + fdr_threshold = float(raw.get("fdr_threshold", 0.05)) + return raw["models"], output_dir, fdr_threshold + + +# --------------------------------------------------------------------------- +# Smoothness +# --------------------------------------------------------------------------- + + +def _load_smoothness(models: list[dict]) -> pd.DataFrame | None: + frames = [] + for entry in models: + smoothness_dir = Path(entry["eval_dir"]) / "smoothness" + csvs = list(smoothness_dir.glob("*_smoothness_stats.csv")) + if not csvs: + click.echo(f"[smoothness] No smoothness CSV found for {entry['name']}", err=True) + continue + # Take the first (usually only) stats file — not per-group + df = pd.read_csv(csvs[0]) + df["model"] = entry["name"] + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_smoothness(df: pd.DataFrame, output_dir: Path) -> None: + metrics = ["smoothness_score", "dynamic_range"] + present = [m for m in metrics if m in df.columns] + if not present: + return + + fig, axes = plt.subplots(1, len(present), figsize=(5 * len(present), 4), squeeze=False) + for ax, metric in zip(axes[0], present): + vals = df.set_index("model")[metric] + ax.bar(vals.index, vals.values, color=plt.cm.tab10(np.arange(len(vals)) / len(vals))) + ax.set_title(metric.replace("_", " ").title()) + ax.set_ylabel(metric) + plt.setp(ax.get_xticklabels(), rotation=30, ha="right") + + fig.tight_layout() + out = output_dir / "smoothness_comparison.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[smoothness] Saved: {out}", err=True) + + +# --------------------------------------------------------------------------- +# Linear classifiers +# --------------------------------------------------------------------------- + + +def _load_linear_classifiers(models: list[dict]) -> pd.DataFrame | None: + frames = [] + for entry in models: + csv = Path(entry["eval_dir"]) / "linear_classifiers" / "metrics_summary.csv" + if not csv.exists(): + click.echo(f"[linear_classifiers] Not found for {entry['name']}: {csv}", err=True) + continue + df = pd.read_csv(csv) + df["model"] = entry["name"] + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_linear_classifiers(df: pd.DataFrame, output_dir: Path) -> None: + if "auroc" not in df.columns: + return + + tasks = sorted(df["task"].unique()) if "task" in df.columns else ["all"] + ncols = min(4, len(tasks)) + nrows = int(np.ceil(len(tasks) / ncols)) + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False) + axes_flat = axes.flatten() + models = sorted(df["model"].unique()) + colors = plt.cm.tab10(np.linspace(0, 1, len(models))) + model_color = dict(zip(models, colors)) + + for ax_idx, task in enumerate(tasks): + ax = axes_flat[ax_idx] + sub = df[df["task"] == task] if "task" in df.columns else df + pivot = sub.pivot_table( + index="marker" if "marker" in sub.columns else sub.index, columns="model", values="auroc" + ) + pivot = pivot.reindex(columns=models) + + x = np.arange(len(pivot)) + width = 0.8 / len(models) + for i, model in enumerate(models): + if model not in pivot.columns: + continue + ax.bar(x + i * width, pivot[model].values, width, label=model, color=model_color[model]) + + ax.set_xticks(x + width * (len(models) - 1) / 2) + ax.set_xticklabels(pivot.index, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("AUROC") + ax.set_title(task, fontsize=9) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--") + ax.set_ylim(0, 1.05) + + for ax in axes_flat[len(tasks) :]: + ax.set_visible(False) + + handles = [plt.Rectangle((0, 0), 1, 1, color=model_color[m], label=m) for m in models] + fig.legend(handles=handles, loc="lower center", ncol=len(models), fontsize=8, bbox_to_anchor=(0.5, 0)) + fig.tight_layout(rect=[0, 0.05, 1, 1]) + out = output_dir / "linear_classifiers_comparison.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[linear_classifiers] Saved: {out}", err=True) + + +# --------------------------------------------------------------------------- +# MMD +# --------------------------------------------------------------------------- + + +def _load_mmd(models: list[dict]) -> pd.DataFrame | None: + frames = [] + for entry in models: + mmd_root = Path(entry["eval_dir"]) / "mmd" + if not mmd_root.exists(): + click.echo(f"[mmd] No mmd directory for {entry['name']}", err=True) + continue + for csv in sorted(mmd_root.rglob("mmd_results.csv")): + block_name = csv.parent.name + df = pd.read_csv(csv) + df["model"] = entry["name"] + df["block"] = block_name + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_mmd_kinetics(df: pd.DataFrame, output_dir: Path, fdr_threshold: float) -> None: + temporal = df.dropna(subset=["hours_bin_start", "hours_bin_end"]).copy() + if temporal.empty: + click.echo("[mmd] No temporal rows — skipping kinetics plot", err=True) + return + + temporal["hours_mid"] = (temporal["hours_bin_start"] + temporal["hours_bin_end"]) / 2 + markers = sorted(temporal["marker"].unique()) + models = sorted(temporal["model"].unique()) + labels = sorted(temporal["label"].unique()) + blocks = sorted(temporal["block"].unique()) + + for block in blocks: + sub_block = temporal[temporal["block"] == block] + ncols = min(4, len(markers)) + nrows = int(np.ceil(len(markers) / ncols)) + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False) + axes_flat = axes.flatten() + + colors = plt.cm.tab10(np.linspace(0, 1, len(models))) + linestyles = ["-", "--", ":", "-."] + model_color = dict(zip(models, colors)) + label_ls = dict(zip(labels, linestyles[: len(labels)])) + + for ax_idx, marker in enumerate(markers): + ax = axes_flat[ax_idx] + sub = sub_block[sub_block["marker"] == marker] + for model in models: + for label in labels: + grp = sub[(sub["model"] == model) & (sub["label"] == label)].sort_values("hours_mid") + if grp.empty: + continue + ax.plot( + grp["hours_mid"], + grp["activity_zscore"], + color=model_color[model], + linestyle=label_ls[label], + linewidth=1.5, + ) + if "q_value" in grp.columns: + sig = grp[grp["q_value"] < fdr_threshold] + ax.scatter(sig["hours_mid"], sig["activity_zscore"], color=model_color[model], s=30, zorder=5) + ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") + ax.set_title(marker, fontsize=9) + ax.set_xlabel("Hours post perturbation") + ax.set_ylabel("Activity z-score") + + for ax in axes_flat[len(markers) :]: + ax.set_visible(False) + + legend_handles = [Line2D([0], [0], color=model_color[m], linewidth=2, label=m) for m in models] + legend_handles += [ + Line2D([0], [0], color="black", linestyle=label_ls[lb], linewidth=1.5, label=lb) for lb in labels + ] + fig.legend( + handles=legend_handles, + loc="lower center", + ncol=len(models) + len(labels), + fontsize=8, + bbox_to_anchor=(0.5, 0), + ) + fig.tight_layout(rect=[0, 0.05, 1, 1]) + + out = output_dir / f"mmd_kinetics_{block}.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[mmd] Saved: {out}", err=True) + + +def _plot_mmd_summary_heatmap(summary: pd.DataFrame, output_dir: Path) -> None: + blocks = sorted(summary["block"].unique()) + labels = sorted(summary["label"].unique()) + models = sorted(summary["model"].unique()) + + for block in blocks: + sub_block = summary[summary["block"] == block] + ncols = len(labels) + markers = sorted(sub_block["marker"].unique()) + fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, max(3, len(markers) * 0.5 + 1)), squeeze=False) + for col_idx, label in enumerate(labels): + ax = axes[0, col_idx] + pivot = sub_block[sub_block["label"] == label].pivot_table( + index="marker", columns="model", values="mean_activity_zscore", aggfunc="mean" + ) + pivot = pivot.reindex(columns=models) + vmax = np.nanpercentile(np.abs(pivot.values), 95) if pivot.values.size > 0 else 1.0 + im = ax.imshow(pivot.values, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + ax.set_xticks(range(len(models))) + ax.set_xticklabels(models, rotation=45, ha="right", fontsize=8) + ax.set_yticks(range(len(pivot.index))) + ax.set_yticklabels(pivot.index, fontsize=8) + ax.set_title(label, fontsize=9) + plt.colorbar(im, ax=ax, label="Mean activity z-score") + + fig.tight_layout() + out = output_dir / f"mmd_summary_heatmap_{block}.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[mmd] Saved: {out}", err=True) + + +def _build_mmd_summary(df: pd.DataFrame) -> pd.DataFrame: + return ( + df.groupby(["block", "model", "marker", "label"])["activity_zscore"] + .agg(mean_activity_zscore="mean", n_bins="count") + .reset_index() + .sort_values(["block", "label", "marker", "mean_activity_zscore"], ascending=[True, True, True, False]) + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +@click.command() +@click.option( + "-c", "--config", required=True, type=click.Path(exists=True, path_type=Path), help="Path to eval_registry.yml" +) +def main(config: Path) -> None: + """Compare evaluation results across model runs.""" + models, output_dir, fdr_threshold = _load_registry(config) + output_dir.mkdir(parents=True, exist_ok=True) + + # Smoothness + smoothness_df = _load_smoothness(models) + if smoothness_df is not None: + smoothness_df.to_csv(output_dir / "smoothness_comparison.csv", index=False) + _plot_smoothness(smoothness_df, output_dir) + click.echo("\n## Smoothness\n") + click.echo(smoothness_df[["model", "smoothness_score", "dynamic_range"]].to_markdown(index=False)) + + # Linear classifiers + lc_df = _load_linear_classifiers(models) + if lc_df is not None: + lc_df.to_csv(output_dir / "linear_classifiers_comparison.csv", index=False) + _plot_linear_classifiers(lc_df, output_dir) + summary_cols = [c for c in ["model", "task", "marker", "auroc", "f1"] if c in lc_df.columns] + click.echo("\n## Linear Classifiers\n") + click.echo(lc_df[summary_cols].to_markdown(index=False)) + + # MMD + mmd_df = _load_mmd(models) + if mmd_df is not None: + mmd_summary = _build_mmd_summary(mmd_df) + mmd_summary.to_csv(output_dir / "mmd_comparison.csv", index=False) + _plot_mmd_kinetics(mmd_df, output_dir, fdr_threshold) + _plot_mmd_summary_heatmap(mmd_summary, output_dir) + click.echo("\n## MMD activity z-score\n") + click.echo(mmd_summary.to_markdown(index=False)) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py b/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py new file mode 100644 index 000000000..2bd1e4672 --- /dev/null +++ b/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py @@ -0,0 +1,361 @@ +"""Embedding analysis for microglia and ALFI datasets. + +Microglia (unsupervised): + PCA/UMAP colored by perturbation condition and per-track embedding + displacement — proxy for morphological dynamics (Khurana et al. 2022, + https://doi.org/10.1091/mbc.E21-11-0561). + +ALFI HeLa (supervised): + PCA/UMAP colored by cell cycle phase annotations (interphase vs mitosis) + from the ALFI dataset (Dang et al. 2023, + https://doi.org/10.1038/s41597-023-02540-1). + +Usage +----- +python scripts/evaluation/microglia_alfi_analysis.py \\ + --microglia-embeddings /path/to/microglia/embeddings.zarr \\ + --alfi-embeddings /path/to/alfi/embeddings.zarr \\ + --output-dir /path/to/output/ +""" + +import argparse +from pathlib import Path + +import anndata as ad +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP + +ALFI_ANNOTATIONS = Path("/hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv") + +DIVISION_PALETTE = { + "interphase": "cornflowerblue", + "mitosis": "darkorange", +} + + +def compute_track_displacement_metrics(adata: ad.AnnData) -> pd.DataFrame: + """Compute per-track embedding displacement metrics. + + Parameters + ---------- + adata : AnnData + Embeddings with obs columns fov_name, track_id, t. + adata.X contains raw embeddings (N x D). + + Returns + ------- + pd.DataFrame + One row per track with columns: + fov_name, track_id, mean_step_size, total_path_length, + net_displacement, track_length, and any available metadata columns. + """ + embeddings = np.asarray(adata.X) + obs = adata.obs.copy() + obs["_idx"] = np.arange(len(obs)) + + meta_cols = [c for c in ["perturbation", "marker", "experiment"] if c in obs.columns] + records = [] + + for (fov, tid), grp in obs.groupby(["fov_name", "track_id"], sort=False): + grp = grp.sort_values("t") + idxs = grp["_idx"].values + if len(idxs) < 2: + continue + embs = embeddings[idxs] + steps = np.linalg.norm(np.diff(embs, axis=0), axis=1) + record = { + "fov_name": fov, + "track_id": tid, + "mean_step_size": steps.mean(), + "total_path_length": steps.sum(), + "net_displacement": float(np.linalg.norm(embs[-1] - embs[0])), + "track_length": len(idxs), + } + for col in meta_cols: + record[col] = grp[col].iloc[0] + records.append(record) + + return pd.DataFrame(records) + + +def _get_or_compute_pca(adata: ad.AnnData, features_scaled: np.ndarray) -> np.ndarray: + if "X_pca" in adata.obsm: + return adata.obsm["X_pca"] + pca = PCA(n_components=32) + return pca.fit_transform(features_scaled) + + +def _get_or_compute_umap(adata: ad.AnnData, features_scaled: np.ndarray) -> np.ndarray: + if "X_umap" in adata.obsm: + return adata.obsm["X_umap"] + print(" Computing UMAP...") + return UMAP(n_components=2, n_neighbors=15, random_state=42).fit_transform(features_scaled) + + +def analyze_microglia(adata: ad.AnnData, output_dir: Path) -> None: + """Run microglia displacement analysis and save plots.""" + print(f"Microglia: {adata.shape[0]:,} observations") + + features = np.asarray(adata.X) + features_scaled = StandardScaler().fit_transform(features) + pca_emb = _get_or_compute_pca(adata, features_scaled) + umap_emb = _get_or_compute_umap(adata, features_scaled) + + track_metrics = compute_track_displacement_metrics(adata) + print(f" {len(track_metrics):,} tracks") + + obs = adata.obs.copy().merge( + track_metrics[["fov_name", "track_id", "mean_step_size", "net_displacement"]], + on=["fov_name", "track_id"], + how="left", + ) + + perturbations = sorted(obs["perturbation"].unique()) if "perturbation" in obs.columns else [] + markers = sorted(obs["marker"].unique()) if "marker" in obs.columns else [] + palette_p = dict(zip(perturbations, sns.color_palette("tab10", len(perturbations)))) + palette_m = dict(zip(markers, sns.color_palette("Set2", len(markers)))) + + plot_df = pd.DataFrame( + { + "PC1": pca_emb[:, 0], + "PC2": pca_emb[:, 1], + "UMAP1": umap_emb[:, 0], + "UMAP2": umap_emb[:, 1], + "perturbation": obs["perturbation"].values if "perturbation" in obs.columns else "unknown", + "marker": obs["marker"].values if "marker" in obs.columns else "unknown", + "mean_step_size": obs["mean_step_size"].values, + "net_displacement": obs["net_displacement"].values, + } + ) + + vmin = np.nanpercentile(plot_df["mean_step_size"], 5) + vmax = np.nanpercentile(plot_df["mean_step_size"], 95) + + for reduction, x_col, y_col in [("pca", "PC1", "PC2"), ("umap", "UMAP1", "UMAP2")]: + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + sns.scatterplot( + data=plot_df, + x=x_col, + y=y_col, + hue="perturbation", + palette=palette_p, + ax=axes[0], + alpha=0.5, + s=8, + linewidth=0, + ) + axes[0].set_title(f"{reduction.upper()} — perturbation") + + sns.scatterplot( + data=plot_df, + x=x_col, + y=y_col, + hue="marker", + palette=palette_m, + ax=axes[1], + alpha=0.5, + s=8, + linewidth=0, + ) + axes[1].set_title(f"{reduction.upper()} — channel/marker") + + sc = axes[2].scatter( + plot_df[x_col], + plot_df[y_col], + c=plot_df["mean_step_size"], + cmap="plasma", + alpha=0.5, + s=8, + vmin=vmin, + vmax=vmax, + ) + plt.colorbar(sc, ax=axes[2], label="Mean embedding step size") + axes[2].set_title(f"{reduction.upper()} — embedding displacement") + axes[2].set_xlabel(x_col) + axes[2].set_ylabel(y_col) + + plt.tight_layout() + out = output_dir / f"microglia_{reduction}.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + # Displacement by perturbation + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + order = sorted(track_metrics["perturbation"].unique()) if "perturbation" in track_metrics.columns else None + + sns.boxplot(data=track_metrics, x="perturbation", y="mean_step_size", ax=axes[0], order=order) + axes[0].set_title("Mean embedding step size by perturbation") + axes[0].set_ylabel("Mean step size in embedding space") + axes[0].tick_params(axis="x", rotation=30) + + sns.boxplot(data=track_metrics, x="perturbation", y="net_displacement", ax=axes[1], order=order) + axes[1].set_title("Net displacement (start→end) by perturbation") + axes[1].set_ylabel("Net displacement in embedding space") + axes[1].tick_params(axis="x", rotation=30) + + plt.tight_layout() + out = output_dir / "microglia_displacement_by_perturbation.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + summary = track_metrics.groupby("perturbation")[["mean_step_size", "net_displacement", "track_length"]].agg( + ["median", "mean", "std", "count"] + ) + print("\n## Microglia track displacement summary\n") + print(summary.to_markdown()) + + +def analyze_alfi(adata: ad.AnnData, output_dir: Path) -> None: + """Run ALFI HeLa cell cycle analysis and save plots.""" + print(f"\nALFI total: {adata.shape[0]:,} observations") + + # Filter to HeLa (MI06) + if "fov_name" in adata.obs.columns: + hela_mask = adata.obs["fov_name"] == "MI06" + elif "experiment" in adata.obs.columns: + hela_mask = adata.obs["experiment"].str.contains("HeLa") + else: + raise RuntimeError("Cannot identify HeLa cells — no fov_name or experiment column in obs") + + adata_hela = adata[hela_mask].copy() + print(f" HeLa (MI06): {adata_hela.shape[0]:,} observations") + + # Join annotations + annotations = pd.read_csv(ALFI_ANNOTATIONS) + ann_indexed = annotations.set_index(["fov_name", "track_id", "t"]) + + obs_hela = adata_hela.obs.copy() + mi = pd.MultiIndex.from_arrays( + [ + obs_hela["fov_name"], + obs_hela["track_id"].astype(int), + obs_hela["t"].astype(int), + ], + names=["fov_name", "track_id", "t"], + ) + obs_hela["cell_division_state"] = ann_indexed.reindex(mi)["cell_division_state"].values + obs_hela["cell_cycle_fine_state"] = ann_indexed.reindex(mi)["cell_cycle_fine_state"].values + + n_annotated = obs_hela["cell_division_state"].notna().sum() + print(f" Annotated: {n_annotated:,} / {len(obs_hela):,}") + print(obs_hela["cell_division_state"].value_counts().to_string()) + + features_hela = np.asarray(adata_hela.X) + features_scaled = StandardScaler().fit_transform(features_hela) + pca_emb = _get_or_compute_pca(adata_hela, features_scaled) + umap_emb = _get_or_compute_umap(adata_hela, features_scaled) + + unannotated = obs_hela["cell_division_state"].isna() + + for reduction, emb in [("pca", pca_emb), ("umap", umap_emb)]: + x_col, y_col = ("PC1", "PC2") if reduction == "pca" else ("UMAP1", "UMAP2") + + # Division state plot + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + for ax, fine in zip(axes, [False, True]): + col = "cell_cycle_fine_state" if fine else "cell_division_state" + states = obs_hela[col].dropna().unique() + if fine: + palette = dict(zip(sorted(states), sns.color_palette("tab10", len(states)))) + else: + palette = DIVISION_PALETTE + + for state, color in palette.items(): + mask = obs_hela[col] == state + ax.scatter( + emb[mask, 0], + emb[mask, 1], + c=color, + label=state, + alpha=0.6, + s=10, + linewidth=0, + ) + ax.scatter( + emb[unannotated, 0], + emb[unannotated, 1], + c="lightgray", + label="unannotated", + alpha=0.3, + s=6, + linewidth=0, + ) + title = "fine cell cycle state" if fine else "cell division state" + ax.set_title(f"HeLa {reduction.upper()} — {title}") + ax.set_xlabel(x_col) + ax.set_ylabel(y_col) + ax.legend(markerscale=2, bbox_to_anchor=(1, 1), loc="upper left", fontsize=8) + + plt.tight_layout() + out = output_dir / f"alfi_hela_{reduction}_cell_cycle.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + # Displacement by cell cycle state + track_metrics = compute_track_displacement_metrics(adata_hela) + + track_annotations = ( + annotations[annotations["fov_name"] == "MI06"] + .groupby(["fov_name", "track_id"])["cell_division_state"] + .agg(lambda x: x.dropna().mode().iloc[0] if x.dropna().shape[0] > 0 else pd.NA) + .reset_index() + .rename(columns={"cell_division_state": "dominant_state"}) + ) + track_metrics = track_metrics.merge(track_annotations, on=["fov_name", "track_id"], how="left") + + annotated = track_metrics.dropna(subset=["dominant_state"]) + if len(annotated) > 0: + fig, ax = plt.subplots(figsize=(6, 5)) + sns.boxplot( + data=annotated, + x="dominant_state", + y="mean_step_size", + palette=DIVISION_PALETTE, + ax=ax, + order=[s for s in DIVISION_PALETTE if s in annotated["dominant_state"].unique()], + ) + ax.set_title("HeLa: embedding step size by cell cycle state") + ax.set_xlabel("Dominant cell division state (per track)") + ax.set_ylabel("Mean step size in embedding space") + plt.tight_layout() + out = output_dir / "alfi_hela_displacement_by_state.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + summary = annotated.groupby("dominant_state")["mean_step_size"].describe() + print("\n## ALFI HeLa displacement by state\n") + print(summary.to_markdown()) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--microglia-embeddings", type=Path, required=True, help="AnnData zarr from microglia inference" + ) + parser.add_argument("--alfi-embeddings", type=Path, required=True, help="AnnData zarr from ALFI inference") + parser.add_argument("--output-dir", type=Path, required=True, help="Directory to save PDF figures") + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + print("=== Microglia analysis ===") + adata_micro = ad.read_zarr(args.microglia_embeddings) + analyze_microglia(adata_micro, args.output_dir) + + print("\n=== ALFI analysis ===") + adata_alfi = ad.read_zarr(args.alfi_embeddings) + analyze_alfi(adata_alfi, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/README.md b/applications/dynaclr/scripts/profiling/README.md new file mode 100644 index 000000000..b1b44a730 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/README.md @@ -0,0 +1,40 @@ +# DynaCLR I/O profiling scripts + +Scripts that validate data-loading performance on VAST/NFS for the DynaCLR +contrastive training pipeline. + +## Current scripts + +### `benchmark_recheck_cached_data.py` + +Measures the effect of `TensorStoreConfig.recheck_cached_data` on NFS read +latency for the DynaCLR contrastive read pattern. Exercises the iohub +tensorstore implementation directly (no training stack involved) so it can +be run **before** the dynaclr datamodule is ported to iohub 0.3.x. + +**Prerequisite.** Requires an iohub build with the upstream +`recheck_cached_data` knob on `TensorStoreConfig`. Until that lands, either +install iohub from the feature branch locally, or skip this script. + +Run: + +``` +uv run python applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py +``` + +Output is a markdown table comparing median/p95 batch latency, patches/s, +and MiB/s across three configurations (`none`, `"open"`, `false`). Run +twice back-to-back and compare: if the `none` vs `"open"` gap shrinks on +the second run, the Linux NFS client page cache is masking the +per-chunk revalidation cost on this node. + +## Planned follow-ups (after iohub 0.3.x merge into dynadtw) + +- **Dataset-level A/B** — same configurations, but driven through + `MultiExperimentDataModule` + `MultiExperimentTripletDataset` so we + exercise `_get_position`/`_get_tensorstore`/`_slice_patches` and the + `ts.stack(...).read().result()` batched read path exactly as training + does. +- **SLURM DDP A/B** — 200-step fastdev runs with Lightning's + `SimpleProfiler`, comparing `data_time`/`batch_time` and GETATTR/s + from `nfsiostat` across ranks. diff --git a/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py b/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py new file mode 100644 index 000000000..206ddcf2e --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py @@ -0,0 +1,264 @@ +"""Production-config DataLoader benchmark + batch-composition sanity check. + +Exercises the real +``DynaCLR-2D-MIP-BagOfChannels.yml`` training settings against the +committed v2 parquet to measure end-to-end DataLoader throughput and +verify that batch-grouping/stratification actually do what the config +says. + +Two parts +--------- + +**1. Composition check** — forces ``batch_group_by="marker"`` and checks +the first 20 batches: + +- every batch contains exactly one marker (single-marker batches), +- different batches surface different markers (proves the grouping is + shuffled across the epoch, not stuck on one value). + +**2. Throughput A/B** — runs the production config +(``batch_size=256``, ``channels_per_sample=1``, ``stratify_by=[perturbation, marker]``, +``num_workers=2``) under two ``recheck_cached_data`` settings: + +- ``None`` — TensorStore driver default. +- ``"open"`` — validate at open only (our merge's default). + +Reports median/p95 per-iter latency, iter/s, samples/s for each leg. +Because this runs on the real VAST-resident parquet with 7k+ FOVs, the +FOV-open amortisation is representative of real training. + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedChannelWiseZReductiond, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +CELL_INDEX_PARQUET = "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v2.parquet" + +BATCH_SIZE = 256 +NUM_WORKERS = 2 +WARMUP_BATCHES = 10 +N_BATCHES = 60 +SEED = 42 + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 20 +Z_FOCUS_OFFSET = 0.3 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (160, 160) + +COMPOSITION_BATCHES = 20 + +RECHECK_LEGS: list[tuple[str, str | bool | None]] = [ + ("None (driver default)", None), + ("open (our default)", "open"), +] + + +@dataclass +class LegResult: + """Timing outcome for one recheck_cached_data leg on the real parquet.""" + + label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return median per-iter latency in milliseconds.""" + return statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return p95 per-iter latency in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return sustained iterations per second.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return sustained samples per second.""" + return self.iter_per_s * BATCH_SIZE + + +def _build_production_dm( + recheck_cached_data: str | bool | None, + batch_group_by: str | list[str] | None = None, + stratify_by: list[str] | None = None, + num_workers: int = NUM_WORKERS, +) -> MultiExperimentDataModule: + """Build a DataModule matching the production 2D-MIP-BoC training recipe.""" + normalizations = [ + NormalizeSampled( + keys=["channel_0"], + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), + ] + augmentations = [ + BatchedRandSpatialCropd(keys=["channel_0"], roi_size=(10, 192, 192)), + BatchedChannelWiseZReductiond(keys=["channel_0"], allow_missing_keys=True), + ] + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + positive_channel_source="same", + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + batch_group_by=batch_group_by, + stratify_by=stratify_by if stratify_by is not None else ["perturbation", "marker"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=num_workers, + seed=SEED, + normalizations=normalizations, + augmentations=augmentations, + ) + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _composition_check() -> None: + """Verify batch_group_by='marker' yields single-marker, shuffled batches.""" + print("=" * 72) + print("Composition check: batch_group_by='marker'") + print("=" * 72) + + dm = _build_production_dm( + recheck_cached_data="open", + batch_group_by="marker", + stratify_by=None, + num_workers=0, + ) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + markers_by_batch: list[set[str]] = [] + for i in range(COMPOSITION_BATCHES): + batch = next(it) + metas = batch["anchor_meta"] + batch_markers = {m["marker"] for m in metas} + markers_by_batch.append(batch_markers) + print(f" batch {i:>2}: {len(batch_markers)} unique markers → {sorted(batch_markers)[:4]}") + + non_singleton = [i for i, ms in enumerate(markers_by_batch) if len(ms) != 1] + if non_singleton: + print(f"\n FAIL: {len(non_singleton)} of {COMPOSITION_BATCHES} batches had >1 marker") + print(f" offending batches: {non_singleton}") + raise AssertionError("batch_group_by='marker' did not produce single-marker batches") + + unique_markers_seen = set().union(*markers_by_batch) + print(f"\n PASS: all {COMPOSITION_BATCHES} batches are single-marker") + print(f" markers touched across the {COMPOSITION_BATCHES} batches: {len(unique_markers_seen)}") + print(f" → {sorted(unique_markers_seen)}") + + if len(unique_markers_seen) < 2: + print("\n WARNING: only 1 marker touched across all batches — epoch may be stuck on one group") + else: + print(" → grouping is shuffled across markers (good)") + + del it + del loader + + +def _run_throughput_leg(label: str, recheck_cached_data: str | bool | None) -> LegResult: + """Run one throughput leg on the production config.""" + print(f"\n-- Throughput leg: recheck_cached_data = {label} --") + dm = _build_production_dm( + recheck_cached_data=recheck_cached_data, + batch_group_by=None, + stratify_by=["perturbation", "marker"], + num_workers=NUM_WORKERS, + ) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + for _ in range(WARMUP_BATCHES): + _ = next(it) + + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + del it + del loader + + result = LegResult(label=label, iter_latencies_s=latencies_s, total_s=total_s) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[LegResult]) -> None: + """Emit a markdown-formatted throughput table.""" + print() + print("## Throughput (real 2D-MIP-BoC v2 parquet)") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, num_workers: {NUM_WORKERS}") + print(f"- Warmup: {WARMUP_BATCHES} batches; timed: {N_BATCHES} batches") + print(f"- Z_extraction={Z_EXTRACTION_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print("- channels_per_sample=1, stratify_by=[perturbation, marker]") + print() + print("| recheck_cached_data | median ms | p95 ms | iter/s | samples/s |") + print("|---|---:|---:|---:|---:|") + for r in results: + print(f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | {r.iter_per_s:.2f} | {r.samples_per_s:.1f} |") + print() + + +def main() -> None: + """Run composition check, then the throughput A/B, and print a summary.""" + _composition_check() + + print() + print("=" * 72) + print("Throughput A/B: production config, real parquet") + print("=" * 72) + + results: list[LegResult] = [] + for label, value in RECHECK_LEGS: + results.append(_run_throughput_leg(label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py b/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py new file mode 100644 index 000000000..bd8626fad --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py @@ -0,0 +1,198 @@ +"""Full-pipeline A/B benchmark for TensorStoreConfig.recheck_cached_data. + +Drives :class:`dynaclr.data.datamodule.MultiExperimentDataModule` +end-to-end — ``__getitems__`` + ``collate_fn=lambda x:x`` + +PyTorch DataLoader with ``num_workers`` forked workers — to measure the +effect of ``recheck_cached_data`` on sustained training-loader +throughput, the only number that actually matters for GPU utilization. + +Three legs are compared against the same parquet, in the same process, +with the same FOVs and the same seed so sampling is deterministic: + +- ``"open"`` — validate at open only, trust cache thereafter (our + expected production setting). +- ``None`` — driver default, revalidate cached chunk metadata every + read (one stat/GETATTR per chunk per read on NFS). +- ``False`` — never revalidate (included for completeness). + +Per leg the script: + +1. Constructs a fresh ``MultiExperimentDataModule``, forcibly overriding + ``self.tensorstore_config.recheck_cached_data`` after ``__init__`` so + every Plate opens with the configured setting. +2. Runs ``setup("fit")`` once. +3. Warms the DataLoader with ``WARMUP_BATCHES`` batches (discarded). +4. Times ``N_BATCHES`` steady-state batches by wall-clocking the + iterator yield interval — this is what the training loop sees. +5. Reports median/p95 iteration time and steady-state iter/s. + +Because we use forked DataLoader workers, each config opens its own +Plates inside the worker after fork — matching real DDP training. + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py + +Requires: + +- iohub with ``recheck_cached_data`` on ``TensorStoreConfig`` + (czbiohub-sf/iohub#406 or later). +- A parquet whose ``store_path`` entries are readable on this node. +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule + +CELL_INDEX_PARQUET = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 32 +NUM_WORKERS = 4 +WARMUP_BATCHES = 10 +N_BATCHES = 100 +SEED = 42 + +Z_WINDOW = 8 +YX_PATCH_SIZE = (192, 192) +FINAL_YX_PATCH_SIZE = (160, 160) + +LEGS: list[tuple[str, str | bool | None]] = [ + ("open (recommended)", "open"), + ("None (driver default)", None), + ("False (never revalidate)", False), +] + + +@dataclass +class LegResult: + """Timing outcome for one recheck_cached_data leg.""" + + label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return the median inter-batch iteration time in milliseconds.""" + return statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return the p95 inter-batch iteration time in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return steady-state iterations per second.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return steady-state samples per second.""" + return self.iter_per_s * BATCH_SIZE + + +def _build_datamodule(recheck_cached_data: str | bool | None) -> MultiExperimentDataModule: + """Construct a DataModule and force the recheck_cached_data leg onto its config.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=None, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + seed=SEED, + normalizations=[], + augmentations=[], + ) + # The datamodule sets recheck_cached_data="open" by default; override + # it here so every leg can dial the knob independently without editing + # the production code path. + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _run_leg(label: str, recheck_cached_data: str | bool | None) -> LegResult: + """Run one A/B leg and return a populated LegResult.""" + print(f"\n-- Leg: recheck_cached_data = {label} --") + dm = _build_datamodule(recheck_cached_data) + dm.setup("fit") + loader = dm.train_dataloader() + + it = iter(loader) + + # Warmup — discard. Forks workers, populates each worker's plate/ts + # caches, amortises Python import cost in the forked child. + for _ in range(WARMUP_BATCHES): + _ = next(it) + + # Steady-state timing. We measure the inter-batch yield interval, + # which is exactly what the training loop observes. + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + # Release workers before the next leg so forked processes do not + # pile up and compete for file descriptors. + del it + del loader + + result = LegResult(label=label, iter_latencies_s=latencies_s, total_s=total_s) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[LegResult]) -> None: + """Emit a markdown-formatted summary for the PR / Confluence.""" + print() + print("## Results (dataloader-level A/B)") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, num_workers: {NUM_WORKERS}") + print(f"- Warmup: {WARMUP_BATCHES} batches; timed: {N_BATCHES} batches") + print(f"- Z={Z_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print() + print("| recheck_cached_data | median ms | p95 ms | iter/s | samples/s |") + print("|---|---:|---:|---:|---:|") + for r in results: + print(f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | {r.iter_per_s:.2f} | {r.samples_per_s:.1f} |") + print() + + +def main() -> None: + """Run all three legs and print a combined markdown summary.""" + print("=" * 72) + print("Dataloader-level recheck_cached_data benchmark — MultiExperimentDataModule") + print("=" * 72) + + results: list[LegResult] = [] + for label, value in LEGS: + results.append(_run_leg(label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py b/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py new file mode 100644 index 000000000..3a650e155 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py @@ -0,0 +1,183 @@ +"""Sweep num_workers × recheck_cached_data for the DynaCLR dataloader. + +Purpose +------- + +The first pass A/B (``benchmark_dataloader_recheck.py``) showed a counter- +intuitive result on ``MultiExperimentDataModule.train_dataloader()`` with +``num_workers=4``: ``recheck_cached_data="open"`` was slower than the +driver default. The raw ``ts.stack`` benchmark showed the opposite. Most +likely the p95 tails were dominated by first-touch FOV opens while the +ThreadDataLoader prefetch buffer masked them differently per leg. + +This sweep pins down the cause by running every ``recheck_cached_data`` +value across several ``num_workers`` settings with generous warmup, so we +can tell: + +- Does the ordering flip between ``num_workers=0`` (no fork, no thread + buffer) and ``num_workers>0`` (forked workers)? +- Is the ``"open"`` penalty paid only on cold FOV opens? If yes, longer + warmup should close the gap. +- Does the ``p95`` converge once steady-state is reached? + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule + +CELL_INDEX_PARQUET = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 32 +WARMUP_BATCHES = 30 +N_BATCHES = 150 +SEED = 42 + +Z_WINDOW = 8 +YX_PATCH_SIZE = (192, 192) +FINAL_YX_PATCH_SIZE = (160, 160) + +WORKER_COUNTS: list[int] = [0, 1, 4] +RECHECK_VALUES: list[tuple[str, str | bool | None]] = [ + ("None", None), + ("open", "open"), + ("False", False), +] + + +@dataclass +class SweepResult: + """One cell of the ``num_workers`` × ``recheck_cached_data`` grid.""" + + num_workers: int + recheck_label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return median per-iter latency in milliseconds.""" + return statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return p95 per-iter latency in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return sustained iterations per second across timed batches.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return sustained samples per second (iter/s × batch).""" + return self.iter_per_s * BATCH_SIZE + + +def _build(num_workers: int, recheck_cached_data: str | bool | None) -> MultiExperimentDataModule: + """Build one datamodule with forced num_workers and recheck_cached_data.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=None, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=num_workers, + seed=SEED, + normalizations=[], + augmentations=[], + ) + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _run_cell(num_workers: int, label: str, recheck_cached_data: str | bool | None) -> SweepResult: + """Run one cell of the sweep.""" + print(f"\n-- num_workers={num_workers}, recheck_cached_data={label} --") + dm = _build(num_workers, recheck_cached_data) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + for _ in range(WARMUP_BATCHES): + _ = next(it) + + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + del it + del loader + + result = SweepResult( + num_workers=num_workers, + recheck_label=label, + iter_latencies_s=latencies_s, + total_s=total_s, + ) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[SweepResult]) -> None: + """Emit a markdown-formatted sweep table for the PR / Confluence.""" + print() + print("## Sweep results") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, warmup: {WARMUP_BATCHES}, timed: {N_BATCHES}") + print(f"- Z={Z_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print() + print("| num_workers | recheck | median ms | p95 ms | iter/s | samples/s |") + print("|---:|---|---:|---:|---:|---:|") + for r in results: + print( + f"| {r.num_workers} | {r.recheck_label} | " + f"{r.median_ms:.1f} | {r.p95_ms:.1f} | " + f"{r.iter_per_s:.2f} | {r.samples_per_s:.1f} |" + ) + print() + + +def main() -> None: + """Run the full sweep and print a combined markdown summary.""" + print("=" * 72) + print("num_workers × recheck_cached_data sweep — MultiExperimentDataModule") + print("=" * 72) + + results: list[SweepResult] = [] + for nw in WORKER_COUNTS: + for label, value in RECHECK_VALUES: + results.append(_run_cell(nw, label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py b/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py new file mode 100644 index 000000000..98f503a11 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py @@ -0,0 +1,215 @@ +"""Measure the impact of ``TensorStoreConfig.recheck_cached_data`` on NFS reads. + +Single-process raw ``ts.stack(...).read().result()`` loop against a +2-experiment parquet for three TensorStoreConfig settings: + +- ``none`` — driver default, revalidate on every read (one stat/GETATTR + per chunk per read). +- ``open`` — validate only at open time, trust the cache thereafter. +- ``false`` — never revalidate. + +The loop issues ``N_BATCHES`` batches of stacked 3D crops sampled from +random FOVs, reports median/p95 read latency and sustained patches/s. +For the DataLoader-driven end-to-end view see +``benchmark_dataloader_workers_sweep.py``. + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pandas as pd +import tensorstore as ts +from iohub import open_ome_zarr +from iohub.core.config import TensorStoreConfig + +CELL_INDEX_PARQUET = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 32 +N_BATCHES = 50 +PATCH_Z = 8 +PATCH_YX = (192, 192) +SEED = 0 + +DATA_COPY_CONCURRENCY = 16 +FILE_IO_CONCURRENCY = 64 +CACHE_POOL_BYTES: int | None = None + +CONFIGS: list[tuple[str, dict[str, Any]]] = [ + ("none (driver default)", {}), + ("open", {"recheck_cached_data": "open"}), + ("false", {"recheck_cached_data": False}), +] + + +@dataclass +class Result: + """Timing results for one ``recheck_cached_data`` configuration.""" + + label: str + batch_latencies_ms: list[float] + total_bytes: int + total_s: float + + @property + def median_ms(self) -> float: + """Return the median per-batch read latency in milliseconds.""" + return statistics.median(self.batch_latencies_ms) + + @property + def p95_ms(self) -> float: + """Return the p95 per-batch read latency in milliseconds.""" + return float(np.percentile(self.batch_latencies_ms, 95)) + + @property + def patches_per_s(self) -> float: + """Return the sustained patch-read throughput.""" + return BATCH_SIZE * len(self.batch_latencies_ms) / self.total_s + + @property + def mib_per_s(self) -> float: + """Return the sustained read throughput in MiB/s.""" + return (self.total_bytes / (1024 * 1024)) / self.total_s + + +def _load_fov_index() -> pd.DataFrame: + """Return unique (store_path, well, fov, shape) rows from the benchmark parquet.""" + df = pd.read_parquet(CELL_INDEX_PARQUET) + unique = df[["store_path", "well", "fov", "C_shape", "Z_shape", "Y_shape", "X_shape"]].drop_duplicates( + subset=["store_path", "well", "fov"] + ) + return unique.reset_index(drop=True) + + +def _open_stores(fov_df: pd.DataFrame, ts_config: TensorStoreConfig) -> dict[str, Any]: + """Open each unique zarr store once with the given TensorStoreConfig.""" + store_paths = fov_df["store_path"].drop_duplicates().tolist() + plates: dict[str, Any] = {} + for sp in store_paths: + plates[sp] = open_ome_zarr( + sp, + mode="r", + implementation="tensorstore", + implementation_config=ts_config, + ) + return plates + + +def _sample_patches( + fov_df: pd.DataFrame, + plates: dict[str, Any], + batch_size: int, + rng: np.random.Generator, +) -> tuple[list[ts.TensorStore], int]: + """Pick ``batch_size`` random (fov, z, y, x) crops and return lazy slices + byte count. + + Returns a list of tensorstore lazy slices (one per crop) plus the + total number of bytes the resulting stacked read will pull. + """ + lazies: list[ts.TensorStore] = [] + total_bytes = 0 + rows = fov_df.sample(n=batch_size, replace=True, random_state=rng.integers(0, 2**31 - 1)) + for _, row in rows.iterrows(): + plate = plates[row["store_path"]] + position_path = f"{row['well']}/{row['fov']}" + arr = plate[position_path]["0"].native + z_start = int(rng.integers(0, max(1, row["Z_shape"] - PATCH_Z + 1))) + y_start = int(rng.integers(0, max(1, row["Y_shape"] - PATCH_YX[0] + 1))) + x_start = int(rng.integers(0, max(1, row["X_shape"] - PATCH_YX[1] + 1))) + lazy = arr[ + 0, # t=0 — keep indexing simple; timepoint is not what we're benchmarking + :, + z_start : z_start + PATCH_Z, + y_start : y_start + PATCH_YX[0], + x_start : x_start + PATCH_YX[1], + ] + lazies.append(lazy) + total_bytes += PATCH_Z * PATCH_YX[0] * PATCH_YX[1] * row["C_shape"] * 4 # assume float32 + return lazies, total_bytes + + +def _run_one_config(label: str, extra_cfg: dict[str, Any], fov_df: pd.DataFrame) -> Result: + """Run the read-loop benchmark for one recheck_cached_data setting.""" + ts_config = TensorStoreConfig( + data_copy_concurrency=DATA_COPY_CONCURRENCY, + file_io_concurrency=FILE_IO_CONCURRENCY, + cache_pool_bytes=CACHE_POOL_BYTES, + **extra_cfg, + ) + plates = _open_stores(fov_df, ts_config) + + def _translate_all(lazies: list[ts.TensorStore]) -> list[ts.TensorStore]: + """Translate each lazy slice to origin so ts.stack can combine them.""" + return [p.translate_to[0] for p in lazies] # noqa: PD013 + + rng_warm = np.random.default_rng(SEED) + warm_lazies, _ = _sample_patches(fov_df, plates, BATCH_SIZE, rng_warm) + _ = ts.stack(_translate_all(warm_lazies)).read().result() + + rng = np.random.default_rng(SEED + 1) + latencies_ms: list[float] = [] + total_bytes = 0 + t_total = time.perf_counter() + for _ in range(N_BATCHES): + lazies, batch_bytes = _sample_patches(fov_df, plates, BATCH_SIZE, rng) + t0 = time.perf_counter() + _ = ts.stack(_translate_all(lazies)).read().result() + latencies_ms.append((time.perf_counter() - t0) * 1000.0) + total_bytes += batch_bytes + total_s = time.perf_counter() - t_total + + for plate in plates.values(): + plate.close() + + return Result(label=label, batch_latencies_ms=latencies_ms, total_bytes=total_bytes, total_s=total_s) + + +def _print_markdown_table(results: list[Result]) -> None: + """Print a markdown-formatted results table suitable for Confluence/PR pasting.""" + print() + print("## Results") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, N batches: {N_BATCHES}") + print(f"- Patch shape: (C, Z={PATCH_Z}, Y={PATCH_YX[0]}, X={PATCH_YX[1]})") + print(f"- data_copy_concurrency={DATA_COPY_CONCURRENCY}, file_io_concurrency={FILE_IO_CONCURRENCY}") + print() + print("| recheck_cached_data | median ms | p95 ms | patches/s | MiB/s | total s |") + print("|---|---:|---:|---:|---:|---:|") + for r in results: + print( + f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | " + f"{r.patches_per_s:.1f} | {r.mib_per_s:.1f} | {r.total_s:.2f} |" + ) + print() + + +def main() -> None: + """Run the three configurations back-to-back and print a markdown summary.""" + print("=" * 72) + print("recheck_cached_data benchmark — DynaCLR contrastive read pattern on VAST") + print("=" * 72) + + fov_df = _load_fov_index() + print(f"Loaded {len(fov_df)} unique FOVs across {fov_df['store_path'].nunique()} stores") + + results: list[Result] = [] + for label, extra_cfg in CONFIGS: + print(f"\n-- Running: recheck_cached_data = {label} --") + r = _run_one_config(label, extra_cfg, fov_df) + print(f" median {r.median_ms:.1f} ms | p95 {r.p95_ms:.1f} ms | {r.patches_per_s:.1f} patches/s") + results.append(r) + + _print_markdown_table(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_dataloaders.py b/applications/dynaclr/scripts/profiling/profile_dataloaders.py new file mode 100644 index 000000000..956fb201d --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_dataloaders.py @@ -0,0 +1,371 @@ +"""Profile BatchedConcatDataModule + TripletDatasets vs MultiExperimentDataModule. + +Benchmarks setup time, raw __getitems__ latency, and full dataloader +throughput for: +- Old: BatchedConcatDataModule wrapping 2 TripletDataModules (one per experiment) +- New: Single MultiExperimentDataModule with flat parquet index + +Uses two real datasets: +- 2025_07_24 G3BP1 (stress granules) +- 2025_04_15 H2B (chromatin) + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_dataloaders.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import pandas as pd +import torch + +# --------------------------------------------------------------------------- +# Dataset paths +# --------------------------------------------------------------------------- + +COLLECTION_YAML = "applications/dynaclr/configs/collections/benchmark_2exp.yml" +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +DATASETS = { + "G3BP1": { + "data_path": ( + "/hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" + "/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr" + ), + "tracks_path": ( + "/hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr" + ), + "source_channel": ["raw GFP EX488 EM525-45"], + "include_wells": ["C/1", "C/2"], + }, + "H2B": { + "data_path": ( + "/hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV" + "/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr" + ), + "tracks_path": ( + "/hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr" + ), + "source_channel": ["raw Cy5 EX639 EM698-70"], + "include_wells": ["B/1", "B/2"], + }, +} + +# Shared benchmark parameters +BATCH_SIZES = [8, 32, 64, 128] +N_BATCHES = 20 +WARMUP_BATCHES = 3 +CACHE_POOL_BYTES = 500_000_000 # 500 MB +Z_RANGE = (30, 46) # 16 z-slices, 3D benchmark + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + return f"{seconds:.2f} s" + + +# ====================================================================== +# Old: BatchedConcatDataModule wrapping 2 TripletDataModules +# ====================================================================== + + +def setup_old(): + """Set up legacy BatchedConcatDataModule with 2 TripletDataModules.""" + from viscy_data.combined import BatchedConcatDataModule + from viscy_data.triplet import TripletDataModule + + dms = [] + for name, cfg in DATASETS.items(): + dm = TripletDataModule( + data_path=cfg["data_path"], + tracks_path=cfg["tracks_path"], + source_channel=cfg["source_channel"], + z_range=Z_RANGE, + initial_yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + split_ratio=0.8, + batch_size=BATCH_SIZES[-1], + num_workers=1, + time_interval=3, + return_negative=False, + cache_pool_bytes=CACHE_POOL_BYTES, + fit_include_wells=cfg["include_wells"], + ) + dms.append(dm) + print(f" Created TripletDataModule for {name}") + + concat_dm = BatchedConcatDataModule(data_modules=dms) + concat_dm.setup("fit") + return concat_dm + + +# ====================================================================== +# New: MultiExperimentDataModule +# ====================================================================== + + +def setup_new(): + """Set up MultiExperimentDataModule with pre-built parquet.""" + from dynaclr.data.datamodule import MultiExperimentDataModule + + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_RANGE[1] - Z_RANGE[0], # 16 + yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation"], + split_ratio=0.8, + batch_size=BATCH_SIZES[-1], + num_workers=1, + seed=42, + cache_pool_bytes=CACHE_POOL_BYTES, + normalizations=[], + augmentations=[], + ) + dm.setup("fit") + return dm + + +# ====================================================================== +# Benchmark helpers +# ====================================================================== + + +def benchmark_getitems( + dataset: torch.utils.data.Dataset, + batch_size: int, + n_batches: int = N_BATCHES, + warmup: int = WARMUP_BATCHES, +) -> dict: + """Time raw __getitems__ calls. + + Parameters + ---------- + dataset : Dataset + Must implement __getitems__(indices). + batch_size : int + Number of indices per call. + n_batches : int + Total batches to time (excluding warmup). + warmup : int + Batches to discard for cache warmup. + + Returns + ------- + dict + Timing statistics. + """ + n_samples = len(dataset) + rng = np.random.default_rng(42) + total = warmup + n_batches + + times = [] + for i in range(total): + indices = rng.integers(0, n_samples, size=batch_size).tolist() + t0 = time.perf_counter() + _ = dataset.__getitems__(indices) + t1 = time.perf_counter() + if i >= warmup: + times.append(t1 - t0) + + times_arr = np.array(times) + return { + "batch_size": batch_size, + "mean_ms": times_arr.mean() * 1000, + "std_ms": times_arr.std() * 1000, + "median_ms": np.median(times_arr) * 1000, + "p95_ms": np.percentile(times_arr, 95) * 1000, + "throughput_samples_per_sec": batch_size / times_arr.mean(), + } + + +def benchmark_dataloader( + dataloader, + n_batches: int = N_BATCHES, + warmup: int = WARMUP_BATCHES, +) -> dict: + """Time full dataloader iteration. + + Parameters + ---------- + dataloader : DataLoader + Configured dataloader. + n_batches : int + Batches to time after warmup. + warmup : int + Batches to discard. + + Returns + ------- + dict + Timing statistics. + """ + timestamps = [] + total_samples = 0 + + for i, batch in enumerate(dataloader): + if i >= warmup + n_batches: + break + now = time.perf_counter() + if i >= warmup: + timestamps.append(now) + # Count samples in batch + if isinstance(batch, list): + # BatchedConcatDataModule returns list of micro-batches + for mb in batch: + if isinstance(mb, dict) and "anchor" in mb: + total_samples += mb["anchor"].shape[0] + elif isinstance(batch, dict) and "anchor" in batch: + total_samples += batch["anchor"].shape[0] + + if len(timestamps) < 2: + return {"note": "not enough batches"} + + inter_batch = np.diff(timestamps) + return { + "n_batches": len(inter_batch), + "total_samples": total_samples, + "mean_inter_batch_ms": inter_batch.mean() * 1000, + "std_inter_batch_ms": inter_batch.std() * 1000, + "median_inter_batch_ms": np.median(inter_batch) * 1000, + "throughput_samples_per_sec": total_samples / inter_batch.sum() if inter_batch.sum() > 0 else 0, + } + + +# ====================================================================== +# Main +# ====================================================================== + + +def main(): + """Profile and compare dataloader implementations.""" + results = [] + + print("=" * 70) + print("DATALOADER PROFILING") + print("BatchedConcatDataModule + TripletDatasets vs MultiExperimentDataModule") + print("=" * 70) + print("\nDatasets: G3BP1 (2025_07_24) + H2B (2025_04_15)") + print(f"Z range: {Z_RANGE} ({Z_RANGE[1] - Z_RANGE[0]} slices)") + print("Patch: 192x192 -> 160x160") + print(f"Cache: {CACHE_POOL_BYTES / 1e6:.0f} MB") + + # ------------------------------------------------------------------ + # Setup timing + # ------------------------------------------------------------------ + print("\n## Setup: Old (BatchedConcatDataModule + 2x TripletDataModule)") + t0 = time.perf_counter() + old_dm = setup_old() + old_setup_time = time.perf_counter() - t0 + n_old_train = len(old_dm.train_dataset) + n_old_val = len(old_dm.val_dataset) + print(f" Setup time: {_fmt(old_setup_time)}") + print(f" Train samples: {n_old_train} | Val samples: {n_old_val}") + + print("\n## Setup: New (MultiExperimentDataModule)") + t0 = time.perf_counter() + new_dm = setup_new() + new_setup_time = time.perf_counter() - t0 + n_new_train = len(new_dm.train_dataset) + n_new_val = len(new_dm.val_dataset) if new_dm.val_dataset else 0 + print(f" Setup time: {_fmt(new_setup_time)}") + print(f" Train samples: {n_new_train} | Val samples: {n_new_val}") + + # ------------------------------------------------------------------ + # Benchmark 1: Raw __getitems__ + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("BENCHMARK 1: Raw __getitems__ (no dataloader, no transforms)") + print("=" * 70) + + for bs in BATCH_SIZES: + print(f"\n### batch_size={bs}") + + stats_old = benchmark_getitems(old_dm.train_dataset, bs) + stats_old["dataset"] = "Old (BatchedConcatDataset)" + results.append(stats_old) + print( + f" Old: {stats_old['mean_ms']:.1f} ± {stats_old['std_ms']:.1f} ms/batch " + f"| p95={stats_old['p95_ms']:.1f} ms " + f"| {stats_old['throughput_samples_per_sec']:.0f} samples/s" + ) + + stats_new = benchmark_getitems(new_dm.train_dataset, bs) + stats_new["dataset"] = "New (MultiExperimentTripletDataset)" + results.append(stats_new) + print( + f" New: {stats_new['mean_ms']:.1f} ± {stats_new['std_ms']:.1f} ms/batch " + f"| p95={stats_new['p95_ms']:.1f} ms " + f"| {stats_new['throughput_samples_per_sec']:.0f} samples/s" + ) + + speedup = stats_old["mean_ms"] / stats_new["mean_ms"] if stats_new["mean_ms"] > 0 else float("inf") + direction = "faster" if speedup > 1 else "slower" + print(f" New is {abs(speedup):.2f}x {direction}") + + # ------------------------------------------------------------------ + # Benchmark 2: Full dataloader + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("BENCHMARK 2: Full ThreadDataLoader iteration") + print("=" * 70) + + for bs in [32, 64]: + print(f"\n### batch_size={bs}") + + # Old + old_dm.batch_size = bs + for sub_dm in old_dm.data_modules: + sub_dm.batch_size = bs + old_dl = old_dm.train_dataloader() + dl_old = benchmark_dataloader(old_dl) + print( + f" Old: {dl_old.get('mean_inter_batch_ms', 0):.1f} ± " + f"{dl_old.get('std_inter_batch_ms', 0):.1f} ms/batch " + f"| {dl_old.get('throughput_samples_per_sec', 0):.0f} samples/s" + ) + + # New + new_dm.batch_size = bs + new_dl = new_dm.train_dataloader() + dl_new = benchmark_dataloader(new_dl) + print( + f" New: {dl_new.get('mean_inter_batch_ms', 0):.1f} ± " + f"{dl_new.get('std_inter_batch_ms', 0):.1f} ms/batch " + f"| {dl_new.get('throughput_samples_per_sec', 0):.0f} samples/s" + ) + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + + print("\n### __getitems__ throughput (samples/sec)") + summary = pd.DataFrame(results) + pivot = summary.pivot_table( + index="batch_size", + columns="dataset", + values="throughput_samples_per_sec", + ) + print(pivot.to_string(float_format=lambda x: f"{x:.0f}")) + + print("\n### Setup times") + print("| Pipeline | Setup Time |") + print("|----------|-----------|") + print(f"| Old (BatchedConcatDataModule) | {_fmt(old_setup_time)} |") + print(f"| New (MultiExperimentDataModule) | {_fmt(new_setup_time)} |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_num_workers.py b/applications/dynaclr/scripts/profiling/profile_num_workers.py new file mode 100644 index 000000000..e57279021 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_num_workers.py @@ -0,0 +1,175 @@ +"""Sweep num_workers to find optimal dataloader parallelism. + +Holds all other parameters constant and measures end-to-end ThreadDataLoader +throughput (samples/sec and inter-batch latency) for num_workers in [1, 2, 4, 8]. + +Unlike profile_stages.py (which isolates individual pipeline stages) or +profile_dataloaders.py (which compares two dataloader implementations), this +script answers: does adding more CPU workers reduce GPU starvation? + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_num_workers.py +""" + +from __future__ import annotations + +import time + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 128 +N_BATCHES = 30 +WARMUP = 5 +CACHE_POOL_BYTES = 500_000_000 # 500 MB + +Z_WINDOW = 16 +Z_EXTRACTION_WINDOW = 45 +YX_PATCH = (192, 192) +FINAL_YX_PATCH = (160, 160) + +NUM_WORKERS_SWEEP = [1, 2, 4, 8] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def setup_dm(num_workers: int) -> MultiExperimentDataModule: + """Build a MultiExperimentDataModule with the given num_workers.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=0.3, + yx_patch_size=YX_PATCH, + final_yx_patch_size=FINAL_YX_PATCH, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=num_workers, + seed=42, + cache_pool_bytes=CACHE_POOL_BYTES, + normalizations=[], + augmentations=[], + ) + dm.setup("fit") + return dm + + +def benchmark_dataloader(dataloader, n_batches: int = N_BATCHES, warmup: int = WARMUP) -> dict: + """Measure inter-batch latency and throughput over the dataloader. + + Parameters + ---------- + dataloader : ThreadDataLoader + Configured training dataloader. + n_batches : int + Number of batches to time after warmup. + warmup : int + Batches to discard for cache/thread warmup. + + Returns + ------- + dict + Inter-batch timing stats, throughput in samples/sec, and VAST bandwidth in MB/s. + """ + timestamps = [] + total_samples = 0 + read_mb_per_batch = None + + for i, batch in enumerate(dataloader): + if i >= warmup + n_batches: + break + now = time.perf_counter() + if i >= warmup: + timestamps.append(now) + if isinstance(batch, dict) and "anchor" in batch: + total_samples += batch["anchor"].shape[0] + if read_mb_per_batch is None: + # anchor + positive (fit mode). Lower bound — ignores chunk alignment overhead. + n_tensors = 2 if "positive" in batch else 1 + read_mb_per_batch = batch["anchor"].nelement() * batch["anchor"].element_size() * n_tensors / 1e6 + + if len(timestamps) < 2: + return {"note": "not enough batches"} + + inter_batch = np.diff(timestamps) + mean_s = inter_batch.mean() + bandwidth_mb_s = read_mb_per_batch / mean_s if read_mb_per_batch else 0.0 + return { + "mean_ms": mean_s * 1000, + "std_ms": inter_batch.std() * 1000, + "median_ms": float(np.median(inter_batch) * 1000), + "p95_ms": float(np.percentile(inter_batch, 95) * 1000), + "throughput_samples_per_sec": total_samples / inter_batch.sum(), + "read_mb_per_batch": read_mb_per_batch or 0.0, + "bandwidth_mb_s": bandwidth_mb_s, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + """Sweep num_workers and report throughput.""" + print("=" * 60) + print("num_workers SWEEP — ThreadDataLoader throughput") + print("=" * 60) + print(f"batch_size={BATCH_SIZE}, z={Z_EXTRACTION_WINDOW}→{Z_WINDOW}") + print(f"patch={YX_PATCH}→{FINAL_YX_PATCH}, channels_per_sample=1") + print(f"warmup={WARMUP} batches, measured over {N_BATCHES} batches") + print() + + # Setup is shared across runs — only the dataloader changes. + # Re-setup for each num_workers since ThreadDataLoader is created in train_dataloader(). + results = [] + for nw in NUM_WORKERS_SWEEP: + print(f"## num_workers={nw}") + dm = setup_dm(nw) + dl = dm.train_dataloader() + stats = benchmark_dataloader(dl) + stats["num_workers"] = nw + results.append(stats) + print( + f" {stats['mean_ms']:.1f} ± {stats['std_ms']:.1f} ms/batch" + f" | p95={stats['p95_ms']:.1f} ms" + f" | {stats['throughput_samples_per_sec']:.0f} samples/sec" + f" | {stats['bandwidth_mb_s']:.0f} MB/s" + ) + print() + + print("=" * 60) + print("SUMMARY") + print("=" * 60) + print() + read_mb = results[0]["read_mb_per_batch"] if results else 0.0 + print(f"Read volume per batch (lower bound): {read_mb:.0f} MB") + print() + print("| num_workers | mean ms/batch | p95 ms | samples/sec | MB/s (VAST) |") + print("|-------------|---------------|--------|-------------|-------------|") + for r in results: + print( + f"| {r['num_workers']:11d} | {r['mean_ms']:13.1f} | {r['p95_ms']:6.1f}" + f" | {r['throughput_samples_per_sec']:11.0f} | {r['bandwidth_mb_s']:11.0f} |" + ) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py b/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py new file mode 100644 index 000000000..a5b820164 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py @@ -0,0 +1,219 @@ +"""Sweep batch_size for prediction to find GPU utilization sweet spot. + +Times the full predict pipeline (dataloader I/O + GPU forward) at increasing +batch sizes to find where GPU utilization saturates on the local A40. + +Uses the microglia-eval parquet and the 2D MIP checkpoint. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_predict_batch_size.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_data._utils import _transform_channel_wise +from viscy_models.contrastive import ContrastiveEncoder +from viscy_transforms import BatchedChannelWiseZReductiond, NormalizeSampled + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +CELL_INDEX_PARQUET = "/hpc/projects/organelle_phenotyping/models/collections/microglia-eval.parquet" +CKPT_PATH = ( + "/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels" + "/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11" + "/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt" +) + +BATCH_SIZES = [256, 512, 1024, 2048, 4096] +N_BATCHES = 20 +WARMUP = 3 +NUM_WORKERS = 4 +DEVICE = "cuda" + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + + +def setup_dm(batch_size: int) -> MultiExperimentDataModule: + """Build a predict-mode MultiExperimentDataModule for the given batch size.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + z_window=1, + z_extraction_window=11, + z_focus_offset=0.5, + yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + split_ratio=1.0, + batch_size=batch_size, + num_workers=NUM_WORKERS, + pin_memory=True, + seed=42, + normalizations=[ + NormalizeSampled( + keys=["channel_0"], + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), + BatchedChannelWiseZReductiond(keys=["channel_0"], allow_missing_keys=True), + ], + augmentations=[], + ) + dm.setup("predict") + return dm + + +def load_model() -> torch.nn.Module: + """Load ConvNeXt-Tiny encoder from the benchmark checkpoint.""" + encoder = ContrastiveEncoder( + backbone="convnext_tiny", + in_channels=1, + in_stack_depth=1, + stem_kernel_size=[1, 4, 4], + stem_stride=[1, 4, 4], + embedding_dim=768, + projection_dim=32, + drop_path_rate=0.0, + ) + ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=True) + # checkpoint keys are prefixed with "model." since ContrastiveModule stores encoder as self.model + state = {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items() if k.startswith("model.")} + encoder.load_state_dict(state) + encoder.eval() + encoder.to(DEVICE) + return encoder + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + + +def benchmark(batch_size: int, model: torch.nn.Module) -> dict: + """Time the predict pipeline (I/O + forward) over N_BATCHES after warmup.""" + dm = setup_dm(batch_size) + dl = dm.predict_dataloader() + + forward_times = [] + samples_processed = 0 + t_start = None + + with torch.inference_mode(): + for i, batch in enumerate(dl): + if i >= WARMUP + N_BATCHES: + break + + # Mirror the predict path: apply _predict_transform then forward + norm_meta = batch.get("anchor_norm_meta") + if isinstance(norm_meta, list) and all(m is None for m in norm_meta): + norm_meta = None + anchor = _transform_channel_wise( + transform=dm._predict_transform, + channel_names=dm._channel_names, + patch=batch["anchor"].to(DEVICE), + norm_meta=norm_meta, + ) + + if i == WARMUP: + torch.cuda.synchronize() + t_start = time.perf_counter() + + torch.cuda.synchronize() + t0 = time.perf_counter() + _ = model(anchor) + torch.cuda.synchronize() + t1 = time.perf_counter() + + if i >= WARMUP: + forward_times.append(t1 - t0) + samples_processed += anchor.shape[0] + + torch.cuda.synchronize() + t_end = time.perf_counter() + + wall_s = t_end - t_start if t_start else 1.0 + fwd = np.array(forward_times) * 1000 + + return { + "batch_size": batch_size, + "forward_mean_ms": fwd.mean(), + "forward_std_ms": fwd.std(), + "e2e_samples_per_sec": samples_processed / wall_s, + "gpu_mem_mib": torch.cuda.max_memory_allocated() // (1024**2), + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Sweep batch sizes and print a throughput summary table.""" + if not torch.cuda.is_available(): + print("No GPU available.") + return + + gpu_name = torch.cuda.get_device_name(0) + total_mib = torch.cuda.get_device_properties(0).total_memory // (1024**2) + print("=" * 65) + print(f"Predict batch_size sweep — {gpu_name} ({total_mib} MiB)") + print("=" * 65) + print(f"num_workers={NUM_WORKERS}, warmup={WARMUP}, measured={N_BATCHES} batches") + print("model: ConvNeXt-Tiny 2D MIP, input 1×1×160×160") + print() + + print("Loading model...") + model = load_model() + torch.cuda.reset_peak_memory_stats() + + results = [] + for bs in BATCH_SIZES: + print(f"batch_size={bs} ...", end=" ", flush=True) + try: + torch.cuda.reset_peak_memory_stats() + r = benchmark(bs, model) + results.append(r) + print( + f"{r['forward_mean_ms']:.1f} ms fwd | " + f"{r['e2e_samples_per_sec']:.0f} samples/sec | " + f"{r['gpu_mem_mib']} MiB" + ) + except torch.cuda.OutOfMemoryError: + print("OOM") + break + + print() + print("=" * 65) + print("SUMMARY") + print("=" * 65) + print() + print("| batch_size | fwd ms | samples/sec | GPU MiB |") + print("|------------|--------|-------------|---------|") + for r in results: + print( + f"| {r['batch_size']:10d} | {r['forward_mean_ms']:6.1f} | " + f"{r['e2e_samples_per_sec']:11.0f} | {r['gpu_mem_mib']:7d} |" + ) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_stages.py b/applications/dynaclr/scripts/profiling/profile_stages.py new file mode 100644 index 000000000..6b7c4b415 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_stages.py @@ -0,0 +1,326 @@ +"""Profile per-stage breakdown: I/O vs normalization vs augmentation vs crop. + +Isolates each stage of the training batch pipeline to find the bottleneck: +1. I/O: __getitems__ (tensorstore read + positive sampling) +2. CPU→GPU: .to(device) transfer +3. Normalization: NormalizeSampled (fov/timepoint stats) +4. Augmentation: affine + flip + contrast + scale + smooth + noise +5. Final crop: BatchedRandSpatialCropd (z_extraction → z_window) + +Uses the new MultiExperimentDataModule with the benchmark_2exp collection. +Requires GPU. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_stages.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch +from monai.transforms import Compose + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +COLLECTION_YAML = "applications/dynaclr/configs/collections/benchmark_2exp.yml" +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 128 +N_BATCHES = 15 +WARMUP = 3 +CACHE_POOL_BYTES = 500_000_000 + +Z_WINDOW = 32 +Z_EXTRACTION_WINDOW = 45 +YX_PATCH = (192, 192) +FINAL_YX_PATCH = (160, 160) + +CHANNEL_KEY = "channel_0" +DEVICE = "cuda" + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + return f"{seconds:.2f} s" + + +def setup(): + """Set up MultiExperimentDataModule with production-like config.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=0.3, + yx_patch_size=YX_PATCH, + final_yx_patch_size=FINAL_YX_PATCH, + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=1, + seed=42, + cache_pool_bytes=CACHE_POOL_BYTES, + normalizations=[], + augmentations=[], + ) + dm.setup("fit") + return dm + + +def build_transforms(): + """Build the individual transform stages matching DynaCLR-3D-BagOfChannels-v2.""" + normalization = NormalizeSampled( + keys=[CHANNEL_KEY], + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) + + augmentations = [ + BatchedRandAffined( + keys=[CHANNEL_KEY], + prob=0.8, + scale_range=[[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05], + ), + BatchedRandFlipd( + keys=[CHANNEL_KEY], + spatial_axes=[1, 2], + prob=0.5, + ), + BatchedRandAdjustContrastd( + keys=[CHANNEL_KEY], + prob=0.5, + gamma=(0.6, 1.6), + ), + BatchedRandScaleIntensityd( + keys=[CHANNEL_KEY], + prob=0.5, + factors=0.5, + ), + BatchedRandGaussianSmoothd( + keys=[CHANNEL_KEY], + prob=0.5, + sigma_x=[0.25, 0.50], + sigma_y=[0.25, 0.50], + sigma_z=[0.0, 0.2], + ), + BatchedRandGaussianNoised( + keys=[CHANNEL_KEY], + prob=0.5, + mean=0.0, + std=0.1, + ), + ] + + final_crop = BatchedRandSpatialCropd( + keys=[CHANNEL_KEY], + roi_size=(Z_WINDOW, FINAL_YX_PATCH[0], FINAL_YX_PATCH[1]), + ) + + return normalization, augmentations, final_crop + + +def time_stage(fn, n_batches=N_BATCHES, warmup=WARMUP): + """Time a callable over multiple iterations, return stats. + + Parameters + ---------- + fn : callable + Function to time. Called with no arguments. + n_batches : int + Iterations to time after warmup. + warmup : int + Iterations to discard. + + Returns + ------- + dict + mean_ms, std_ms, median_ms. + """ + times = [] + for i in range(warmup + n_batches): + if DEVICE == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + result = fn() + if DEVICE == "cuda": + torch.cuda.synchronize() + t1 = time.perf_counter() + if i >= warmup: + times.append(t1 - t0) + arr = np.array(times) + return { + "mean_ms": arr.mean() * 1000, + "std_ms": arr.std() * 1000, + "median_ms": np.median(arr) * 1000, + }, result + + +def main(): + """Profile individual dataloader pipeline stages.""" + print("=" * 70) + print("STAGE BREAKDOWN: I/O → Transfer → Normalize → Augment → Crop") + print("=" * 70) + print(f"batch_size={BATCH_SIZE}, z_extraction={Z_EXTRACTION_WINDOW}→z_window={Z_WINDOW}") + print(f"patch={YX_PATCH}→{FINAL_YX_PATCH}, device={DEVICE}") + print() + + # Setup + dm = setup() + dataset = dm.train_dataset + normalization, augmentations, final_crop = build_transforms() + rng = np.random.default_rng(42) + n_samples = len(dataset) + + def random_indices(): + return rng.integers(0, n_samples, size=BATCH_SIZE).tolist() + + # Pre-generate index lists so index generation doesn't pollute timing + all_indices = [random_indices() for _ in range(WARMUP + N_BATCHES + 5)] + idx_iter = iter(all_indices) + + # ── Stage 1: I/O (__getitems__) ── + print("## Stage 1: I/O (__getitems__)") + batches = [] + + def io_step(): + indices = next(idx_iter) + batch = dataset.__getitems__(indices) + batches.append(batch) + return batch + + io_stats, _ = time_stage(io_step) + + # Use the last batch for subsequent stages + sample_batch = batches[-1] + anchor = sample_batch["anchor"] + positive = sample_batch.get("positive") + + # Read volume: what was actually fetched from VAST (z_extraction_window, not z_window). + # anchor + positive (fit mode reads both). Lower bound — chunk alignment may add overhead. + n_tensors = 2 if positive is not None else 1 + read_bytes = anchor.nelement() * anchor.element_size() * n_tensors + read_mb = read_bytes / 1e6 + bandwidth_mb_s = read_mb / (io_stats["mean_ms"] / 1000) + io_stats["read_mb"] = read_mb + io_stats["bandwidth_mb_s"] = bandwidth_mb_s + + print(f" {io_stats['mean_ms']:.1f} ± {io_stats['std_ms']:.1f} ms") + pos_label = "+ positive" if positive is not None else "" + print(f" read volume: {read_mb:.0f} MB (anchor{pos_label}) | bandwidth: {bandwidth_mb_s:.0f} MB/s") + print(f" anchor shape: {anchor.shape}, dtype: {anchor.dtype}") + + # ── Stage 2: CPU→GPU transfer ── + print("\n## Stage 2: CPU → GPU transfer") + + def transfer_step(): + return anchor.to(DEVICE, non_blocking=True) + + transfer_stats, gpu_anchor = time_stage(transfer_step) + print(f" {transfer_stats['mean_ms']:.1f} ± {transfer_stats['std_ms']:.1f} ms") + print(f" tensor size: {anchor.nelement() * anchor.element_size() / 1e6:.1f} MB") + + # ── Stage 3: Normalization ── + print("\n## Stage 3: Normalization (subtract mean, divide std — manual)") + # NormalizeSampled via _transform_channel_wise requires channel-name + # alignment that depends on the full DataModule context. Time the raw + # arithmetic instead: this is what NormalizeSampled does per channel. + + def norm_step(): + x = gpu_anchor.clone() + mean = x.mean(dim=(-3, -2, -1), keepdim=True) + std = x.std(dim=(-3, -2, -1), keepdim=True) + return (x - mean) / (std + 1e-8) + + norm_stats, normed = time_stage(norm_step) + print(f" {norm_stats['mean_ms']:.1f} ± {norm_stats['std_ms']:.1f} ms") + + # ── Stage 4: Augmentations (individually) ── + print("\n## Stage 4: Augmentations (individual)") + aug_names = [ + "RandAffined", + "RandFlipd", + "RandAdjustContrastd", + "RandScaleIntensityd", + "RandGaussianSmoothd", + "RandGaussianNoised", + ] + aug_total = 0.0 + current_input = normed + + for aug_name, aug_transform in zip(aug_names, augmentations): + t = Compose([aug_transform]) + inp = current_input + + def aug_step(transform=t, data=inp): + d = {CHANNEL_KEY: data.clone()} + return transform(d)[CHANNEL_KEY] + + stats, current_input = time_stage(aug_step) + aug_total += stats["mean_ms"] + print(f" {aug_name:30s} {stats['mean_ms']:8.1f} ± {stats['std_ms']:.1f} ms") + + print(f" {'TOTAL':30s} {aug_total:8.1f} ms") + + # ── Stage 5: Final crop ── + print("\n## Stage 5: Final crop (BatchedRandSpatialCropd)") + crop_input = current_input + + def crop_step(): + d = {CHANNEL_KEY: crop_input.clone()} + return final_crop(d)[CHANNEL_KEY] + + crop_stats, _ = time_stage(crop_step) + print(f" {crop_stats['mean_ms']:.1f} ± {crop_stats['std_ms']:.1f} ms") + + # ── Summary ── + print("\n" + "=" * 70) + print("SUMMARY (mean ms per batch)") + print("=" * 70) + + stages = { + "I/O (__getitems__)": io_stats["mean_ms"], + "CPU→GPU transfer": transfer_stats["mean_ms"], + "Normalization": norm_stats["mean_ms"], + "Augmentations (total)": aug_total, + "Final crop": crop_stats["mean_ms"], + } + total = sum(stages.values()) + + print("\n| Stage | Time (ms) | % of total | Bandwidth |") + print("|-------|-----------|------------|-----------|") + for name, ms in stages.items(): + if name == "I/O (__getitems__)": + bw = f"{io_stats['bandwidth_mb_s']:.0f} MB/s ({io_stats['read_mb']:.0f} MB read)" + else: + bw = "—" + print(f"| {name} | {ms:.1f} | {ms / total * 100:.1f}% | {bw} |") + print(f"| **Total** | **{total:.1f}** | **100%** | |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_label_timing.py b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_label_timing.py new file mode 100644 index 000000000..11590b185 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_label_timing.py @@ -0,0 +1,466 @@ +r"""Per-cell label-timing metrics from linear classifier predictions (Stage 3). + +Embedding-based timing (``compute_timing_metrics.py``) measures cosine +distance from each cell's pre-baseline. This script is the label-side +complement: it reduces each cell's *predicted label* trajectory to timing +scalars. Both scripts share the sensor-aligned ``t_rel`` axis, so their +outputs are directly comparable. + +Label taxonomy (~/memory/project_label_taxonomy.md): + +- ``{state}`` : human annotation (sparse). +- ``predicted_{state}`` : linear classifier output (dense). **Used here.** +- ``dtw_{state}`` : DTW-propagated template label (aligned-only). + +Per-cell metrics on the binarized predicted-label trajectory (1 = positive): + +- ``t_first_pos`` : first t_rel where the cell is predicted positive. +- ``t_run_start`` : first t_rel where the cell enters a run of + ``min_run`` consecutive positive predictions + (default 3). Robust to single-frame flicker. +- ``t_run_end`` : last t_rel where the cell is in a positive run. +- ``pos_duration`` : ``t_run_end − t_run_start`` (minutes). +- ``pos_fraction`` : fraction of aligned frames predicted positive. +- ``flips`` : number of 0→1 or 1→0 transitions over the full track. + +Outputs: + +- ``_per_cell.parquet`` : one row per cell. +- ``_summary.md`` : per-well + pooled median ± bootstrap CI. + +Example:: + + cd applications/dynaclr/scripts/pseudotime/3-organelle-remodeling + uv run python compute_label_timing.py compute \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw \ + --query-set sensor_all_07_24 \ + --organelle-channel organelle_sec61 \ + --state-column organelle_state --state-positive remodel \ + --top-n 30 + +Pair a SEC61 and G3BP1 run then:: + + uv run python compute_label_timing.py compare \ + --per-cell timing_labels/..._sec61_..._per_cell.parquet \ + timing_labels/..._g3bp1_..._per_cell.parquet \ + --out-stem timing_labels/compare_sec61_vs_g3bp1 +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import anndata as ad +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import zarr +from scipy import stats + +matplotlib.use("Agg") + +SCRIPT_DIR = Path(__file__).resolve().parent +TEMPLATES_DIR = SCRIPT_DIR.parent / "1-build_template" / "templates" +ALIGNMENTS_DIR = SCRIPT_DIR.parent / "2-align_cells" / "alignments" +OUT_DIR = SCRIPT_DIR / "timing_labels" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +sys.path.insert(0, str(SCRIPT_DIR.parent / "1-build_template")) +from evaluate_template import _date_prefix_from_dataset_id, _find_zarr # noqa: E402 +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + + +def _top_n_cells(alignments: pd.DataFrame, top_n: int) -> pd.DataFrame: + """Select rows belonging to the top-N cells by length-normalized DTW cost.""" + costs = alignments.groupby(["dataset_id", "fov_name", "track_id"])["length_normalized_cost"].first() + top_keys = set(costs.sort_values().head(top_n).index) + mask = [ + (ds, fov, tid) in top_keys + for ds, fov, tid in zip(alignments["dataset_id"], alignments["fov_name"], alignments["track_id"]) + ] + return alignments[mask].reset_index(drop=True) + + +def _lookup_predicted_labels( + selected: pd.DataFrame, + dataset_cfgs: dict[str, dict], + organelle_pattern: str, + predicted_column: str, + positive_value: str, +) -> np.ndarray: + """Per-row binarized predicted-label value (1.0, 0.0, or NaN if missing).""" + labels = np.full(len(selected), np.nan, dtype=np.float64) + for dataset_id, ds_rows in selected.groupby("dataset_id"): + ds_cfg = dataset_cfgs[dataset_id] + prefix = _date_prefix_from_dataset_id(dataset_id) + zarr_path = _find_zarr(ds_cfg["pred_dir"], prefix + organelle_pattern) + adata = ad.read_zarr(zarr_path) + adata.obs_names_make_unique() + if predicted_column not in adata.obs.columns: + _logger.warning(f" [{dataset_id}] obs has no {predicted_column!r} column — skipping") + continue + lookup = { + (str(row["fov_name"]), int(row["track_id"]), int(row["t"])): str(row[predicted_column]) + for _, row in adata.obs.iterrows() + } + for idx_local, row in enumerate(ds_rows.itertuples(index=False)): + key = (str(row.fov_name), int(row.track_id), int(row.t)) + val = lookup.get(key) + if val is None or val == "nan": + continue + global_idx = ds_rows.index[idx_local] + labels[global_idx] = 1.0 if val == positive_value else 0.0 + return labels + + +def _longest_positive_run(is_pos: np.ndarray, min_run: int) -> tuple[int, int] | None: + """Return (start_idx, end_idx) of the earliest run of ≥``min_run`` consecutive True values.""" + in_run = False + run_start = -1 + for i, v in enumerate(is_pos): + if v and not in_run: + in_run = True + run_start = i + elif not v and in_run: + if i - run_start >= min_run: + return run_start, i - 1 + in_run = False + if in_run and len(is_pos) - run_start >= min_run: + return run_start, len(is_pos) - 1 + return None + + +def _compute_per_cell( + selected: pd.DataFrame, + labels: np.ndarray, + t_rel: np.ndarray, + min_run: int, +) -> pd.DataFrame: + """Return one row per (dataset_id, fov, track_id) with label-timing scalars.""" + df = selected.copy() + df["predicted_pos"] = labels + df["t_rel"] = t_rel + + rows = [] + for (ds, fov, tid), grp in df.groupby(["dataset_id", "fov_name", "track_id"], sort=False): + grp = grp.sort_values("t_rel") + y = grp["predicted_pos"].to_numpy(dtype=float) + t = grp["t_rel"].to_numpy(dtype=float) + aligned_mask = grp["alignment_region"].to_numpy() == "aligned" + mask = np.isfinite(y) & np.isfinite(t) + if mask.sum() < 3: + continue + y = y[mask] + t = t[mask] + aligned_mask = aligned_mask[mask] + + is_pos = y == 1.0 + flips = int(np.abs(np.diff(y)).sum()) + + if is_pos.any(): + t_first_pos = float(t[int(np.argmax(is_pos))]) + else: + t_first_pos = np.nan + + run = _longest_positive_run(is_pos, min_run=min_run) + if run is not None: + t_run_start = float(t[run[0]]) + t_run_end = float(t[run[1]]) + pos_duration = t_run_end - t_run_start + else: + t_run_start = np.nan + t_run_end = np.nan + pos_duration = np.nan + + if aligned_mask.any(): + pos_fraction = float(is_pos[aligned_mask].mean()) + else: + pos_fraction = float(is_pos.mean()) + + rows.append( + { + "dataset_id": ds, + "fov_name": fov, + "track_id": int(tid), + "cell_uid": f"{ds}/{fov}/{tid}", + "well": _extract_well(fov), + "length_normalized_cost": float(grp["length_normalized_cost"].iloc[0]), + "n_frames_labeled": int(mask.sum()), + "t_first_pos": t_first_pos, + "t_run_start": t_run_start, + "t_run_end": t_run_end, + "pos_duration": pos_duration, + "pos_fraction": pos_fraction, + "flips": flips, + } + ) + return pd.DataFrame(rows) + + +def _extract_well(fov_name: str) -> str: + """Return ``'A/2'`` from ``'A/2/000000'`` style FOV names.""" + parts = fov_name.split("/") + if len(parts) >= 2: + return "/".join(parts[:2]) + return fov_name + + +def _bootstrap_ci(values: np.ndarray, n_boot: int = 2000, alpha: float = 0.05) -> tuple[float, float, float]: + """Percentile bootstrap on the median.""" + values = values[np.isfinite(values)] + if len(values) == 0: + return float("nan"), float("nan"), float("nan") + if len(values) == 1: + v = float(values[0]) + return v, v, v + rng = np.random.default_rng(42) + boots = np.empty(n_boot) + for i in range(n_boot): + boots[i] = np.median(rng.choice(values, size=len(values), replace=True)) + return float(np.median(values)), float(np.quantile(boots, alpha / 2)), float(np.quantile(boots, 1 - alpha / 2)) + + +def _summary_markdown(per_cell: pd.DataFrame, state_column: str, organelle_channel: str) -> str: + """Per-well + pooled markdown summary.""" + lines = [f"# Label-timing metrics — predicted_{state_column} ({organelle_channel})", ""] + lines.append(f"**n cells**: {len(per_cell)}") + lines.append("") + lines.append("## Per-well medians") + lines.append("") + header = ( + "| well | n | t_first_pos (min) | t_run_start (min) | t_run_end (min) | " + "pos_duration (min) | pos_fraction | flips |" + ) + lines.append(header) + lines.append("|---|---|---|---|---|---|---|---|") + for well, grp in per_cell.groupby("well"): + lines.append( + f"| {well} | {len(grp)} | " + f"{grp['t_first_pos'].median():.0f} | {grp['t_run_start'].median():.0f} | " + f"{grp['t_run_end'].median():.0f} | {grp['pos_duration'].median():.0f} | " + f"{grp['pos_fraction'].median():.3f} | {grp['flips'].median():.0f} |" + ) + lines.append("") + lines.append("## Pooled median ± 95% bootstrap CI") + lines.append("") + lines.append("| metric | median | 95% CI |") + lines.append("|---|---|---|") + for metric in ["t_first_pos", "t_run_start", "t_run_end", "pos_duration", "pos_fraction", "flips"]: + med, lo, hi = _bootstrap_ci(per_cell[metric].to_numpy(dtype=float)) + lines.append(f"| {metric} | {med:.3f} | [{lo:.3f}, {hi:.3f}] |") + lines.append("") + return "\n".join(lines) + + +def _compare(per_cell_files: list[Path], out_stem: Path, group_by: str | None = None) -> None: + """Merge per-cell parquets, emit strips + stats grouped by a column. + + Parameters + ---------- + per_cell_files : list[Path] + Per-cell parquets written by ``compute``. + out_stem : Path + Output path stem (no extension). + group_by : str or None + Column to group cells by in the comparison plot/stats. If ``None`` + (default), auto-select: use ``organelle_channel`` when multiple + organelle values are present, otherwise fall back to ``query_set`` + so cross-virus pools (same organelle, different query sets) split + correctly. + """ + dfs = [pd.read_parquet(p) for p in per_cell_files] + merged = pd.concat(dfs, ignore_index=True) + + metrics = ["t_first_pos", "t_run_start", "t_run_end", "pos_duration", "pos_fraction", "flips"] + if group_by is None: + n_organelles = len(merged["organelle_channel"].unique()) + group_by = "organelle_channel" if n_organelles > 1 else "query_set" + organelles = sorted(merged[group_by].unique()) + + fig, axes = plt.subplots(1, len(metrics), figsize=(3.3 * len(metrics), 4.2), squeeze=False) + axes = axes[0] + colors = plt.get_cmap("tab10").colors + for ax, metric in zip(axes, metrics): + for i, org in enumerate(organelles): + vals = merged.loc[merged[group_by] == org, metric].to_numpy(dtype=float) + vals = vals[np.isfinite(vals)] + if len(vals) == 0: + continue + jitter = np.random.default_rng(0).uniform(-0.12, 0.12, size=len(vals)) + ax.scatter( + np.full_like(vals, i, dtype=float) + jitter, + vals, + s=22, + color=colors[i % len(colors)], + alpha=0.7, + edgecolor="none", + ) + med, lo, hi = _bootstrap_ci(vals) + ax.hlines(med, i - 0.25, i + 0.25, color="black", linewidth=2, zorder=5) + ax.vlines(i, lo, hi, color="black", linewidth=1.2, zorder=5) + ax.set_xticks(np.arange(len(organelles))) + ax.set_xticklabels(organelles, rotation=30, ha="right") + ax.set_ylabel(metric) + ax.set_title(metric) + + fig.tight_layout() + out_stem.parent.mkdir(parents=True, exist_ok=True) + png = out_stem.with_suffix(".png") + fig.savefig(png, dpi=160, bbox_inches="tight", facecolor="white") + plt.close(fig) + _logger.info(f"Wrote {png}") + + lines = [ + "# Label-timing comparison", + "", + f"**Grouped by**: `{group_by}`", + f"**Groups**: {', '.join(organelles)}", + "", + ] + for metric in metrics: + lines.append(f"## {metric}") + lines.append("") + lines.append(f"| {group_by} | n | median | 95% CI |") + lines.append("|---|---|---|---|") + per_org = {} + for org in organelles: + vals = merged.loc[merged[group_by] == org, metric].to_numpy(dtype=float) + vals = vals[np.isfinite(vals)] + per_org[org] = vals + med, lo, hi = _bootstrap_ci(vals) + lines.append(f"| {org} | {len(vals)} | {med:.3f} | [{lo:.3f}, {hi:.3f}] |") + lines.append("") + if len(organelles) >= 2: + lines.append("**Pairwise rank-sum tests**") + lines.append("") + lines.append("| a | b | median(a) − median(b) | U | p |") + lines.append("|---|---|---|---|---|") + for i in range(len(organelles)): + for j in range(i + 1, len(organelles)): + a, b = per_org[organelles[i]], per_org[organelles[j]] + if len(a) >= 2 and len(b) >= 2: + u, p = stats.mannwhitneyu(a, b, alternative="two-sided") + diff = float(np.median(a) - np.median(b)) + lines.append(f"| {organelles[i]} | {organelles[j]} | {diff:.3f} | {u:.1f} | {p:.3g} |") + lines.append("") + + md = out_stem.with_suffix(".md") + md.write_text("\n".join(lines)) + _logger.info(f"Wrote {md}") + + +def main() -> None: + """Compute per-cell label timing OR merge across organelles.""" + parser = argparse.ArgumentParser(description="Per-cell label-timing from LC predictions.") + sub = parser.add_subparsers(dest="cmd", required=True) + + p_c = sub.add_parser("compute") + p_c.add_argument("--datasets", required=True) + p_c.add_argument("--config", required=True) + p_c.add_argument("--template", required=True) + p_c.add_argument("--flavor", choices=["raw", "pca"], default="raw") + p_c.add_argument("--query-set", required=True) + p_c.add_argument("--organelle-channel", required=True) + p_c.add_argument( + "--state-column", required=True, help="Base state column; the script looks up 'predicted_{state_column}'." + ) + p_c.add_argument("--state-positive", required=True) + p_c.add_argument("--top-n", type=int, default=30) + p_c.add_argument( + "--min-run", type=int, default=3, help="Minimum consecutive positive frames for t_run_start (flicker filter)." + ) + + p_cmp = sub.add_parser("compare") + p_cmp.add_argument("--per-cell", nargs="+", required=True) + p_cmp.add_argument("--out-stem", required=True) + p_cmp.add_argument( + "--group-by", + default=None, + help=( + "Column to split cells by. Default auto-picks organelle_channel " + "if multiple organelles are present, else query_set (so cross-virus " + "pools with the same organelle split correctly)." + ), + ) + + args = parser.parse_args() + + if args.cmd == "compute": + config = load_stage_config(args.datasets, args.config) + dataset_cfgs = {d["dataset_id"]: d for d in config["datasets"]} + if args.organelle_channel not in config["embeddings"]: + raise ValueError(f"organelle-channel {args.organelle_channel!r} not in embeddings") + organelle_pattern = config["embeddings"][args.organelle_channel] + + alignment_path = ALIGNMENTS_DIR / f"{args.template}_{args.flavor}_on_{args.query_set}.parquet" + if not alignment_path.exists(): + raise FileNotFoundError(alignment_path) + alignments = pd.read_parquet(alignment_path) + + selected = _top_n_cells(alignments, args.top_n) + frame_interval_by_ds = {d["dataset_id"]: float(d["frame_interval_minutes"]) for d in config["datasets"]} + selected = selected.copy() + selected["frame_interval"] = selected["dataset_id"].map(frame_interval_by_ds) + + template_path = TEMPLATES_DIR / f"template_{args.template}.zarr" + tc_grp = zarr.open(str(template_path), mode="r")[args.flavor] + tc = np.asarray(tc_grp["time_calibration"]) if "time_calibration" in tc_grp else None + + def _extrapolate(row): + if row["alignment_region"] == "aligned": + return float(row["estimated_t_rel_minutes"]) + fi = row["frame_interval"] + if tc is None: + return float("nan") + if row["alignment_region"] == "pre": + return float(tc[0] + (row["t"] - row["match_q_start"]) * fi) + return float(tc[-1] + (row["t"] - row["match_q_end"]) * fi) + + selected["t_rel_minutes_extrap"] = selected.apply(_extrapolate, axis=1) + + predicted_col = f"predicted_{args.state_column}" + _logger.info(f"Looking up {predicted_col!r} (positive={args.state_positive!r}) from {organelle_pattern}") + labels = _lookup_predicted_labels(selected, dataset_cfgs, organelle_pattern, predicted_col, args.state_positive) + n_labeled = int(np.isfinite(labels).sum()) + _logger.info(f" {n_labeled}/{len(labels)} rows labeled") + if n_labeled == 0: + raise RuntimeError( + f"No rows had {predicted_col!r}. Has the linear classifier been run for this dataset/organelle?" + ) + + t_rel = selected["t_rel_minutes_extrap"].to_numpy(dtype=float) + per_cell = _compute_per_cell(selected, labels, t_rel, min_run=args.min_run) + per_cell["organelle_channel"] = args.organelle_channel + per_cell["state_column"] = args.state_column + per_cell["template"] = args.template + per_cell["flavor"] = args.flavor + per_cell["query_set"] = args.query_set + + OUT_DIR.mkdir(parents=True, exist_ok=True) + stem = OUT_DIR / ( + f"label_timing_{args.template}_{args.flavor}_{args.organelle_channel}_{args.state_column}_{args.query_set}" + ) + parquet = stem.with_name(stem.name + "_per_cell.parquet") + per_cell.to_parquet(parquet, index=False) + _logger.info(f"Wrote {parquet} ({len(per_cell)} cells)") + + md = _summary_markdown(per_cell, args.state_column, args.organelle_channel) + md_path = stem.with_name(stem.name + "_summary.md") + md_path.write_text(md) + _logger.info(f"Wrote {md_path}") + + elif args.cmd == "compare": + _compare([Path(p) for p in args.per_cell], Path(args.out_stem), group_by=args.group_by) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_timing_metrics.py b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_timing_metrics.py new file mode 100644 index 000000000..c5d632888 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/compute_timing_metrics.py @@ -0,0 +1,464 @@ +r"""Per-cell timing metrics for organelle remodeling (Stage 3 analysis). + +Given a sensor alignment parquet and one organelle channel, computes per-cell +timing scalars on each cell's cosine-distance-from-pre-baseline curve, then +pools them into a per-organelle distribution. Cross-organelle comparisons are +population-level (disjoint FOVs share only the sensor-aligned t_rel axis). + +Metrics per cell (computed on the aligned region only): + +- ``t_onset_abs`` : first t_rel where (distance − pre_median) crosses + an absolute threshold (default 0.10 cosine-distance + units). SNR-robust: small Δpeak cells can't fake an + onset by having their noise floor crossed. +- ``t50`` : first t_rel where distance crosses pre_median + 0.5 × Δpeak, + restricted to the pre-endpoint window so DTW endpoint-pinning + doesn't saturate the metric. +- ``t_peak`` : t_rel of argmax distance within the *interior* of the + aligned region (last 2 frames excluded — they're where + DTW endpoint-pinning crowds cells onto ``tc[-1]``). +- ``rise_rate_per_hour`` : slope of distance vs t_rel over the aligned region, + in Δcos per hour (not per minute). +- ``delta_peak`` : max(aligned distance) − median(pre distance). + +Outputs: + +- ``_per_cell.parquet`` : one row per cell with all metrics + dataset_id, + fov_name, track_id, organelle_channel, length_normalized_cost. +- ``_summary.md`` : markdown summary — per-well medians, pooled + median ± 95% bootstrap CI, rank-sum vs a reference organelle (optional). +- ``_strips.png`` : per-metric strip/violin comparing organelles + (only meaningful when called twice with different organelles then merged). + +Example:: + + cd applications/dynaclr/scripts/pseudotime/3-organelle-remodeling + uv run python compute_timing_metrics.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw \ + --query-set sensor_all_07_24 \ + --organelle-channel organelle_sec61 --top-n 30 + +Run twice (once per organelle) then pass both per-cell parquets to +``--compare`` to emit cross-organelle plots and stats. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import anndata as ad +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import zarr +from scipy import stats + +matplotlib.use("Agg") + +SCRIPT_DIR = Path(__file__).resolve().parent +TEMPLATES_DIR = SCRIPT_DIR.parent / "1-build_template" / "templates" +ALIGNMENTS_DIR = SCRIPT_DIR.parent / "2-align_cells" / "alignments" +OUT_DIR = SCRIPT_DIR / "timing" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +sys.path.insert(0, str(SCRIPT_DIR.parent / "1-build_template")) +from evaluate_template import _date_prefix_from_dataset_id, _find_zarr # noqa: E402 +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + + +def _top_n_cells(alignments: pd.DataFrame, top_n: int) -> pd.DataFrame: + """Select rows for top-N cells ranked by length-normalized DTW cost.""" + costs = alignments.groupby(["dataset_id", "fov_name", "track_id"])["length_normalized_cost"].first() + top_keys = set(costs.sort_values().head(top_n).index) + mask = [ + (ds, fov, tid) in top_keys + for ds, fov, tid in zip(alignments["dataset_id"], alignments["fov_name"], alignments["track_id"]) + ] + return alignments[mask].reset_index(drop=True) + + +def _join_organelle_embeddings( + selected: pd.DataFrame, + dataset_cfgs: dict[str, dict], + organelle_pattern: str, +) -> pd.DataFrame: + """Attach organelle embedding vectors via ``(fov, track, t)`` lookup.""" + parts = [] + for dataset_id, ds_align in selected.groupby("dataset_id"): + ds_cfg = dataset_cfgs[dataset_id] + prefix = _date_prefix_from_dataset_id(dataset_id) + zarr_path = _find_zarr(ds_cfg["pred_dir"], prefix + organelle_pattern) + adata = ad.read_zarr(zarr_path) + adata.obs_names_make_unique() + X = adata.X + if hasattr(X, "toarray"): + X = X.toarray() + X = np.asarray(X, dtype=np.float64) + obs = adata.obs.reset_index(drop=True) + lookup = {(str(row["fov_name"]), int(row["track_id"]), int(row["t"])): i for i, row in obs.iterrows()} + aligned_rows = ds_align.reset_index(drop=True).copy() + embeddings = [] + for _, row in aligned_rows.iterrows(): + key = (str(row["fov_name"]), int(row["track_id"]), int(row["t"])) + idx = lookup.get(key) + embeddings.append(X[idx] if idx is not None else None) + aligned_rows["embedding"] = embeddings + aligned_rows = aligned_rows[aligned_rows["embedding"].notna()].reset_index(drop=True) + parts.append(aligned_rows) + return pd.concat(parts, ignore_index=True) + + +def _cosine_distance_from_baseline(joined: pd.DataFrame) -> np.ndarray: + """Per-frame cosine distance to that cell's mean pre-event embedding.""" + distances = np.full(len(joined), np.nan, dtype=np.float64) + for (_, _, _), group in joined.groupby(["dataset_id", "fov_name", "track_id"], sort=False): + idx = group.index.to_numpy() + emb = np.stack(group["embedding"].to_list()) + pre_mask = group["alignment_region"].to_numpy() == "pre" + if pre_mask.any(): + baseline = emb[pre_mask].mean(axis=0) + else: + aligned_mask = group["alignment_region"].to_numpy() == "aligned" + if not aligned_mask.any(): + continue + earliest = aligned_mask.nonzero()[0][: max(1, aligned_mask.sum() // 4)] + baseline = emb[earliest].mean(axis=0) + bn = np.linalg.norm(baseline) + en = np.linalg.norm(emb, axis=1) + denom = bn * en + cos_sim = np.where(denom > 0, (emb @ baseline) / np.where(denom > 0, denom, 1.0), 0.0) + distances[idx] = 1.0 - cos_sim + return distances + + +def _compute_per_cell_metrics( + joined: pd.DataFrame, + distances: np.ndarray, + t_rel: np.ndarray, +) -> pd.DataFrame: + """Return one row per (dataset_id, fov_name, track_id) with timing scalars.""" + joined = joined.copy() + joined["distance"] = distances + joined["t_rel"] = t_rel + + rows = [] + for (ds, fov, tid), grp in joined.groupby(["dataset_id", "fov_name", "track_id"], sort=False): + aligned = grp[grp["alignment_region"] == "aligned"].sort_values("t_rel") + pre = grp[grp["alignment_region"] == "pre"] + if len(aligned) < 3: + continue + a_t = aligned["t_rel"].to_numpy(dtype=float) + a_d = aligned["distance"].to_numpy(dtype=float) + mask = np.isfinite(a_t) & np.isfinite(a_d) + if mask.sum() < 3: + continue + a_t = a_t[mask] + a_d = a_d[mask] + + pre_median = float(np.nanmedian(pre["distance"])) if len(pre) else float(a_d.min()) + + # Drop the last 2 aligned frames for peak/t_peak/t50 — DTW endpoint + # constraints pin many cells' warp paths onto tc[-1], crowding frames + # at the last template position. The INTERIOR peak is what reflects + # true remodeling amplitude; the endpoint pile-up is a warp-path artifact. + interior_n = max(3, len(a_t) - 2) + i_t = a_t[:interior_n] + i_d = a_d[:interior_n] + peak = float(i_d.max()) + delta_peak = peak - pre_median + + # t50 on the interior (half-rise in absolute units, not normalized). + if delta_peak <= 1e-6: + t50 = np.nan + else: + t50 = _first_crossing(i_t, i_d, pre_median + 0.5 * delta_peak) + + # Absolute-threshold onset (SNR-robust across cells with different Δpeak). + t_onset_abs = _first_crossing(a_t, a_d, pre_median + 0.10) + + t_peak = float(i_t[int(np.argmax(i_d))]) + + # Rise-rate in Δcos per hour (multiply per-minute slope by 60). + if len(a_t) >= 2 and (a_t.max() - a_t.min()) > 1e-6: + slope, _intercept, _r, _p, _se = stats.linregress(a_t, a_d) + rise_rate_per_hour = float(slope) * 60.0 + else: + rise_rate_per_hour = np.nan + + rows.append( + { + "dataset_id": ds, + "fov_name": fov, + "track_id": int(tid), + "cell_uid": f"{ds}/{fov}/{tid}", + "well": _extract_well(fov), + "length_normalized_cost": float(grp["length_normalized_cost"].iloc[0]), + "n_aligned_frames": int(len(a_t)), + "pre_median_distance": pre_median, + "peak_distance": peak, + "delta_peak": delta_peak, + "t_onset_abs": t_onset_abs, + "t50": t50, + "t_peak": t_peak, + "rise_rate_per_hour": rise_rate_per_hour, + } + ) + return pd.DataFrame(rows) + + +def _first_crossing(t: np.ndarray, y: np.ndarray, threshold: float) -> float: + """First ``t`` value where the signal crosses ``threshold`` upward, linearly interpolated.""" + above = y >= threshold + if not above.any(): + return float("nan") + first_above = int(np.argmax(above)) + if first_above == 0: + return float(t[0]) + t_before, t_after = t[first_above - 1], t[first_above] + y_before, y_after = y[first_above - 1], y[first_above] + if y_after == y_before: + return float(t_after) + frac = (threshold - y_before) / (y_after - y_before) + return float(t_before + frac * (t_after - t_before)) + + +def _extract_well(fov_name: str) -> str: + """Return ``'A/2'`` from ``'A/2/000000'`` style FOV names, else full FOV.""" + parts = fov_name.split("/") + if len(parts) >= 2: + return "/".join(parts[:2]) + return fov_name + + +def _bootstrap_ci(values: np.ndarray, n_boot: int = 2000, alpha: float = 0.05) -> tuple[float, float, float]: + """Return (median, lo, hi) with a percentile bootstrap on the median.""" + values = values[np.isfinite(values)] + if len(values) == 0: + return float("nan"), float("nan"), float("nan") + if len(values) == 1: + v = float(values[0]) + return v, v, v + rng = np.random.default_rng(42) + boots = np.empty(n_boot) + for i in range(n_boot): + boots[i] = np.median(rng.choice(values, size=len(values), replace=True)) + med = float(np.median(values)) + lo = float(np.quantile(boots, alpha / 2)) + hi = float(np.quantile(boots, 1 - alpha / 2)) + return med, lo, hi + + +def _summary_markdown(per_cell: pd.DataFrame, organelle_channel: str) -> str: + """Render per-well + pooled median ± CI as markdown for copy to Confluence.""" + lines = [] + lines.append(f"# Timing metrics — {organelle_channel}") + lines.append("") + lines.append(f"**n cells**: {len(per_cell)}") + lines.append("") + + lines.append("## Per-well medians") + lines.append("") + lines.append("| well | n | t_onset_abs (min) | t50 (min) | t_peak (min) | delta_peak | rise_rate (Δcos/hr) |") + lines.append("|---|---|---|---|---|---|---|") + for well, grp in per_cell.groupby("well"): + lines.append( + f"| {well} | {len(grp)} | " + f"{grp['t_onset_abs'].median():.0f} | {grp['t50'].median():.0f} | " + f"{grp['t_peak'].median():.0f} | {grp['delta_peak'].median():.3f} | " + f"{grp['rise_rate_per_hour'].median():.3f} |" + ) + lines.append("") + + lines.append("## Pooled median ± 95% bootstrap CI") + lines.append("") + lines.append("| metric | median | 95% CI |") + lines.append("|---|---|---|") + for metric in ["t_onset_abs", "t50", "t_peak", "delta_peak", "rise_rate_per_hour"]: + med, lo, hi = _bootstrap_ci(per_cell[metric].to_numpy(dtype=float)) + lines.append(f"| {metric} | {med:.3f} | [{lo:.3f}, {hi:.3f}] |") + lines.append("") + return "\n".join(lines) + + +def _compare_organelles(per_cell_files: list[Path], out_stem: Path) -> None: + """Merge per-cell parquets from multiple organelles, emit comparison plots + stats.""" + dfs = [] + for p in per_cell_files: + df = pd.read_parquet(p) + dfs.append(df) + merged = pd.concat(dfs, ignore_index=True) + + metrics = ["t_onset_abs", "t50", "t_peak", "delta_peak", "rise_rate_per_hour"] + organelles = sorted(merged["organelle_channel"].unique()) + + fig, axes = plt.subplots(1, len(metrics), figsize=(3.3 * len(metrics), 4.2), squeeze=False) + axes = axes[0] + colors = plt.get_cmap("tab10").colors + for ax, metric in zip(axes, metrics): + positions = np.arange(len(organelles)) + for i, org in enumerate(organelles): + vals = merged.loc[merged["organelle_channel"] == org, metric].to_numpy(dtype=float) + vals = vals[np.isfinite(vals)] + if len(vals) == 0: + continue + jitter = np.random.default_rng(0).uniform(-0.12, 0.12, size=len(vals)) + ax.scatter( + np.full_like(vals, i, dtype=float) + jitter, + vals, + s=22, + color=colors[i % len(colors)], + alpha=0.7, + edgecolor="none", + ) + med, lo, hi = _bootstrap_ci(vals) + ax.hlines(med, i - 0.25, i + 0.25, color="black", linewidth=2, zorder=5) + ax.vlines(i, lo, hi, color="black", linewidth=1.2, zorder=5) + ax.set_xticks(positions) + ax.set_xticklabels(organelles, rotation=30, ha="right") + ax.set_ylabel(metric) + ax.axhline( + 0 if metric in {"t_onset_abs", "t50", "t_peak"} else ax.get_ylim()[0], + color="red", + linestyle=":", + alpha=0.3, + linewidth=0.8, + ) + ax.set_title(metric) + + fig.tight_layout() + out_stem.parent.mkdir(parents=True, exist_ok=True) + png_path = out_stem.with_suffix(".png") + fig.savefig(png_path, dpi=160, bbox_inches="tight", facecolor="white") + plt.close(fig) + _logger.info(f"Wrote {png_path}") + + lines = ["# Cross-organelle timing comparison", ""] + lines.append(f"**Organelles**: {', '.join(organelles)}") + lines.append("") + for metric in metrics: + lines.append(f"## {metric}") + lines.append("") + lines.append("| organelle | n | median | 95% CI |") + lines.append("|---|---|---|---|") + per_org = {} + for org in organelles: + vals = merged.loc[merged["organelle_channel"] == org, metric].to_numpy(dtype=float) + vals = vals[np.isfinite(vals)] + per_org[org] = vals + med, lo, hi = _bootstrap_ci(vals) + lines.append(f"| {org} | {len(vals)} | {med:.3f} | [{lo:.3f}, {hi:.3f}] |") + lines.append("") + if len(organelles) >= 2: + lines.append("**Pairwise rank-sum tests (Mann-Whitney U, two-sided)**") + lines.append("") + lines.append("| a | b | median(a) − median(b) | U | p |") + lines.append("|---|---|---|---|---|") + for i in range(len(organelles)): + for j in range(i + 1, len(organelles)): + a, b = per_org[organelles[i]], per_org[organelles[j]] + if len(a) >= 2 and len(b) >= 2: + u, p = stats.mannwhitneyu(a, b, alternative="two-sided") + diff = float(np.median(a) - np.median(b)) + lines.append(f"| {organelles[i]} | {organelles[j]} | {diff:.3f} | {u:.1f} | {p:.3g} |") + lines.append("") + + md_path = out_stem.with_suffix(".md") + md_path.write_text("\n".join(lines)) + _logger.info(f"Wrote {md_path}") + + +def main() -> None: + """Compute per-cell timing metrics OR merge existing per-cell parquets for comparison.""" + parser = argparse.ArgumentParser(description="Per-cell timing metrics for organelle remodeling.") + sub = parser.add_subparsers(dest="cmd", required=True) + + p_compute = sub.add_parser("compute", help="Compute per-cell metrics for one organelle.") + p_compute.add_argument("--datasets", required=True) + p_compute.add_argument("--config", required=True) + p_compute.add_argument("--template", required=True) + p_compute.add_argument("--flavor", choices=["raw", "pca"], default="raw") + p_compute.add_argument("--query-set", required=True) + p_compute.add_argument("--organelle-channel", required=True) + p_compute.add_argument("--top-n", type=int, default=30) + + p_compare = sub.add_parser("compare", help="Merge per-cell parquets across organelles.") + p_compare.add_argument( + "--per-cell", nargs="+", required=True, help="Paths to per-cell parquets from prior `compute` runs." + ) + p_compare.add_argument("--out-stem", required=True, help="Output path stem (no extension).") + + args = parser.parse_args() + + if args.cmd == "compute": + config = load_stage_config(args.datasets, args.config) + dataset_cfgs = {d["dataset_id"]: d for d in config["datasets"]} + if args.organelle_channel not in config["embeddings"]: + raise ValueError(f"organelle-channel {args.organelle_channel!r} not found in embeddings") + organelle_pattern = config["embeddings"][args.organelle_channel] + + alignment_path = ALIGNMENTS_DIR / f"{args.template}_{args.flavor}_on_{args.query_set}.parquet" + if not alignment_path.exists(): + raise FileNotFoundError(f"Sensor alignment parquet not found: {alignment_path}") + + _logger.info(f"Reading sensor alignment {alignment_path}") + alignments = pd.read_parquet(alignment_path) + + selected = _top_n_cells(alignments, args.top_n) + frame_interval_by_ds = {d["dataset_id"]: float(d["frame_interval_minutes"]) for d in config["datasets"]} + selected = selected.copy() + selected["frame_interval"] = selected["dataset_id"].map(frame_interval_by_ds) + + template_path = TEMPLATES_DIR / f"template_{args.template}.zarr" + tc_grp = zarr.open(str(template_path), mode="r")[args.flavor] + tc = np.asarray(tc_grp["time_calibration"]) if "time_calibration" in tc_grp else None + + def _extrapolate_minutes(row: pd.Series) -> float: + if row["alignment_region"] == "aligned": + return float(row["estimated_t_rel_minutes"]) + fi = row["frame_interval"] + if tc is None: + return float("nan") + if row["alignment_region"] == "pre": + return float(tc[0] + (row["t"] - row["match_q_start"]) * fi) + return float(tc[-1] + (row["t"] - row["match_q_end"]) * fi) + + selected["t_rel_minutes_extrap"] = selected.apply(_extrapolate_minutes, axis=1) + + joined = _join_organelle_embeddings(selected, dataset_cfgs, organelle_pattern) + _logger.info(f" {joined['cell_uid'].nunique()} cells after organelle join") + + distances = _cosine_distance_from_baseline(joined) + t_rel = joined["t_rel_minutes_extrap"].to_numpy(dtype=float) + + per_cell = _compute_per_cell_metrics(joined, distances, t_rel) + per_cell["organelle_channel"] = args.organelle_channel + per_cell["template"] = args.template + per_cell["flavor"] = args.flavor + per_cell["query_set"] = args.query_set + + OUT_DIR.mkdir(parents=True, exist_ok=True) + stem = OUT_DIR / f"timing_{args.template}_{args.flavor}_{args.organelle_channel}_{args.query_set}" + per_cell_path = stem.with_name(stem.name + "_per_cell.parquet") + per_cell.to_parquet(per_cell_path, index=False) + _logger.info(f"Wrote {per_cell_path} ({len(per_cell)} cells)") + + md = _summary_markdown(per_cell, args.organelle_channel) + md_path = stem.with_name(stem.name + "_summary.md") + md_path.write_text(md) + _logger.info(f"Wrote {md_path}") + + elif args.cmd == "compare": + _compare_organelles([Path(p) for p in args.per_cell], Path(args.out_stem)) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/README.md b/applications/dynaclr/scripts/pseudotime/README.md deleted file mode 100644 index 4b86214aa..000000000 --- a/applications/dynaclr/scripts/pseudotime/README.md +++ /dev/null @@ -1,146 +0,0 @@ -# Pseudotime Remodeling Analysis - -Measure organelle remodeling timing relative to viral infection onset using lineage-aware alignment and multiple signal extraction methods. - -## Overview - -This directory is organized into `src/` (importable library modules) and `analysis/` (HPC scripts): - -``` -pseudotime/ -├── README.md -├── src/ -│ ├── __init__.py -│ ├── alignment.py -│ ├── signals.py -│ ├── metrics.py -│ └── plotting.py -└── analysis/ - ├── annotation_remodeling.py - ├── prediction_remodeling.py - └── embedding_distance.py -``` - -The pipeline follows: - -``` -alignment → signal extraction → aggregation → metrics → plotting -``` - -### Library Modules (`src/`) - -| Module | Description | -|--------|-------------| -| `src/alignment.py` | Lineage detection, FOV/track filtering, T_perturb assignment | -| `src/signals.py` | Signal extraction: annotation binary, classifier prediction, embedding distance | -| `src/metrics.py` | Population aggregation, onset/T50/peak detection, per-track timing, statistical tests | -| `src/plotting.py` | Response curves, per-track heatmaps, timing distributions, onset comparison | - -### Analysis Scripts (`analysis/`) - -Each script runs the full pipeline with a different signal source. They are Jupyter-compatible (`# %%` cell markers) and designed for HPC execution. - -| Script | Signal Source | Requires | -|--------|--------------|----------| -| `analysis/annotation_remodeling.py` | Human annotations (`organelle_state` column) | Tracking CSV + annotation CSV | -| `analysis/prediction_remodeling.py` | Classifier predictions (`predicted_organelle_state` in AnnData) | Tracking CSV + predicted AnnData zarr | -| `analysis/embedding_distance.py` | Cosine distance from baseline embeddings | Tracking CSV + embedding AnnData zarr | - -## Prerequisites - -Install DynaCLR with the eval extras and statsmodels: - -```bash -cd applications/dynaclr -uv pip install -e ".[eval]" statsmodels -``` - -## Running Tests - -Unit tests cover all four library modules using synthetic data (no HPC paths required): - -```bash -cd applications/dynaclr -uv run pytest tests/test_pseudotime.py -v -``` - -### Test Structure - -| Test Class | Tests | Module Covered | -|------------|-------|----------------| -| `TestAlignment` | 7 | `src/alignment.py` — lineage detection, FOV filtering, T_perturb assignment | -| `TestSignals` | 5 | `src/signals.py` — annotation/prediction/embedding-distance signal extraction | -| `TestMetrics` | 8 | `src/metrics.py` — population aggregation, onset/T50/peak, track timing, stats | -| `TestPlotting` | 4 | `src/plotting.py` — file output (pdf+png) and Figure return for all plot types | - -### Synthetic Data - -Tests use a self-contained tracking DataFrame with: -- **C/2/000**: 3 tracks with parent-child lineage, infected at t=5 -- **C/2/001**: 1 orphan track, infected at t=7 -- **B/1/000**: 2 control tracks (no infection) - -Plus a matching AnnData with 16-dim random embeddings and classifier predictions. - -## Pipeline Details - -### 1. Alignment - -Tracks are filtered by FOV pattern and minimum length, then aligned to infection onset (T_perturb). Lineage-aware logic ensures all tracks in a parent-child lineage share the same T_perturb. - -```python -from src.alignment import align_tracks - -aligned_df = align_tracks( - tracking_df, - frame_interval_minutes=30.0, - fov_pattern="C/2", - min_track_timepoints=3, -) -# Adds columns: t_perturb, t_relative_minutes -``` - -### 2. Signal Extraction - -Three modes producing a common `signal` column: - -```python -from src.signals import ( - extract_annotation_signal, - extract_prediction_signal, - extract_embedding_distance, -) - -# Binary from annotations -df = extract_annotation_signal(aligned_df, state_col="organelle_state") - -# Binary or continuous from classifier predictions -df = extract_prediction_signal(adata, aligned_df, task="organelle_state") - -# Cosine distance from baseline embeddings -df = extract_embedding_distance(adata, aligned_df, baseline_method="per_track") -``` - -### 3. Aggregation and Metrics - -```python -from src.metrics import aggregate_population, find_onset_time - -time_bins = np.arange(-600, 901, 30) -pop_df = aggregate_population(df, time_bins, signal_type="fraction") -onset, threshold, bl_mean, bl_std = find_onset_time(pop_df) -``` - -### 4. Plotting - -All plot functions save pdf+png and return the matplotlib Figure: - -```python -from src.plotting import plot_response_curves - -fig = plot_response_curves( - organelle_curves={"SEC61": pop_df}, - organelle_configs={"SEC61": {"label": "SEC61", "color": "#1f77b4"}}, - output_dir=Path("figures/"), -) -``` diff --git a/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py b/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py deleted file mode 100644 index 96b446045..000000000 --- a/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py +++ /dev/null @@ -1,338 +0,0 @@ -# %% -""" -Annotation-based organelle remodeling analysis. - -Measures remodeling timing using human annotations (organelle_state column) -directly from annotation CSVs — no model predictions required. - -Pipeline: alignment → annotation signal → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_annotation_signal, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") - -ORGANELLE_CONFIG = { - "G3BP1_ZIKV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_07_22 ZIKV", - }, - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/1", - "frame_interval_minutes": 30, - "label": "2025_07_24 control (C/1)", - }, - ], - "label": "G3BP1 ZIKV (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B_ZIKV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV (SEC61B)", - }, - ], - "controls": [], - "label": "SEC61B ZIKV (ER)", - "color": "#ff7f0e", - }, - "G3BP1_DENV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_01_24 DENV", - }, - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "C/4", - "frame_interval_minutes": 30, - "label": "2025_01_28 DENV", - }, - ], - "controls": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "B/4", - "frame_interval_minutes": 30, - "label": "2025_01_28 control (B/4)", - }, - ], - "label": "G3BP1 DENV (Stress Granule)", - "color": "#2ca02c", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 5 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "annotation_remodeling" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - df = pd.read_csv(exp["csv_path"]) - print(f" Loaded {len(df):,} annotations, t range: {df['t'].min()}-{df['t'].max()}") - - # Ensure parent_track_id exists - if "parent_track_id" not in df.columns: - df["parent_track_id"] = -1 - - # Step 1: Alignment - aligned = align_tracks( - df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (annotation-based) - aligned = extract_annotation_signal(aligned, state_col="organelle_state", positive_value="remodel") - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - fraction_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type="fraction") - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "fraction_df": fraction_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Process controls -# =========================================================================== - -control_results = {} -for marker, config in ORGANELLE_CONFIG.items(): - if not config.get("controls"): - continue - ctrl_dfs = [] - for ctrl in config["controls"]: - df = pd.read_csv(ctrl["csv_path"]) - df = df[df["fov_name"].str.startswith(ctrl["fov_pattern"])].copy() - ctrl_dfs.append(df) - if ctrl_dfs: - control_combined = pd.concat(ctrl_dfs, ignore_index=True) - n_total = len(control_combined.dropna(subset=["organelle_state"])) - n_remodel = (control_combined["organelle_state"] == "remodel").sum() - fraction = n_remodel / n_total if n_total > 0 else 0 - control_results[marker] = { - "n_total": n_total, - "n_remodel": n_remodel, - "fraction": fraction, - } - print(f" {marker} control: {n_remodel}/{n_total} = {fraction:.4f}") - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - frac_df = res["fraction_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - frac_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(frac_df) - peak = find_peak_metrics(frac_df) - - timing_rows.append( - { - "marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Remodeling Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type="fraction") - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -track_timing_df = pd.concat(all_track_timing, ignore_index=True) - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["fraction_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - marker_configs, - RESULTS_DIR, - signal_type="fraction", - min_cells_per_bin=MIN_CELLS_PER_BIN, - title="Annotation-based organelle remodeling after sensor translocation", - filename_prefix="annotation_remodeling_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type="fraction", - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_annotation_heatmap", - ) - -plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - filename_prefix="per_track_onset_duration", -) - -plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", -) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1: - stats_df = run_statistical_tests(marker_results, track_timing_df, control_results or None) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - frac_path = RESULTS_DIR / f"{marker}_fraction_curve.csv" - res["fraction_df"].to_csv(frac_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/embedding_distance.py b/applications/dynaclr/scripts/pseudotime/embedding_distance.py deleted file mode 100644 index e9311e3c0..000000000 --- a/applications/dynaclr/scripts/pseudotime/embedding_distance.py +++ /dev/null @@ -1,301 +0,0 @@ -# %% -""" -Embedding distance-based organelle remodeling analysis. - -Measures remodeling timing using cosine distance from pre-infection -baseline embeddings. Supports per-track and control-well baselines, -with optional PCA projection. - -Pipeline: alignment → embedding distance → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -import glob -from pathlib import Path - -import anndata as ad -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_embedding_distance, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -ORGANELLE_CONFIG = { - "G3BP1": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "control_fov_pattern": "C/1", - "frame_interval_minutes": 30, - "label": "2025_07_22 ZIKV", - }, - ], - "label": "G3BP1 (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2024_11_07_A549_SEC61_DENV" - / "4-phenotyping/2-predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2024_11_07_A549_SEC61B_DENV" - / "2024_11_07_A549_SEC61B_DENV_combined_annotations.csv", - "fov_pattern": "C/2", - "control_fov_pattern": "B/3", - "frame_interval_minutes": 10, - "label": "2024_11_07 DENV", - }, - ], - "label": "SEC61B (ER)", - "color": "#2ca02c", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" -BASELINE_METHOD = "per_track" # "per_track" or "control_well" -BASELINE_WINDOW_MINUTES = (-240, -180) -DISTANCE_METRIC = "cosine" -PCA_N_COMPONENTS = 20 # Set to None to use full embedding space -MIN_BASELINE_FRAMES = 2 -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 10 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "embedding_distance" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - - # Load embeddings - emb_files = glob.glob(str(exp["embeddings_path"] / exp["embeddings_pattern"])) - if not emb_files: - print(f" No embeddings found matching: {exp['embeddings_pattern']}") - continue - - adata = ad.read_zarr(emb_files[0]) - print(f" Loaded {adata.shape[0]:,} embeddings") - - # Load annotations for infection state alignment - ann_df = pd.read_csv(exp["annotations_path"]) - if "parent_track_id" not in ann_df.columns: - ann_df["parent_track_id"] = -1 - - # Step 1: Alignment - aligned = align_tracks( - ann_df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (embedding distance) - aligned = extract_embedding_distance( - adata, - aligned, - baseline_method=BASELINE_METHOD, - baseline_window_minutes=BASELINE_WINDOW_MINUTES, - control_fov_pattern=exp.get("control_fov_pattern"), - distance_metric=DISTANCE_METRIC, - pca_n_components=PCA_N_COMPONENTS, - min_baseline_frames=MIN_BASELINE_FRAMES, - ) - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - population_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type="continuous") - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "population_df": population_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - pop_df = res["population_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - pop_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(pop_df) - peak = find_peak_metrics(pop_df) - - timing_rows.append( - { - "marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "baseline_method": BASELINE_METHOD, - "distance_metric": DISTANCE_METRIC, - "pca_components": PCA_N_COMPONENTS, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Embedding Distance Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type="continuous") - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -track_timing_df = pd.concat(all_track_timing, ignore_index=True) - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["population_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - marker_configs, - RESULTS_DIR, - signal_type="continuous", - min_cells_per_bin=MIN_CELLS_PER_BIN, - title=f"Embedding distance remodeling ({BASELINE_METHOD}, {DISTANCE_METRIC})", - filename_prefix="embedding_distance_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type="continuous", - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_distance_heatmap", - ) - -if len(track_timing_df) > 0: - plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - filename_prefix="per_track_onset_duration", - ) - - plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", - ) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1 and len(track_timing_df) > 0: - stats_df = run_statistical_tests(marker_results, track_timing_df) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - curve_path = RESULTS_DIR / f"{marker}_distance_curve.csv" - res["population_df"].to_csv(curve_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py b/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py deleted file mode 100644 index 890b6c83d..000000000 --- a/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py +++ /dev/null @@ -1,386 +0,0 @@ -# %% -""" -Multi-channel correlation: infection, death, and organelle remodeling. - -Uses classifier predictions from different channels to ask: -- Do cells that get infected earlier also die faster? -- Is faster death correlated with faster organelle remodeling? - -Pipeline: -1. Load sensor zarr → T_perturb (infection onset), T_death (cell death onset) -2. Load organelle zarr → T_remodel (organelle remodeling onset) -3. Merge per-track event timings -4. Correlate and visualize - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import anndata as ad -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from scipy import stats - -# %% -# =========================================================================== -# Configuration -# =========================================================================== - -DATASET_ROOT = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics" - "/2025_01_24_A549_G3BP1_DENV/4-phenotyping/predictions" - "/DynaCLR-2D-BagOfChannels-timeaware/v3" -) - -SENSOR_ZARR = DATASET_ROOT / "timeaware_sensor_160patch_104ckpt.zarr" -ORGANELLE_ZARR = DATASET_ROOT / "timeaware_organelle_160patch_104ckpt.zarr" - -FOV_PATTERN = "C/2" # infected wells -FRAME_INTERVAL_MINUTES = 10 -MIN_TRACK_TIMEPOINTS = 3 - -RESULTS_DIR = Path(__file__).parent / "results" / "infection_death_remodeling" - -# %% -# =========================================================================== -# Step 1: Load data and filter to infected wells -# =========================================================================== - -sensor = ad.read_zarr(SENSOR_ZARR) -organelle = ad.read_zarr(ORGANELLE_ZARR) - -print(f"Sensor: {sensor.shape[0]:,} cells") -print(f"Organelle: {organelle.shape[0]:,} cells") - -# Filter to infected FOVs -sensor_obs = sensor.obs[sensor.obs["fov_name"].astype(str).str.startswith(FOV_PATTERN)].copy() -organelle_obs = organelle.obs[organelle.obs["fov_name"].astype(str).str.startswith(FOV_PATTERN)].copy() - -print(f"\nAfter FOV filter ({FOV_PATTERN}):") -print(f" Sensor: {len(sensor_obs):,} cells") -print(f" Organelle: {len(organelle_obs):,} cells") - -# %% -# =========================================================================== -# Step 2: Build per-cell merged dataframe -# =========================================================================== - -merge_keys = ["fov_name", "track_id", "t"] - -sensor_cols = merge_keys + [ - "predicted_infection_state", - "predicted_cell_death_state", -] -organelle_cols = merge_keys + [ - "predicted_organelle_state_g3bp1", -] - -merged = sensor_obs[sensor_cols].merge( - organelle_obs[organelle_cols], - on=merge_keys, - how="inner", -) - -merged["t_minutes"] = merged["t"] * FRAME_INTERVAL_MINUTES - -print(f"\nMerged: {len(merged):,} cells across {merged.groupby(['fov_name', 'track_id']).ngroups} tracks") -print(f" Infection: {merged['predicted_infection_state'].value_counts().to_dict()}") -print(f" Death: {merged['predicted_cell_death_state'].value_counts().to_dict()}") -print(f" Remodel: {merged['predicted_organelle_state_g3bp1'].value_counts().to_dict()}") - -# %% -# =========================================================================== -# Step 3: Compute per-track event timings -# =========================================================================== - - -def find_first_event(group: pd.DataFrame, col: str, value: str) -> float | None: - """Return t_minutes of the first frame matching value, or None.""" - hits = group.loc[group[col] == value, "t_minutes"] - if len(hits) > 0: - return hits.min() - return None - - -track_events = [] -for (fov, tid), group in merged.groupby(["fov_name", "track_id"]): - group = group.sort_values("t") - n_frames = len(group) - if n_frames < MIN_TRACK_TIMEPOINTS: - continue - - t_start = group["t_minutes"].min() - t_end = group["t_minutes"].max() - track_duration = t_end - t_start - - t_infection = find_first_event(group, "predicted_infection_state", "infected") - t_death = find_first_event(group, "predicted_cell_death_state", "dead") - t_remodel = find_first_event(group, "predicted_organelle_state_g3bp1", "remodel") - - # Was cell ever infected, dead, remodeled? - ever_infected = t_infection is not None - ever_dead = t_death is not None - ever_remodeled = t_remodel is not None - - # Time from infection to death / remodeling - infection_to_death = (t_death - t_infection) if (ever_infected and ever_dead) else None - infection_to_remodel = (t_remodel - t_infection) if (ever_infected and ever_remodeled) else None - remodel_to_death = (t_death - t_remodel) if (ever_remodeled and ever_dead) else None - - track_events.append( - { - "fov_name": fov, - "track_id": tid, - "n_frames": n_frames, - "track_duration_min": track_duration, - "t_infection_min": t_infection, - "t_death_min": t_death, - "t_remodel_min": t_remodel, - "ever_infected": ever_infected, - "ever_dead": ever_dead, - "ever_remodeled": ever_remodeled, - "infection_to_death_min": infection_to_death, - "infection_to_remodel_min": infection_to_remodel, - "remodel_to_death_min": remodel_to_death, - } - ) - -events_df = pd.DataFrame(track_events) - -print(f"\n## Track Event Summary ({len(events_df)} tracks)") -print(f" Ever infected: {events_df['ever_infected'].sum()}") -print(f" Ever dead: {events_df['ever_dead'].sum()}") -print(f" Ever remodeled: {events_df['ever_remodeled'].sum()}") -print(f" Infected & dead: {(events_df['ever_infected'] & events_df['ever_dead']).sum()}") -print(f" Infected & remodeled: {(events_df['ever_infected'] & events_df['ever_remodeled']).sum()}") -print(f" All three: {(events_df['ever_infected'] & events_df['ever_dead'] & events_df['ever_remodeled']).sum()}") - -# %% -# =========================================================================== -# Step 4: Descriptive statistics -# =========================================================================== - -infected_tracks = events_df[events_df["ever_infected"]].copy() - -print("\n## Timing distributions (infected tracks only)") -for col_label, col in [ - ("Infection → Death", "infection_to_death_min"), - ("Infection → Remodel", "infection_to_remodel_min"), - ("Remodel → Death", "remodel_to_death_min"), -]: - valid = infected_tracks[col].dropna() - if len(valid) > 0: - print(f"\n **{col_label}** (n={len(valid)})") - print(f" median: {valid.median():.0f} min, mean: {valid.mean():.0f} min, std: {valid.std():.0f} min") - print(f" range: [{valid.min():.0f}, {valid.max():.0f}] min") - -# Compare death rates: infected vs uninfected -infected_dead = events_df["ever_infected"] & events_df["ever_dead"] -uninfected_dead = ~events_df["ever_infected"] & events_df["ever_dead"] -n_infected = events_df["ever_infected"].sum() -n_uninfected = (~events_df["ever_infected"]).sum() - -print("\n## Death rates") -print(f" Infected tracks: {infected_dead.sum()}/{n_infected} = {infected_dead.sum() / max(n_infected, 1):.1%}") -print( - f" Uninfected tracks: {uninfected_dead.sum()}/{n_uninfected} = {uninfected_dead.sum() / max(n_uninfected, 1):.1%}" -) - -if n_infected > 0 and n_uninfected > 0: - table = np.array( - [ - [infected_dead.sum(), n_infected - infected_dead.sum()], - [uninfected_dead.sum(), n_uninfected - uninfected_dead.sum()], - ] - ) - chi2, p_val, _, _ = stats.chi2_contingency(table) - print(f" Chi-squared: {chi2:.2f}, p={p_val:.4g}") - -# %% -# =========================================================================== -# Step 5: Correlation — infection_to_death vs infection_to_remodel -# =========================================================================== - -both = infected_tracks.dropna(subset=["infection_to_death_min", "infection_to_remodel_min"]).copy() - -print(f"\n## Correlation: Infection→Death vs Infection→Remodel (n={len(both)})") - -if len(both) >= 5: - r_pearson, p_pearson = stats.pearsonr(both["infection_to_remodel_min"], both["infection_to_death_min"]) - r_spearman, p_spearman = stats.spearmanr(both["infection_to_remodel_min"], both["infection_to_death_min"]) - print(f" Pearson r={r_pearson:.3f}, p={p_pearson:.4g}") - print(f" Spearman rho={r_spearman:.3f}, p={p_spearman:.4g}") - - # Bin tracks into early/late remodelers (median split) - median_remodel = both["infection_to_remodel_min"].median() - both["remodel_speed"] = np.where( - both["infection_to_remodel_min"] <= median_remodel, "early_remodel", "late_remodel" - ) - - for label, subdf in both.groupby("remodel_speed"): - death_times = subdf["infection_to_death_min"] - print( - f"\n {label} (n={len(subdf)}): death at median {death_times.median():.0f} min," - f" mean {death_times.mean():.0f} min" - ) - - early = both.loc[both["remodel_speed"] == "early_remodel", "infection_to_death_min"] - late = both.loc[both["remodel_speed"] == "late_remodel", "infection_to_death_min"] - if len(early) >= 3 and len(late) >= 3: - u_stat, u_p = stats.mannwhitneyu(early, late, alternative="two-sided") - print(f"\n Mann-Whitney U test (early vs late remodelers death time): U={u_stat:.0f}, p={u_p:.4g}") - -# %% -# =========================================================================== -# Step 6: Plots -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -fig, axes = plt.subplots(2, 2, figsize=(14, 12)) - -# --- Panel A: Scatter of infection→remodel vs infection→death --- -ax = axes[0, 0] -if len(both) >= 5: - ax.scatter( - both["infection_to_remodel_min"], - both["infection_to_death_min"], - alpha=0.4, - s=15, - edgecolors="none", - ) - # Regression line - slope, intercept, _, _, _ = stats.linregress(both["infection_to_remodel_min"], both["infection_to_death_min"]) - x_fit = np.linspace(both["infection_to_remodel_min"].min(), both["infection_to_remodel_min"].max(), 100) - ax.plot(x_fit, slope * x_fit + intercept, "r--", label=f"r={r_pearson:.2f}, p={p_pearson:.2g}") - ax.legend() -ax.set_xlabel("Infection → Remodel (min)") -ax.set_ylabel("Infection → Death (min)") -ax.set_title("A. Remodeling vs Death timing") - -# --- Panel B: Distribution of infection→death for infected vs all tracks --- -ax = axes[0, 1] -infected_death_times = infected_tracks["infection_to_death_min"].dropna() -if len(infected_death_times) > 0: - ax.hist(infected_death_times, bins=30, alpha=0.7, color="#d62728", edgecolor="white") -ax.set_xlabel("Infection → Death (min)") -ax.set_ylabel("Number of tracks") -ax.set_title("B. Time from infection to death") - -# --- Panel C: Death rate comparison --- -ax = axes[1, 0] -categories = ["Infected", "Uninfected"] -dead_counts = [infected_dead.sum(), uninfected_dead.sum()] -alive_counts = [n_infected - infected_dead.sum(), n_uninfected - uninfected_dead.sum()] -x = np.arange(len(categories)) -width = 0.35 -ax.bar(x - width / 2, dead_counts, width, label="Dead", color="#d62728") -ax.bar(x + width / 2, alive_counts, width, label="Alive", color="#2ca02c") -ax.set_xticks(x) -ax.set_xticklabels(categories) -ax.set_ylabel("Number of tracks") -ax.set_title("C. Death rates by infection status") -ax.legend() - -# --- Panel D: Boxplot of death timing by remodel speed --- -ax = axes[1, 1] -if len(both) >= 5: - early_vals = both.loc[both["remodel_speed"] == "early_remodel", "infection_to_death_min"].to_numpy() - late_vals = both.loc[both["remodel_speed"] == "late_remodel", "infection_to_death_min"].to_numpy() - bp = ax.boxplot( - [early_vals, late_vals], - labels=["Early remodelers", "Late remodelers"], - patch_artist=True, - ) - bp["boxes"][0].set_facecolor("#1f77b4") - bp["boxes"][1].set_facecolor("#ff7f0e") - ax.set_ylabel("Infection → Death (min)") - ax.set_title("D. Death timing by remodel speed") - -plt.tight_layout() -fig.savefig(RESULTS_DIR / "infection_death_remodeling.png", dpi=150, bbox_inches="tight") -fig.savefig(RESULTS_DIR / "infection_death_remodeling.pdf", bbox_inches="tight") -plt.show() -print(f"Saved to {RESULTS_DIR}") - -# %% -# =========================================================================== -# Step 7: Timeline heatmap — per-track state over time -# =========================================================================== - -# Show a sample of infected tracks with all 3 states over time -infected_tids = infected_tracks.sort_values("t_infection_min").head(50) -sample_keys = set(zip(infected_tids["fov_name"], infected_tids["track_id"])) - -sample = merged[merged.apply(lambda r: (r["fov_name"], r["track_id"]) in sample_keys, axis=1)].copy() - -if len(sample) > 0: - # Align to infection time - sample = sample.merge( - infected_tids[["fov_name", "track_id", "t_infection_min"]], - on=["fov_name", "track_id"], - ) - sample["t_rel"] = sample["t_minutes"] - sample["t_infection_min"] - - # Encode states as numeric for heatmap - sample["infection_num"] = (sample["predicted_infection_state"] == "infected").astype(int) - sample["death_num"] = (sample["predicted_cell_death_state"] == "dead").astype(int) - sample["remodel_num"] = (sample["predicted_organelle_state_g3bp1"] == "remodel").astype(int) - - fig, axes = plt.subplots(1, 3, figsize=(18, 8), sharey=True) - time_bins = np.arange(sample["t_rel"].min(), sample["t_rel"].max() + FRAME_INTERVAL_MINUTES, FRAME_INTERVAL_MINUTES) - - track_labels = [] - for i, ((fov, tid), _) in enumerate(infected_tids.iterrows()): - track_labels.append(f"{fov}:{tid}") - - for ax, (title, col) in zip( - axes, - [ - ("Infection", "infection_num"), - ("Death", "death_num"), - ("Remodeling", "remodel_num"), - ], - ): - # Pivot: rows=tracks, cols=time bins - track_list = list(zip(infected_tids["fov_name"], infected_tids["track_id"])) - matrix = np.full((len(track_list), len(time_bins) - 1), np.nan) - - for i, (fov, tid) in enumerate(track_list): - track_data = sample[(sample["fov_name"] == fov) & (sample["track_id"] == tid)] - for _, row in track_data.iterrows(): - bin_idx = np.searchsorted(time_bins, row["t_rel"]) - 1 - if 0 <= bin_idx < matrix.shape[1]: - matrix[i, bin_idx] = row[col] - - im = ax.imshow(matrix, aspect="auto", cmap="RdYlBu_r", vmin=0, vmax=1, interpolation="nearest") - ax.set_xlabel("Time relative to infection (min)") - ax.set_title(title) - - # Set x tick labels - n_ticks = min(10, len(time_bins)) - tick_positions = np.linspace(0, len(time_bins) - 2, n_ticks, dtype=int) - ax.set_xticks(tick_positions) - ax.set_xticklabels([f"{time_bins[t]:.0f}" for t in tick_positions], rotation=45) - - axes[0].set_ylabel("Tracks (sorted by infection time)") - plt.colorbar(im, ax=axes[-1], label="State (0=no, 1=yes)") - plt.tight_layout() - fig.savefig(RESULTS_DIR / "track_timeline_heatmap.png", dpi=150, bbox_inches="tight") - fig.savefig(RESULTS_DIR / "track_timeline_heatmap.pdf", bbox_inches="tight") - plt.show() - -# %% -# =========================================================================== -# Step 8: Save results -# =========================================================================== - -events_df.to_csv(RESULTS_DIR / "track_events.csv", index=False) -if len(both) > 0: - both.to_csv(RESULTS_DIR / "infected_remodeled_dead_tracks.csv", index=False) - -print(f"\nAll results saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py b/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py deleted file mode 100644 index 276f3e99c..000000000 --- a/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py +++ /dev/null @@ -1,1028 +0,0 @@ -# %% -""" -Infection onset timing distribution and phenotype binning. - -Measures the absolute time from experiment start to first infection -(T_perturbation) per track, then bins cells by early/mid/late infection -to compare downstream phenotype responses (death, remodeling). - -Supports both annotation-based and prediction-based infection timing. - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import anndata as ad -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from scipy import stats - -# %% -# =========================================================================== -# Configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -# All experiments start at 3 HPI (hours post-infection). -# t=0 in the data corresponds to 3 HPI, so absolute HPI = t_minutes/60 + T_OFFSET_HPI. -T_OFFSET_HPI = 3.0 - -EXPERIMENTS = { - "G3BP1 (Stress Granule)": { - "datasets": [ - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "C/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_01_24 DENV", - }, - ], - "remodel_task": "organelle_state_g3bp1", - "remodel_ann_col": "organelle_state", - "remodel_positive": "remodel", - }, - "SEC61B (ER)": { - "datasets": [ - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - ], - "remodel_task": "organelle_state_sec61b", - "remodel_ann_col": "organelle_state", - "remodel_positive": "remodel", - }, -} - -MIN_TRACK_TIMEPOINTS = 10 - -# Smoothing: require N consecutive frames of a state before calling it a true event. -# Set to 1 to disable (raw first-frame detection). -MIN_CONSECUTIVE_FRAMES = 3 - -# Binning strategy: terciles by default, or custom edges -N_BINS = 3 - -RESULTS_DIR = Path(__file__).parent / "results" / "infection_onset_distribution" - -SAVE_FIGURES = False - -# %% -# =========================================================================== -# Step 1: Helper — extract per-track events from annotations -# =========================================================================== - - -def extract_annotation_events( - ann_df: pd.DataFrame, - fov_pattern: str, - frame_interval: float, - remodel_col: str = "organelle_state", - remodel_positive: str = "remodel", -) -> pd.DataFrame: - """Extract per-track first-event timings from annotation CSV.""" - filtered = ann_df[ann_df["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - has_division = "cell_division_state" in filtered.columns - rows = [] - for (fov, tid), g in filtered.groupby(["fov_name", "track_id"]): - if len(g) < MIN_TRACK_TIMEPOINTS: - continue - t_start, t_end = g["t"].min(), g["t"].max() - inf = g[g["infection_state"] == "infected"] - dead = g[g["cell_death_state"] == "dead"] - remodel = g[g[remodel_col] == remodel_positive] - - t_infection = inf["t"].min() if len(inf) > 0 else None - t_death = dead["t"].min() if len(dead) > 0 else None - t_remodel = remodel["t"].min() if len(remodel) > 0 else None - - t_division = None - if has_division: - mitosis = g[g["cell_division_state"] == "mitosis"] - t_division = mitosis["t"].min() if len(mitosis) > 0 else None - - rows.append( - { - "fov_name": fov, - "track_id": tid, - "source": "annotation", - "t_track_start": t_start * frame_interval, - "t_track_end": t_end * frame_interval, - "track_duration_min": (t_end - t_start) * frame_interval, - "t_infection_min": (t_infection * frame_interval if t_infection is not None else None), - "t_death_min": (t_death * frame_interval if t_death is not None else None), - "t_remodel_min": (t_remodel * frame_interval if t_remodel is not None else None), - "t_division_min": (t_division * frame_interval if t_division is not None else None), - "ever_infected": t_infection is not None, - "ever_dead": t_death is not None, - "ever_remodeled": t_remodel is not None, - "ever_divided": t_division is not None, - } - ) - return pd.DataFrame(rows) - - -# %% -# =========================================================================== -# Step 2: Helper — extract per-track events from predictions -# =========================================================================== - - -def _first_consecutive_event( - sorted_t: np.ndarray, - is_positive: np.ndarray, - min_consecutive: int, -) -> float | None: - """Return the t value where min_consecutive consecutive positive frames first occur.""" - if min_consecutive <= 1: - positives = sorted_t[is_positive] - return float(positives[0]) if len(positives) > 0 else None - - run = 0 - for i, pos in enumerate(is_positive): - if pos: - run += 1 - if run >= min_consecutive: - return float(sorted_t[i - min_consecutive + 1]) - else: - run = 0 - return None - - -def extract_prediction_events( - embeddings_path: Path, - fov_pattern: str, - frame_interval: float, - remodel_task: str = "organelle_state_g3bp1", - remodel_positive: str = "remodel", -) -> pd.DataFrame: - """Extract per-track first-event timings from sensor + organelle + phase zarrs.""" - sensor = ad.read_zarr(embeddings_path / "timeaware_sensor_160patch_104ckpt.zarr") - organelle = ad.read_zarr(embeddings_path / "timeaware_organelle_160patch_104ckpt.zarr") - phase = ad.read_zarr(embeddings_path / "timeaware_phase_160patch_104ckpt.zarr") - - sensor_obs = sensor.obs[sensor.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - organelle_obs = organelle.obs[organelle.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - phase_obs = phase.obs[phase.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - - merge_keys = ["fov_name", "track_id", "t"] - pred_remodel_col = f"predicted_{remodel_task}" - - # Check if phase has division predictions - has_division = "predicted_cell_division_state" in phase_obs.columns - - merged = sensor_obs[merge_keys + ["predicted_infection_state", "predicted_cell_death_state"]].merge( - organelle_obs[merge_keys + [pred_remodel_col]], - on=merge_keys, - how="inner", - ) - if has_division: - merged = merged.merge( - phase_obs[merge_keys + ["predicted_cell_division_state"]], - on=merge_keys, - how="inner", - ) - - rows = [] - for (fov, tid), g in merged.groupby(["fov_name", "track_id"]): - if len(g) < MIN_TRACK_TIMEPOINTS: - continue - g = g.sort_values("t") - t_start, t_end = g["t"].min(), g["t"].max() - - sorted_t = g["t"].to_numpy() - t_infection = _first_consecutive_event( - sorted_t, - (g["predicted_infection_state"] == "infected").values, - MIN_CONSECUTIVE_FRAMES, - ) - t_death = _first_consecutive_event( - sorted_t, - (g["predicted_cell_death_state"] == "dead").values, - MIN_CONSECUTIVE_FRAMES, - ) - t_remodel = _first_consecutive_event( - sorted_t, - (g[pred_remodel_col] == remodel_positive).values, - MIN_CONSECUTIVE_FRAMES, - ) - t_division = None - if has_division: - t_division = _first_consecutive_event( - sorted_t, - (g["predicted_cell_division_state"] == "mitosis").values, - MIN_CONSECUTIVE_FRAMES, - ) - - rows.append( - { - "fov_name": fov, - "track_id": tid, - "source": "prediction", - "t_track_start": t_start * frame_interval, - "t_track_end": t_end * frame_interval, - "track_duration_min": (t_end - t_start) * frame_interval, - "t_infection_min": (t_infection * frame_interval if t_infection is not None else None), - "t_death_min": (t_death * frame_interval if t_death is not None else None), - "t_remodel_min": (t_remodel * frame_interval if t_remodel is not None else None), - "t_division_min": (t_division * frame_interval if t_division is not None else None), - "ever_infected": t_infection is not None, - "ever_dead": t_death is not None, - "ever_remodeled": t_remodel is not None, - "ever_divided": t_division is not None, - } - ) - return pd.DataFrame(rows) - - -# %% -# =========================================================================== -# Step 3: Process all experiments (multiple datasets per organelle) -# =========================================================================== - -all_results = {} - -for exp_name, cfg in EXPERIMENTS.items(): - print(f"\n{'=' * 60}") - print(f" {exp_name}") - print(f"{'=' * 60}") - - all_ann_events = [] - all_pred_events = [] - - for ds in cfg["datasets"]: - print(f"\n Dataset: {ds['label']}") - - ann_df = pd.read_csv(ds["annotations_path"]) - ann_ev = extract_annotation_events( - ann_df, - fov_pattern=ds["fov_pattern"], - frame_interval=ds["frame_interval_minutes"], - remodel_col=cfg["remodel_ann_col"], - remodel_positive=cfg["remodel_positive"], - ) - ann_ev["dataset"] = ds["label"] - all_ann_events.append(ann_ev) - print(f" Annotation: {len(ann_ev)} tracks, {ann_ev['ever_infected'].sum()} infected") - - pred_ev = extract_prediction_events( - embeddings_path=ds["embeddings_path"], - fov_pattern=ds["fov_pattern"], - frame_interval=ds["frame_interval_minutes"], - remodel_task=cfg["remodel_task"], - remodel_positive=cfg["remodel_positive"], - ) - pred_ev["dataset"] = ds["label"] - all_pred_events.append(pred_ev) - print(f" Prediction: {len(pred_ev)} tracks, {pred_ev['ever_infected'].sum()} infected") - - ann_events_df = pd.concat(all_ann_events, ignore_index=True) - pred_events_df = pd.concat(all_pred_events, ignore_index=True) - - # Convert to HPI (hours post-inoculation) - for df in [ann_events_df, pred_events_df]: - df["t_infection_hpi"] = df["t_infection_min"] / 60 + T_OFFSET_HPI - df["t_death_hpi"] = df["t_death_min"] / 60 + T_OFFSET_HPI - df["t_remodel_hpi"] = df["t_remodel_min"] / 60 + T_OFFSET_HPI - df["t_division_hpi"] = df["t_division_min"] / 60 + T_OFFSET_HPI - - print(f"\n Combined annotation: {len(ann_events_df)} tracks, {ann_events_df['ever_infected'].sum()} infected") - print(f" Combined prediction: {len(pred_events_df)} tracks, {pred_events_df['ever_infected'].sum()} infected") - - all_results[exp_name] = { - "cfg": cfg, - "ann_events_df": ann_events_df, - "pred_events_df": pred_events_df, - } - -# %% -# =========================================================================== -# Step 4: Bin infected tracks by infection onset time -# =========================================================================== - - -def bin_and_analyze(events_df: pd.DataFrame, source_label: str) -> pd.DataFrame: - """Bin infected tracks by T_infection terciles and summarize phenotypes.""" - infected = events_df[events_df["ever_infected"]].copy() - if len(infected) < N_BINS: - print(f" Too few infected tracks ({len(infected)}) for {N_BINS} bins") - return infected - - # Tercile binning — labels in HPI (hours post-inoculation) - _, bin_edges = pd.qcut(infected["t_infection_hpi"], q=N_BINS, retbins=True) - bin_labels = [f"{bin_edges[i]:.1f}–{bin_edges[i + 1]:.1f} HPI" for i in range(len(bin_edges) - 1)] - infected["infection_bin"] = pd.qcut( - infected["t_infection_hpi"], - q=N_BINS, - labels=bin_labels, - ) - - print(f"\n## {source_label}: Translocation onset bins") - print(f" Bin edges (HPI): {[f'{e:.1f}' for e in bin_edges]}") - print(f" Labels: {bin_labels}") - - has_division = "ever_divided" in infected.columns - - for bin_label in bin_labels: - subset = infected[infected["infection_bin"] == bin_label] - n = len(subset) - n_dead = subset["ever_dead"].sum() - n_remodel = subset["ever_remodeled"].sum() - - print( - f"\n **{bin_label}** (n={n}, T_inf range: " - f"{subset['t_infection_min'].min():.0f}-{subset['t_infection_min'].max():.0f} min)" - ) - print(f" Death rate: {n_dead}/{n} = {n_dead / max(n, 1):.1%}") - print(f" Remodel rate: {n_remodel}/{n} = {n_remodel / max(n, 1):.1%}") - - if has_division: - n_divided = subset["ever_divided"].sum() - print(f" Division rate: {n_divided}/{n} = {n_divided / max(n, 1):.1%}") - - # Time from infection to death/remodel for those that have it - both_dead = subset[subset["ever_dead"]].copy() - if len(both_dead) > 0: - dt = both_dead["t_death_min"] - both_dead["t_infection_min"] - print( - f" Translocation→Death: median={dt.median():.0f} min, mean={dt.mean():.0f} min (n={len(both_dead)})" - ) - - both_remodel = subset[subset["ever_remodeled"]].copy() - if len(both_remodel) > 0: - dt = both_remodel["t_remodel_min"] - both_remodel["t_infection_min"] - print( - f" Translocation→Remodel: median={dt.median():.0f} min," - f" mean={dt.mean():.0f} min (n={len(both_remodel)})" - ) - - if has_division: - both_divided = subset[subset["ever_divided"]].copy() - if len(both_divided) > 0: - dt = both_divided["t_division_min"] - both_divided["t_infection_min"] - print( - f" Translocation→Division: median={dt.median():.0f} min," - f" mean={dt.mean():.0f} min (n={len(both_divided)})" - ) - - # Kruskal-Wallis across bins for infection→death, infection→remodel, infection→division - event_tests = [ - ("Translocation→Death", "t_death_min"), - ("Translocation→Remodel", "t_remodel_min"), - ] - if has_division: - event_tests.append(("Translocation→Division", "t_division_min")) - for event_label, event_col in event_tests: - infected_with_event = infected.dropna(subset=[event_col]).copy() - infected_with_event["delta"] = infected_with_event[event_col] - infected_with_event["t_infection_min"] - groups = [g["delta"].to_numpy() for _, g in infected_with_event.groupby("infection_bin") if len(g) >= 2] - if len(groups) >= 2: - h_stat, h_p = stats.kruskal(*groups) - print(f"\n Kruskal-Wallis ({event_label} across bins): H={h_stat:.2f}, p={h_p:.4g}") - - return infected - - -for exp_name, res in all_results.items(): - ann_binned = bin_and_analyze(res["ann_events_df"], f"{exp_name} (Annotation)") - pred_binned = bin_and_analyze(res["pred_events_df"], f"{exp_name} (Prediction)") - res["ann_binned"] = ann_binned - res["pred_binned"] = pred_binned - -# %% -# =========================================================================== -# Step 5: Plots — per experiment: onset distribution + response histograms -# =========================================================================== - -if SAVE_FIGURES: - RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -BIN_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"] - - -def _plot_kde_by_bin(ax, binned_df, event_col, delta_label): - """Plot KDE curves of response time per infection bin.""" - if "infection_bin" not in binned_df.columns: - return - categories = binned_df["infection_bin"].cat.categories - for i, bin_label in enumerate(categories): - subset = binned_df[binned_df["infection_bin"] == bin_label] - dt = (subset[event_col] - subset["t_infection_min"]).dropna() - if len(dt) >= 3: - from scipy.stats import gaussian_kde - - kde = gaussian_kde(dt, bw_method="scott") - x_grid = np.linspace(dt.min() - 30, dt.max() + 30, 200) - ax.plot(x_grid, kde(x_grid), color=BIN_COLORS[i % len(BIN_COLORS)], linewidth=2) - ax.fill_between( - x_grid, - kde(x_grid), - alpha=0.15, - color=BIN_COLORS[i % len(BIN_COLORS)], - label=f"{bin_label} (n={len(dt)})", - ) - elif len(dt) > 0: - ax.axvline( - dt.median(), - color=BIN_COLORS[i % len(BIN_COLORS)], - linestyle="--", - label=f"{bin_label} (n={len(dt)})", - ) - ax.legend(fontsize=8) - ax.set_xlabel(f"{delta_label} (min)") - ax.set_ylabel("Density") - - -for exp_name, res in all_results.items(): - ann_infected = res["ann_events_df"][res["ann_events_df"]["ever_infected"]] - pred_infected = res["pred_events_df"][res["pred_events_df"]["ever_infected"]] - ann_binned = res["ann_binned"] - pred_binned = res["pred_binned"] - - fig, axes = plt.subplots(2, 4, figsize=(24, 10)) - fig.suptitle(exp_name, fontsize=14, fontweight="bold") - - # --- Row 1: Annotation-based --- - ax = axes[0, 0] - if len(ann_infected) > 0: - ax.hist( - ann_infected["t_infection_hpi"], - bins=20, - alpha=0.7, - color="#1f77b4", - edgecolor="white", - ) - ax.set_xlabel("T_infection (HPI)") - ax.set_ylabel("Number of tracks") - ax.set_title("A. Annotation: infection onset") - - for ax, (delta_label, event_col, panel) in zip( - [axes[0, 1], axes[0, 2], axes[0, 3]], - [ - ("Translocation → Death", "t_death_min", "B"), - ("Translocation → Remodel", "t_remodel_min", "C"), - ("Translocation → Division", "t_division_min", "D"), - ], - ): - _plot_kde_by_bin(ax, ann_binned, event_col, delta_label) - ax.set_title(f"{panel}. Annotation: {delta_label}") - - # --- Row 2: Prediction-based --- - ax = axes[1, 0] - if len(pred_infected) > 0: - ax.hist( - pred_infected["t_infection_hpi"], - bins=30, - alpha=0.7, - color="#ff7f0e", - edgecolor="white", - ) - ax.set_xlabel("T_infection (HPI)") - ax.set_ylabel("Number of tracks") - ax.set_title("E. Prediction: infection onset") - - for ax, (delta_label, event_col, panel) in zip( - [axes[1, 1], axes[1, 2], axes[1, 3]], - [ - ("Translocation → Death", "t_death_min", "F"), - ("Translocation → Remodel", "t_remodel_min", "G"), - ("Translocation → Division", "t_division_min", "H"), - ], - ): - _plot_kde_by_bin(ax, pred_binned, event_col, delta_label) - ax.set_title(f"{panel}. Prediction: {delta_label}") - - plt.tight_layout() - if SAVE_FIGURES: - prefix = exp_name.replace(" ", "_").replace("(", "").replace(")", "") - fig.savefig(RESULTS_DIR / f"{prefix}_onset_binning.png", dpi=150, bbox_inches="tight") - fig.savefig(RESULTS_DIR / f"{prefix}_onset_binning.pdf", bbox_inches="tight") - plt.show() - -# %% -# =========================================================================== -# Step 7: Response time comparison — are elapsed times the same across bins? -# =========================================================================== - - -def plot_response_time_comparison( - binned_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Boxplot + swarm of response times per infection bin with pairwise tests.""" - if "infection_bin" not in binned_df.columns: - return - - # Compute deltas - binned_df = binned_df.copy() - binned_df["infection_to_death"] = binned_df["t_death_min"] - binned_df["t_infection_min"] - binned_df["infection_to_remodel"] = binned_df["t_remodel_min"] - binned_df["t_infection_min"] - has_division = "t_division_min" in binned_df.columns - if has_division: - binned_df["infection_to_division"] = binned_df["t_division_min"] - binned_df["t_infection_min"] - - n_panels = 4 if has_division else 3 - fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6)) - - bin_categories = list(binned_df["infection_bin"].cat.categories) - - # --- Response time boxplots --- - boxplot_items = [ - ("infection_to_death", "Translocation → Death (min)", "Death"), - ("infection_to_remodel", "Translocation → Remodel (min)", "Remodel"), - ] - if has_division: - boxplot_items.append(("infection_to_division", "Translocation → Division (min)", "Division")) - for ax, (delta_col, ylabel, title_suffix) in zip( - axes[: len(boxplot_items)], - boxplot_items, - ): - plot_data = [] - positions = [] - tick_labels = [] - bin_names = [] - for i, bin_label in enumerate(bin_categories): - vals = binned_df.loc[binned_df["infection_bin"] == bin_label, delta_col].dropna() - if len(vals) > 0: - plot_data.append(vals.values) - positions.append(i) - tick_labels.append(f"{bin_label}\n(n={len(vals)})") - bin_names.append(bin_label) - - if len(plot_data) == 0: - ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes) - ax.set_title(f"{source_label}: {title_suffix}") - continue - - bp = ax.boxplot(plot_data, positions=positions, patch_artist=True, widths=0.5) - colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"] - for patch, color in zip(bp["boxes"], colors[: len(plot_data)]): - patch.set_facecolor(color) - patch.set_alpha(0.6) - - # Overlay individual points - for pos, vals in zip(positions, plot_data): - jitter = np.random.default_rng(42).uniform(-0.12, 0.12, len(vals)) - ax.scatter(pos + jitter, vals, alpha=0.4, s=12, color="black", zorder=3) - - ax.set_xticks(positions) - ax.set_xticklabels(tick_labels) - ax.set_ylabel(ylabel) - ax.set_title(f"{source_label}: {title_suffix} response time") - ax.set_xlabel("Translocation onset bin") - - # Pairwise Mann-Whitney U tests - test_results = [] - for i in range(len(plot_data)): - for j in range(i + 1, len(plot_data)): - if len(plot_data[i]) >= 3 and len(plot_data[j]) >= 3: - u_stat, u_p = stats.mannwhitneyu(plot_data[i], plot_data[j], alternative="two-sided") - test_results.append(f"{bin_names[i]} vs {bin_names[j]}: p={u_p:.4g}") - - if test_results: - test_text = "\n".join(test_results) - ax.text( - 0.98, - 0.98, - test_text, - transform=ax.transAxes, - ha="right", - va="top", - fontsize=8, - family="monospace", - bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), - ) - - # --- Phenotype rates per bin --- - ax = axes[-1] - rates = [] - for bin_label in bin_categories: - subset = binned_df[binned_df["infection_bin"] == bin_label] - n = len(subset) - row_dict = { - "bin": bin_label, - "death_rate": subset["ever_dead"].sum() / max(n, 1), - "remodel_rate": subset["ever_remodeled"].sum() / max(n, 1), - "n": n, - } - if has_division: - row_dict["division_rate"] = subset["ever_divided"].sum() / max(n, 1) - rates.append(row_dict) - rates_df = pd.DataFrame(rates) - - x = np.arange(len(bin_categories)) - n_bars = 3 if has_division else 2 - width = 0.8 / n_bars - ax.bar( - x - width, - rates_df["death_rate"], - width, - label="Death rate", - color="#d62728", - alpha=0.7, - ) - ax.bar( - x, - rates_df["remodel_rate"], - width, - label="Remodel rate", - color="#1f77b4", - alpha=0.7, - ) - if has_division: - ax.bar( - x + width, - rates_df["division_rate"], - width, - label="Division rate", - color="#2ca02c", - alpha=0.7, - ) - for i, row in rates_df.iterrows(): - max_rate = max(row["death_rate"], row["remodel_rate"]) - if has_division: - max_rate = max(max_rate, row["division_rate"]) - ax.text( - i, - max_rate + 0.02, - f"n={row['n']}", - ha="center", - fontsize=9, - ) - ax.set_xticks(x) - ax.set_xticklabels(bin_categories, rotation=15, ha="right") - ax.set_ylabel("Fraction of tracks") - ax.set_title(f"{source_label}: phenotype rates by bin") - ax.legend() - ax.set_ylim(0, 1.1) - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_response_time_comparison.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig(output_dir / f"{prefix}_response_time_comparison.pdf", bbox_inches="tight") - plt.show() - - # Print summary table - print(f"\n## {source_label}: Response time summary (median min)") - summary_rows = [] - for bin_label in bin_categories: - subset = binned_df[binned_df["infection_bin"] == bin_label] - death_dt = subset["infection_to_death"].dropna() - remodel_dt = subset["infection_to_remodel"].dropna() - row_dict = { - "bin": bin_label, - "n_tracks": len(subset), - "transloc→death median": (f"{death_dt.median():.0f}" if len(death_dt) > 0 else "—"), - "transloc→death n": len(death_dt), - "transloc→remodel median": (f"{remodel_dt.median():.0f}" if len(remodel_dt) > 0 else "—"), - "transloc→remodel n": len(remodel_dt), - } - if has_division: - division_dt = subset["infection_to_division"].dropna() - row_dict["transloc→division median"] = f"{division_dt.median():.0f}" if len(division_dt) > 0 else "—" - row_dict["transloc→division n"] = len(division_dt) - summary_rows.append(row_dict) - print(pd.DataFrame(summary_rows).to_string(index=False)) - - -for exp_name, res in all_results.items(): - plot_response_time_comparison(res["pred_binned"], f"{exp_name} (Prediction)", RESULTS_DIR) - plot_response_time_comparison(res["ann_binned"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 7a: Continuous scatter — HPI vs response time (no binning) -# =========================================================================== - - -def plot_hpi_vs_response( - events_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Scatter plot of translocation onset (HPI) vs response time with regression.""" - infected = events_df[events_df["ever_infected"]].copy() - if len(infected) < 5: - print(f" {source_label}: too few infected tracks ({len(infected)}) for scatter") - return - - infected["infection_to_death"] = infected["t_death_min"] - infected["t_infection_min"] - infected["infection_to_remodel"] = infected["t_remodel_min"] - infected["t_infection_min"] - - response_items = [ - ("infection_to_death", "Transloc → Death (min)"), - ("infection_to_remodel", "Transloc → Remodel (min)"), - ] - has_division = "t_division_min" in infected.columns - if has_division: - infected["infection_to_division"] = infected["t_division_min"] - infected["t_infection_min"] - response_items.append(("infection_to_division", "Transloc → Division (min)")) - - n_panels = len(response_items) - fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5)) - if n_panels == 1: - axes = [axes] - fig.suptitle( - f"{source_label}: T_translocation vs response time", - fontsize=14, - fontweight="bold", - ) - - for ax, (delta_col, xlabel) in zip(axes, response_items): - valid = infected.dropna(subset=[delta_col]) - x = valid[delta_col].to_numpy() - y = valid["t_infection_hpi"].to_numpy() - - if len(x) < 3: - ax.text( - 0.5, - 0.5, - f"n={len(x)}", - ha="center", - va="center", - transform=ax.transAxes, - ) - ax.set_xlabel(xlabel) - ax.set_ylabel("T_translocation (HPI)") - continue - - # Color by division status if available - if has_division and "ever_divided" in valid.columns: - divided_mask = valid["ever_divided"].to_numpy() - ax.scatter( - x[~divided_mask], - y[~divided_mask], - alpha=0.5, - s=20, - color="#1f77b4", - label="No division", - zorder=2, - ) - ax.scatter( - x[divided_mask], - y[divided_mask], - alpha=0.7, - s=30, - color="#2ca02c", - marker="^", - label="Divided", - zorder=3, - ) - ax.legend(fontsize=8) - else: - ax.scatter(x, y, alpha=0.5, s=20, color="#1f77b4", zorder=2) - - ax.text( - 0.03, - 0.97, - f"n={len(x)}", - transform=ax.transAxes, - ha="left", - va="top", - fontsize=9, - family="monospace", - bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), - ) - - ax.set_xlabel(xlabel) - ax.set_ylabel("T_translocation (HPI)") - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_hpi_vs_response.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig( - output_dir / f"{prefix}_hpi_vs_response.pdf", - bbox_inches="tight", - ) - plt.show() - - -for exp_name, res in all_results.items(): - plot_hpi_vs_response(res["pred_events_df"], f"{exp_name} (Prediction)", RESULTS_DIR) - plot_hpi_vs_response(res["ann_events_df"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 7b: Division confound analysis — do divided cells respond faster? -# =========================================================================== - - -def plot_division_confound( - binned_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Compare response times between divided and non-divided cells. - - Tests whether cells that underwent mitosis have shorter - translocation→death or translocation→remodel times, which would - indicate division is a confound for the observed phenotype timing. - """ - if "ever_divided" not in binned_df.columns: - return - if "infection_bin" not in binned_df.columns: - return - - binned_df = binned_df.copy() - binned_df["infection_to_death"] = binned_df["t_death_min"] - binned_df["t_infection_min"] - binned_df["infection_to_remodel"] = binned_df["t_remodel_min"] - binned_df["t_infection_min"] - binned_df["division_label"] = binned_df["ever_divided"].map({True: "Divided", False: "No division"}) - - bin_categories = list(binned_df["infection_bin"].cat.categories) - response_cols = [ - ("infection_to_death", "Transloc → Death (min)"), - ("infection_to_remodel", "Transloc → Remodel (min)"), - ] - - # --- Figure 1: Boxplots stratified by division within each bin --- - fig, axes = plt.subplots( - len(response_cols), - len(bin_categories), - figsize=(6 * len(bin_categories), 5 * len(response_cols)), - squeeze=False, - ) - fig.suptitle( - f"{source_label}: Response times — Divided vs Not divided", - fontsize=14, - fontweight="bold", - ) - - for row_idx, (delta_col, ylabel) in enumerate(response_cols): - for col_idx, bin_label in enumerate(bin_categories): - ax = axes[row_idx, col_idx] - subset = binned_df[binned_df["infection_bin"] == bin_label].dropna(subset=[delta_col]) - divided = subset[subset["ever_divided"]][delta_col] - not_divided = subset[~subset["ever_divided"]][delta_col] - - plot_data = [] - labels = [] - colors_box = [] - if len(not_divided) > 0: - plot_data.append(not_divided.values) - labels.append(f"No div\n(n={len(not_divided)})") - colors_box.append("#1f77b4") - if len(divided) > 0: - plot_data.append(divided.values) - labels.append(f"Divided\n(n={len(divided)})") - colors_box.append("#2ca02c") - - if len(plot_data) == 0: - ax.text( - 0.5, - 0.5, - "No data", - ha="center", - va="center", - transform=ax.transAxes, - ) - else: - bp = ax.boxplot( - plot_data, - patch_artist=True, - widths=0.5, - ) - for patch, c in zip(bp["boxes"], colors_box): - patch.set_facecolor(c) - patch.set_alpha(0.6) - for pos, vals in enumerate(plot_data, 1): - jitter = np.random.default_rng(42).uniform(-0.1, 0.1, len(vals)) - ax.scatter( - pos + jitter, - vals, - alpha=0.4, - s=12, - color="black", - zorder=3, - ) - ax.set_xticklabels(labels) - - # Mann-Whitney if both groups have enough data - if len(divided) >= 3 and len(not_divided) >= 3: - _, p = stats.mannwhitneyu(not_divided, divided, alternative="two-sided") - ax.set_title(f"{bin_label}\np={p:.4g}", fontsize=10) - else: - ax.set_title(bin_label, fontsize=10) - - if col_idx == 0: - ax.set_ylabel(ylabel) - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_division_confound.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig( - output_dir / f"{prefix}_division_confound.pdf", - bbox_inches="tight", - ) - plt.show() - - # --- Figure 2: Was division before or after translocation? --- - infected_divided = binned_df[binned_df["ever_divided"]].dropna(subset=["t_division_min"]) - if len(infected_divided) > 0: - infected_divided = infected_divided.copy() - infected_divided["division_relative_to_transloc"] = ( - infected_divided["t_division_min"] - infected_divided["t_infection_min"] - ) - n_before = (infected_divided["division_relative_to_transloc"] < 0).sum() - n_after = (infected_divided["division_relative_to_transloc"] >= 0).sum() - median_dt = infected_divided["division_relative_to_transloc"].median() - - print(f"\n## {source_label}: Division timing relative to translocation") - print(f" Divided before translocation: {n_before}/{len(infected_divided)}") - print(f" Divided after translocation: {n_after}/{len(infected_divided)}") - print(f" Median division–translocation gap: {median_dt:.0f} min") - - # Per-bin breakdown - for bin_label in bin_categories: - sub = infected_divided[infected_divided["infection_bin"] == bin_label] - if len(sub) > 0: - n_b = (sub["division_relative_to_transloc"] < 0).sum() - n_a = (sub["division_relative_to_transloc"] >= 0).sum() - print( - f" {bin_label}: {n_b} before, {n_a} after transloc " - f"(median gap: {sub['division_relative_to_transloc'].median():.0f} min)" - ) - - # --- Summary: overall Mann-Whitney (pooled across bins) --- - print(f"\n## {source_label}: Pooled divided vs not-divided response times") - for delta_col, label in response_cols: - valid = binned_df.dropna(subset=[delta_col]) - div_vals = valid[valid["ever_divided"]][delta_col] - nodiv_vals = valid[~valid["ever_divided"]][delta_col] - if len(div_vals) >= 3 and len(nodiv_vals) >= 3: - _, p = stats.mannwhitneyu(nodiv_vals, div_vals, alternative="two-sided") - print( - f" {label}: no-div median={nodiv_vals.median():.0f} min (n={len(nodiv_vals)}), " - f"div median={div_vals.median():.0f} min (n={len(div_vals)}), " - f"p={p:.4g}" - ) - else: - print(f" {label}: no-div n={len(nodiv_vals)}, div n={len(div_vals)} — too few for test") - - -for exp_name, res in all_results.items(): - plot_division_confound(res["pred_binned"], f"{exp_name} (Prediction)", RESULTS_DIR) - plot_division_confound(res["ann_binned"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 8: Save CSVs -# =========================================================================== - -if SAVE_FIGURES: - RESULTS_DIR.mkdir(parents=True, exist_ok=True) - for exp_name, res in all_results.items(): - prefix = exp_name.replace(" ", "_").replace("(", "").replace(")", "") - res["ann_events_df"].to_csv(RESULTS_DIR / f"{prefix}_annotation_events.csv", index=False) - res["pred_events_df"].to_csv(RESULTS_DIR / f"{prefix}_prediction_events.csv", index=False) - - if "infection_bin" in res["ann_binned"].columns: - res["ann_binned"].to_csv(RESULTS_DIR / f"{prefix}_annotation_binned.csv", index=False) - if "infection_bin" in res["pred_binned"].columns: - res["pred_binned"].to_csv(RESULTS_DIR / f"{prefix}_prediction_binned.csv", index=False) - - print(f"\nAll results saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py b/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py deleted file mode 100644 index 0f7a426e1..000000000 --- a/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py +++ /dev/null @@ -1,355 +0,0 @@ -# %% -""" -Prediction-based organelle remodeling analysis. - -Measures remodeling timing using classifier predictions -(predicted_organelle_state in AnnData) instead of human annotations. - -Pipeline: alignment → prediction signal → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -import glob -from pathlib import Path - -import anndata as ad -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_prediction_signal, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -ORGANELLE_CONFIG = { - "G3BP1": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", # uninf c/1, inf c/2 - "frame_interval_minutes": 10, - "task": "organelle_state_g3bp1", - "label": "2025_07_22 ZIKV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "fov_pattern": "C/2", # ZIKV uninf B/3, inf C/2 - "frame_interval_minutes": 10, - "task": "organelle_state_g3bp1", - "label": "2025_01_24 DENV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "C/4", # DENV uninf B/4 and inf C/4 - "frame_interval_minutes": 30, - "task": "organelle_state_g3bp1", - "label": "2025_01_28 ZIKV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", # ZIKV uinf C/1 and inf C/2 - "frame_interval_minutes": 30, - "task": "organelle_state_g3bp1", - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [], - "label": "G3BP1 (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "task": "organelle_state_sec61b", - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [], - "label": "SEC61B (ER)", - "color": "#ff7f0e", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" # Default: use human annotations for T_perturb -USE_PROBABILITY = False # Set True to use continuous probability instead of binary -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 5 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "prediction_remodeling" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - - # Load embeddings (AnnData with predictions) - emb_files = glob.glob(str(Path(exp["embeddings_path"]) / exp["embeddings_pattern"])) - if not emb_files: - print(f" No embeddings found matching: {exp['embeddings_pattern']}") - continue - - adata = ad.read_zarr(emb_files[0]) - print(f" Loaded {adata.shape[0]:,} embeddings") - - # Check predictions exist - task = exp.get("task", "organelle_state") - pred_col = f"predicted_{task}" - if pred_col not in adata.obs.columns: - print(f" WARNING: '{pred_col}' not in adata.obs — skipping") - continue - - # Load annotations for infection state alignment - ann_df = pd.read_csv(exp["annotations_path"]) - if "parent_track_id" not in ann_df.columns: - ann_df["parent_track_id"] = -1 - - # Step 1: Alignment (using annotations for T_perturb) - aligned = align_tracks( - ann_df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (prediction-based) - aligned = extract_prediction_signal( - adata, - aligned, - task=task, - positive_value="remodel", - use_probability=USE_PROBABILITY, - ) - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - signal_type = "continuous" if USE_PROBABILITY else "fraction" - population_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type=signal_type) - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "population_df": population_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - pop_df = res["population_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - pop_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(pop_df) - peak = find_peak_metrics(pop_df) - - timing_rows.append( - { - "marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Prediction-based Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -signal_type = "continuous" if USE_PROBABILITY else "fraction" -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type=signal_type) - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -if all_track_timing: - track_timing_df = pd.concat(all_track_timing, ignore_index=True) -else: - track_timing_df = pd.DataFrame( - columns=[ - "fov_name", - "track_id", - "onset_minutes", - "total_positive_minutes", - "span_minutes", - "n_positive_frames", - "n_total_frames", - "marker", - ] - ) - print("WARNING: No tracks with positive signal detected across any marker.") - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["population_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - marker_configs, - RESULTS_DIR, - signal_type=signal_type, - min_cells_per_bin=MIN_CELLS_PER_BIN, - title="Prediction-based organelle remodeling after infection", - filename_prefix="prediction_remodeling_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type=signal_type, - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_prediction_heatmap", - ) - -if len(track_timing_df) > 0: - plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - filename_prefix="per_track_onset_duration", - ) - - plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", - ) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1 and len(track_timing_df) > 0: - stats_df = run_statistical_tests(marker_results, track_timing_df) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - curve_path = RESULTS_DIR / f"{marker}_population_curve.csv" - res["population_df"].to_csv(curve_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/src/dynaclr/cli.py b/applications/dynaclr/src/dynaclr/cli.py index ade202d7c..cf79d25c1 100644 --- a/applications/dynaclr/src/dynaclr/cli.py +++ b/applications/dynaclr/src/dynaclr/cli.py @@ -85,6 +85,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="evaluate-tracking-accuracy", + import_path="dynaclr.evaluation.benchmarking.tracking_accuracy.evaluate_tracking.main", + short_help="Evaluate CTC tracking accuracy with DynaCLR ONNX embeddings", + ) +) + dynaclr.add_command( LazyCommand( name="append-obs", @@ -101,6 +109,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="combined-dim-reduction", + import_path="dynaclr.evaluation.dimensionality_reduction.reduce_combined.main", + short_help="Joint PCA/PHATE across multiple AnnData stores", + ) +) + dynaclr.add_command( LazyCommand( name="cross-validate", @@ -109,6 +125,22 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="run-linear-classifiers", + import_path="dynaclr.evaluation.linear_classifiers.orchestrated.main", + short_help="Run linear classifiers on orchestrator embeddings (batch, CSV metrics)", + ) +) + +dynaclr.add_command( + LazyCommand( + name="split-embeddings", + import_path="dynaclr.evaluation.split_embeddings.main", + short_help="Split combined embeddings zarr into one zarr per experiment", + ) +) + dynaclr.add_command( LazyCommand( name="info", @@ -125,6 +157,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="preprocess-cell-index", + import_path="dynaclr.data.preprocess_cell_index.main", + short_help="Remove empty-frame rows from a cell index parquet", + ) +) + dynaclr.add_command( LazyCommand( name="convert-ops-parquet", @@ -157,6 +197,63 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="compute-mmd", + import_path="dynaclr.evaluation.mmd.compute_mmd.main", + short_help="Compute MMD between perturbation groups in cell embeddings", + ) +) + +dynaclr.add_command( + LazyCommand( + name="plot-mmd-heatmap", + import_path="dynaclr.evaluation.mmd.compute_mmd.plot_mmd_heatmap_cmd", + short_help="Plot combined MMD heatmap (all markers) from per-experiment CSVs", + ) +) + +dynaclr.add_command( + LazyCommand( + name="prepare-eval-configs", + import_path="dynaclr.evaluation.evaluate.main", + short_help="Generate evaluation YAML configs and print JSON manifest (Nextflow entry point)", + ) +) + +dynaclr.add_command( + LazyCommand( + name="check-evals", + import_path="dynaclr.evaluation.check_evals.main", + short_help="Check eval completion status for all models in the registry", + ) +) + +dynaclr.add_command( + LazyCommand( + name="append-annotations", + import_path="dynaclr.evaluation.append_annotations.main", + short_help="Append annotation columns to per-experiment zarrs", + ) +) + +dynaclr.add_command( + LazyCommand( + name="append-predictions", + import_path="dynaclr.evaluation.append_predictions.main", + short_help="Apply saved classifiers and write predictions to per-experiment zarrs", + ) +) + + +dynaclr.add_command( + LazyCommand( + name="plot-embeddings", + import_path="dynaclr.evaluation.plot_embeddings.main", + short_help="Generate scatter plots from an AnnData embedding store", + ) +) + def main(): """Run the DynaCLR CLI.""" diff --git a/applications/dynaclr/src/dynaclr/data/datamodule.py b/applications/dynaclr/src/dynaclr/data/datamodule.py index cd702508b..70fac59f6 100644 --- a/applications/dynaclr/src/dynaclr/data/datamodule.py +++ b/applications/dynaclr/src/dynaclr/data/datamodule.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +import torch from iohub.core.config import TensorStoreConfig from lightning.pytorch import LightningDataModule from monai.data.thread_buffer import ThreadDataLoader @@ -27,8 +28,8 @@ from dynaclr.data.index import MultiExperimentIndex from viscy_data._utils import BatchedCenterSpatialCropd, _transform_channel_wise from viscy_data.channel_dropout import ChannelDropout +from viscy_data.channel_utils import parse_channel_name from viscy_data.sampler import FlexibleBatchSampler -from viscy_transforms import BatchedRandSpatialCropd _logger = logging.getLogger(__name__) @@ -51,11 +52,10 @@ class MultiExperimentDataModule(LightningDataModule): Parameters ---------- - collection_path : str or None - Path to collection YAML for ExperimentRegistry.from_collection(). - Optional when ``cell_index_path`` is provided — the registry is - built directly from parquet + zarr metadata via - ExperimentRegistry.from_cell_index(). + cell_index_path : str + Path to preprocessed cell index parquet (from ``build-cell-index`` + + ``preprocess-cell-index``). Contains all metadata needed for + training: TCZYX shape, normalization stats, focus slice. z_window : int Number of Z slices the model consumes (final crop size). z_extraction_window : int or None @@ -85,7 +85,7 @@ class MultiExperimentDataModule(LightningDataModule): batch_size : int Batch size. Default: 128. num_workers : int - Thread workers for ThreadDataLoader. Default: 1. + Thread workers for ThreadDataLoader. Default: 4. batch_group_by : str or list[str] or None Column(s) to group batches by (e.g. ``"experiment"``). Default: None. stratify_by : str | list[str] | None @@ -122,17 +122,9 @@ class MultiExperimentDataModule(LightningDataModule): Only include these wells. Default: None. exclude_fovs : list[str] | None Exclude these FOVs. Default: None. - cell_index_path : str | None - Optional path to a pre-built cell index parquet for faster startup. - When provided, both train and val indices load from this parquet - (filtered by their respective registries). Default: None. focus_channel : str | None Channel name for ``focus_slice`` lookup when auto-resolving z_range. Default: None (uses first source_channel). - num_workers_index : int - Number of parallel processes for building the cell index. Default: 1 - (sequential). When > 1, one process is spawned per experiment. - Ignored when ``cell_index_path`` is provided. reference_pixel_size_xy_um : float or None Reference pixel size in XY (micrometers) for physical-scale normalization. None = no rescaling. Default: None. @@ -157,7 +149,7 @@ class MultiExperimentDataModule(LightningDataModule): def __init__( self, - collection_path: str | None, + cell_index_path: str, z_window: int, z_extraction_window: int | None = None, z_focus_offset: float = 0.5, @@ -168,7 +160,7 @@ def __init__( tau_range: tuple[float, float] = (0.5, 2.0), tau_decay_rate: float = 2.0, batch_size: int = 128, - num_workers: int = 1, + num_workers: int = 4, # Sampling hyperparameters (passed to FlexibleBatchSampler) batch_group_by: str | list[str] | None = None, stratify_by: str | list[str] | None = "perturbation", @@ -185,13 +177,11 @@ def __init__( normalizations: list[MapTransform] | None = None, augmentations: list[MapTransform] | None = None, # Other - cache_pool_bytes: int = 0, + cache_pool_bytes: int = 500_000_000, seed: int = 0, include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, - cell_index_path: str | None = None, focus_channel: str | None = None, - num_workers_index: int = 1, reference_pixel_size_xy_um: float | None = None, reference_pixel_size_z_um: float | None = None, positive_cell_source: str = "lookup", @@ -200,11 +190,14 @@ def __init__( label_columns: dict[str, str] | None = None, max_border_shift: int = -1, shuffle_val: bool = False, + pin_memory: bool = True, + prefetch_factor: int | None = None, + buffer_size: int = 4, ) -> None: super().__init__() # Core parameters - self.collection_path = collection_path + self.cell_index_path = cell_index_path self.z_window = z_window self.z_extraction_window = z_extraction_window self.z_focus_offset = z_focus_offset @@ -249,9 +242,7 @@ def __init__( self.seed = seed self.include_wells = include_wells self.exclude_fovs = exclude_fovs - self.cell_index_path = cell_index_path self.focus_channel = focus_channel - self.num_workers_index = num_workers_index self.reference_pixel_size_xy_um = reference_pixel_size_xy_um self.reference_pixel_size_z_um = reference_pixel_size_z_um self.positive_cell_source = positive_cell_source @@ -260,6 +251,9 @@ def __init__( self.label_columns = label_columns self.max_border_shift = max_border_shift self.shuffle_val = shuffle_val + self.pin_memory = pin_memory + self.prefetch_factor = prefetch_factor + self.buffer_size = buffer_size # Create ChannelDropout module self.channel_dropout = ChannelDropout( @@ -270,6 +264,7 @@ def __init__( # Datasets (populated in setup) self.train_dataset: MultiExperimentTripletDataset | None = None self.val_dataset: MultiExperimentTripletDataset | None = None + self.predict_dataset: MultiExperimentTripletDataset | None = None # ------------------------------------------------------------------ # Setup @@ -292,33 +287,20 @@ def setup(self, stage: str | None = None) -> None: Lightning stage: ``"fit"``, ``"predict"``, etc. """ if stage == "fit" or stage is None: - if self.collection_path is not None: - registry = ExperimentRegistry.from_collection( - self.collection_path, - z_window=self.z_window, - z_extraction_window=self.z_extraction_window, - z_focus_offset=self.z_focus_offset, - focus_channel=getattr(self, "focus_channel", None), - reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, - reference_pixel_size_z_um=self.reference_pixel_size_z_um, - ) - elif self.cell_index_path is not None: - registry = ExperimentRegistry.from_cell_index( - self.cell_index_path, - z_window=self.z_window, - z_extraction_window=self.z_extraction_window, - z_focus_offset=self.z_focus_offset, - focus_channel=getattr(self, "focus_channel", None), - reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, - reference_pixel_size_z_um=self.reference_pixel_size_z_um, - ) - else: - raise ValueError("Either collection_path or cell_index_path must be provided.") + registry, cell_index_df = ExperimentRegistry.from_cell_index( + self.cell_index_path, + z_window=self.z_window, + z_extraction_window=self.z_extraction_window, + z_focus_offset=self.z_focus_offset, + focus_channel=self.focus_channel, + reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, + reference_pixel_size_z_um=self.reference_pixel_size_z_um, + ) if self.val_experiments: - self._setup_experiment_split(registry) + self._setup_experiment_split(registry, cell_index_df) else: - self._setup_fov_split(registry) + self._setup_fov_split(registry, cell_index_df) if self.channels_per_sample is None: self._channel_names = registry.source_channel_labels @@ -333,7 +315,6 @@ def setup(self, stage: str | None = None) -> None: self._augmentation_transform = Compose( self.normalizations + self.augmentations + [self._train_final_crop()] ) - self._no_augmentation_transform = Compose(self.normalizations + [self._val_final_crop()]) _logger.info( "MultiExperimentDataModule setup: %d train anchors, %d val anchors", @@ -341,7 +322,62 @@ def setup(self, stage: str | None = None) -> None: len(self.val_dataset) if self.val_dataset else 0, ) - def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: + elif stage == "predict": + self._setup_predict() + _logger.info( + "MultiExperimentDataModule predict setup: %d anchors", + len(self.predict_dataset) if self.predict_dataset else 0, + ) + + def _setup_predict(self) -> None: + """Set up predict dataset over the full cell index (no train/val split).""" + registry, cell_index_df = ExperimentRegistry.from_cell_index( + self.cell_index_path, + z_window=self.z_window, + z_extraction_window=self.z_extraction_window, + z_focus_offset=self.z_focus_offset, + focus_channel=self.focus_channel, + reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, + reference_pixel_size_z_um=self.reference_pixel_size_z_um, + ) + + if self.channels_per_sample is None: + self._channel_names = registry.source_channel_labels + elif isinstance(self.channels_per_sample, int): + self._channel_names = [f"channel_{i}" for i in range(self.channels_per_sample)] + else: + self._channel_names = list(self.channels_per_sample) + + predict_index = MultiExperimentIndex( + registry=registry, + yx_patch_size=self.yx_patch_size, + tau_range_hours=self.tau_range, + include_wells=self.include_wells, + exclude_fovs=self.exclude_fovs, + cell_index_df=cell_index_df, + positive_cell_source=self.positive_cell_source, + positive_match_columns=self.positive_match_columns, + fit=False, + ) + self.predict_dataset = MultiExperimentTripletDataset( + index=predict_index, + fit=False, + tau_range_hours=self.tau_range, + tau_decay_rate=self.tau_decay_rate, + channels_per_sample=self.channels_per_sample, + positive_cell_source=self.positive_cell_source, + positive_match_columns=self.positive_match_columns, + positive_channel_source=self.positive_channel_source, + label_columns=self.label_columns, + ) + + # Predict transform: normalizations + final center crop only (no augmentations). + # BatchedChannelWiseZReductiond is kept if present in self.augmentations + # since it is architecturally required to produce the 2D model input. + z_reduction = [t for t in self.augmentations if type(t).__name__ == "BatchedChannelWiseZReductiond"] + self._predict_transform = Compose(self.normalizations + z_reduction + [self._train_final_crop()]) + + def _setup_experiment_split(self, registry: ExperimentRegistry, cell_index_df: pd.DataFrame) -> None: """Split by whole experiments into train/val.""" train_names = [e.name for e in registry.experiments if e.name not in self.val_experiments] val_names = [e.name for e in registry.experiments if e.name in self.val_experiments] @@ -364,8 +400,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, @@ -391,8 +426,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, @@ -410,7 +444,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: label_columns=self.label_columns, ) - def _setup_fov_split(self, registry: ExperimentRegistry) -> None: + def _setup_fov_split(self, registry: ExperimentRegistry, cell_index_df: pd.DataFrame) -> None: """Split FOVs within each experiment by split_ratio. Uses experiment-qualified keys ``(experiment, fov_name)`` so that @@ -423,45 +457,119 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, tensorstore_config=self.tensorstore_config, ) rng = np.random.default_rng(self.seed) - train_keys: set[tuple[str, str]] = set() - val_keys: set[tuple[str, str]] = set() + + # Build per-row boolean masks directly during the per-experiment + # groupby walk. The previous implementation built + # pd.MultiIndex.from_arrays over every row of tracks + valid_anchors + # (81M+ rows for OPS), which hashes a Python tuple per row and + # dominates setup-time memory. Per-group isin against a small + # Python-set of FOV names is O(group_size) with no object index. + train_fovs_per_exp: dict[str, set[str]] = {} + val_fovs_per_exp: dict[str, set[str]] = {} for exp_name, group in full_index.tracks.groupby("experiment"): fovs = sorted(group["fov_name"].unique()) n_train = max(1, int(len(fovs) * self.split_ratio)) rng.shuffle(fovs) - for f in fovs[:n_train]: - train_keys.add((exp_name, f)) - for f in fovs[n_train:]: - val_keys.add((exp_name, f)) + train_fovs_per_exp[exp_name] = set(fovs[:n_train]) + val_fovs_per_exp[exp_name] = set(fovs[n_train:]) + n_train_fovs = sum(len(s) for s in train_fovs_per_exp.values()) + n_val_fovs = sum(len(s) for s in val_fovs_per_exp.values()) _logger.info( "FOV split (ratio=%.2f): %d train FOVs, %d val FOVs", self.split_ratio, - len(train_keys), - len(val_keys), + n_train_fovs, + n_val_fovs, ) - full_qual = list(zip(full_index.tracks["experiment"], full_index.tracks["fov_name"])) - train_mask = pd.Series([k in train_keys for k in full_qual], index=full_index.tracks.index) + def _build_train_mask(df: pd.DataFrame) -> np.ndarray: + """Row-wise boolean mask: True if (experiment, fov_name) is train.""" + mask = np.zeros(len(df), dtype=bool) + # groupby("experiment") returns integer positions in ``df`` via + # group.index after reset_index; we rely on the caller passing + # reset-indexed frames (which is what MultiExperimentIndex produces). + for exp_name, group in df.groupby("experiment", sort=False): + train_fovs = train_fovs_per_exp.get(exp_name, set()) + if not train_fovs: + continue + sub_mask = group["fov_name"].isin(train_fovs).to_numpy() + mask[group.index.to_numpy()] = sub_mask + return mask + + def _split_by_mask(df: pd.DataFrame, mask: np.ndarray) -> tuple[pd.DataFrame, pd.DataFrame]: + """Partition ``df`` by a boolean mask using integer row indices. + + ``df[bool_mask]`` on an Arrow-backed DataFrame routes through + ``pyarrow.compute.take`` which allocates a fresh buffer per + string column and scales badly with row count × column count. + On a 16M-row × 15-string-col frame this can take 7-8 minutes + per call on a contended node. + + Using ``df.take(int_indices)`` on a frame whose Arrow string + columns have been cast to ``object`` upfront is ~20× faster + because pandas uses plain NumPy fancy indexing on the + materialized object arrays. + """ + train_rows = np.flatnonzero(mask) + val_rows = np.flatnonzero(~mask) + return ( + df.take(train_rows).reset_index(drop=True), + df.take(val_rows).reset_index(drop=True), + ) - train_tracks = full_index.tracks[train_mask].reset_index(drop=True) - val_tracks = full_index.tracks[~train_mask].reset_index(drop=True) + def _materialize_strings(df: pd.DataFrame) -> pd.DataFrame: + """In-place cast remaining ArrowStringArray columns to Categorical. + + ArrowStringArray routes every ``df[mask]`` through + ``pyarrow.compute.take`` which allocates a fresh per-column + buffer and scales catastrophically (7-8 min per call on 16M rows + with 15 string columns on a contended node). Casting to pandas + Categorical uses int codes + a single categories dict, so + slicing is pure NumPy fancy indexing on the codes. + + Low-cardinality columns (``experiment``, ``marker``, etc.) are + already Categorical from ``read_cell_index``/``_align_parquet_columns`` + — those are skipped. High-cardinality columns like ``cell_id`` + become effectively int32-indexed even at ~80M unique values, + since the dict overhead is one-time and the row-aligned codes + are cheap. NumPy-object casts were tried first but allocate + ~5-10 GB of Python string objects per frame, which on 4-rank DDP + OOMs the node. + """ + for col in df.columns: + s = df[col] + if isinstance(s.dtype, pd.CategoricalDtype): + continue + if pd.api.types.is_string_dtype(s) or str(s.dtype).startswith(("string", "Arrow")): + df[col] = s.astype("category") + return df + + _materialize_strings(full_index.tracks) + _materialize_strings(full_index.valid_anchors) + + train_mask = _build_train_mask(full_index.tracks) + train_tracks, val_tracks = _split_by_mask(full_index.tracks, train_mask) + + va = full_index.valid_anchors + train_va_mask = _build_train_mask(va) + train_va, val_va = _split_by_mask(va, train_va_mask) train_index = full_index.clone_with_subset( train_tracks, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, + precomputed_valid_anchors=train_va, ) + self.train_dataset = MultiExperimentTripletDataset( index=train_index, fit=True, @@ -474,11 +582,12 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: label_columns=self.label_columns, ) - if val_keys: + if not val_tracks.empty: val_index = full_index.clone_with_subset( val_tracks, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, + precomputed_valid_anchors=val_va, ) self.val_dataset = MultiExperimentTripletDataset( index=val_index, @@ -496,8 +605,29 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: # Dataloaders # ------------------------------------------------------------------ + def _ddp_topology(self) -> tuple[int, int]: + """Return ``(num_replicas, rank)`` for the current trainer. + + Lightning's auto-wrap hook only passes ``world_size``/``rank`` to + ``sampler``, not ``batch_sampler``. With ``use_distributed_sampler: + false`` and a batch sampler, the datamodule must read them from the + trainer itself and forward them; otherwise every rank iterates the + full sequence and yields identical batches. + + Returns ``(1, 0)`` when no trainer is attached (e.g. bare + dataloader construction in tests) or when the trainer stub lacks + DDP attributes (e.g. the ``_FakeTrainer`` in demo scripts). + """ + trainer = getattr(self, "trainer", None) + world_size = getattr(trainer, "world_size", None) + global_rank = getattr(trainer, "global_rank", None) + if world_size is None or global_rank is None: + return 1, 0 + return world_size, global_rank + def train_dataloader(self) -> ThreadDataLoader: """Return training data loader with FlexibleBatchSampler.""" + num_replicas, rank = self._ddp_topology() sampler = FlexibleBatchSampler( valid_anchors=self.train_dataset.index.valid_anchors, batch_size=self.batch_size, @@ -508,27 +638,76 @@ def train_dataloader(self) -> ThreadDataLoader: temporal_enrichment=self.temporal_enrichment, temporal_window_hours=self.temporal_window_hours, temporal_global_fraction=self.temporal_global_fraction, + num_replicas=num_replicas, + rank=rank, seed=self.seed, ) return ThreadDataLoader( self.train_dataset, use_thread_workers=True, + buffer_size=self.buffer_size, batch_sampler=sampler, num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, collate_fn=lambda x: x, ) def val_dataloader(self) -> ThreadDataLoader | None: - """Return validation data loader.""" + """Return validation data loader. + + Uses the same ``FlexibleBatchSampler`` as training so ``loss/val`` + is measured on batches whose composition matches the training + regime — e.g. single-marker batches when ``batch_group_by="marker"``, + or perturbation-stratified batches when ``stratify_by`` is set. + + Without this, val was a plain sequential DataLoader that served + one experiment/marker at a time (all 4 example batches end up as + the same marker), and DDP sync of ``loss/val`` silently desynced + across ranks because each rank's shard had a different set of + markers. + + Temporal enrichment is disabled for val (we want a deterministic + representative sample, not oversampled biology-of-interest windows). + """ if self.val_dataset is None: return None + num_replicas, rank = self._ddp_topology() + sampler = FlexibleBatchSampler( + valid_anchors=self.val_dataset.index.valid_anchors, + batch_size=self.batch_size, + batch_group_by=self.batch_group_by, + leaky=self.leaky, + group_weights=self.group_weights, + stratify_by=self.stratify_by, + temporal_enrichment=False, + num_replicas=num_replicas, + rank=rank, + seed=self.seed, + ) return ThreadDataLoader( self.val_dataset, use_thread_workers=True, + buffer_size=self.buffer_size, + batch_sampler=sampler, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, + collate_fn=lambda x: x, + ) + + def predict_dataloader(self) -> ThreadDataLoader: + """Return predict data loader (no shuffling, no dropping).""" + return ThreadDataLoader( + self.predict_dataset, + use_thread_workers=True, + buffer_size=self.buffer_size, batch_size=self.batch_size, num_workers=self.num_workers, - shuffle=self.shuffle_val, + shuffle=False, drop_last=False, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, collate_fn=lambda x: x, ) @@ -536,18 +715,15 @@ def val_dataloader(self) -> ThreadDataLoader | None: # Transforms # ------------------------------------------------------------------ - def _train_final_crop(self) -> BatchedRandSpatialCropd: - """Random crop from extraction size to model input size (training).""" - return BatchedRandSpatialCropd( - keys=self._channel_names, - roi_size=(self.z_window, self.final_yx_patch_size[0], self.final_yx_patch_size[1]), - ) - - def _val_final_crop(self) -> BatchedCenterSpatialCropd: - """Center crop from extraction size to model input size (validation).""" + def _train_final_crop(self) -> BatchedCenterSpatialCropd: + """Center crop from extraction size to model input size (training).""" return BatchedCenterSpatialCropd( keys=self._channel_names, - roi_size=(self.z_window, self.final_yx_patch_size[0], self.final_yx_patch_size[1]), + roi_size=( + self.z_window, + self.final_yx_patch_size[0], + self.final_yx_patch_size[1], + ), ) def on_after_batch_transfer(self, batch, dataloader_idx: int): @@ -568,11 +744,38 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): if isinstance(batch, Tensor): return batch - # Determine transform: augmentation for training, no-aug for val - if self.trainer and self.trainer.validating: - transform = self._no_augmentation_transform - else: - transform = self._augmentation_transform + # During predict: normalizations + z_reduction only (no augmentations, no channel dropout). + if self.trainer.predicting: + norm_meta = batch.get("anchor_norm_meta") + if isinstance(norm_meta, list): + non_none = [m for m in norm_meta if m is not None] + if len(non_none) == 0: + norm_meta = None + elif len(non_none) != len(norm_meta): + raise ValueError("Mixed None/non-None norm_meta in predict batch.") + extra = None + if isinstance(self.channels_per_sample, int): + meta = batch.get("anchor_meta") + if meta is not None: + extra = { + "_is_labelfree": torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in meta], + dtype=torch.bool, + device=batch["anchor"].device, + ) + } + batch["anchor"] = _transform_channel_wise( + transform=self._predict_transform, + channel_names=self._channel_names, + patch=batch["anchor"], + norm_meta=norm_meta, + extra=extra, + ) + batch.pop("anchor_norm_meta", None) + batch.pop("anchor_meta", None) + return batch + + transform = self._augmentation_transform for key in ["anchor", "positive", "negative"]: if key in batch: @@ -588,20 +791,31 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): "All FOVs must have normalization metadata or none of them." ) # else: all non-None, pass through as list + extra = None + if isinstance(self.channels_per_sample, int): + meta = batch.get(f"{key}_meta") + if meta is not None: + extra = { + "_is_labelfree": torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in meta], + dtype=torch.bool, + device=batch[key].device, + ) + } transformed = _transform_channel_wise( transform=transform, channel_names=self._channel_names, patch=batch[key], norm_meta=norm_meta, + extra=extra, ) batch[key] = transformed if norm_meta_key in batch: del batch[norm_meta_key] - # Apply ChannelDropout to anchor and positive (training only) - if not (self.trainer and self.trainer.validating): - for key in ["anchor", "positive"]: - if key in batch: - batch[key] = self.channel_dropout(batch[key]) + # Apply ChannelDropout to anchor and positive + for key in ["anchor", "positive"]: + if key in batch: + batch[key] = self.channel_dropout(batch[key]) return batch diff --git a/applications/dynaclr/src/dynaclr/data/dataset.py b/applications/dynaclr/src/dynaclr/data/dataset.py index a2313fe1a..332067a40 100644 --- a/applications/dynaclr/src/dynaclr/data/dataset.py +++ b/applications/dynaclr/src/dynaclr/data/dataset.py @@ -31,21 +31,70 @@ except ImportError: ts = None +from iohub.ngff import open_ome_zarr + from dynaclr.data.index import MultiExperimentIndex from dynaclr.data.tau_sampling import sample_tau from viscy_data._typing import ULTRACK_INDEX_COLUMNS, NormMeta, SampleMeta from viscy_data._utils import _read_norm_meta + +def _pick_temporal_candidate( + timepoints: dict[int, list[int]], + anchor_t: int, + tau_min: int, + tau_max: int, + tau_decay_rate: float, + rng: np.random.Generator, + tr_marker_arr: np.ndarray | None, + anchor_marker: object | None, +) -> int | None: + """Pick one positive tracks-index for a temporal anchor. + + Mirrors the legacy ``_find_temporal_positive._pick`` logic but + operates on pre-computed NumPy arrays. Returns ``None`` if no + candidate is found in the ``[tau_min, tau_max]`` window. + """ + + def _filter_and_pick(cand_indices: list[int]) -> int | None: + if not cand_indices: + return None + if tr_marker_arr is not None: + # NumPy fancy-index filter: O(n) with n = number of candidates, + # single vectorized array op. + idx_arr = np.asarray(cand_indices, dtype=np.int64) + mask = tr_marker_arr[idx_arr] == anchor_marker + filtered = idx_arr[mask] + if len(filtered) > 0: + return int(filtered[rng.integers(len(filtered))]) + return int(cand_indices[rng.integers(len(cand_indices))]) + + sampled_tau = sample_tau(tau_min, tau_max, rng, tau_decay_rate) + result = _filter_and_pick(timepoints.get(anchor_t + sampled_tau, [])) + if result is not None: + return result + for tau in range(tau_min, tau_max + 1): + if tau == 0: + continue + result = _filter_and_pick(timepoints.get(anchor_t + tau, [])) + if result is not None: + return result + return None + + _META_COLUMNS = [ "experiment", "perturbation", "microscope", "fov_name", + "store_path", "global_track_id", "t", "hours_post_perturbation", "lineage_id", "marker", + "y_clamp", + "x_clamp", ] _logger = logging.getLogger(__name__) @@ -202,8 +251,16 @@ def __init__( self._rng = np.random.default_rng() self._tensorstores: dict[str, ts.TensorStore] = {} + self._store_cache: dict[str, object] = {} # store_path -> Plate + self._position_cache: dict[str, object] = {} # fov_name -> Position self._norm_meta_cache: dict[str, NormMeta | None] = {} - self._build_match_lookup() + if self.fit: + self._build_match_lookup() + self._build_anchor_cache() + + # ------------------------------------------------------------------ + # Initialization helpers + # ------------------------------------------------------------------ def _build_match_lookup(self) -> None: """Build lookup structures for O(1) positive candidate lookup. @@ -222,21 +279,107 @@ def _build_match_lookup(self) -> None: tracks = self.index.tracks if "lineage_id" in self.positive_match_columns: + # observed=True skips unobserved Categorical cross-products; + # without it groupby yields empty groups for every Categorical + # combination, exploding memory and time. Keys are coerced to + # str so the lookup works regardless of dtype (Categorical vs + # object vs ArrowString). + grouped = tracks.groupby(["experiment", "lineage_id", "t"], observed=True).indices self._lineage_timepoints: dict[tuple[str, str], dict[int, list[int]]] = defaultdict( lambda: defaultdict(list) ) - experiments = tracks["experiment"].to_numpy() - lineage_ids = tracks["lineage_id"].to_numpy() - t_values = tracks["t"].to_numpy() - for idx in range(len(tracks)): - self._lineage_timepoints[(experiments[idx], lineage_ids[idx])][t_values[idx]].append(idx) + for (exp, lid, t), row_indices in grouped.items(): + self._lineage_timepoints[(str(exp), str(lid))][int(t)] = row_indices.tolist() else: cols = self.positive_match_columns - self._match_lookup: dict[tuple, list[int]] = defaultdict(list) - col_arrays = [tracks[c].to_numpy() for c in cols] - for idx in range(len(tracks)): - key = tuple(arr[idx] for arr in col_arrays) - self._match_lookup[key].append(idx) + grouped = tracks.groupby(cols).indices + # Store candidate indices as ndarray for O(1) random choice without list copy. + self._match_lookup: dict[tuple, np.ndarray] = { + (k if isinstance(k, tuple) else (k,)): v for k, v in grouped.items() + } + + def _build_anchor_cache(self) -> None: + """Cache valid_anchors/tracks columns as NumPy arrays for fast per-sample access. + + Avoids pandas ``.iloc[idx][col]`` in the hot path, which constructs a + Series per call (~9 ms per anchor on 81M-row indices). NumPy indexing + is ~20 ns. Measured end-to-end speedup: ~3000× on positive-lookup. + + Both ``_va_arrays`` (for anchors) and ``_tr_arrays`` (for positives) + cache the full set of columns needed by ``_slice_patch`` and + ``_build_norm_meta``: ``store_path``, ``fov_name``, ``experiment``, + ``t``, ``y_clamp``, ``x_clamp``, plus ``norm_*`` columns for the + parquet-norm fast path. + + Cache is in-process RAM only — rebuilt on every dataset instantiation + from ``self.index.valid_anchors`` / ``self.index.tracks``. Parquet + remains the source of truth. + + Also precomputes per-experiment tau range (frames) to avoid a registry + lookup per anchor inside ``_sample_positives_temporal``. + """ + + # High-cardinality string columns (store_path, fov_name, experiment, + # marker, channel_name, lineage_id) have few unique values relative to + # row count, so cache them as category codes + categories lookup instead + # of object arrays. Object arrays of strings are ~40-80 bytes/entry; a + # categorical code is 4-8 bytes. On 81M rows this is the difference + # between an OOM and a healthy init. + # + # Access pattern: array[idx] still works if array is a pandas Categorical + # (returns the underlying string); downstream code doesn't care. + def _cache_columns(df: pd.DataFrame, columns: list[str]) -> dict: + out = {} + for col in columns: + if col not in df.columns: + continue + s = df[col] + if s.dtype == object or pd.api.types.is_string_dtype(s): + out[col] = s.astype("category").array # pd.Categorical + else: + out[col] = s.to_numpy() + return out + + # Whitelist columns actually read in the hot path. Caching every + # column of valid_anchors (81M+ rows × ~20 cols × 4 DDP ranks) blows + # the node memory budget; holding only the read set keeps per-rank + # RSS in the low tens of GiB. `positive_match_columns` (user-defined) + # and label column values must also be cached because they drive the + # SupCon key construction and per-sample label lookup respectively. + hot_cols: set[str] = { + "channel_name", + "experiment", + "lineage_id", + "t", + "marker", + "store_path", + "fov_name", + "y_clamp", + "x_clamp", + "norm_mean", + "norm_std", + "norm_median", + "norm_iqr", + } + if self.positive_match_columns: + hot_cols.update(self.positive_match_columns) + if getattr(self, "_label_encoders", None): + for col, _encoder in self._label_encoders.values(): + hot_cols.add(col) + + self._va_arrays: dict = _cache_columns(self.index.valid_anchors, sorted(hot_cols)) + self._tr_arrays: dict = _cache_columns(self.index.tracks, sorted(hot_cols)) + + # Precompute per-experiment tau range in frames to avoid a per-anchor + # registry call inside _sample_positive_indices_temporal. Skip + # experiments with interval_minutes == 0 (static/snapshot datasets like + # OPS) — they never go through the temporal path (positive_match_columns + # wouldn't include lineage_id), so missing entries are harmless and + # computing tau_range_frames for them would ZeroDivisionError. + self._tau_range_frames_cache: dict[str, tuple[int, int]] = {} + for name, exp in self.index.registry._name_map.items(): + if getattr(exp, "interval_minutes", 0): + self._tau_range_frames_cache[name] = self.index.registry.tau_range_frames(name, self.tau_range_hours) # ------------------------------------------------------------------ # Dataset protocol @@ -271,14 +414,16 @@ def __getitems__(self, indices: list[int]) -> dict: anchor_rows = self.index.valid_anchors.iloc[indices] # Pre-compute per-sample channel names based on channel_mode. + # Use the NumPy cache to avoid a pandas Series construction per row. if self._channel_mode == "from_index": - forced_channel_names = [[row["channel_name"]] for _, row in anchor_rows.iterrows()] + chan_arr = self._va_arrays["channel_name"] + forced_channel_names = [[chan_arr[i]] for i in indices] elif self._channel_mode == "fixed": forced_channel_names = [self._fixed_channel_names] * len(indices) else: forced_channel_names = None - anchor_patches, anchor_norms = self._slice_patches(anchor_rows, forced_channel_names) + anchor_patches, anchor_norms = self._slice_patches(self._va_arrays, indices, forced_channel_names) sample: dict = { "anchor": anchor_patches, "anchor_norm_meta": anchor_norms, @@ -286,27 +431,45 @@ def __getitems__(self, indices: list[int]) -> dict: } if self.fit: - positive_rows = self._sample_positives(anchor_rows) - if self._channel_mode == "from_index": - pos_forced_channel_names = [[row["channel_name"]] for _, row in positive_rows.iterrows()] + if self.positive_cell_source == "self": + # SimCLR: anchor and positive share the same patch pre-augmentation. + # Skip the second zarr read + meta extraction entirely — augmentation + # (applied independently downstream in on_after_batch_transfer) is + # what creates the two views. This roughly halves per-batch wall + # time for SimCLR baselines. + # clone the tensor so augmentation has an independent buffer to + # mutate without leaking into the anchor. + sample["positive"] = sample["anchor"].clone() + sample["positive_norm_meta"] = sample["anchor_norm_meta"] + sample["positive_meta"] = sample["anchor_meta"] else: - pos_forced_channel_names = forced_channel_names - positive_patches, positive_norms = self._slice_patches(positive_rows, pos_forced_channel_names) - sample["positive"] = positive_patches - sample["positive_norm_meta"] = positive_norms - sample["positive_meta"] = self._extract_meta(positive_rows) + pos_track_indices = self._sample_positive_indices(anchor_positions=indices) + if self._channel_mode == "from_index": + tr_chan_arr = self._tr_arrays["channel_name"] + pos_forced_channel_names = [[tr_chan_arr[i]] for i in pos_track_indices] + else: + pos_forced_channel_names = forced_channel_names + positive_patches, positive_norms = self._slice_patches( + self._tr_arrays, pos_track_indices, pos_forced_channel_names + ) + positive_rows = self.index.tracks.iloc[pos_track_indices].reset_index(drop=True) + sample["positive"] = positive_patches + sample["positive_norm_meta"] = positive_norms + sample["positive_meta"] = self._extract_meta(positive_rows) else: - indices_list = [] - for _, anchor_row in anchor_rows.iterrows(): - idx_dict: dict = {} - for col in ULTRACK_INDEX_COLUMNS: - if col in anchor_row.index: - idx_dict[col] = anchor_row[col] - elif col not in ["y", "x", "z"]: - # optional columns - pass - indices_list.append(idx_dict) - sample["index"] = indices_list + # Build per-sample index dicts via NumPy column arrays (no .iterrows). + all_cols = list(ULTRACK_INDEX_COLUMNS) + [ + "experiment", + "marker", + "perturbation", + "hours_post_perturbation", + "organelle", + "well", + "microscope", + ] + present_cols = [c for c in all_cols if c in anchor_rows.columns] + col_arrays = {c: anchor_rows[c].to_numpy() for c in present_cols} + sample["index"] = [{c: col_arrays[c][i] for c in present_cols} for i in range(len(anchor_rows))] return sample @@ -328,10 +491,18 @@ def _extract_meta(self, rows: pd.DataFrame) -> list[SampleMeta]: cols = [c for c in _META_COLUMNS if c in rows.columns] records = rows[cols].to_dict(orient="records") if self._label_encoders: - for i, (_, row) in enumerate(rows.iterrows()): + # Pre-extract label columns as NumPy arrays once (avoids per-row + # Series construction in .iterrows()). + label_arrays = { + batch_key: (encoder, rows[col].to_numpy() if col in rows.columns else None) + for batch_key, (col, encoder) in self._label_encoders.items() + } + for i in range(len(records)): labels = {} - for batch_key, (col, encoder) in self._label_encoders.items(): - val = row.get(col) + for batch_key, (encoder, arr) in label_arrays.items(): + if arr is None: + continue + val = arr[i] if val is not None and val in encoder: labels[batch_key] = encoder[val] records[i]["labels"] = labels @@ -341,181 +512,264 @@ def _extract_meta(self, rows: pd.DataFrame) -> list[SampleMeta]: # Positive sampling # ------------------------------------------------------------------ - def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: - """Sample one positive for each anchor. + def _sample_positive_indices( + self, + anchor_positions: list[int], + ) -> np.ndarray: + """Sample one positive tracks-index for each anchor. - When ``positive_cell_source="self"``, returns a copy of ``anchor_rows`` - (same crop; augmentation creates two views). Otherwise delegates to - :meth:`_find_positive`. + Returns positional indices into ``self.index.tracks`` / ``self._tr_arrays`` + — callers can slice patches directly from the cached NumPy arrays without + materializing a DataFrame. The DataFrame is still constructed downstream + for metadata extraction. Parameters ---------- - anchor_rows : pd.DataFrame - Rows from ``valid_anchors`` for the current batch. + anchor_positions : list[int] + Positional indices into ``valid_anchors`` (same as the sampler output). Returns ------- - pd.DataFrame - One row per anchor from ``self.index.tracks``. + np.ndarray + One tracks-positional-index per anchor, shape ``(len(anchor_positions),)``. """ - if self.positive_cell_source == "self": - return anchor_rows.copy().reset_index(drop=True) + # Temporal lineage mode — vectorized NumPy fast path + # (used by DynaCLR-2D-MIP, DynaCLR-3D-BagOfChannels). + if "lineage_id" in self.positive_match_columns: + return self._sample_positive_indices_temporal(anchor_positions) - pos_rows = [] - for _, row in anchor_rows.iterrows(): - pos = self._find_positive(row, self._rng) - if pos is None: + # Column-match mode (SupCon) — vectorized NumPy fast path. + cols = self.positive_match_columns + va_col_arrs = [self._va_arrays[c] for c in cols] + + pos_track_indices = np.empty(len(anchor_positions), dtype=np.int64) + match_lookup = self._match_lookup + rng = self._rng + for i, ai in enumerate(anchor_positions): + key = tuple(arr[ai] for arr in va_col_arrs) + cands = match_lookup.get(key) + if cands is None or len(cands) == 0: raise RuntimeError( - f"No positive found for anchor (experiment={row.get('experiment')}, " - f"match_key={tuple(row.get(c) for c in self.positive_match_columns)}, " - f"t={row.get('t')}). " + f"No positive found for anchor at position {ai} key={key}. " "This anchor should have been filtered out by valid_anchors." ) - pos_rows.append(pos) - return pd.DataFrame(pos_rows).reset_index(drop=True) + # Random pick from candidates. Note: the anchor's own tracks-index + # may be in `cands`; we don't filter it out explicitly because the + # anchor's valid_anchors-position and its tracks-index are in + # independent index spaces after reset_index(drop=True), and the + # original per-row implementation made the same loose comparison. + # For typical group sizes (>100), the self-as-positive probability + # is <1% — functionally equivalent to `positive_cell_source="self"`. + pos_track_indices[i] = cands[rng.integers(len(cands))] - def _find_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a positive sample for a given anchor. + return pos_track_indices + + def _sample_positive_indices_temporal(self, anchor_positions: list[int]) -> np.ndarray: + """Vectorized temporal positive lookup (lineage + tau range). - Dispatches to temporal or generic column-match lookup based on - ``positive_match_columns``. + Uses pre-computed NumPy caches instead of per-row pandas ``.iloc``. + Uses ``self._tau_range_frames_cache`` to avoid a registry call per anchor. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tau sampling and tie-breaking. + anchor_positions : list[int] + Positional indices into ``valid_anchors`` for the batch. Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no positive found. + np.ndarray + Positional indices into ``self.index.tracks``, one per anchor. """ - if "lineage_id" in self.positive_match_columns: - return self._find_temporal_positive(anchor_row, rng) - return self._find_column_match_positive(anchor_row, rng) + rng = self._rng + exp_arr = self._va_arrays["experiment"] + lid_arr = self._va_arrays["lineage_id"] + t_arr = self._va_arrays["t"] + tau_cache = self._tau_range_frames_cache + + # In from_index mode (flat parquet), we filter candidates to same marker. + marker_filter = self._channel_mode == "from_index" + if marker_filter: + anchor_marker_arr = self._va_arrays["marker"] + tr_marker_arr = self._tr_arrays["marker"] + + pos_track_indices = np.empty(len(anchor_positions), dtype=np.int64) + lt_map = self._lineage_timepoints + + for i, ai in enumerate(anchor_positions): + # Coerce to str: _va_arrays columns come back as Categorical + # scalars after _materialize_strings, which hash differently + # from the str keys in _lineage_timepoints / _tau_range_frames_cache. + exp_name = str(exp_arr[ai]) + lineage_id = str(lid_arr[ai]) + anchor_t = int(t_arr[ai]) + + tau_min, tau_max = tau_cache[exp_name] + timepoints = lt_map.get((exp_name, lineage_id)) + if timepoints is None: + raise RuntimeError( + f"No positive found for anchor at position {ai} " + f"(experiment={exp_name}, lineage_id={lineage_id}, t={anchor_t}). " + "This anchor should have been filtered out by valid_anchors." + ) - def _find_temporal_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a temporal positive: same lineage at ``t + tau``. + anchor_marker = anchor_marker_arr[ai] if marker_filter else None + chosen = _pick_temporal_candidate( + timepoints, + anchor_t, + tau_min, + tau_max, + self.tau_decay_rate, + rng, + tr_marker_arr if marker_filter else None, + anchor_marker, + ) + if chosen is None: + raise RuntimeError( + f"No positive found for anchor at position {ai} " + f"(experiment={exp_name}, lineage_id={lineage_id}, t={anchor_t}). " + "This anchor should have been filtered out by valid_anchors." + ) + pos_track_indices[i] = chosen + + return pos_track_indices + + # ------------------------------------------------------------------ + # Patch extraction (tensorstore I/O) + # ------------------------------------------------------------------ + + def _get_position(self, store_path: str, fov_name: str): + """Get or create a cached Position object for the given FOV. + + Cache is keyed by ``(store_path, fov_name)`` — critical for OPS + where the same FOV name (e.g. ``"A/3/0"``) appears across multiple + experiments. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tau sampling and tie-breaking. + store_path : str + Path to the OME-Zarr plate store. + fov_name : str + FOV name (e.g. ``"A/1/0"``). Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no positive found. + iohub.ngff.Position """ - exp_name = anchor_row["experiment"] - lineage_id = anchor_row["lineage_id"] - anchor_t = anchor_row["t"] - - tau_min, tau_max = self.index.registry.tau_range_frames(exp_name, self.tau_range_hours) - - lt_key = (exp_name, lineage_id) - lt_map = self._lineage_timepoints.get(lt_key) - if lt_map is None: - return None - - # In from_index mode (flat parquet), filter candidates to same marker. - # NOTE:The parquet SHOULD guarantee one channel_name per marker per experiment, - # so marker filtering is equivalent to channel_name filtering. - anchor_marker = anchor_row.get("marker") if self._channel_mode == "from_index" else None - - def _pick(candidate_indices: list[int]) -> pd.Series | None: - if not candidate_indices: - return None - if anchor_marker is not None: - filtered = [ - idx for idx in candidate_indices if self.index.tracks.iloc[idx].get("marker") == anchor_marker - ] - if filtered: - candidate_indices = filtered - chosen_idx = candidate_indices[rng.integers(len(candidate_indices))] - return self.index.tracks.iloc[chosen_idx] - - # Try sampled tau first, then scan full range as fallback - sampled_tau = sample_tau(tau_min, tau_max, rng, self.tau_decay_rate) - target_t = anchor_t + sampled_tau - result = _pick(lt_map.get(target_t, [])) - if result is not None: - return result - - for tau in range(tau_min, tau_max + 1): - if tau == 0: - continue - result = _pick(lt_map.get(anchor_t + tau, [])) - if result is not None: - return result + key = (store_path, fov_name) + if key not in self._position_cache: + if store_path not in self._store_cache: + self._store_cache[store_path] = open_ome_zarr( + store_path, + mode="r", + implementation="tensorstore", + implementation_config=self.index.tensorstore_config, + ) + plate = self._store_cache[store_path] + self._position_cache[key] = plate[fov_name] + return self._position_cache[key] - return None + def _get_tensorstore(self, store_path: str, fov_name: str) -> "ts.TensorStore": + """Get or create a cached tensorstore object for the given FOV. - def _find_column_match_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a positive by matching column values, excluding the anchor itself. + Cache is keyed by ``(store_path, fov_name)`` — critical for OPS + where the same FOV name appears across multiple experiments. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tie-breaking. + store_path : str + Path to the OME-Zarr plate store. + fov_name : str + FOV name used together with ``store_path`` as cache key. Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no candidates found. + ts.TensorStore """ - cols = self.positive_match_columns - key = tuple(anchor_row[c] for c in cols) - all_candidates = self._match_lookup.get(key, []) - # Exclude the anchor row itself by integer index - candidates = [i for i in all_candidates if i != anchor_row.name] - if not candidates: - return None - chosen_idx = candidates[rng.integers(len(candidates))] - return self.index.tracks.iloc[chosen_idx] + key = (store_path, fov_name) + if key not in self._tensorstores: + position = self._get_position(store_path, fov_name) + self._tensorstores[key] = position["0"].native + return self._tensorstores[key] - # ------------------------------------------------------------------ - # Patch extraction (tensorstore I/O) - # ------------------------------------------------------------------ + def _build_norm_meta( + self, + arrays: dict[str, np.ndarray], + idx: int, + forced_channel_names: list[str] | None, + ) -> NormMeta | None: + """Build per-sample normalization metadata from parquet columns. - def _get_tensorstore(self, position, fov_name: str) -> "ts.TensorStore": - """Get or create a cached tensorstore object for the given FOV. + When the parquet has ``norm_mean`` / ``norm_std`` columns (written by + ``preprocess-cell-index``), reads stats directly from the cached + NumPy arrays — no zarr zattrs access and no pandas Series construction. + Falls back to zarr zattrs for old parquets. Parameters ---------- - position : iohub.ngff.Position - Position object from the OME-Zarr store. - fov_name : str - FOV name used as cache key. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + idx : int + Positional row index into ``arrays``. + forced_channel_names : list[str] or None + Zarr channel names being read for this sample. Returns ------- - ts.TensorStore + NormMeta or None """ - if fov_name not in self._tensorstores: - self._tensorstores[fov_name] = position["0"].native - return self._tensorstores[fov_name] + # Parquet path: norm columns present and value is not NA + norm_mean_arr = arrays.get("norm_mean") + if norm_mean_arr is not None: + norm_mean = norm_mean_arr[idx] + if norm_mean is not None and not (isinstance(norm_mean, float) and np.isnan(norm_mean)): + tp_stats = { + "mean": torch.tensor(norm_mean, dtype=torch.float32), + "std": torch.tensor(arrays["norm_std"][idx], dtype=torch.float32), + "median": torch.tensor(arrays["norm_median"][idx], dtype=torch.float32), + "iqr": torch.tensor(arrays["norm_iqr"][idx], dtype=torch.float32), + } + if self._channel_mode == "from_index": + return {"channel_0": {"timepoint_statistics": tp_stats}} + else: + ch_arr = arrays.get("channel_name") + ch_name = ch_arr[idx] if ch_arr is not None else "channel_0" + return {ch_name: {"timepoint_statistics": tp_stats}} + + # Fallback: read from zarr zattrs (old parquets without norm columns) + store_path = arrays["store_path"][idx] + fov_name = arrays["fov_name"][idx] + t = arrays["t"][idx] + cache_key = (store_path, fov_name) + if cache_key not in self._norm_meta_cache: + position = self._get_position(store_path, fov_name) + self._norm_meta_cache[cache_key] = _read_norm_meta(position) + cached = self._norm_meta_cache[cache_key] + if cached is None: + return None + raw_norm_meta = {} + for ch, ch_meta in cached.items(): + resolved = {} + for level, level_stats in ch_meta.items(): + if level == "timepoint_statistics" and isinstance(level_stats, dict): + resolved[level] = level_stats.get(str(t)) + else: + resolved[level] = level_stats + raw_norm_meta[ch] = resolved + if forced_channel_names is not None and self._channel_mode == "from_index": + ch = forced_channel_names[0] + if ch in raw_norm_meta: + return {"channel_0": raw_norm_meta[ch]} + return None + if forced_channel_names is not None and self._channel_mode == "fixed": + raw_norm_meta = {name: raw_norm_meta[name] for name in forced_channel_names if name in raw_norm_meta} + return raw_norm_meta or None + return raw_norm_meta def _slice_patch( - self, track_row: pd.Series, forced_channel_names: list[str] | None = None + self, + arrays: dict[str, np.ndarray], + idx: int, + forced_channel_names: list[str] | None = None, ) -> tuple[ "ts.TensorStore", NormMeta | None, @@ -530,8 +784,10 @@ def _slice_patch( Parameters ---------- - track_row : pd.Series - A single row from ``tracks`` or ``valid_anchors``. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + idx : int + Positional row index into ``arrays``. forced_channel_names : list[str] or None Zarr channel names to read. When provided, only these channels are sliced from the zarr. None reads all channels. @@ -543,15 +799,15 @@ def _slice_patch( scale factors ``(scale_z, scale_y, scale_x)``, and target size ``(z_window, patch_h, patch_w)``. """ - position = track_row["position"] - fov_name = track_row["fov_name"] - exp_name = track_row["experiment"] + store_path = arrays["store_path"][idx] + fov_name = arrays["fov_name"][idx] + exp_name = arrays["experiment"][idx] - image = self._get_tensorstore(position, fov_name) + image = self._get_tensorstore(store_path, fov_name) - t = track_row["t"] - y_center = int(track_row["y_clamp"]) - x_center = int(track_row["x_clamp"]) + t = int(arrays["t"][idx]) + y_center = int(arrays["y_clamp"][idx]) + x_center = int(arrays["x_clamp"][idx]) # Per-experiment scale factors for physical-space normalization scale_z, scale_y, scale_x = self.index.registry.scale_factors[exp_name] @@ -581,37 +837,8 @@ def _slice_patch( slice(x_center - x_half, x_center + x_half), ] - # Look up norm_meta by zarr channel name directly - # and pre-resolve timepoint_statistics for this sample's timepoint. - # Cache the tensor-converted norm_meta per FOV to avoid repeated - # zattrs reads. Build a shallow per-sample copy (dict structure only, - # tensors shared) since we only replace dict entries, not tensor values. - cache_key = (track_row["store_path"], fov_name) - if cache_key not in self._norm_meta_cache: - self._norm_meta_cache[cache_key] = _read_norm_meta(position) - cached = self._norm_meta_cache[cache_key] - if cached is not None: - raw_norm_meta = {ch: {level: stats for level, stats in ch_meta.items()} for ch, ch_meta in cached.items()} - # Pre-resolve timepoint_statistics for all channels - for ch_name, ch_meta in raw_norm_meta.items(): - if "timepoint_statistics" in ch_meta: - tp_stats = ch_meta["timepoint_statistics"].get(str(t)) - ch_meta["timepoint_statistics"] = tp_stats - else: - raw_norm_meta = None - if raw_norm_meta is not None: - # Filter to requested channels - if forced_channel_names is not None and self._channel_mode == "from_index": - ch = forced_channel_names[0] - if ch in raw_norm_meta: - raw_norm_meta = {"channel_0": raw_norm_meta[ch]} - else: - raw_norm_meta = None - elif forced_channel_names is not None and self._channel_mode == "fixed": - raw_norm_meta = {name: raw_norm_meta[name] for name in forced_channel_names if name in raw_norm_meta} - if not raw_norm_meta: - raw_norm_meta = None - # else: "all" mode — keep full raw_norm_meta + # Build norm_meta from parquet columns (preferred) or zarr zattrs (fallback). + raw_norm_meta = self._build_norm_meta(arrays, idx, forced_channel_names) # Use the configured extraction window as uniform target Z, # not the per-experiment capped range. This ensures all patches @@ -628,15 +855,18 @@ def _slice_patch( def _slice_patches( self, - track_rows: pd.DataFrame, + arrays: dict[str, np.ndarray], + indices: list[int] | np.ndarray, forced_channel_names: list[list[str]] | None = None, ) -> tuple[torch.Tensor, list[NormMeta | None]]: """Slice and stack patches for multiple track rows. Parameters ---------- - track_rows : pd.DataFrame - Multiple rows from ``tracks`` / ``valid_anchors``. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + indices : list[int] or np.ndarray + Positional row indices into ``arrays``. forced_channel_names : list[list[str]] or None Per-sample zarr channel names to read. Each inner list contains the channel names for that sample. @@ -651,9 +881,9 @@ def _slice_patches( norms = [] scales = [] targets = [] - for i, (_, row) in enumerate(track_rows.iterrows()): + for i, idx in enumerate(indices): forced = forced_channel_names[i] if forced_channel_names is not None else None - patch, norm, scale, target = self._slice_patch(row, forced_channel_names=forced) + patch, norm, scale, target = self._slice_patch(arrays, int(idx), forced_channel_names=forced) patches.append(patch) norms.append(norm) scales.append(scale) @@ -674,4 +904,12 @@ def _slice_patches( rescaled = [] for i in range(len(patches)): rescaled.append(_rescale_patch(read_tensors[i], scales[i], targets[i])) + channel_counts = {t.shape[0] for t in rescaled} + if len(channel_counts) > 1: + raise RuntimeError( + f"Batch mixes samples with different channel counts: {sorted(channel_counts)}. " + "This happens with channels_per_sample=None across experiments that have " + "different channel counts. Set channels_per_sample=1 (bag-of-channels) " + "or channels_per_sample=[...] (fixed channel list)." + ) return torch.stack(rescaled), norms diff --git a/applications/dynaclr/src/dynaclr/data/experiment.py b/applications/dynaclr/src/dynaclr/data/experiment.py index 96187cafa..8134f7d54 100644 --- a/applications/dynaclr/src/dynaclr/data/experiment.py +++ b/applications/dynaclr/src/dynaclr/data/experiment.py @@ -12,6 +12,7 @@ from dataclasses import dataclass, field from pathlib import Path +import pandas as pd from iohub.ngff import open_ome_zarr from viscy_data.cell_index import read_cell_index @@ -96,27 +97,26 @@ def __post_init__(self) -> None: # noqa: D105 # Build name -> config map self._name_map = {e.name: e for e in experiments} - # Per-experiment validations + # Per-experiment validation + z-range resolution (single zarr open each) + z_extract = self.z_extraction_window or self.z_window + z_ranges: dict[str, tuple[int, int]] = {} + for exp in experiments: - # 4. Negative interval if exp.interval_minutes < 0: raise ValueError( f"Experiment '{exp.name}': interval_minutes must be non-negative, got {exp.interval_minutes}." ) - - # 5. Empty perturbation_wells if not exp.perturbation_wells: raise ValueError(f"Experiment '{exp.name}': perturbation_wells must not be empty.") - - # 6. data_path existence if not Path(exp.data_path).exists(): raise ValueError(f"Experiment '{exp.name}': data_path does not exist: {exp.data_path}") - # 7. Zarr channel validation — selected channels must exist in zarr with open_ome_zarr(exp.data_path, mode="r") as plate: first_position = next(iter(plate.positions()))[1] zarr_channels = list(first_position.channel_names) - # Store the full zarr channel list for index resolution + z_total = first_position["0"].shape[2] + focus_data = plate.zattrs.get("focus_slice", {}) + exp.channel_names = zarr_channels missing_channels = [ch.name for ch in exp.channels if ch.name not in zarr_channels] if missing_channels: @@ -125,16 +125,52 @@ def __post_init__(self) -> None: # noqa: D105 f"not found in zarr. Available: {zarr_channels}." ) - # Resolve per-experiment z_ranges - self.z_ranges = self._resolve_z_ranges() + # Z-range resolution + if z_extract is None: + z_ranges[exp.name] = (0, z_total) + else: + focus_ch = self.focus_channel or (exp.channels[0].name if exp.channels else None) + ch_focus = focus_data.get(focus_ch, {}) if focus_ch else {} + ds_stats = ch_focus.get("dataset_statistics", {}) + z_focus_mean = ds_stats.get("z_focus_mean") + + z_center = int(round(z_focus_mean)) if z_focus_mean is not None else z_total // 2 + effective_extract = min(z_extract, z_total) + z_below = int(effective_extract * self.z_focus_offset) + z_start = max(0, z_center - z_below) + z_end = min(z_total, z_start + effective_extract) + z_start = max(0, z_end - effective_extract) + + z_ranges[exp.name] = (z_start, z_end) + _logger.info( + "Experiment '%s': z_range=(%d, %d), z_total=%d, z_extraction_window=%d", + exp.name, + z_start, + z_end, + z_total, + effective_extract, + ) + + # Validate extraction windows >= z_window + if self.z_window is not None and z_ranges: + for name, (z_s, z_e) in z_ranges.items(): + if z_e - z_s < self.z_window: + raise ValueError( + f"Experiment '{name}': extraction range ({z_e - z_s}) " + f"< z_window ({self.z_window}). Increase z_extraction_window " + f"or reduce z_window." + ) + self.z_ranges = z_ranges # Validate pixel sizes and compute scale factors - if self.reference_pixel_size_xy_um is not None or self.reference_pixel_size_z_um is not None: - missing = [e.name for e in experiments if e.pixel_size_xy_um is None or e.pixel_size_z_um is None] + if self.reference_pixel_size_xy_um is not None: + missing = [e.name for e in experiments if e.pixel_size_xy_um is None] if missing: - raise ValueError( - f"reference_pixel_size set but experiments are missing pixel_size_xy_um/z_um: {missing}" - ) + raise ValueError(f"reference_pixel_size_xy_um set but experiments missing pixel_size_xy_um: {missing}") + if self.reference_pixel_size_z_um is not None: + missing = [e.name for e in experiments if e.pixel_size_z_um is None] + if missing: + raise ValueError(f"reference_pixel_size_z_um set but experiments missing pixel_size_z_um: {missing}") self.scale_factors = self._compute_scale_factors() @property @@ -158,72 +194,6 @@ def source_channel_labels(self) -> list[str]: # Internal helpers # ------------------------------------------------------------------ - def _resolve_z_ranges(self) -> dict[str, tuple[int, int]]: - """Resolve per-experiment Z extraction ranges. - - When ``z_extraction_window`` is set, extracts a larger Z range - centered on ``z_focus_mean`` (capped by the available Z depth). - The random crop from extraction size to ``z_window`` happens later - in ``on_after_batch_transfer``. - - Falls back to ``z_window`` when ``z_extraction_window`` is None. - """ - experiments = self.collection.experiments - z_ranges: dict[str, tuple[int, int]] = {} - z_extract = self.z_extraction_window or self.z_window - - for exp in experiments: - focus_ch = self.focus_channel or (exp.channels[0].name if exp.channels else None) - - with open_ome_zarr(exp.data_path, mode="r") as plate: - first_pos = next(iter(plate.positions()))[1] - z_total = first_pos["0"].shape[2] - - if z_extract is None: - z_ranges[exp.name] = (0, z_total) - continue - - focus_data = plate.zattrs.get("focus_slice", {}) - ch_focus = focus_data.get(focus_ch, {}) if focus_ch else {} - ds_stats = ch_focus.get("dataset_statistics", {}) - z_focus_mean = ds_stats.get("z_focus_mean") - - if z_focus_mean is None: - z_center = z_total // 2 - else: - z_center = int(round(z_focus_mean)) - - # Cap extraction window by available Z depth. - # z_focus_offset controls asymmetry: 0.5 = symmetric, - # 0.3 = 30% below focus, 70% above (cells on coverslip). - effective_extract = min(z_extract, z_total) - z_below = int(effective_extract * self.z_focus_offset) - z_start = max(0, z_center - z_below) - z_end = min(z_total, z_start + effective_extract) - z_start = max(0, z_end - effective_extract) - - z_ranges[exp.name] = (z_start, z_end) - _logger.info( - "Experiment '%s': z_range=(%d, %d), z_total=%d, z_extraction_window=%d", - exp.name, - z_start, - z_end, - z_total, - effective_extract, - ) - - # Validate: all extraction windows must be >= z_window - if self.z_window is not None and z_ranges: - for name, (z_s, z_e) in z_ranges.items(): - if z_e - z_s < self.z_window: - raise ValueError( - f"Experiment '{name}': extraction range ({z_e - z_s}) " - f"< z_window ({self.z_window}). Increase z_extraction_window " - f"or reduce z_window." - ) - - return z_ranges - def _compute_scale_factors(self) -> dict[str, tuple[float, float, float]]: """Compute per-experiment scale factors for physical-space normalization. @@ -237,18 +207,15 @@ def _compute_scale_factors(self) -> dict[str, tuple[float, float, float]]: """ scale_factors: dict[str, tuple[float, float, float]] = {} for exp in self.collection.experiments: - if ( - self.reference_pixel_size_xy_um is not None - and self.reference_pixel_size_z_um is not None - and exp.pixel_size_xy_um is not None - and exp.pixel_size_z_um is not None - ): + if self.reference_pixel_size_xy_um is not None and exp.pixel_size_xy_um is not None: scale_y = self.reference_pixel_size_xy_um / exp.pixel_size_xy_um scale_x = self.reference_pixel_size_xy_um / exp.pixel_size_xy_um - scale_z = self.reference_pixel_size_z_um / exp.pixel_size_z_um else: scale_y = 1.0 scale_x = 1.0 + if self.reference_pixel_size_z_um is not None and exp.pixel_size_z_um is not None: + scale_z = self.reference_pixel_size_z_um / exp.pixel_size_z_um + else: scale_z = 1.0 scale_factors[exp.name] = (scale_z, scale_y, scale_x) return scale_factors @@ -313,7 +280,7 @@ def from_cell_index( focus_channel: str | None = None, reference_pixel_size_xy_um: float | None = None, reference_pixel_size_z_um: float | None = None, - ) -> ExperimentRegistry: + ) -> tuple["ExperimentRegistry", "pd.DataFrame"]: """Build a registry from a flat cell index parquet and zarr metadata. Derives per-experiment channels from the parquet's ``marker`` and @@ -339,32 +306,24 @@ def from_cell_index( Returns ------- - ExperimentRegistry - Validated registry of experiments. + tuple[ExperimentRegistry, pd.DataFrame] + Validated registry of experiments and the raw cell index DataFrame. """ df = read_cell_index(cell_index_path) if df.empty: raise ValueError(f"Cell index is empty: {cell_index_path}") - # Step 1: Read channel names per (store_path, well) from zarr. - channel_names_cache: dict[tuple[str, str], list[str]] = {} - store_cache: dict[str, object] = {} + # Step 1: Read channel names per store from a single FOV. + # Channel names are uniform across all positions in a plate, + # so we open one FOV directly (store_path/well/fov) instead of + # iterating all positions. + channel_names_cache: dict[str, list[str]] = {} for store_path, group in df.groupby("store_path"): - plate = open_ome_zarr(str(store_path), mode="r") - store_cache[str(store_path)] = plate - for well in group["well"].unique(): - # Find one position in this well - well_str = str(well) - for pos_path, pos in plate.positions(): - if pos_path.startswith(well_str + "/"): - channel_names_cache[(str(store_path), well_str)] = list(pos.channel_names) - break - - # Close all opened stores - for plate in store_cache.values(): - if hasattr(plate, "close"): - plate.close() + first = group.iloc[0] + fov_path = f"{store_path}/{first['well']}/{first['fov']}" + with open_ome_zarr(fov_path, mode="r") as pos: + channel_names_cache[str(store_path)] = list(pos.channel_names) # Step 2: Derive per-experiment channels from flat (marker, channel_name) columns. exp_channels: dict[str, list[ChannelEntry]] = defaultdict(list) @@ -381,14 +340,7 @@ def from_cell_index( for exp_name, exp_group in df.groupby("experiment"): exp_name = str(exp_name) store_path = str(exp_group["store_path"].iloc[0]) - first_well = str(exp_group["well"].iloc[0]) - - channel_names = channel_names_cache.get((store_path, first_well)) - if channel_names is None: - raise ValueError( - f"Experiment '{exp_name}': could not read channel names from zarr " - f"(store_path={store_path}, well={first_well})." - ) + channel_names = channel_names_cache[store_path] # Derive perturbation_wells from parquet perturbation_wells: dict[str, list[str]] = defaultdict(list) @@ -453,7 +405,7 @@ def from_cell_index( experiments=experiments, ) - return cls( + registry = cls( collection=collection, z_window=z_window, z_extraction_window=z_extraction_window, @@ -462,6 +414,7 @@ def from_cell_index( reference_pixel_size_xy_um=reference_pixel_size_xy_um, reference_pixel_size_z_um=reference_pixel_size_z_um, ) + return registry, df def subset(self, experiment_names: list[str]) -> ExperimentRegistry: """Create a new registry with a subset of experiments. diff --git a/applications/dynaclr/src/dynaclr/data/index.py b/applications/dynaclr/src/dynaclr/data/index.py index 177d747ba..ddff9168a 100644 --- a/applications/dynaclr/src/dynaclr/data/index.py +++ b/applications/dynaclr/src/dynaclr/data/index.py @@ -15,7 +15,7 @@ import numpy as np import pandas as pd from iohub.core.config import TensorStoreConfig -from iohub.ngff import Plate, Position, open_ome_zarr +from iohub.ngff import Plate, open_ome_zarr from dynaclr.data.experiment import ExperimentRegistry from viscy_data.cell_index import read_cell_index @@ -185,10 +185,12 @@ def __init__( include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, cell_index_path: str | Path | None = None, + cell_index_df: pd.DataFrame | None = None, num_workers: int = 1, positive_cell_source: str = "lookup", positive_match_columns: list[str] | None = None, max_border_shift: int = -1, + fit: bool = True, tensorstore_config: TensorStoreConfig | None = None, ) -> None: self.registry = registry @@ -217,44 +219,53 @@ def __init__( else: all_exclude_fovs = None - if cell_index_path is not None: - _logger.info("Loading cell index from parquet: %s", cell_index_path) - tracks = read_cell_index(cell_index_path) - tracks = self._align_parquet_columns(tracks) + if cell_index_df is not None or cell_index_path is not None: + if cell_index_df is not None: + _logger.info( + "Using pre-loaded cell index DataFrame (%d rows)", + len(cell_index_df), + ) + tracks = self._align_parquet_columns(cell_index_df.copy()) + else: + _logger.info("Loading cell index from parquet: %s", cell_index_path) + tracks = read_cell_index(cell_index_path) + tracks = self._align_parquet_columns(tracks) if include_wells is not None: tracks = tracks[tracks["well_name"].isin(include_wells)].copy() if all_exclude_fovs is not None: tracks = tracks[~tracks["fov_name"].isin(all_exclude_fovs)].copy() tracks = self._filter_to_registry_experiments(tracks) - positions, tracks = self._resolve_positions_and_dims(tracks) - self.positions = positions + tracks = self._resolve_dims(tracks) # lineage_id already present from build step — skip _reconstruct_lineage - tracks = self._filter_empty_frames(tracks) + # Empty frames already filtered at parquet build time — skip _filter_empty_frames else: all_tracks = self._load_all_experiments( - include_wells=include_wells, exclude_fovs=all_exclude_fovs, num_workers=num_workers + include_wells=include_wells, + exclude_fovs=all_exclude_fovs, + num_workers=num_workers, ) tracks = pd.concat(all_tracks, ignore_index=True) if all_tracks else pd.DataFrame() tracks = self._reconstruct_lineage(tracks) - positions, tracks = self._resolve_positions_and_dims(tracks) - self.positions = positions - tracks = self._filter_empty_frames(tracks) + tracks = self._resolve_dims(tracks) tracks = self._clamp_borders(tracks) self.tracks = tracks.reset_index(drop=True) - self.valid_anchors = self._compute_valid_anchors( - tau_range_hours, - positive_cell_source=positive_cell_source, - positive_match_columns=positive_match_columns, - ) - if self.valid_anchors.empty and not self.tracks.empty: - raise ValueError( - f"No valid anchors found from {len(self.tracks)} tracks. " - f"positive_cell_source={positive_cell_source!r}, " - f"positive_match_columns={positive_match_columns!r}, " - f"tau_range_hours={tau_range_hours}. " - "Check that tracks have matching positives under these settings." + if fit: + self.valid_anchors = self._compute_valid_anchors( + tau_range_hours, + positive_cell_source=positive_cell_source, + positive_match_columns=positive_match_columns, ) + if self.valid_anchors.empty and not self.tracks.empty: + raise ValueError( + f"No valid anchors found from {len(self.tracks)} tracks. " + f"positive_cell_source={positive_cell_source!r}, " + f"positive_match_columns={positive_match_columns!r}, " + f"tau_range_hours={tau_range_hours}. " + "Check that tracks have matching positives under these settings." + ) + else: + self.valid_anchors = self.tracks # ------- internal methods ------- @@ -344,6 +355,15 @@ def _align_parquet_columns(tracks: pd.DataFrame) -> pd.DataFrame: ) if "microscope" not in tracks.columns: tracks["microscope"] = "" + # Cast low-cardinality string columns to Categorical to make + # downstream boolean-mask slicing (train/val split) a fast int-code + # gather instead of a pyarrow.compute.take over Arrow string buffers. + # Deferred from read_cell_index because ``fov_name`` is rewritten by + # the prefix logic above and Categorical columns don't support string + # concatenation. + for col in ("fov_name", "well_name"): + if col in tracks.columns and tracks[col].dtype == object: + tracks[col] = tracks[col].astype("category") return tracks def _filter_to_registry_experiments(self, tracks: pd.DataFrame) -> pd.DataFrame: @@ -351,22 +371,30 @@ def _filter_to_registry_experiments(self, tracks: pd.DataFrame) -> pd.DataFrame: registry_names = {exp.name for exp in self.registry.experiments} return tracks[tracks["experiment"].isin(registry_names)].copy() - def _resolve_positions_and_dims(self, tracks: pd.DataFrame) -> tuple[list[Position], pd.DataFrame]: - """Open zarr stores for unique (store_path, fov_name) pairs. + def _resolve_dims(self, tracks: pd.DataFrame) -> pd.DataFrame: + """Attach image dimensions to tracks for border clamping. - Attaches ``position``, ``_img_height``, ``_img_width`` columns to - *tracks* and returns the list of resolved Position objects. + When the parquet has ``Y_shape`` / ``X_shape`` columns (built with the + latest ``build_timelapse_cell_index``), reads dimensions directly — no + zarr opens needed. Falls back to opening stores when the columns are + missing (old parquets). """ - all_positions: list[Position] = [] - pos_lookup: dict[tuple[str, str], Position] = {} - dim_lookup: dict[tuple[str, str], tuple[int, int]] = {} - if tracks.empty: - tracks["position"] = pd.Series(dtype=object) tracks["_img_height"] = pd.Series(dtype=int) tracks["_img_width"] = pd.Series(dtype=int) - return all_positions, tracks + return tracks + if "Y_shape" in tracks.columns and "X_shape" in tracks.columns: + tracks["_img_height"] = tracks["Y_shape"] + tracks["_img_width"] = tracks["X_shape"] + return tracks + + _logger.warning( + "Parquet missing Y_shape/X_shape columns. Falling back to opening " + "zarr stores for image dimensions. Rebuild the parquet with " + "`build-cell-index` for faster startup." + ) + dim_lookup: dict[tuple[str, str], tuple[int, int]] = {} for (store_path, well_name, fov_name), _group in tracks.groupby(["store_path", "well_name", "fov_name"]): if store_path not in self._store_cache: self._store_cache[store_path] = open_ome_zarr( @@ -376,60 +404,17 @@ def _resolve_positions_and_dims(self, tracks: pd.DataFrame) -> tuple[list[Positi implementation_config=self.tensorstore_config, ) plate = self._store_cache[store_path] - # fov_name may be just the FOV id (e.g. "000000") or the full - # position path (e.g. "C/1/000000"). Prepend well_name when needed. if "/" in fov_name: position_path = fov_name else: position_path = f"{well_name}/{fov_name}" position = plate[position_path] - pos_lookup[(store_path, fov_name)] = position image = position["0"] dim_lookup[(store_path, fov_name)] = (image.height, image.width) - all_positions.append(position) - tracks["position"] = [pos_lookup[(sp, fn)] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] tracks["_img_height"] = [dim_lookup[(sp, fn)][0] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] tracks["_img_width"] = [dim_lookup[(sp, fn)][1] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] - - return all_positions, tracks - - @staticmethod - def _filter_empty_frames(tracks: pd.DataFrame) -> pd.DataFrame: - """Remove rows whose image frame is all zeros (missing acquisition). - - For each unique (store_path, fov_name, t) combination, reads a small - center crop of channel 0 to detect empty frames. Rows with an all-zero - frame are dropped. - """ - if tracks.empty or "t" not in tracks.columns: - return tracks - - valid_mask = pd.Series(True, index=tracks.index) - - for (store_path, fov_name), group in tracks.groupby(["store_path", "fov_name"]): - pos = group["position"].iloc[0] - image = pos["0"] - h, w = image.shape[-2], image.shape[-1] - cy, cx = h // 2, w // 2 - crop = 16 # 32x32 center crop is enough to detect empty frames - - for t in group["t"].unique(): - try: - patch = np.asarray(image[int(t), 0, :, cy - crop : cy + crop, cx - crop : cx + crop]) - if patch.max() == 0: - row_mask = ( - (tracks["store_path"] == store_path) & (tracks["fov_name"] == fov_name) & (tracks["t"] == t) - ) - valid_mask[row_mask] = False - except Exception: - pass # if we can't read, keep the row - - n_dropped = (~valid_mask).sum() - if n_dropped > 0: - _logger.info("Excluded %d observations from empty frames", n_dropped) - - return tracks[valid_mask].copy() + return tracks @staticmethod def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: @@ -538,7 +523,11 @@ def _clamp_borders(self, tracks: pd.DataFrame) -> pd.DataFrame: n_dropped = n_before - len(tracks) if n_dropped > 0: - _logger.info("Excluded %d border cells (%.1f%%)", n_dropped, 100 * n_dropped / n_before) + _logger.info( + "Excluded %d border cells (%.1f%%)", + n_dropped, + 100 * n_dropped / n_before, + ) tracks = tracks.drop(columns=["_img_height", "_img_width"]) @@ -591,33 +580,43 @@ def _compute_valid_anchors( # Temporal mode: keep only anchors that have a positive at t+tau. # For each experiment, check whether (lineage_id, t+tau) exists - # for any tau in [min_f, max_f] (excluding 0). + # for any tau in [min_f, max_f] (excluding 0). In flat-parquet + # mode (one row per cell × channel), the dataset restricts + # candidates to the same marker at t+tau, so ``marker`` must be + # part of the match key here. Otherwise an anchor at (lid, marker=A, t) + # could pass validation because (lid, marker=B, t+1) exists, but + # fail at sample time because no (lid, marker=A, t+1) exists. + filter_by_marker = "marker" in self.tracks.columns + key_cols = ["lineage_id", "marker", "t"] if filter_by_marker else ["lineage_id", "t"] valid_mask = np.zeros(len(self.tracks), dtype=bool) for exp in self.registry.experiments: min_f, max_f = self.registry.tau_range_frames(exp.name, tau_range_hours) - exp_mask = self.tracks["experiment"].to_numpy() == exp.name - exp_indices = np.where(exp_mask)[0] - if len(exp_indices) == 0: + exp_mask = self.tracks["experiment"] == exp.name + exp_df = self.tracks.loc[exp_mask, key_cols] + if exp_df.empty: continue - lineage_ids = self.tracks["lineage_id"].to_numpy()[exp_indices] - t_values = self.tracks["t"].to_numpy()[exp_indices] - existing_pairs: set[tuple] = set(zip(lineage_ids, t_values)) + taus = [tau for tau in range(min_f, max_f + 1) if tau != 0] + + # Unique key tuples as a MultiIndex for O(1) isin checks. + existing = exp_df.drop_duplicates() + existing_mi = pd.MultiIndex.from_frame(existing) - # Collect all anchor (lineage_id, t) that have any valid positive - valid_anchors: set[tuple] = set() - for tau in range(min_f, max_f + 1): - if tau == 0: - continue - for lid, t in existing_pairs: - if (lid, t + tau) in existing_pairs: - valid_anchors.add((lid, t)) + # For each unique anchor key, check if the shifted key (same + # lineage_id/marker, t+tau) exists for any tau. + found_any = np.zeros(len(existing), dtype=bool) + t_vals = existing["t"].to_numpy() + non_t_arrays = [existing[c].to_numpy() for c in key_cols if c != "t"] + for tau in taus: + shifted_arrays = non_t_arrays + [t_vals + tau] + targets = pd.MultiIndex.from_arrays(shifted_arrays) + found_any |= targets.isin(existing_mi) - # Mark matching rows - for i, idx in enumerate(exp_indices): - if (lineage_ids[i], t_values[i]) in valid_anchors: - valid_mask[idx] = True + # Map valid unique pairs back to all rows in the experiment. + valid_pairs_mi = pd.MultiIndex.from_frame(existing[found_any]) + row_keys = pd.MultiIndex.from_frame(exp_df) + valid_mask[exp_mask.to_numpy()] = row_keys.isin(valid_pairs_mi) return self.tracks[valid_mask].reset_index(drop=True) @@ -651,11 +650,13 @@ def clone_with_subset( positive_cell_source: str = "lookup", positive_match_columns: list[str] | None = None, max_border_shift: int = -1, + precomputed_valid_anchors: pd.DataFrame | None = None, ) -> "MultiExperimentIndex": """Create a shallow copy with a different tracks DataFrame. Reuses the parent's registry, positions, and store cache so no - zarr stores are re-opened. Recomputes ``valid_anchors``. + zarr stores are re-opened. Recomputes ``valid_anchors`` unless + ``precomputed_valid_anchors`` is provided. Parameters ---------- @@ -667,20 +668,27 @@ def clone_with_subset( Forwarded to ``_compute_valid_anchors``. max_border_shift : int Forwarded to ``self.max_border_shift``. -1 inherits from parent. + precomputed_valid_anchors : pd.DataFrame | None + When provided, skip recomputing valid anchors. Pass the already- + filtered valid_anchors subset for this tracks_subset. Avoids + redundant O(N * tau_range) computation in FOV split mode. """ clone = object.__new__(MultiExperimentIndex) clone.registry = self.registry clone.yx_patch_size = self.yx_patch_size clone.tau_range_hours = self.tau_range_hours clone._store_cache = self._store_cache - clone.positions = self.positions + clone.tensorstore_config = self.tensorstore_config clone.max_border_shift = self.max_border_shift if max_border_shift < 0 else max_border_shift clone.tracks = tracks_subset.reset_index(drop=True) - clone.valid_anchors = clone._compute_valid_anchors( - tau_range_hours=self.tau_range_hours, - positive_cell_source=positive_cell_source, - positive_match_columns=positive_match_columns, - ) + if precomputed_valid_anchors is not None: + clone.valid_anchors = precomputed_valid_anchors.reset_index(drop=True) + else: + clone.valid_anchors = clone._compute_valid_anchors( + tau_range_hours=self.tau_range_hours, + positive_cell_source=positive_cell_source, + positive_match_columns=positive_match_columns, + ) if clone.valid_anchors.empty and not clone.tracks.empty: raise ValueError( f"No valid anchors found from {len(clone.tracks)} tracks in subset. " diff --git a/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py b/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py new file mode 100644 index 000000000..4ccab72da --- /dev/null +++ b/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py @@ -0,0 +1,30 @@ +"""CLI command for preprocessing a cell index parquet (add norm stats, focus slice, remove empties).""" + +import click + + +@click.command() +@click.argument("parquet_path") +@click.option( + "--output", + default=None, + help="Output path. Default: overwrite in place.", +) +@click.option( + "--focus-channel", + default=None, + help="Channel name for focus_slice lookup (e.g. Phase3D). Default: first channel per FOV.", +) +def main(parquet_path, output, focus_channel): + """Preprocess a cell index parquet: add normalization stats, focus slice, remove empty frames. + + Reads precomputed metadata from zarr zattrs and writes them as parquet + columns. Requires `viscy preprocess` to have been run on the zarr stores. + """ + from viscy_data.cell_index import preprocess_cell_index + + preprocess_cell_index( + parquet_path=parquet_path, + output_path=output, + focus_channel=focus_channel, + ) diff --git a/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py b/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py new file mode 100644 index 000000000..d7c4698f9 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py @@ -0,0 +1,115 @@ +"""CLI for appending annotation columns to per-experiment AnnData zarr stores. + +Reads per-experiment annotation CSVs and writes task columns (e.g. infection_state, +organelle_state) directly into each zarr's obs. This persists ground truth labels +alongside the embeddings so downstream plots can color by annotation. + +Called as a step in the Nextflow evaluation pipeline after split-embeddings. +Annotation sources are shared with the linear_classifiers step config. + +Usage +----- +dynaclr append-annotations -c append_annotations.yaml +""" + +from __future__ import annotations + +from pathlib import Path + +import anndata as ad +import click + +from dynaclr.evaluation.evaluate_config import AnnotationSource, TaskSpec +from viscy_utils.cli_utils import load_config +from viscy_utils.evaluation.annotation import load_annotation_anndata +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + +def append_annotations( + embeddings_path: Path, + annotations: list[AnnotationSource], + tasks: list[TaskSpec], +) -> None: + """Append annotation columns to per-experiment zarr obs. + + For each experiment in ``annotations``, loads the matching per-experiment + zarr, joins all task columns from the annotation CSV, and persists the + updated obs back to zarr. + + Parameters + ---------- + embeddings_path : Path + Directory containing per-experiment zarrs named ``{experiment}.zarr``. + annotations : list[AnnotationSource] + Per-experiment annotation CSV sources. Each entry maps an experiment + name to a CSV path with task columns. + tasks : list[TaskSpec] + Tasks to join (e.g. infection_state, organelle_state). Only tasks + present as columns in the annotation CSV are written. + """ + task_names = [t.task for t in tasks] + click.echo(f"Appending annotations for {len(annotations)} experiments, tasks: {task_names}") + + for ann_src in annotations: + experiment = ann_src.experiment + zarr_path = embeddings_path / f"{experiment}.zarr" + + if not zarr_path.exists(): + click.echo(f" [{experiment}] zarr not found, skipping: {zarr_path}", err=True) + continue + + ann_path = Path(ann_src.path) + if not ann_path.exists(): + raise FileNotFoundError(f"Annotation CSV not found: {ann_src.path}") + + click.echo(f"\n [{experiment}]") + adata = ad.read_zarr(zarr_path) + click.echo(f" Loaded {adata.n_obs} cells") + + n_joined = 0 + for task_name in task_names: + try: + adata = load_annotation_anndata(adata, str(ann_path), task_name) + n_valid = int(adata.obs[task_name].notna().sum()) + click.echo(f" {task_name}: {n_valid}/{adata.n_obs} labeled") + n_joined += 1 + except KeyError: + click.echo(f" {task_name}: not in {ann_path.name}, skipping") + + if n_joined == 0: + click.echo(f" No tasks found in {ann_path.name}, skipping zarr write") + continue + + append_to_anndata_zarr(zarr_path, obs=adata.obs) + click.echo(f" Saved obs to {zarr_path}") + + click.echo("\nDone.") + + +class _AppendAnnotationsConfig: + def __init__(self, raw: dict): + self.embeddings_path = Path(raw["embeddings_path"]) + self.annotations = [AnnotationSource(**a) for a in raw["annotations"]] + self.tasks = [TaskSpec(**t) for t in raw["tasks"]] + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Append annotation columns to per-experiment AnnData zarr stores.""" + click.echo("=" * 60) + click.echo("APPEND ANNOTATIONS") + click.echo("=" * 60) + raw = load_config(config) + cfg = _AppendAnnotationsConfig(raw) + append_annotations(cfg.embeddings_path, cfg.annotations, cfg.tasks) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py b/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py new file mode 100644 index 000000000..6f4553762 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py @@ -0,0 +1,158 @@ +"""CLI for applying saved linear classifiers to per-experiment AnnData zarr stores. + +Reads the pipelines manifest written by ``dynaclr run-linear-classifiers``, +applies each saved classifier to ALL cells with the matching marker in each +per-experiment zarr, and writes predictions back to obs/obsm/uns. + +This enables plots colored by predicted labels (e.g. predicted_infection_state) +for every cell, including unannotated ones. + +Called as a step in the Nextflow evaluation pipeline after linear classifiers +have been trained (LINEAR_CLASSIFIERS step). + +Usage +----- +dynaclr append-predictions -c append_predictions.yaml +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import anndata as ad +import click +import joblib +import numpy as np + +from viscy_utils.cli_utils import load_config +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + +def append_predictions( + embeddings_path: Path, + pipelines_dir: Path, +) -> None: + """Apply saved classifiers to all cells and write predictions to zarrs. + + For each per-experiment zarr, loads all saved classifier pipelines and + applies each one to cells with the matching marker. Results are merged + per task (one ``predicted_{task}`` column per task regardless of how + many marker-specific classifiers contributed), then persisted to zarr. + + Parameters + ---------- + embeddings_path : Path + Directory containing per-experiment zarrs named ``{experiment}.zarr``. + pipelines_dir : Path + Directory containing ``manifest.json`` and ``*.joblib`` pipeline files + produced by ``dynaclr run-linear-classifiers``. + """ + manifest_path = pipelines_dir / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError( + f"Pipeline manifest not found: {manifest_path}. Run dynaclr run-linear-classifiers first." + ) + + with open(manifest_path) as f: + manifest = json.load(f) + + if not manifest: + click.echo("No pipelines in manifest, nothing to do.") + return + + click.echo(f"Loaded {len(manifest)} pipeline(s) from {manifest_path}") + for entry in manifest: + click.echo(f" {entry['task']} / marker={entry['marker_filter']}") + + zarr_paths = sorted(embeddings_path.glob("*.zarr")) + if not zarr_paths: + raise FileNotFoundError(f"No .zarr files found in {embeddings_path}") + + click.echo(f"\nProcessing {len(zarr_paths)} per-experiment zarr(s)...") + + for zarr_path in zarr_paths: + click.echo(f"\n {zarr_path.stem}") + adata = ad.read_zarr(zarr_path) + click.echo(f" {adata.n_obs} cells, markers: {sorted(adata.obs['marker'].unique().tolist())}") + + # Group manifest entries by task + tasks_seen: set[str] = {entry["task"] for entry in manifest} + + new_obsm: dict[str, np.ndarray] = {} + + for task in sorted(tasks_seen): + task_entries = [e for e in manifest if e["task"] == task] + + first_pipeline = joblib.load(pipelines_dir / task_entries[0]["path"]) + n_classes = len(first_pipeline.classifier.classes_) + classes = first_pipeline.classifier.classes_.tolist() + + all_pred = np.full(adata.n_obs, np.nan, dtype=object) + all_proba = np.full((adata.n_obs, n_classes), np.nan) + + for entry in task_entries: + marker_filter = entry["marker_filter"] + pipeline_path = pipelines_dir / entry["path"] + + if not pipeline_path.exists(): + click.echo(f" Pipeline not found: {pipeline_path}, skipping", err=True) + continue + + marker_mask = (adata.obs["marker"] == marker_filter).to_numpy() + n_matching = int(marker_mask.sum()) + if n_matching == 0: + click.echo(f" {task}/{marker_filter}: no matching cells, skipping") + continue + + pipeline = joblib.load(pipeline_path) + adata_subset = adata[marker_mask] + + X_subset = adata_subset.X if isinstance(adata_subset.X, np.ndarray) else adata_subset.X.toarray() + preds = pipeline.predict(X_subset) + probas = pipeline.predict_proba(X_subset) + + all_pred[marker_mask] = preds + all_proba[marker_mask] = probas + click.echo(f" {task}/{marker_filter}: predicted {n_matching} cells") + + adata.obs[f"predicted_{task}"] = all_pred + adata.uns[f"predicted_{task}_classes"] = classes + new_obsm[f"predicted_{task}_proba"] = all_proba + + if not new_obsm: + click.echo(" No predictions written (no matching markers)") + continue + + append_to_anndata_zarr(zarr_path, obs=adata.obs, obsm=new_obsm, uns=adata.uns) + click.echo(f" Saved predictions to {zarr_path}") + + click.echo("\nDone.") + + +class _AppendPredictionsConfig: + def __init__(self, raw: dict): + self.embeddings_path = Path(raw["embeddings_path"]) + self.pipelines_dir = Path(raw["pipelines_dir"]) + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Apply saved linear classifiers to per-experiment zarrs and write predictions.""" + click.echo("=" * 60) + click.echo("APPEND PREDICTIONS") + click.echo("=" * 60) + raw = load_config(config) + cfg = _AppendPredictionsConfig(raw) + append_predictions(cfg.embeddings_path, cfg.pipelines_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py index 77af8cf07..20e028f05 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py @@ -34,6 +34,10 @@ class SmoothnessEvalConfig(BaseModel): Whether to use memory-optimized computation. verbose : bool Print verbose progress messages. + group_by : str or None + obs column to group by before computing smoothness (e.g. "marker"). + Smoothness is computed per group; the reported aggregate stats are + mean ± std across groups. Set to null to compute on the whole embedding. """ models: list[ModelEntry] = Field(..., min_length=1) @@ -44,6 +48,7 @@ class SmoothnessEvalConfig(BaseModel): save_distributions: bool = False use_optimized: bool = True verbose: bool = False + group_by: Optional[str] = "marker" @model_validator(mode="after") def validate_paths(self): diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py index 91a2e6db7..ae9e7c650 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py @@ -50,6 +50,7 @@ def main(config: Path): for i, model_entry in enumerate(config.models, 1): model_path = Path(model_entry.path) model_label = model_entry.label + experiment_name = model_path.stem click.echo(f"\nProcessing {i}/{len(config.models)}: {model_label}...") @@ -60,28 +61,87 @@ def main(config: Path): if config.verbose: click.echo(f" Loaded {features_ad.shape[0]:,} samples with {features_ad.shape[1]} features") - stats, distributions, _ = compute_embeddings_smoothness( - features_ad, - distance_metric=config.distance_metric, - verbose=config.verbose, - ) + group_col = config.group_by + if group_col and group_col in features_ad.obs.columns: + groups = features_ad.obs[group_col].unique().tolist() + click.echo(f" Computing smoothness per {group_col}: {groups}") + + per_group_rows = [] + group_stats_list = [] + group_distributions = {} + + for group_val in groups: + mask = features_ad.obs[group_col] == group_val + group_ad = features_ad[mask].copy() + + if config.verbose: + click.echo(f" {group_col}={group_val}: {group_ad.shape[0]:,} cells") + + g_stats, g_dists, _ = compute_embeddings_smoothness( + group_ad, + distance_metric=config.distance_metric, + verbose=config.verbose, + ) + per_group_rows.append({group_col: group_val, **g_stats}) + group_stats_list.append(g_stats) + group_distributions[group_val] = g_dists + + if config.save_plots: + _create_smoothness_plot( + g_dists, + g_stats, + f"{model_label}_{experiment_name}_{group_val}", + config.distance_metric, + output_dir, + ) + + per_group_df = pd.DataFrame(per_group_rows) + per_group_df.insert(0, "experiment", experiment_name) + per_group_df.to_csv( + output_dir / f"{model_label}_{experiment_name}_per_{group_col}_smoothness.csv", index=False + ) + click.echo(f" Per-{group_col} stats saved.") + + # Aggregate: mean ± std across groups + metric_cols = [c for c in per_group_df.columns if c != group_col] + agg_means = per_group_df[metric_cols].mean() + agg_stds = per_group_df[metric_cols].std() + stats = agg_means.to_dict() + stats_std = {f"{k}_std": v for k, v in agg_stds.to_dict().items()} + stats.update(stats_std) + + # Concatenate distributions across groups for the combined plot + distributions = { + "adjacent_frame_distribution": np.concatenate( + [d["adjacent_frame_distribution"] for d in group_distributions.values()] + ), + "random_frame_distribution": np.concatenate( + [d["random_frame_distribution"] for d in group_distributions.values()] + ), + } + else: + stats, distributions, _ = compute_embeddings_smoothness( + features_ad, + distance_metric=config.distance_metric, + verbose=config.verbose, + ) all_results[model_label] = stats all_distributions[model_label] = distributions save_results( stats, - output_dir / f"{model_label}_smoothness_stats.csv", + output_dir / f"{model_label}_{experiment_name}_smoothness_stats.csv", format="csv", ) if config.save_distributions: np.save( - output_dir / f"{model_label}_adjacent_distribution.npy", + output_dir / f"{model_label}_{experiment_name}_adjacent_distribution.npy", distributions["adjacent_frame_distribution"], ) np.save( - output_dir / f"{model_label}_random_distribution.npy", + output_dir / f"{model_label}_{experiment_name}_random_distribution.npy", distributions["random_frame_distribution"], ) @@ -91,7 +151,7 @@ def main(config: Path): _create_smoothness_plot( distributions, stats, - model_label, + f"{model_label}_{experiment_name}", config.distance_metric, output_dir, ) diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/__init__.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py new file mode 100644 index 000000000..aae0b9a9d --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py @@ -0,0 +1,107 @@ +"""Configuration models for CTC tracking accuracy evaluation.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class ONNXModelEntry(BaseModel): + """One model to benchmark. + + Parameters + ---------- + path : str or None + Path to the ONNX model file. None runs the baseline (IoU + spatial edges only, + no embedding model). + label : str + Display name for this model in results. + pixel_size_um : float or None + Pixel size (µm/px) the model was trained at. Used to rescale input crops + when the dataset pixel size differs. None disables rescaling. + """ + + path: str | None + label: str + pixel_size_um: float | None = None + + +class CTCDatasetEntry(BaseModel): + """One CTC dataset directory. + + Parameters + ---------- + path : str + Path to the dataset root (e.g. /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC). + Must contain ``{seq}_ERR_SEG/``, ``{seq}/`` (raw images), and ``{seq}_GT/TRA/`` + subdirectories for each sequence. + sequences : list[str] + Sequence numbers to evaluate (e.g. ["01", "02"]). + pixel_size_um : float or None + Pixel size (µm/px) of the raw images. Used with ``ONNXModelEntry.pixel_size_um`` + to rescale crops before ONNX inference. If None, looked up from + ``TrackingAccuracyConfig.ctc_metadata_path`` by dataset name, then + falls back to reading TIFF XResolution metadata. + """ + + path: str + sequences: list[str] = Field(default=["01", "02"]) + pixel_size_um: float | None = None + + +class TrackingAccuracyConfig(BaseModel): + """Configuration for CTC tracking accuracy evaluation. + + Parameters + ---------- + models : list[ONNXModelEntry] + Models to benchmark. Include an entry with ``path: null`` for the IoU baseline. + datasets : list[CTCDatasetEntry] + CTC datasets to evaluate. + model_input_shape : tuple[int, int] + Height x width of the ONNX model input (must match what the model was exported with). + Default (160, 160) matches the DynaCLR-2D-MIP training resolution. + distance_threshold : float + Maximum spatial distance (pixels) for candidate edges in DistanceEdges. + n_neighbors : int + Maximum candidate edges per cell. + delta_t : int + Maximum frame gap for candidate edges. + division_weight : float + ILP solver weight for cell division events. + appearance_weight : float + ILP solver weight for cell appearance. + disappearance_weight : float + ILP solver weight for cell disappearance. + node_weight : float + ILP solver weight per node (negative = prefer more detections). + output_dir : str + Directory for results CSV. + ctc_metrics : list[str] or None + CTC metric names to include in output. None = all available metrics. + batch_size : int + Number of cell crops per ONNX inference call. + ctc_metadata_path : str or None + Path to a CTC metadata YAML mapping dataset names to + ``[interval_min, y_um, x_um]``. Used to look up pixel size when + ``CTCDatasetEntry.pixel_size_um`` is not set. Falls back to reading + TIFF XResolution tags if the dataset is not in the file. + show_napari : bool + Open a napari viewer after tracking each sequence. Only use when running + interactively on a partition with a display. Default: False. + """ + + models: list[ONNXModelEntry] = Field(..., min_length=1) + datasets: list[CTCDatasetEntry] = Field(..., min_length=1) + ctc_metadata_path: str | None = None + model_input_shape: tuple[int, int] = (160, 160) + distance_threshold: float = 325.0 + n_neighbors: int = 10 + delta_t: int = 5 + division_weight: float = 0.5 + appearance_weight: float = 0.0 + disappearance_weight: float = 0.0 + node_weight: float = -10.0 + output_dir: str + ctc_metrics: list[str] | None = None + batch_size: int = 128 + show_napari: bool = False diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py new file mode 100644 index 000000000..c4005e068 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py @@ -0,0 +1,484 @@ +"""CLI tool for CTC tracking accuracy benchmarking with DynaCLR embeddings. + +Evaluates how well DynaCLR embedding similarity, used as an additional edge cost, +improves cell tracking accuracy on CTC (Cell Tracking Challenge) benchmark datasets. + +For each (ONNX model, CTC dataset, sequence) combination: +1. Load segmentation masks and raw images. +2. Build a tracksdata graph (nodes from masks, candidate edges via DistanceEdges). +3. If a model is provided, run ONNX inference on cell crops and weight edges by + embedding cosine similarity * spatial distance weight. +4. If no model is provided, use IoU + spatial distance (baseline). +5. Solve the tracking with ILP and evaluate against CTC ground truth. + +Usage +----- +dynaclr evaluate-tracking-accuracy -c tracking_accuracy_config.yaml +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import click +import numpy as np +import polars as pl +import tracksdata as td +from dask.array.image import imread +from numpy.typing import NDArray +from rich import print as rprint +from skimage.transform import resize + +from dynaclr.evaluation.benchmarking.tracking_accuracy.config import ( + CTCDatasetEntry, + ONNXModelEntry, + TrackingAccuracyConfig, +) +from dynaclr.evaluation.benchmarking.tracking_accuracy.utils import ( + normalize_crop, + pad_to_shape, + seg_dir, +) +from viscy_utils.cli_utils import load_config + +_logger = logging.getLogger(__name__) + + +def _load_ctc_metadata(path: Path) -> dict[str, float]: + """Load dataset name → x pixel size (µm) from Jordao's CTC metadata YAML. + + Format: ``dataset_name: [interval_min, y_um, x_um]`` + + Parameters + ---------- + path : Path + Path to the metadata YAML file. + + Returns + ------- + dict[str, float] + Mapping from dataset name to x pixel size in µm. + """ + import yaml + + with open(path) as f: + raw = yaml.safe_load(f) + # value is [interval_min, y_um, x_um] — take x (index 2) + return {name: values[2] for name, values in raw.items() if isinstance(values, list)} + + +def _crop_embedding( + frame: NDArray, + mask: list, + source_shape: tuple[int, int], + final_shape: tuple[int, int], + session: Any, + input_name: str, +) -> list[NDArray]: + """Crop cells from a frame and compute DynaCLR embeddings via ONNX. + + Parameters + ---------- + frame : NDArray + Raw image frame (2-D or 3-D with a single z-slice). + mask : list[td.nodes.Mask] + Cell masks for this frame. The parameter name must match the graph + attribute key (``"mask"`` in ``attr_keys``). + source_shape : tuple[int, int] + (height, width) to extract from the image in dataset pixels. + If different from ``final_shape``, the crop is resized to ``final_shape`` + to correct for pixel size differences between dataset and training data. + final_shape : tuple[int, int] + (height, width) of the model input (must match ONNX input size). + session : ort.InferenceSession + ONNX runtime inference session. + input_name : str + Name of the ONNX model's input tensor. + + Returns + ------- + list[NDArray] + L2-normalized embedding vector for each mask (same order). + """ + # Compute frame-level stats once — matches timepoint_statistics normalization used in training + frame_f32 = frame.astype(np.float32) + frame_mean = float(np.mean(frame_f32)) + frame_std = float(np.std(frame_f32)) + + label_img = np.zeros_like(frame, dtype=np.int16) + crops = [] + + for i, m in enumerate(mask, start=1): + if frame.ndim == 3: + extract_shape = (1, *source_shape) + else: + extract_shape = source_shape + + label_img[m.mask_indices()] = i + + crop = m.crop(frame, shape=extract_shape).astype(np.float32) + + if crop.ndim == 3: + if crop.shape[0] != 1: + raise ValueError(f"Expected 1 z-slice in 3D crop, got {crop.shape[0]}") + crop = crop[0] + + crop = pad_to_shape(crop, source_shape, mode="reflect") + + if source_shape != final_shape: + crop = resize(crop, final_shape, order=1, anti_aliasing=True, preserve_range=True).astype(np.float32) + + crop = normalize_crop(crop, frame_mean, frame_std) + + if crop.shape != final_shape: + raise ValueError(f"Crop shape {crop.shape} != final_shape {final_shape}") + + crops.append(crop) + + # shape: (batch, channel, z, h, w) + batch = np.stack(crops, axis=0)[:, np.newaxis, np.newaxis, ...] + output = session.run(None, {input_name: batch}) + + embeddings = output[0] # backbone features (e.g. 768-dim) + embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) + return list(embeddings) + + +def _add_dynaclr_attrs( + model_path: Path, + graph: td.graph.InMemoryGraph, + images: NDArray, + model_input_shape: tuple[int, int], + batch_size: int, + pixel_size_scale: float, +) -> None: + """Add DynaCLR embedding node attributes and cosine similarity edge attributes. + + Parameters + ---------- + model_path : Path + Path to the exported ONNX model. + graph : td.graph.InMemoryGraph + Graph with nodes already added (must have ``mask`` attribute). + images : NDArray + Raw image stack, shape (T, H, W) or (T, Z, H, W). + model_input_shape : tuple[int, int] + (height, width) of the ONNX model input (e.g. (160, 160)). + batch_size : int + Number of crops per ONNX inference call. + pixel_size_scale : float + Ratio of dataset pixel size to model training pixel size + (dataset_um / model_um). Crops are extracted at + ``model_input_shape * pixel_size_scale`` and resized to ``model_input_shape``. + Use 1.0 when no rescaling is needed. + """ + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.intra_op_num_threads = 1 + session_options.inter_op_num_threads = 1 + session = ort.InferenceSession( + str(model_path), + sess_options=session_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + input_name = session.get_inputs()[0].name + _logger.info( + "ONNX model: input='%s' shape=%s type=%s", + input_name, + session.get_inputs()[0].shape, + session.get_inputs()[0].type, + ) + + source_shape = ( + round(model_input_shape[0] * pixel_size_scale), + round(model_input_shape[1] * pixel_size_scale), + ) + _logger.info( + "Crop pipeline: extract %s px -> resize to %s px (scale=%.3f)", + source_shape, + model_input_shape, + pixel_size_scale, + ) + + from toolz import curry + + crop_fn = curry(_crop_embedding)( + source_shape=source_shape, + final_shape=model_input_shape, + session=session, + input_name=input_name, + ) + + graph.add_node_attr_key("dynaclr_embedding", dtype=pl.List(pl.Float32)) + + td.nodes.GenericFuncNodeAttrs( + func=crop_fn, + output_key="dynaclr_embedding", + attr_keys=["mask"], + batch_size=batch_size, + ).add_node_attrs(graph, frames=images) + + td.edges.GenericFuncEdgeAttrs( + func=np.dot, + output_key="dynaclr_similarity", + attr_keys="dynaclr_embedding", + ).add_edge_attrs(graph) + + +def _build_and_solve( + model_path: Path | None, + images: NDArray, + labels: NDArray, + config: TrackingAccuracyConfig, + pixel_size_scale: float = 1.0, +) -> tuple[td.graph.InMemoryGraph, td.graph.InMemoryGraph]: + """Build a tracksdata graph and solve tracking. + + Parameters + ---------- + model_path : Path or None + ONNX model path. None uses the IoU + spatial baseline. + images : NDArray + Raw image stack (T, H, W). + labels : NDArray + Segmentation label stack (T, H, W). + config : TrackingAccuracyConfig + Evaluation configuration. + pixel_size_scale : float + Ratio of dataset pixel size to model training pixel size + (dataset_um / model_um). Passed to ``_add_dynaclr_attrs``. Default 1.0. + + Returns + ------- + graph : td.graph.InMemoryGraph + Full candidate graph (all nodes + candidate edges). + solution_graph : td.graph.InMemoryGraph + ILP-solved tracking result. + """ + graph = td.graph.InMemoryGraph() + + td.nodes.RegionPropsNodes().add_nodes(graph, labels=labels) + _logger.info("Nodes: %d", graph.num_nodes()) + + dist_op = td.edges.DistanceEdges( + distance_threshold=config.distance_threshold, + n_neighbors=config.n_neighbors, + delta_t=config.delta_t, + ) + dist_op.add_edges(graph) + _logger.info("Candidate edges: %d", graph.num_edges()) + + td.edges.GenericFuncEdgeAttrs( + func=lambda x, y: abs(x - y), + output_key="delta_t", + attr_keys="t", + ).add_edge_attrs(graph) + + dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / config.distance_threshold).exp() + + if model_path is not None: + _add_dynaclr_attrs(model_path, graph, images, config.model_input_shape, config.batch_size, pixel_size_scale) + edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight + else: + td.edges.IoUEdgeAttr(output_key="iou").add_edge_attrs(graph) + edge_weight = -(td.EdgeAttr("iou") + 0.1) * dist_weight + + edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) + + solver = td.solvers.ILPSolver( + edge_weight=edge_weight, + appearance_weight=config.appearance_weight, + disappearance_weight=config.disappearance_weight, + division_weight=config.division_weight, + node_weight=config.node_weight, + ) + solution_graph = solver.solve(graph) + + return graph, solution_graph + + +def _show_napari_viewer( + graph: td.graph.InMemoryGraph, + images: NDArray, + labels: NDArray, +) -> None: + """Open a napari viewer with the tracking result overlaid on the raw images. + + Parameters + ---------- + graph : td.graph.InMemoryGraph + Full candidate graph (used to derive napari tracks format). + images : NDArray + Raw image stack (T, H, W). + labels : NDArray + Segmentation label stack (T, H, W). + """ + import napari + + tracks_df, track_graph, label_stack = td.functional.to_napari_format( + graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK + ) + viewer = napari.Viewer() + viewer.add_image(images) + viewer.add_labels(label_stack) + viewer.add_tracks(tracks_df, graph=track_graph) + napari.run() + + +def track_single_dataset( + dataset_entry: CTCDatasetEntry, + sequence: str, + model_entry: ONNXModelEntry, + config: TrackingAccuracyConfig, +) -> dict: + """Track one CTC sequence and evaluate metrics. + + Parameters + ---------- + dataset_dir : Path + CTC dataset root. + sequence : str + Sequence number (e.g. "01"). + model_entry : ONNXModelEntry + Model to use (path=None for baseline). + config : TrackingAccuracyConfig + Evaluation configuration. + + Returns + ------- + dict + CTC metrics dict plus ``model``, ``dataset``, ``sequence`` keys. + """ + dataset_dir = Path(dataset_entry.path) + _seg_dir = seg_dir(dataset_dir, sequence) + if not _seg_dir.exists(): + raise FileNotFoundError(f"Segmentation directory not found: {_seg_dir}") + + model_path = Path(model_entry.path) if model_entry.path is not None else None + + _logger.info("Loading labels from %s", _seg_dir) + labels = imread(str(_seg_dir / "*.tif")).compute() + images = imread(str(dataset_dir / sequence / "*.tif")).compute() + + gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{sequence}_GT" / "TRA") + + _logger.info( + "Tracking: model=%s dataset=%s seq=%s", + model_entry.label, + dataset_dir.name, + sequence, + ) + dataset_pixel_size = dataset_entry.pixel_size_um + if dataset_pixel_size is None and config.ctc_metadata_path is not None: + ctc_meta = _load_ctc_metadata(Path(config.ctc_metadata_path)) + dataset_pixel_size = ctc_meta.get(dataset_dir.name) + if dataset_pixel_size is not None: + _logger.info("Pixel size from metadata: %.4f µm/px (%s)", dataset_pixel_size, dataset_dir.name) + else: + _logger.warning( + "Dataset %s not found in %s; no rescaling applied", dataset_dir.name, config.ctc_metadata_path + ) + + if model_entry.pixel_size_um is not None and dataset_pixel_size is not None: + pixel_size_scale = dataset_pixel_size / model_entry.pixel_size_um + else: + pixel_size_scale = 1.0 + + graph, solution_graph = _build_and_solve(model_path, images, labels, config, pixel_size_scale) + + if config.show_napari: + _show_napari_viewer(graph, images, labels) + + _logger.info("Evaluating CTC metrics ...") + metrics = td.metrics.evaluate_ctc_metrics( + solution_graph, + gt_graph, + input_reset=False, + reference_reset=False, + metrics=config.ctc_metrics, + ) + + metrics["model"] = model_entry.label + metrics["dataset"] = dataset_dir.name + metrics["sequence"] = sequence + return metrics + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to tracking accuracy YAML configuration file", +) +def main(config: Path) -> None: + """Evaluate CTC tracking accuracy with DynaCLR ONNX embeddings. + + Runs ILP-based tracking on CTC benchmark datasets, comparing a spatial+IoU + baseline against models that use DynaCLR embedding similarity as an additional + edge cost. Writes results.csv to the configured output directory. + """ + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + ) + + raw = load_config(config) + cfg = TrackingAccuracyConfig(**raw) + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + results: list[dict] = [] + + for model_entry in cfg.models: + for dataset_entry in cfg.datasets: + dataset_dir = Path(dataset_entry.path) + for sequence in dataset_entry.sequences: + _seg = seg_dir(dataset_dir, sequence) + if not _seg.exists(): + click.echo( + f"Skipping {dataset_dir.name}/{sequence}: no segmentation at {_seg}", + err=True, + ) + continue + + try: + row = track_single_dataset(dataset_entry, sequence, model_entry, cfg) + except Exception as exc: + click.echo( + f"Error {model_entry.label} / {dataset_dir.name} / {sequence}: {exc}", + err=True, + ) + _logger.exception("Tracking failed") + continue + + rprint(row) + results.append(row) + + # Write incrementally so partial results are never lost + df = pl.DataFrame(results) + df.write_csv(output_dir / "results.csv") + + if not results: + click.echo("No results produced.", err=True) + return + + df = pl.DataFrame(results) + df.write_csv(output_dir / "results.csv") + click.echo(f"\nResults written to {output_dir / 'results.csv'}") + + # Summary: mean across sequences, grouped by model x dataset + key_metrics = [c for c in ["LNK", "BIO(0)", "OP_CLB(0)", "CHOTA", "TRA", "DET"] if c in df.columns] + if key_metrics: + summary = df.group_by("model", "dataset").agg([pl.col(m).mean() for m in key_metrics]).sort("model", "dataset") + click.echo("\n## Tracking Accuracy Summary (mean over sequences)\n") + click.echo(summary.to_pandas().to_markdown(index=False, floatfmt=".3f")) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py new file mode 100644 index 000000000..8fc998465 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py @@ -0,0 +1,66 @@ +"""Utilities for CTC tracking accuracy evaluation.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray + + +def seg_dir(dataset_dir: Path, sequence: str) -> Path: + """Return path to the error-segmentation directory for a CTC sequence. + + Parameters + ---------- + dataset_dir : Path + CTC dataset root (e.g. .../BF-C2DL-HSC). + sequence : str + Sequence number (e.g. "01"). + """ + return dataset_dir / f"{sequence}_ERR_SEG" + + +def pad_to_shape(image: NDArray, shape: tuple[int, int], mode: str) -> NDArray: + """Pad image symmetrically to at least the given spatial shape. + + Parameters + ---------- + image : NDArray + 2-D array to pad. + shape : tuple[int, int] + Target (height, width). No-op if image is already large enough. + mode : str + Padding mode passed to ``np.pad``. + """ + diff = np.asarray(shape) - np.asarray(image.shape) + if diff.sum() == 0: + return image + left = diff // 2 + right = diff - left + return np.pad(image, tuple(zip(left, right)), mode=mode) + + +def normalize_crop(crop: NDArray, frame_mean: float, frame_std: float) -> NDArray: + """Z-score normalize a cell crop using whole-frame statistics. + + Matches the training normalization (``NormalizeSampled`` with + ``level=timepoint_statistics``): mean/std are computed over the full + frame, not the cell foreground, so the model sees the same intensity + distribution it was trained on. + + Parameters + ---------- + crop : NDArray + Float32 2-D cell image. + frame_mean : float + Mean pixel intensity of the full frame at this timepoint. + frame_std : float + Std pixel intensity of the full frame at this timepoint. + + Returns + ------- + NDArray + Z-score normalized crop. + """ + return (crop - frame_mean) / max(frame_std, 1e-8) diff --git a/applications/dynaclr/src/dynaclr/evaluation/check_evals.py b/applications/dynaclr/src/dynaclr/evaluation/check_evals.py new file mode 100644 index 000000000..83b6e3ee2 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/check_evals.py @@ -0,0 +1,161 @@ +"""Check completion status of eval runs defined in an eval registry YAML. + +Derives status from filesystem sentinels rather than stored state, so it +is always ground-truth. + +Usage +----- +dynaclr check-evals -r eval_registry.yaml +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +import click +import yaml + +from dynaclr.evaluation.evaluate_config import EvaluationConfig +from viscy_utils.cli_utils import load_config + +_STEP_SENTINELS: dict[str, str] = { + "predict": "embeddings/embeddings.zarr", + "split": "configs/viewer.yaml", + "reduce_dimensionality": "configs/reduce.yaml", + "reduce_combined": "configs/reduce_combined.yaml", + "smoothness": "smoothness/combined_smoothness_stats.csv", + "plot": "plots", + "linear_classifiers": "linear_classifiers/metrics_summary.csv", +} + +Status = Literal["done", "partial", "pending"] + + +def _check_mmd_step(output_dir: Path, eval_cfg: EvaluationConfig) -> bool: + """Return True if all MMD blocks have at least one result CSV.""" + if not eval_cfg.mmd: + return True # no MMD configured — not a blocking step + for i, block in enumerate(eval_cfg.mmd): + block_name = block.name if block.name else f"mmd_{i}" + block_dir = output_dir / "mmd" / block_name + if not any(block_dir.glob("*.csv")): + return False + return True + + +def _check_plot_step(output_dir: Path) -> bool: + """Return True if the plots directory has at least one PDF.""" + plots_dir = output_dir / "plots" + return any(plots_dir.rglob("*.pdf")) + + +def _missing_steps(eval_cfg: EvaluationConfig) -> list[str]: + """Return steps from eval_cfg.steps that have not yet produced their sentinel output.""" + output_dir = Path(eval_cfg.output_dir) + missing = [] + for step in eval_cfg.steps: + if step == "mmd": + if not _check_mmd_step(output_dir, eval_cfg): + missing.append(step) + elif step == "plot": + if not _check_plot_step(output_dir): + missing.append(step) + elif step in _STEP_SENTINELS: + sentinel = output_dir / _STEP_SENTINELS[step] + if not sentinel.exists(): + missing.append(step) + # unknown steps: skip silently + return missing + + +def _model_status(eval_cfg: EvaluationConfig, force_rerun: bool) -> tuple[Status, list[str]]: + """Return (status, missing_steps) for one model entry.""" + if force_rerun: + return "pending", ["(force_rerun=true)"] + missing = _missing_steps(eval_cfg) + if not missing: + return "done", [] + if len(missing) < len(eval_cfg.steps): + return "partial", missing + return "pending", missing + + +def _load_registry(registry_path: Path) -> list[dict]: + with open(registry_path) as f: + data = yaml.safe_load(f) + return data["models"] + + +def check_evals(registry: Path, workspace_dir: Path | None) -> None: + """Print a markdown table showing completion status for each registered model.""" + models = _load_registry(registry) + + rows = [] + for entry in models: + name = entry["name"] + force_rerun = entry.get("force_rerun", False) + eval_config_path = Path(entry["eval_config"]) + + # Resolve relative paths against workspace_dir (if provided) or registry location + if not eval_config_path.is_absolute(): + base = workspace_dir if workspace_dir else registry.parent.parent.parent.parent + eval_config_path = base / eval_config_path + + try: + raw = load_config(eval_config_path) + eval_cfg = EvaluationConfig(**raw) + status, missing = _model_status(eval_cfg, force_rerun) + missing_str = ", ".join(missing) if missing else "—" + except FileNotFoundError as e: + status = "pending" + missing_str = f"config not found: {e}" + except Exception as e: # noqa: BLE001 + status = "pending" + missing_str = f"error: {e}" + + rows.append((name, status, missing_str)) + + # Print markdown table + col_name = max(len(r[0]) for r in rows) + col_status = max(len(r[1]) for r in rows) + col_missing = max(len(r[2]) for r in rows) + + col_name = max(col_name, len("Model")) + col_status = max(col_status, len("Status")) + col_missing = max(col_missing, len("Missing Steps")) + + header = f"| {'Model':<{col_name}} | {'Status':<{col_status}} | {'Missing Steps':<{col_missing}} |" + sep = f"| {'-' * col_name} | {'-' * col_status} | {'-' * col_missing} |" + click.echo(header) + click.echo(sep) + for name, status, missing_str in rows: + click.echo(f"| {name:<{col_name}} | {status:<{col_status}} | {missing_str:<{col_missing}} |") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-r", + "--registry", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to eval_registry.yaml", +) +@click.option( + "--workspace-dir", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Workspace root for resolving relative eval_config paths. Defaults to four levels above the registry file.", +) +def main(registry: Path, workspace_dir: Path | None) -> None: + """Print a markdown table showing eval completion status for each registered model. + + Status is derived from filesystem sentinels — never stored manually. + Set force_rerun: true in the registry to mark a model for re-execution + regardless of existing outputs. + """ + check_evals(registry, workspace_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py index a5448b261..1f3aa86c4 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py @@ -30,6 +30,9 @@ class PHATEConfig(BaseModel): knn_dist: str = "cosine" scale_embeddings: bool = False random_state: int = 42 + n_pca: int = 50 + subsample: Optional[int] = 50_000 + n_jobs: int = 1 class DimensionalityReductionConfig(BaseModel): @@ -65,3 +68,64 @@ def validate_config(self): if self.pca is None and self.umap is None and self.phate is None: raise ValueError("At least one reduction method must be specified (pca, umap, or phate)") return self + + +class CombinedDatasetConfig(BaseModel): + """Input dataset spec for combined reductions. + + Parameters + ---------- + anndata : str + Path to AnnData zarr store with features in ``.X``. + hcs_plate : str, optional + Path to the raw HCS plate zarr (not used for reductions, but useful for reuse). + """ + + anndata: str = Field(...) + hcs_plate: Optional[str] = None + + +class CombinedDimensionalityReductionConfig(BaseModel): + """Configuration for computing joint dimensionality reductions across multiple AnnData stores. + + Parameters + ---------- + input_paths : list[str], optional + Paths to AnnData zarr stores. Embeddings from all stores are concatenated before fitting + reductions, then per-store slices are written back with a ``_combined`` suffix. + datasets : dict[str, CombinedDatasetConfig], optional + Alternative to ``input_paths``. When provided, ``input_paths`` is derived from + ``datasets[*].anndata``. This matches the multi-dataset YAML used in organelle dynamics. + pca : PCAConfig, optional + PCA parameters. Results stored as ``X_pca_combined``. + umap : UMAPConfig, optional + UMAP parameters. Results stored as ``X_umap_combined``. + phate : PHATEConfig, optional + PHATE parameters. Results stored as ``X_phate_combined``. + overwrite_keys : bool + If True, overwrite existing ``.obsm`` keys. Otherwise raise on conflict. + """ + + input_paths: Optional[list[str]] = None + datasets: Optional[dict[str, CombinedDatasetConfig]] = None + pca: Optional[PCAConfig] = None + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = None + overwrite_keys: bool = False + + @model_validator(mode="after") + def validate_config(self): + if self.input_paths is None: + if not self.datasets: + raise ValueError("Either input_paths or datasets must be provided") + self.input_paths = [d.anndata for d in self.datasets.values()] + + if len(self.input_paths) < 1: + raise ValueError("At least one input path must be provided") + + for p in self.input_paths: + if not Path(p).exists(): + raise ValueError(f"Input path not found: {p}") + if self.pca is None and self.umap is None and self.phate is None: + raise ValueError("At least one reduction method must be specified (pca, umap, or phate)") + return self diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py new file mode 100644 index 000000000..0561ecba0 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py @@ -0,0 +1,127 @@ +""" +Joint dimensionality reduction (PCA, UMAP, PHATE) across multiple AnnData zarr stores. + +Concatenates embeddings from all stores, fits joint reductions, +then writes per-store slices back as X_*_combined. + +Usage +----- +dynaclr reduce-combined -c multi-dataset-dim-reduction.yml +""" + +import anndata as ad +import click +import numpy as np + +from viscy_utils.cli_utils import format_markdown_table, load_config_section +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + +from .config import CombinedDimensionalityReductionConfig +from .reduce_dimensionality import _run_pca, _run_phate, _run_umap + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=str), + required=True, + help="Path to YAML configuration file", +) +def main(config: str): + """Compute joint PCA, UMAP, and/or PHATE across multiple AnnData zarr stores.""" + click.echo("Loading configuration...") + raw_config = load_config_section(config, None, default_section="reduce_combined") + cfg = CombinedDimensionalityReductionConfig(**raw_config) + + if hasattr(ad, "settings") and hasattr(ad.settings, "allow_write_nullable_strings"): + ad.settings.allow_write_nullable_strings = True + + resolved_paths = [str(p) for p in cfg.input_paths] + dataset_names = list(cfg.datasets.keys()) if cfg.datasets else None + + # Determine which keys will be written + methods_to_run: list[tuple[str, object]] = [] + if cfg.pca is not None: + methods_to_run.append(("pca", cfg.pca)) + if cfg.umap is not None: + methods_to_run.append(("umap", cfg.umap)) + if cfg.phate is not None: + methods_to_run.append(("phate", cfg.phate)) + + key_map = {"pca": "X_pca_combined", "umap": "X_umap_combined", "phate": "X_phate_combined"} + keys_to_write = [key_map[name] for name, _ in methods_to_run] + + # Check for existing keys before loading data + if not cfg.overwrite_keys: + for path in resolved_paths: + adata = ad.read_zarr(path) + for key in keys_to_write: + if key in adata.obsm: + raise click.ClickException( + f"Key '{key}' already exists in {path}. Use overwrite_keys: true to replace." + ) + + # Load embeddings from all stores + all_features = [] + all_lineage_ids = [] + sample_counts = [] + for path in resolved_paths: + click.echo(f"Reading {path}...") + adata = ad.read_zarr(path) + features = np.asarray(adata.X) + all_features.append(features) + sample_counts.append(features.shape[0]) + if "lineage_id" in adata.obs.columns: + all_lineage_ids.append(adata.obs["lineage_id"].to_numpy()) + click.echo(f" {features.shape[0]:,} samples x {features.shape[1]} features") + + combined = np.concatenate(all_features, axis=0) + combined_lineage_ids = np.concatenate(all_lineage_ids) if all_lineage_ids else None + click.echo(f"Combined: {combined.shape[0]:,} samples x {combined.shape[1]} features") + + # Compute reductions on joint data + results: dict[str, np.ndarray] = {} + + runner_map = {"pca": _run_pca, "umap": _run_umap, "phate": _run_phate} + for method_name, method_cfg in methods_to_run: + if method_name == "phate": + _, embedding = _run_phate(combined, method_cfg, lineage_ids=combined_lineage_ids) + else: + _, embedding = runner_map[method_name](combined, method_cfg) + out_key = key_map[method_name] + results[out_key] = embedding + click.echo(f" {method_name.upper()} done -> {out_key} ({embedding.shape[1]} components)") + + # Slice and write back to each store + offset = 0 + for i, path in enumerate(resolved_paths): + n = sample_counts[i] + store_obsm = {key: emb[offset : offset + n] for key, emb in results.items()} + store_uns = {} + for method_name, _ in methods_to_run: + store_uns[f"{method_name}_combined_datasets"] = resolved_paths + if dataset_names is not None: + store_uns[f"{method_name}_combined_dataset_names"] = dataset_names + offset += n + + click.echo(f"Writing to {path} ({n:,} rows)...") + append_to_anndata_zarr(path, obsm=store_obsm, uns=store_uns) + + # Summary + summary_data = [] + for key, embedding in sorted(results.items()): + summary_data.append( + { + "method": key, + "components": embedding.shape[1], + "total_samples": embedding.shape[0], + "stores": len(resolved_paths), + } + ) + click.echo("\n" + format_markdown_table(summary_data, title="Combined Dimensionality Reduction")) + click.echo(f"Results written to {len(resolved_paths)} store(s)") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py index ed2b47aa2..ccd82b464 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py @@ -17,7 +17,7 @@ import numpy as np from numpy.typing import NDArray -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr from .config import ( @@ -51,7 +51,7 @@ def _run_umap(features: NDArray, cfg: UMAPConfig) -> tuple[str, NDArray]: return "X_umap", umap_embedding -def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: +def _run_phate(features: NDArray, cfg: PHATEConfig, lineage_ids: NDArray | None = None) -> tuple[str, NDArray]: from viscy_utils.evaluation.dimensionality_reduction import compute_phate _, phate_embedding = compute_phate( @@ -62,6 +62,10 @@ def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: knn_dist=cfg.knn_dist, scale_embeddings=cfg.scale_embeddings, random_state=cfg.random_state, + n_pca=cfg.n_pca, + subsample=cfg.subsample, + lineage_ids=lineage_ids, + n_jobs=cfg.n_jobs, ) return "X_phate", phate_embedding @@ -77,7 +81,7 @@ def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: def main(config: Path): """Compute PCA, UMAP, and/or PHATE on saved embeddings.""" click.echo("Loading configuration...") - raw_config = load_config(config) + raw_config = load_config_section(config, None, default_section="reduce_dimensionality") cfg = DimensionalityReductionConfig(**raw_config) click.echo(f"Reading embeddings from {cfg.input_path}...") @@ -103,10 +107,15 @@ def main(config: Path): click.echo(f"Computing {len(methods_to_run)} reduction(s): {', '.join(name for name, _, _ in methods_to_run)}") + lineage_ids = adata.obs["lineage_id"].to_numpy() if "lineage_id" in adata.obs.columns else None + results = {} for method_name, method_cfg, obsm_key in methods_to_run: try: - key, embedding = runner_map[method_name](features, method_cfg) + if method_name == "phate": + key, embedding = _run_phate(features, method_cfg, lineage_ids=lineage_ids) + else: + key, embedding = runner_map[method_name](features, method_cfg) results[key] = embedding click.echo(f" {method_name.upper()} done -> {key} ({embedding.shape[1]} components)") except Exception as e: diff --git a/applications/dynaclr/src/dynaclr/evaluation/evaluate.py b/applications/dynaclr/src/dynaclr/evaluation/evaluate.py new file mode 100644 index 000000000..4ec16a648 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/evaluate.py @@ -0,0 +1,525 @@ +"""Evaluation config generator for DynaCLR trained models. + +Generates per-step YAML configs from a single eval YAML and prints a JSON manifest +mapping step names to config paths. Called internally by the Nextflow PREPARE_CONFIGS step. + +Usage +----- +dynaclr prepare-eval-configs -c eval_config.yaml +""" + +from __future__ import annotations + +import json +import shutil +from pathlib import Path +from typing import Any + +import click +import yaml + +from dynaclr.evaluation.evaluate_config import EvaluationConfig +from viscy_utils.cli_utils import load_config + +_Z_REDUCTION_CLASS = "viscy_transforms.BatchedChannelWiseZReductiond" + +# Placeholders used in template YAMLs that operate per-experiment zarr. +# Nextflow processes substitute these at runtime when handling per-experiment channels. +_ZARR_PLACEHOLDER = "__ZARR_PATH__" +_PLOT_DIR_PLACEHOLDER = "__PLOT_DIR__" + + +def _load_training_config(path: str) -> dict: + with open(path) as f: + return yaml.safe_load(f) + + +def _extract_predict_data_config(training_cfg: dict, eval_cfg: EvaluationConfig) -> dict: + """Extract data init_args for the predict YAML from the training config. + + Strips augmentations (except BatchedChannelWiseZReductiond which is + architecturally required), overrides batch_size and split_ratio. + """ + data_init = dict(training_cfg["data"]["init_args"]) + + # Override cell_index_path if user supplied one + if eval_cfg.cell_index_path is not None: + data_init["cell_index_path"] = eval_cfg.cell_index_path + + # Move z-reduction transform from augmentations to end of normalizations + augmentations = data_init.pop("augmentations", []) or [] + z_reduction = [t for t in augmentations if _is_z_reduction(t)] + normalizations = list(data_init.get("normalizations") or []) + data_init["normalizations"] = normalizations + z_reduction + data_init["augmentations"] = [] + + # Predict-specific overrides + data_init["batch_size"] = eval_cfg.predict.batch_size + data_init["num_workers"] = eval_cfg.predict.num_workers + data_init["split_ratio"] = 1.0 + + # Remove training-only keys that are irrelevant for predict + for key in ["stratify_by", "batch_group_by", "temporal_enrichment", "leaky", "group_weights"]: + data_init.pop(key, None) + + return data_init + + +def _is_z_reduction(transform: Any) -> bool: + """Check if a transform config is BatchedChannelWiseZReductiond.""" + if isinstance(transform, dict): + return transform.get("class_path", "") == _Z_REDUCTION_CLASS + return False + + +def _extract_model_config(training_cfg: dict) -> dict: + """Extract model config, setting drop_path_rate=0 for inference. + + Only sets drop_path_rate if the encoder already declares it (e.g. ContrastiveEncoder). + Encoders like DINOv3Model do not accept this parameter and must not receive it. + """ + model = dict(training_cfg["model"]) + init_args = dict(model.get("init_args", {})) + encoder = dict(init_args.get("encoder", {})) + encoder_init = dict(encoder.get("init_args", {})) + if "drop_path_rate" in encoder_init: + encoder_init["drop_path_rate"] = 0.0 + encoder["init_args"] = encoder_init + init_args["encoder"] = encoder + model["init_args"] = init_args + return model + + +# --------------------------------------------------------------------------- +# YAML config generators +# --------------------------------------------------------------------------- + + +def _generate_predict_yaml(eval_cfg: EvaluationConfig, training_cfg: dict, output_dir: Path) -> Path: + """Generate the Lightning predict YAML config.""" + embeddings_path = str(output_dir / "embeddings" / "embeddings.zarr") + data_init = _extract_predict_data_config(training_cfg, eval_cfg) + model_cfg = _extract_model_config(training_cfg) + + embedding_writer: dict = { + "class_path": "viscy_utils.callbacks.embedding_writer.EmbeddingWriter", + "init_args": { + "output_path": embeddings_path, + "overwrite": True, + }, + } + + predict_cfg: dict = { + "seed_everything": 42, + "trainer": { + "accelerator": "gpu", + "devices": eval_cfg.predict.devices, + "num_nodes": 1, + "precision": eval_cfg.predict.precision, + "inference_mode": True, + "logger": False, + "callbacks": [embedding_writer], + }, + "model": model_cfg, + "data": { + "class_path": training_cfg["data"]["class_path"], + "init_args": data_init, + }, + "ckpt_path": eval_cfg.ckpt_path, + } + + out_path = output_dir / "configs" / "predict.yml" + with open(out_path, "w") as f: + yaml.dump(predict_cfg, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _generate_reduce_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate dim reduction template config YAML. + + Uses a placeholder for ``input_path`` because the actual per-experiment + zarr paths are only known after the split step runs. + """ + cfg_dict: dict = { + "input_path": _ZARR_PLACEHOLDER, + "overwrite_keys": eval_cfg.reduce_dimensionality.overwrite_keys, + } + if eval_cfg.reduce_dimensionality.pca: + cfg_dict["pca"] = eval_cfg.reduce_dimensionality.pca.model_dump() + if eval_cfg.reduce_dimensionality.umap: + cfg_dict["umap"] = eval_cfg.reduce_dimensionality.umap.model_dump() + if eval_cfg.reduce_dimensionality.phate: + cfg_dict["phate"] = eval_cfg.reduce_dimensionality.phate.model_dump() + + out_path = output_dir / "configs" / "reduce.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_reduce_combined_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate joint dimensionality reduction config YAML. + + ``input_paths`` is populated at runtime by Nextflow (collecting per-experiment zarrs). + """ + rc = eval_cfg.reduce_combined + cfg_dict: dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "overwrite_keys": rc.overwrite_keys, + } + if rc.pca: + cfg_dict["pca"] = rc.pca.model_dump() + if rc.umap: + cfg_dict["umap"] = rc.umap.model_dump() + if rc.phate: + cfg_dict["phate"] = rc.phate.model_dump() + + out_path = output_dir / "configs" / "reduce_combined.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_smoothness_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate smoothness evaluation config YAML.""" + model_name = Path(eval_cfg.training_config).stem + + cfg_dict = { + "models": [{"path": _ZARR_PLACEHOLDER, "label": model_name}], + "evaluation": { + "distance_metric": eval_cfg.smoothness.distance_metric, + "output_dir": str(output_dir / "smoothness"), + "save_plots": eval_cfg.smoothness.save_plots, + "save_distributions": eval_cfg.smoothness.save_distributions, + "verbose": eval_cfg.smoothness.verbose, + }, + } + + out_path = output_dir / "configs" / "smoothness.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_plot_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate per-experiment plot config YAML (template with placeholders).""" + cfg_dict = { + "input_path": _ZARR_PLACEHOLDER, + "output_dir": _PLOT_DIR_PLACEHOLDER, + "embedding_keys": eval_cfg.plot.embedding_keys, + "color_by": eval_cfg.plot.color_by, + "point_size": eval_cfg.plot.point_size, + "components": list(eval_cfg.plot.components), + "format": eval_cfg.plot.format, + } + + out_path = output_dir / "configs" / "plot.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_plot_combined_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate combined plot config YAML. + + The input_paths list is patched at runtime by Nextflow. + """ + cfg_dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "output_dir": str(output_dir / "plots" / "combined"), + "embedding_keys": eval_cfg.plot.combined_embedding_keys, + "color_by": eval_cfg.plot.combined_color_by, + "point_size": eval_cfg.plot.point_size, + "components": list(eval_cfg.plot.components), + "format": eval_cfg.plot.format, + } + + out_path = output_dir / "configs" / "plot_combined.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_append_annotations_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate append-annotations config YAML.""" + lc = eval_cfg.linear_classifiers + cfg_dict = { + "embeddings_path": str(output_dir / "embeddings"), + "annotations": [{"experiment": a.experiment, "path": a.path} for a in lc.annotations], + "tasks": [{"task": t.task, "marker_filters": t.marker_filters} for t in lc.tasks], + } + out_path = output_dir / "configs" / "append_annotations.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _generate_append_predictions_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate append-predictions config YAML.""" + cfg_dict = { + "embeddings_path": str(output_dir / "embeddings"), + "pipelines_dir": str(output_dir / "linear_classifiers" / "pipelines"), + } + out_path = output_dir / "configs" / "append_predictions.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_linear_classifiers_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate linear classifiers config YAML for dynaclr run-linear-classifiers.""" + lc = eval_cfg.linear_classifiers + embeddings_dir = str(output_dir / "embeddings") + lc_output_dir = str(output_dir / "linear_classifiers") + + cfg_dict = { + "embeddings_path": embeddings_dir, + "output_dir": lc_output_dir, + "annotations": [{"experiment": a.experiment, "path": a.path} for a in lc.annotations], + "tasks": [{"task": t.task, "marker_filters": t.marker_filters} for t in lc.tasks], + "use_scaling": lc.use_scaling, + "use_pca": lc.use_pca, + "n_pca_components": lc.n_pca_components, + "max_iter": lc.max_iter, + "class_weight": lc.class_weight, + "solver": lc.solver, + "split_train_data": lc.split_train_data, + "random_seed": lc.random_seed, + } + + out_path = output_dir / "configs" / "linear_classifiers.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _mmd_block_name(mmd: "MMDStepConfig", idx: int) -> str: # noqa: F821 + """Derive a filesystem-safe name for an MMD block.""" + if mmd.name: + return mmd.name + return f"mmd_{idx}" + + +def _generate_mmd_yaml(mmd: "MMDStepConfig", output_dir: Path, block_name: str) -> Path: # noqa: F821 + """Generate per-experiment MMD config YAML template (uses __ZARR_PATH__ placeholder).""" + cfg_dict = { + "input_path": _ZARR_PLACEHOLDER, + "output_dir": str(output_dir / "mmd" / block_name), + "comparisons": [{"cond_a": c.cond_a, "cond_b": c.cond_b, "label": c.label} for c in mmd.comparisons], + "group_by": mmd.group_by, + "obs_filter": mmd.obs_filter, + "embedding_key": mmd.embedding_key, + "mmd": mmd.mmd.model_dump(), + "map_settings": mmd.map_settings.model_dump(), + "temporal_bin_size": mmd.temporal_bin_size, + "save_plots": mmd.save_plots, + } + out_path = output_dir / "configs" / f"{block_name}.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_mmd_combined_yaml(mmd: "MMDStepConfig", output_dir: Path, block_name: str) -> Path: # noqa: F821 + """Generate cross-experiment MMD config YAML template (input_paths patched at runtime).""" + combined_name = f"{block_name}_cross_exp" + combined_bin_size = ( + mmd.combined_temporal_bin_size if mmd.combined_temporal_bin_size is not None else mmd.temporal_bin_size + ) + cfg_dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "output_dir": str(output_dir / "mmd" / combined_name), + "group_by": mmd.group_by, + "obs_filter": mmd.obs_filter, + "embedding_key": mmd.embedding_key, + "mmd": mmd.mmd.model_dump(), + "map_settings": mmd.map_settings.model_dump(), + "temporal_bin_size": combined_bin_size, + "save_plots": mmd.save_plots, + } + out_path = output_dir / "configs" / f"{combined_name}.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _resolve_cell_index_path(eval_cfg: EvaluationConfig, training_cfg: dict) -> str: + """Resolve the cell index parquet path from eval config or training config fallback.""" + if eval_cfg.cell_index_path is not None: + return eval_cfg.cell_index_path + return training_cfg["data"]["init_args"]["cell_index_path"] + + +# --------------------------------------------------------------------------- +# Main prepare_configs function +# --------------------------------------------------------------------------- + + +def prepare_configs(config: Path) -> None: + """Generate all per-step YAML configs and print a JSON manifest to stdout. + + The manifest maps step names to generated config paths and includes paths + needed by Nextflow to wire the pipeline (embeddings_dir, output_dir, + cell_index_path, mmd_blocks). + """ + raw = load_config(config) + eval_cfg = EvaluationConfig(**raw) + + training_cfg = _load_training_config(eval_cfg.training_config) + output_dir = Path(eval_cfg.output_dir) + + # Create output directories for active steps + subdirs = ["configs", "embeddings"] + step_subdirs = { + "smoothness": "smoothness", + "mmd": "mmd", + "plot": "plots", + "linear_classifiers": "linear_classifiers", + } + for step in eval_cfg.steps: + if step in step_subdirs: + subdirs.append(step_subdirs[step]) + for subdir in subdirs: + (output_dir / subdir).mkdir(parents=True, exist_ok=True) + + # Save a copy of the input eval config for reproducibility and re-runs + shutil.copy(config, output_dir / "configs" / "eval.yaml") + + manifest: dict = { + "output_dir": str(output_dir), + "embeddings_dir": str(output_dir / "embeddings"), + "cell_index_path": _resolve_cell_index_path(eval_cfg, training_cfg), + "mmd_blocks": [], + "mmd_combined_blocks": [], + } + + for step in eval_cfg.steps: + if step == "predict": + predict_yml = _generate_predict_yaml(eval_cfg, training_cfg, output_dir) + manifest["predict"] = str(predict_yml) + click.echo(f"[predict] {predict_yml}", err=True) + + elif step == "split": + click.echo( + f"[split] viewer.yaml will be written to {output_dir / 'configs' / 'viewer.yaml'} after split runs", + err=True, + ) + + elif step == "reduce_dimensionality": + reduce_yaml = _generate_reduce_yaml(eval_cfg, output_dir) + manifest["reduce"] = str(reduce_yaml) + click.echo(f"[reduce] {reduce_yaml}", err=True) + + elif step == "reduce_combined": + reduce_combined_yaml = _generate_reduce_combined_yaml(eval_cfg, output_dir) + manifest["reduce_combined"] = str(reduce_combined_yaml) + click.echo(f"[combined] {reduce_combined_yaml}", err=True) + + elif step == "smoothness": + smoothness_yaml = _generate_smoothness_yaml(eval_cfg, output_dir) + manifest["smoothness"] = str(smoothness_yaml) + click.echo(f"[smooth] {smoothness_yaml}", err=True) + + elif step == "plot": + plot_yaml = _generate_plot_yaml(eval_cfg, output_dir) + manifest["plot"] = str(plot_yaml) + click.echo(f"[plot] {plot_yaml}", err=True) + plot_combined_yaml = _generate_plot_combined_yaml(eval_cfg, output_dir) + manifest["plot_combined"] = str(plot_combined_yaml) + click.echo(f"[plot] {plot_combined_yaml}", err=True) + + elif step == "mmd": + if not eval_cfg.mmd: + click.echo("[mmd] skipped: no blocks configured", err=True) + continue + for i, mmd_block in enumerate(eval_cfg.mmd): + block_name = _mmd_block_name(mmd_block, i) + mmd_yaml = _generate_mmd_yaml(mmd_block, output_dir, block_name) + manifest[f"mmd_{block_name}"] = str(mmd_yaml) + manifest[f"mmd_{block_name}_dir"] = str(output_dir / "mmd" / block_name) + manifest["mmd_blocks"].append(block_name) + click.echo(f"[mmd] {mmd_yaml}", err=True) + if mmd_block.combined_mode: + mmd_combined_yaml = _generate_mmd_combined_yaml(mmd_block, output_dir, block_name) + combined_name = f"{block_name}_cross_exp" + manifest[f"mmd_{combined_name}"] = str(mmd_combined_yaml) + manifest["mmd_combined_blocks"].append(block_name) + click.echo(f"[mmd] {mmd_combined_yaml}", err=True) + + elif step == "linear_classifiers": + if eval_cfg.linear_classifiers is None: + click.echo("[linear_classifiers] skipped: no config provided", err=True) + continue + if not eval_cfg.linear_classifiers.annotations: + click.echo( + "[linear_classifiers] Warning: annotations is empty. " + "Add experiment + annotation CSV paths before running.", + err=True, + ) + if not eval_cfg.linear_classifiers.tasks: + click.echo( + "[linear_classifiers] Warning: tasks is empty. " + "Add task specs (task + optional marker_filters) before running.", + err=True, + ) + lc_yaml = _generate_linear_classifiers_yaml(eval_cfg, output_dir) + manifest["linear_classifiers"] = str(lc_yaml) + click.echo(f"[lc] {lc_yaml}", err=True) + + elif step == "append_annotations": + if eval_cfg.linear_classifiers is None: + click.echo( + "[append_annotations] skipped: no linear_classifiers config (annotations come from there)", err=True + ) + continue + if not eval_cfg.linear_classifiers.annotations: + click.echo("[append_annotations] Warning: annotations list is empty, nothing to append", err=True) + aa_yaml = _generate_append_annotations_yaml(eval_cfg, output_dir) + manifest["append_annotations"] = str(aa_yaml) + click.echo(f"[append_ann] {aa_yaml}", err=True) + + elif step == "append_predictions": + if eval_cfg.linear_classifiers is None: + click.echo("[append_predictions] skipped: no linear_classifiers config", err=True) + continue + if "linear_classifiers" not in eval_cfg.steps: + raise ValueError( + "'append_predictions' requires 'linear_classifiers' to also be in steps. " + "Pipelines are saved by run-linear-classifiers and must exist before applying predictions." + ) + ap_yaml = _generate_append_predictions_yaml(eval_cfg, output_dir) + manifest["append_predictions"] = str(ap_yaml) + click.echo(f"[append_pred] {ap_yaml}", err=True) + + else: + click.echo(f"Unknown step '{step}', skipping", err=True) + + # Print JSON manifest to stdout for Nextflow to consume + click.echo(json.dumps(manifest, indent=2)) + + +# --------------------------------------------------------------------------- +# CLI entry points +# --------------------------------------------------------------------------- + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to evaluation YAML configuration file", +) +def main(config: Path) -> None: + """Generate evaluation configs for a trained DynaCLR model. + + Writes per-step YAML configs to output_dir/configs/ and prints a JSON manifest + to stdout mapping step names to config paths. Used as the entry point for the + Nextflow evaluation pipeline. + """ + prepare_configs(config) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py b/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py new file mode 100644 index 000000000..624772750 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py @@ -0,0 +1,302 @@ +"""Pydantic configuration models for the DynaCLR evaluation orchestrator.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel + +from dynaclr.evaluation.dimensionality_reduction.config import PCAConfig, PHATEConfig, UMAPConfig +from dynaclr.evaluation.mmd.config import ComparisonSpec, MAPSettings, MMDSettings + + +class PredictStepConfig(BaseModel): + """Configuration for the embedding extraction (predict) step. + + Parameters + ---------- + batch_size : int + Batch size for inference. Default: 128. + num_workers : int + DataLoader thread workers. Default: 2. + precision : str + Mixed-precision setting for Lightning Trainer. Default: "bf16-mixed". + devices : int + Number of GPUs. Default: 1. + """ + + batch_size: int = 128 + num_workers: int = 2 + precision: str = "32-true" + devices: int = 1 + + +class ReduceCombinedStepConfig(BaseModel): + """Configuration for the joint dimensionality reduction step across experiments. + + Parameters + ---------- + overwrite_keys : bool + Whether to overwrite existing obsm keys. Default: True. + pca : PCAConfig or None + PCA parameters for joint fit. Results stored as X_pca_combined. + umap : UMAPConfig or None + UMAP parameters for joint fit. Results stored as X_umap_combined. + phate : PHATEConfig or None + PHATE parameters for joint fit. Results stored as X_phate_combined. + """ + + overwrite_keys: bool = True + pca: Optional[PCAConfig] = PCAConfig(n_components=32, normalize_features=True) + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = PHATEConfig(n_components=2, knn=5, decay=40, scale_embeddings=False) + + +class ReduceStepConfig(BaseModel): + """Configuration for the dimensionality reduction step. + + Parameters + ---------- + overwrite_keys : bool + Whether to overwrite existing obsm keys. Default: True. + pca : PCAConfig or None + PCA parameters. None skips PCA. + umap : UMAPConfig or None + UMAP parameters. None skips UMAP. + phate : PHATEConfig or None + PHATE parameters. None skips PHATE. + """ + + overwrite_keys: bool = True + pca: Optional[PCAConfig] = PCAConfig(n_components=32, normalize_features=True) + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = None # PHATE runs jointly in reduce_combined, not per-experiment + + +class SmoothnessStepConfig(BaseModel): + """Configuration for the temporal smoothness evaluation step. + + Parameters + ---------- + distance_metric : str + Distance metric. "cosine" or "euclidean". Default: "cosine". + save_plots : bool + Save distribution plots. Default: True. + save_distributions : bool + Save raw distribution arrays. Default: False. + verbose : bool + Print verbose progress. Default: True. + """ + + distance_metric: Literal["cosine", "euclidean"] = "cosine" + save_plots: bool = True + save_distributions: bool = False + verbose: bool = True + + +class PlotStepConfig(BaseModel): + """Configuration for the embedding visualization step. + + Parameters + ---------- + embedding_keys : list[str] + Per-experiment obsm keys to plot (looped over each split zarr). + Default: ["X_pca"]. + combined_embedding_keys : list[str] + Cross-experiment obsm keys to plot once across all zarrs concatenated. + Default: ["X_pca_combined", "X_phate_combined"]. + color_by : list[str] + obs columns for per-experiment plots. Default: perturbation, hours, marker. + combined_color_by : list[str] + obs columns for combined (cross-experiment) plots. Adds "experiment" to color_by. + point_size : float + Scatter plot point size. Default: 1.0. + components : tuple[int, int] + Which components to use as X/Y axes (0-indexed). Default: (0, 1). + format : str + Output format. "pdf" or "png". Default: "pdf". + """ + + embedding_keys: list[str] = ["X_pca"] + combined_embedding_keys: list[str] = ["X_pca_combined", "X_phate_combined"] + color_by: list[str] = ["perturbation", "hours_post_perturbation", "marker"] + combined_color_by: list[str] = ["perturbation", "hours_post_perturbation", "experiment", "marker"] + point_size: float = 1.0 + components: tuple[int, int] = (0, 1) + format: str = "pdf" + + +class AnnotationSource(BaseModel): + """Annotation CSV for one experiment. + + Parameters + ---------- + experiment : str + Experiment name matching obs["experiment"] in the embeddings zarr. + path : str + Absolute path to the annotation CSV. Must have fov_name, id, and + at least one task column (e.g. infection_state, organelle_state). + """ + + experiment: str + path: str + + +class TaskSpec(BaseModel): + """One classification task to evaluate. + + Parameters + ---------- + task : str + Task column name in annotation CSVs (e.g. infection_state, organelle_state). + marker_filters : list[str] or None + If set, run one classifier per listed marker. None (default) runs one + classifier per marker discovered in the data (all unique obs["marker"] values). + """ + + task: str + marker_filters: Optional[list[str]] = None + + +class MMDStepConfig(BaseModel): + """Configuration for one MMD evaluation block. + + Comparisons are explicit ``(cond_a, cond_b, label)`` pairs — no auto-discovery. + Include a null comparison (e.g. uninfected1 vs uninfected2) to establish + a baseline false-positive rate. + + Parameters + ---------- + comparisons : list[ComparisonSpec] + Explicit pairwise comparisons to run. + group_by : str + obs column whose values are referenced by ``cond_a``/``cond_b``. + Default: "perturbation". + obs_filter : dict[str, str] or None + Subset adata to rows where obs[key] == value before running MMD. + Example: ``{perturbation: uninfected}`` to restrict batch-QC + comparisons to control cells only. None = use all cells. + embedding_key : str or None + obsm key to use. None = raw .X. Default: None. + mmd : MMDSettings + Kernel MMD algorithm settings (permutations, cell caps, seed, etc.). + map_settings : MAPSettings + copairs-based mAP settings. Default: disabled. + temporal_bin_size : float or None + Width of each temporal bin in hours. Edges derived from data max. + None = aggregate MMD. + combined_temporal_bin_size : float or None + Override temporal_bin_size for the combined (cross-experiment) run only. + If not set, falls back to temporal_bin_size. Use None to aggregate across + all time in the combined run while keeping per-experiment binning. + save_plots : bool + Generate kinetics and heatmap plots. Default: True. + combined_mode : bool + Also run cross-experiment MMD with per-experiment batch centering. + Default: False. + name : str or None + Short name used in output filenames (e.g. "perturbation", "batch_qc"). + Auto-derived from group_by if None. + """ + + comparisons: list[ComparisonSpec] + group_by: str = "perturbation" + obs_filter: Optional[dict[str, str]] = None + embedding_key: Optional[str] = None + mmd: MMDSettings = MMDSettings() + map_settings: MAPSettings = MAPSettings() + temporal_bin_size: Optional[float] = None + combined_temporal_bin_size: Optional[float] = None + save_plots: bool = True + combined_mode: bool = False + name: Optional[str] = None + + +class LinearClassifiersStepConfig(BaseModel): + """Configuration for the orchestrated linear classifiers step. + + Parameters + ---------- + annotations : list[AnnotationSource] + Per-experiment annotation CSVs. Each entry maps an experiment name + (matching obs["experiment"] in embeddings.zarr) to a CSV path. + tasks : list[TaskSpec] + Tasks to evaluate. Each task can optionally filter by marker. + use_scaling : bool + Apply StandardScaler. Default: True. + use_pca : bool + Apply PCA before classifier. Default: False. + n_pca_components : int or None + Number of PCA components (required if use_pca is True). + max_iter : int + Max iterations for solver. Default: 1000. + class_weight : str or None + Class weighting. "balanced" or None. Default: "balanced". + solver : str + Optimization algorithm. Default: "liblinear". + split_train_data : float + Fraction for training. Default: 0.8. + random_seed : int + Random seed for reproducibility. Default: 42. + """ + + annotations: list[AnnotationSource] + tasks: list[TaskSpec] + use_scaling: bool = True + use_pca: bool = False + n_pca_components: Optional[int] = None + max_iter: int = 1000 + class_weight: Optional[str] = "balanced" + solver: str = "liblinear" + split_train_data: float = 0.8 + random_seed: int = 42 + + +class EvaluationConfig(BaseModel): + """Top-level configuration for the DynaCLR evaluation orchestrator. + + Parameters + ---------- + training_config : str + Path to the training YAML config (Lightning CLI format). Model + architecture, normalizations, and data parameters are auto-extracted. + ckpt_path : str + Path to the model checkpoint (.ckpt). + cell_index_path : str or None + Override the cell index parquet path from the training config. + None = use the path from the training config. + output_dir : str + Root directory for all evaluation outputs. + steps : list[str] + Ordered list of steps to generate configs for. + Valid values: predict, split, reduce_dimensionality, reduce_combined, + plot, smoothness, mmd, linear_classifiers. + predict : PredictStepConfig + Predict step configuration. + reduce_dimensionality : ReduceStepConfig + Per-experiment dimensionality reduction step configuration. + reduce_combined : ReduceCombinedStepConfig + Joint dimensionality reduction across all experiments. + smoothness : SmoothnessStepConfig + Smoothness evaluation configuration. + plot : PlotStepConfig + Embedding visualization configuration. + linear_classifiers : LinearClassifiersStepConfig or None + Linear classifier configuration. None disables this step. + mmd : list[MMDStepConfig] + MMD evaluation blocks. Each block is an independent run with its own + group_by, comparisons, and optional obs_filter. Empty list disables MMD. + """ + + training_config: str + ckpt_path: str + cell_index_path: Optional[str] = None + output_dir: str + steps: list[str] = ["predict", "split", "reduce_dimensionality", "reduce_combined", "plot", "smoothness"] + predict: PredictStepConfig = PredictStepConfig() + reduce_dimensionality: ReduceStepConfig = ReduceStepConfig() + reduce_combined: ReduceCombinedStepConfig = ReduceCombinedStepConfig() + smoothness: SmoothnessStepConfig = SmoothnessStepConfig() + plot: PlotStepConfig = PlotStepConfig() + linear_classifiers: Optional[LinearClassifiersStepConfig] = None + mmd: list[MMDStepConfig] = [] diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py index 3c98029b5..34e738ff1 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py @@ -10,7 +10,7 @@ from anndata import read_zarr from pydantic import ValidationError -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.linear_classifier import ( load_pipeline_from_wandb, predict_with_classifier, @@ -92,7 +92,7 @@ def main(config: Path): click.echo("=" * 60) try: - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="apply_linear_classifier") inference_config = LinearClassifierInferenceConfig(**config_dict) except ValidationError as e: click.echo(f"\n Configuration validation failed:\n{e}", err=True) diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py index 3d4a33d80..47bb3172e 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py @@ -37,7 +37,7 @@ get_available_tasks, resolve_task_channels, ) -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.annotation import load_annotation_anndata from viscy_utils.evaluation.linear_classifier import ( load_and_combine_datasets, @@ -137,17 +137,6 @@ def _get_class_counts(datasets_for_combo: list[dict], task: str) -> dict[str, in return dict(pd.Series(all_labels).value_counts()) -def _detect_n_features(datasets: list[dict], channel: str) -> int | None: - """Detect embedding dimensionality from the first available zarr.""" - for ds in datasets: - embeddings_dir = Path(ds["embeddings_dir"]) - channel_zarrs = find_channel_zarrs(embeddings_dir, [channel]) - if channel in channel_zarrs: - adata = ad.read_zarr(channel_zarrs[channel]) - return adata.shape[1] - return None - - # --------------------------------------------------------------------------- # Core rotating CV unit # --------------------------------------------------------------------------- @@ -234,7 +223,7 @@ def _train_and_evaluate( "random_state": seed, } - pipeline, metrics = train_linear_classifier( + pipeline, metrics, _ = train_linear_classifier( adata=combined_adata, task=task, use_scaling=use_scaling, @@ -828,7 +817,7 @@ def _get_recommended_subsets(summary_df: pd.DataFrame) -> pd.DataFrame: ) def main(config: Path, task: str | None, report: bool): """Run rotating test-set leave-one-dataset-out cross-validation.""" - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="cross_validate") if report: config_dict["report"] = True diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py deleted file mode 100644 index ad615758f..000000000 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Evaluation pipeline comparing embedding models on a held-out test dataset. - -Trains linear classifiers on cross-dataset embeddings, applies them to a -held-out test set, evaluates predictions, and optionally generates a PDF -comparison report. - -Usage:: - - python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml - python scripts/evaluate_dataset.py -c config.yaml --report -""" - -from __future__ import annotations - -import argparse -from pathlib import Path -from typing import Any - -import anndata as ad -import joblib -import pandas as pd -from sklearn.metrics import classification_report - -from dynaclr.evaluation.linear_classifiers.utils import ( - find_channel_zarrs, - get_available_tasks, - resolve_task_channels, -) -from viscy_utils.cli_utils import format_markdown_table, load_config -from viscy_utils.evaluation.annotation import load_annotation_anndata -from viscy_utils.evaluation.linear_classifier import ( - load_and_combine_datasets, - predict_with_classifier, - save_pipeline_to_wandb, - train_linear_classifier, -) - -# --------------------------------------------------------------------------- -# Main evaluation function -# --------------------------------------------------------------------------- - - -def run_evaluation(config: dict) -> None: - """Run the full evaluation pipeline: train, infer, evaluate, report. - - Parameters - ---------- - config : dict - Evaluation config parsed from YAML. Expected keys: - - dataset_name: str - - test_annotations_csv: str path - - output_dir: str path - - models: dict of model specs - - task_channels: dict or None (auto-detect from test CSV) - - use_scaling, n_pca_components, max_iter, class_weight, solver, - split_train_data, random_seed - - wandb_logging: bool (default True) - """ - output_dir = Path(config["output_dir"]) - output_dir.mkdir(parents=True, exist_ok=True) - - test_csv = Path(config["test_annotations_csv"]) - tc = resolve_task_channels(config.get("task_channels"), [test_csv]) - if not tc: - raise ValueError("No valid tasks found in test annotations CSV.") - - model_labels = list(config["models"].keys()) - - print("## Evaluation Pipeline") - print(f" Test dataset: {config['dataset_name']}") - print(f" Task-channels: {tc}") - print(f" Models: {model_labels}") - - use_scaling = config.get("use_scaling", True) - n_pca = config.get("n_pca_components") - use_pca = n_pca is not None - split_train_data = config.get("split_train_data", 0.8) - random_seed = config.get("random_seed", 42) - wandb_logging = config.get("wandb_logging", True) - - classifier_params = { - "max_iter": config.get("max_iter", 1000), - "class_weight": config.get("class_weight", "balanced"), - "solver": config.get("solver", "liblinear"), - "random_state": random_seed, - } - - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} - - for model_label, model_spec in config["models"].items(): - print(f"\n### Model: {model_label} ({model_spec.get('name', model_label)})") - model_train: dict[tuple[str, str], dict[str, Any]] = {} - model_eval: dict[tuple[str, str], dict[str, Any]] = {} - model_output_dir = output_dir / model_label - model_output_dir.mkdir(parents=True, exist_ok=True) - - test_embeddings_dir = Path(model_spec["test_embeddings_dir"]) - - for task, channels in tc.items(): - test_channel_zarrs = find_channel_zarrs(test_embeddings_dir, channels) - - for channel in channels: - combo_key = (task, channel) - print(f"\n {task} / {channel}:") - - # --- Train --- - try: - datasets_for_combo = _build_train_datasets(model_spec["train_datasets"], task, channel) - if not datasets_for_combo: - print(" No training datasets available, skipping.") - continue - - print(f" Training on {len(datasets_for_combo)} dataset(s)") - combined_adata = load_and_combine_datasets(datasets_for_combo, task) - - pipeline, metrics = train_linear_classifier( - adata=combined_adata, - task=task, - use_scaling=use_scaling, - use_pca=use_pca, - n_pca_components=n_pca, - classifier_params=classifier_params, - split_train_data=split_train_data, - random_seed=random_seed, - ) - - pipeline_path = model_output_dir / f"{task}_{channel}_pipeline.joblib" - joblib.dump(pipeline, pipeline_path) - print(f" Pipeline saved: {pipeline_path.name}") - - artifact_name = f"{model_spec.get('name', model_label)}_{task}_{channel}_local" - if wandb_logging and "wandb_project" in model_spec: - wandb_config = { - "task": task, - "input_channel": channel, - "marker": config.get("marker"), - "embedding_model": f"{model_spec['name']}-{model_spec['version']}", - "test_dataset": config["dataset_name"], - "use_scaling": use_scaling, - "use_pca": use_pca, - "n_pca_components": n_pca, - "max_iter": classifier_params["max_iter"], - "class_weight": classifier_params["class_weight"], - "solver": classifier_params["solver"], - "split_train_data": split_train_data, - "random_seed": random_seed, - } - wandb_tags = [ - config["dataset_name"], - model_spec["name"], - model_spec["version"], - channel, - task, - "cross-dataset", - ] - artifact_name = save_pipeline_to_wandb( - pipeline=pipeline, - metrics=metrics, - config=wandb_config, - wandb_project=model_spec["wandb_project"], - tags=wandb_tags, - ) - - model_train[combo_key] = { - "pipeline": pipeline, - "metrics": metrics, - "artifact_name": artifact_name, - } - - val_acc = metrics.get("val_accuracy") - val_f1 = metrics.get("val_weighted_f1") - if val_acc is not None: - print(f" Val accuracy: {val_acc:.3f} Val F1: {val_f1:.3f}") - - except Exception as e: - print(f" TRAIN FAILED: {e}") - continue - - # --- Infer + Evaluate --- - if channel not in test_channel_zarrs: - print(f" No test zarr for {channel}, skipping inference.") - continue - - try: - print(" Loading test embeddings...") - test_adata = ad.read_zarr(test_channel_zarrs[channel]) - - artifact_metadata = { - "artifact_name": artifact_name, - "artifact_id": artifact_name, - "artifact_version": "local", - } - test_adata = predict_with_classifier( - test_adata, - pipeline, - task, - artifact_metadata=artifact_metadata, - ) - - pred_path = model_output_dir / f"{task}_{channel}_predictions.zarr" - test_adata.write_zarr(pred_path) - print(f" Saved predictions: {pred_path.name}") - - # Evaluate against ground truth - annotated = load_annotation_anndata(test_adata, str(test_csv), task) - mask = annotated.obs[task].notna() & (annotated.obs[task] != "unknown") - eval_subset = annotated[mask] - - if len(eval_subset) == 0: - print(" No annotated test cells after filtering.") - continue - - pred_col = f"predicted_{task}" - y_true = eval_subset.obs[task].values - y_pred = eval_subset.obs[pred_col].values - - report = classification_report(y_true, y_pred, digits=3, output_dict=True) - - test_metrics = { - "test_accuracy": report["accuracy"], - "test_weighted_precision": report["weighted avg"]["precision"], - "test_weighted_recall": report["weighted avg"]["recall"], - "test_weighted_f1": report["weighted avg"]["f1-score"], - "test_n_samples": len(eval_subset), - } - - for class_name in sorted(set(y_true) | set(y_pred)): - if class_name in report: - test_metrics[f"test_{class_name}_precision"] = report[class_name]["precision"] - test_metrics[f"test_{class_name}_recall"] = report[class_name]["recall"] - test_metrics[f"test_{class_name}_f1"] = report[class_name]["f1-score"] - - annotated_path = model_output_dir / f"{task}_{channel}_annotated.zarr" - annotated.write_zarr(annotated_path) - - model_eval[combo_key] = { - "metrics": test_metrics, - "annotated_adata": annotated, - } - - acc = test_metrics["test_accuracy"] - f1 = test_metrics["test_weighted_f1"] - n = test_metrics["test_n_samples"] - print(f" Test: acc={acc:.3f} F1={f1:.3f} (n={n})") - - except Exception as e: - print(f" EVAL FAILED: {e}") - continue - - train_results[model_label] = model_train - eval_results[model_label] = model_eval - - # Save per-model metrics CSV - _save_metrics_csv( - model_train, - model_eval, - model_output_dir / "metrics_summary.csv", - ) - - # Save combined comparison CSVs - _save_comparison_csv(train_results, output_dir / "train_metrics_comparison.csv") - _save_eval_comparison_csv(eval_results, output_dir / "test_metrics_comparison.csv") - - # Print markdown summary - _print_summary(train_results, eval_results, tc) - - return train_results, eval_results - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _build_train_datasets(train_datasets: list[dict], task: str, channel: str) -> list[dict]: - """Filter and build training dataset dicts for a (task, channel) combo. - - Parameters - ---------- - train_datasets : list[dict] - Raw dataset entries from config, each with 'embeddings_dir' and 'annotations'. - task : str - Classification task to check for. - channel : str - Channel to look for in embeddings_dir. - - Returns - ------- - list[dict] - Filtered list with 'embeddings' and 'annotations' keys. - """ - result = [] - for ds in train_datasets: - embeddings_dir = Path(ds["embeddings_dir"]) - annotations_path = Path(ds["annotations"]) - - channel_zarrs = find_channel_zarrs(embeddings_dir, [channel]) - if channel not in channel_zarrs: - print(f" Skipping {embeddings_dir.parent.name} - no {channel} zarr") - continue - - available_tasks = get_available_tasks(annotations_path) - if task not in available_tasks: - print(f" Skipping {embeddings_dir.parent.name} - no {task} column") - continue - - training_dict = { - "embeddings": str(channel_zarrs[channel]), - "annotations": str(annotations_path), - } - if "include_wells" in ds: - training_dict["include_wells"] = ds["include_wells"] - result.append(training_dict) - return result - - -def _save_metrics_csv( - train_results: dict[tuple[str, str], dict[str, Any]], - eval_results: dict[tuple[str, str], dict[str, Any]], - output_path: Path, -) -> None: - """Save combined train + eval metrics for one model.""" - rows = [] - all_keys = set(train_results.keys()) | set(eval_results.keys()) - for combo_key in sorted(all_keys): - task, channel = combo_key - row = {"task": task, "channel": channel} - if combo_key in train_results: - row.update(train_results[combo_key]["metrics"]) - if combo_key in eval_results: - row.update(eval_results[combo_key]["metrics"]) - rows.append(row) - - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _save_comparison_csv( - all_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - output_path: Path, -) -> None: - """Save combined train metrics comparison across models.""" - rows = [] - for model_label, model_results in all_results.items(): - for (task, channel), result in model_results.items(): - row = {"model": model_label, "task": task, "channel": channel} - row.update(result["metrics"]) - rows.append(row) - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _save_eval_comparison_csv( - all_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - output_path: Path, -) -> None: - """Save combined test metrics comparison across models.""" - rows = [] - for model_label, model_results in all_results.items(): - for (task, channel), result in model_results.items(): - row = {"model": model_label, "task": task, "channel": channel} - row.update(result["metrics"]) - rows.append(row) - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _print_summary( - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - task_channels: dict[str, list[str]], -) -> None: - """Print markdown summary table of all results.""" - headers = ["Task", "Channel"] - model_labels = list(train_results.keys()) - for label in model_labels: - headers += [ - f"{label} Val Acc", - f"{label} Val F1", - f"{label} Test Acc", - f"{label} Test F1", - ] - - rows = [] - for task, channels in task_channels.items(): - for channel in channels: - row_dict = {"Task": task, "Channel": channel} - for label in model_labels: - tr = train_results.get(label, {}).get((task, channel)) - ev = eval_results.get(label, {}).get((task, channel)) - if tr: - row_dict[f"{label} Val Acc"] = f"{tr['metrics'].get('val_accuracy', float('nan')):.3f}" - row_dict[f"{label} Val F1"] = f"{tr['metrics'].get('val_weighted_f1', float('nan')):.3f}" - else: - row_dict[f"{label} Val Acc"] = "-" - row_dict[f"{label} Val F1"] = "-" - if ev: - row_dict[f"{label} Test Acc"] = f"{ev['metrics'].get('test_accuracy', float('nan')):.3f}" - row_dict[f"{label} Test F1"] = f"{ev['metrics'].get('test_weighted_f1', float('nan')):.3f}" - else: - row_dict[f"{label} Test Acc"] = "-" - row_dict[f"{label} Test F1"] = "-" - rows.append(row_dict) - - print(format_markdown_table(rows, title="Evaluation Summary", headers=headers)) - - -# --------------------------------------------------------------------------- -# Entry point -# --------------------------------------------------------------------------- - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Evaluate embedding models on a held-out test dataset") - parser.add_argument( - "-c", - "--config", - type=str, - required=True, - help="Path to YAML config file", - ) - parser.add_argument( - "--report", - action="store_true", - help="Generate PDF comparison report", - ) - args = parser.parse_args() - - config = load_config(args.config) - - print(f"Dataset: {config['dataset_name']}") - print(f"Output: {config['output_dir']}") - for label, spec in config["models"].items(): - n_train = len(spec["train_datasets"]) - print(f" {label}: {n_train} training dataset(s)") - - train_results, eval_results = run_evaluation(config) - - if args.report: - from dynaclr.evaluation.linear_classifiers.report import generate_comparison_report - - test_csv = Path(config["test_annotations_csv"]) - tc = resolve_task_channels(config.get("task_channels"), [test_csv]) - tasks = list(tc.keys()) - channels = sorted({ch for chs in tc.values() for ch in chs}) - - generate_comparison_report( - output_dir=Path(config["output_dir"]), - dataset_name=config["dataset_name"], - model_labels=list(config["models"].keys()), - tasks=tasks, - channels=channels, - train_results=train_results, - eval_results=eval_results, - ) diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py new file mode 100644 index 000000000..555dae847 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py @@ -0,0 +1,442 @@ +"""Orchestrated linear classifiers evaluation from a single embeddings zarr. + +Reads the combined embeddings.zarr produced by the predict step, filters by +experiment and marker, joins per-experiment annotation CSVs, and trains one +logistic regression classifier per (task, marker_filter) combination. + +Outputs a metrics_summary.csv and a summary PDF to the output directory. +No W&B logging. For standalone training with W&B use ``dynaclr train-linear-classifier``. + +Usage +----- +dynaclr run-linear-classifiers -c linear_classifiers.yaml +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import click +import joblib +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.backends.backend_pdf import PdfPages +from sklearn.model_selection import train_test_split + +from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.evaluation.annotation import load_annotation_anndata +from viscy_utils.evaluation.linear_classifier import train_linear_classifier + +matplotlib.use("Agg") + +if TYPE_CHECKING: + import anndata as ad + + from dynaclr.evaluation.evaluate_config import LinearClassifiersStepConfig + + +def run_linear_classifiers( + embeddings_path: Path, + config: LinearClassifiersStepConfig, + output_dir: Path, +) -> pd.DataFrame: + """Train linear classifiers for each (task, marker_filter) combination. + + Parameters + ---------- + embeddings_path : Path + Path to the combined embeddings zarr (AnnData format). Must have + experiment and marker columns in obs (added by the predict step). + config : LinearClassifiersStepConfig + Configuration with annotations list and task specs. + output_dir : Path + Directory to write metrics_summary.csv. + + Returns + ------- + pd.DataFrame + One row per (task, marker_filter) with accuracy, F1, AUROC, etc. + """ + import anndata as ad + + click.echo(f"Loading embeddings from {embeddings_path}") + if embeddings_path.is_dir() and not str(embeddings_path).endswith(".zarr"): + zarr_paths = sorted(embeddings_path.glob("*.zarr")) + if not zarr_paths: + raise FileNotFoundError(f"No .zarr files found in {embeddings_path}") + parts = [ad.read_zarr(p) for p in zarr_paths] + adata = ad.concat(parts, join="outer") + adata.obs_names_make_unique() + click.echo(f" Loaded {len(zarr_paths)} per-experiment zarrs") + else: + adata = ad.read_zarr(embeddings_path) + click.echo(f" {adata.n_obs} cells, {adata.n_vars} features") + + missing = [col for col in ["experiment", "marker"] if col not in adata.obs.columns] + if missing: + raise ValueError( + f"embeddings.zarr obs is missing columns: {missing}. " + "Re-run the predict step with the updated pipeline to include metadata." + ) + + all_metrics: list[dict] = [] + # val_outputs_by_task: task → list of per-marker dicts for plotting + val_outputs_by_task: dict[str, list[dict[str, Any]]] = {} + # Saved pipelines for append-predictions step + pipelines_dir = output_dir / "pipelines" + pipelines_dir.mkdir(parents=True, exist_ok=True) + pipeline_manifest: list[dict] = [] + + for task_spec in config.tasks: + task = task_spec.task + # Expand marker_filters: None → all unique markers; list → one run per specified marker + runs: list[str] = ( + task_spec.marker_filters + if task_spec.marker_filters is not None + else sorted(adata.obs["marker"].unique().tolist()) + ) + val_outputs_by_task[task] = [] + + for marker_filter in runs: + label = f"{task}" + (f" (marker={marker_filter})" if marker_filter else " (all markers)") + click.echo(f"\n{'=' * 60}") + click.echo(f"Task: {label}") + click.echo("=" * 60) + + # Filter by marker if specified + if marker_filter is not None: + adata_task = adata[adata.obs["marker"] == marker_filter] + click.echo(f" Filtered to {adata_task.n_obs} cells with marker={marker_filter}") + else: + adata_task = adata + + if adata_task.n_obs == 0: + click.echo(f" No cells found for marker_filter={marker_filter!r}, skipping.") + continue + + # Join annotation CSVs per experiment and collect annotated subsets + annotated_parts: list[ad.AnnData] = [] + for ann_src in config.annotations: + exp_mask = adata_task.obs["experiment"] == ann_src.experiment + n_exp = int(exp_mask.sum()) + if n_exp == 0: + click.echo(f" Experiment {ann_src.experiment!r}: no matching cells, skipping.") + continue + + adata_exp = adata_task[exp_mask].copy() + ann_path = Path(ann_src.path) + if not ann_path.exists(): + raise FileNotFoundError(f"Annotation CSV not found: {ann_src.path}") + + try: + adata_exp = load_annotation_anndata(adata_exp, str(ann_path), task) + except KeyError: + click.echo(f" Experiment {ann_src.experiment!r}: task {task!r} not in {ann_path.name}, skipping.") + continue + + valid_mask = adata_exp.obs[task].notna() & (adata_exp.obs[task] != "unknown") + n_valid = int(valid_mask.sum()) + if n_valid == 0: + click.echo(f" Experiment {ann_src.experiment!r}: no valid labels for {task!r}, skipping.") + continue + + annotated_parts.append(adata_exp[valid_mask]) + click.echo(f" Experiment {ann_src.experiment!r}: {n_valid}/{n_exp} labeled cells") + + if not annotated_parts: + click.echo(f" No annotated data found for task {task!r}, skipping.") + continue + + combined = annotated_parts[0] if len(annotated_parts) == 1 else ad.concat(annotated_parts, join="outer") + class_dist = combined.obs[task].value_counts().to_dict() + click.echo(f" Total: {combined.n_obs} cells, class distribution: {class_dist}") + + classifier_params = { + "max_iter": config.max_iter, + "class_weight": config.class_weight, + "solver": config.solver, + "random_state": config.random_seed, + } + + try: + pipeline, metrics, val_outputs = train_linear_classifier( + adata=combined, + task=task, + use_scaling=config.use_scaling, + use_pca=config.use_pca, + n_pca_components=config.n_pca_components, + classifier_params=classifier_params, + split_train_data=config.split_train_data, + random_seed=config.random_seed, + ) + except ValueError as exc: + click.echo(f" Skipping {label}: {exc}") + continue + + # Save pipeline for append-predictions step + pipeline_filename = f"{task}_{marker_filter}.joblib" + joblib.dump(pipeline, pipelines_dir / pipeline_filename) + pipeline_manifest.append({"task": task, "marker_filter": marker_filter, "path": pipeline_filename}) + click.echo(f" Pipeline saved: {pipeline_filename}") + + # Replay the same split to recover val obs (hours_post_perturbation) + y_full = combined.obs[task].to_numpy(dtype=object) + val_hours: np.ndarray | None = None + if config.split_train_data < 1.0 and "hours_post_perturbation" in combined.obs.columns: + try: + idx = np.arange(len(combined)) + _, idx_val = train_test_split( + idx, + train_size=config.split_train_data, + random_state=config.random_seed, + stratify=y_full, + shuffle=True, + ) + val_hours = combined.obs["hours_post_perturbation"].to_numpy()[idx_val] + except ValueError: + click.echo(" Could not replay stratified split for val_hours; F1-over-time plot skipped.") + + row = { + "task": task, + "marker_filter": marker_filter, + "n_samples": combined.n_obs, + **metrics, + } + all_metrics.append(row) + val_outputs_by_task[task].append( + { + "marker_filter": marker_filter, + "val_hours": val_hours, + **val_outputs, + } + ) + + if not all_metrics: + click.echo("\nNo classifiers trained — check annotations and marker filters.") + return pd.DataFrame() + + results_df = pd.DataFrame(all_metrics) + output_dir.mkdir(parents=True, exist_ok=True) + summary_path = output_dir / "metrics_summary.csv" + results_df.to_csv(summary_path, index=False) + click.echo(f"\nMetrics summary written to {summary_path}") + + manifest_path = pipelines_dir / "manifest.json" + with open(manifest_path, "w") as f: + json.dump(pipeline_manifest, f, indent=2) + click.echo(f"Pipeline manifest written to {manifest_path}") + + _print_summary(results_df) + for task, task_val_outputs in val_outputs_by_task.items(): + task_df = results_df[results_df["task"] == task] + _save_task_plots(task, task_df, task_val_outputs, output_dir) + return results_df + + +def _print_summary(results_df: pd.DataFrame) -> None: + """Print a markdown summary table of key metrics.""" + click.echo("\n## Linear Classifier Results\n") + + per_class_f1_cols = sorted(c for c in results_df.columns if c.startswith("val_") and c.endswith("_f1")) + summary_cols = [ + "task", + "marker_filter", + "n_samples", + "val_accuracy", + "val_weighted_f1", + "val_auroc", + ] + per_class_f1_cols + display = results_df[[c for c in summary_cols if c in results_df.columns]].copy() + + float_cols = [c for c in display.columns if c not in ("task", "marker_filter")] + for col in float_cols: + if pd.api.types.is_float_dtype(display[col]): + display[col] = display[col].map(lambda v: f"{v:.3f}" if pd.notna(v) else "N/A") + + rows = display.to_dict(orient="records") + click.echo(format_markdown_table(rows, headers=list(display.columns))) + + +def _save_task_plots( + task: str, + task_df: pd.DataFrame, + task_val_outputs: list[dict[str, Any]], + output_dir: Path, +) -> None: + """Save one PDF per task with bar chart, ROC curves, and F1-over-time plots. + + Parameters + ---------- + task : str + Task name (used in filename and titles). + task_df : pd.DataFrame + Rows from metrics_summary.csv for this task (one row per marker). + task_val_outputs : list[dict] + Per-marker val outputs. Each entry has keys ``marker_filter``, + ``y_val``, ``y_val_proba``, ``classes``, ``val_hours``. + output_dir : Path + Directory to write ``{task}_summary.pdf``. + """ + pdf_path = output_dir / f"{task}_summary.pdf" + + with PdfPages(pdf_path) as pdf: + _plot_metrics_bar(pdf, task, task_df) + for vo in task_val_outputs: + if vo["y_val"] is None or vo["y_val_proba"] is None: + continue + _plot_roc_curves(pdf, task, vo["marker_filter"], vo["y_val"], vo["y_val_proba"], vo["classes"]) + if vo["val_hours"] is not None: + _plot_f1_over_time( + pdf, task, vo["marker_filter"], vo["y_val"], vo["y_val_proba"], vo["classes"], vo["val_hours"] + ) + + click.echo(f"Plots written to {pdf_path}") + + +def _plot_metrics_bar(pdf: PdfPages, task: str, task_df: pd.DataFrame) -> None: + """Bar chart of AUROC, accuracy, and weighted F1 per marker for one task.""" + metric_cols = ["val_auroc", "val_accuracy", "val_weighted_f1"] + present = [c for c in metric_cols if c in task_df.columns] + if not present: + return + + labels = task_df["marker_filter"].fillna("all").tolist() + x = np.arange(len(labels)) + n_metrics = len(present) + width = 0.8 / n_metrics + + metric_display = {"val_auroc": "AUROC", "val_accuracy": "Accuracy", "val_weighted_f1": "Weighted F1"} + colors = ["#0072B2", "#E69F00", "#009E73"] + + fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.5), 5)) + for i, col in enumerate(present): + vals = task_df[col].fillna(0).values + ax.bar(x + i * width, vals, width, label=metric_display.get(col, col), color=colors[i], alpha=0.85) + + ax.set_xticks(x + width * (n_metrics - 1) / 2) + ax.set_xticklabels(labels, fontsize=9) + ax.set_ylim(0, 1.05) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--", label="Random (0.5)") + ax.set_ylabel("Score") + ax.set_title(f"{task} — classifier performance per marker") + ax.legend(fontsize=9) + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _plot_roc_curves( + pdf: PdfPages, + task: str, + marker_filter: str | None, + y_val: np.ndarray, + y_val_proba: np.ndarray, + classes: list[str], +) -> None: + """One-vs-rest ROC curves for a single (task, marker) classifier.""" + from sklearn.metrics import roc_curve + from sklearn.preprocessing import label_binarize + + # Colorblind-friendly palette (Wong 2011) + palette = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00", "#56B4E9", "#F0E442"] + + fig, ax = plt.subplots(figsize=(6, 5)) + ax.set_title(f"ROC — {task} ({marker_filter})", fontsize=11) + + if len(classes) == 2: + fpr, tpr, _ = roc_curve(y_val, y_val_proba[:, 1], pos_label=classes[1]) + auroc = float(np.trapezoid(tpr, fpr)) + ax.plot(fpr, tpr, color=palette[0], linewidth=2, label=f"{classes[1]} (AUROC={auroc:.3f})") + else: + y_bin = label_binarize(y_val, classes=classes) + for i, cls in enumerate(classes): + fpr, tpr, _ = roc_curve(y_bin[:, i], y_val_proba[:, i]) + auroc = float(np.trapezoid(tpr, fpr)) + ax.plot(fpr, tpr, color=palette[i % len(palette)], linewidth=1.5, label=f"{cls} (AUROC={auroc:.3f})") + + ax.plot([0, 1], [0, 1], "k--", linewidth=0.8) + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1.05]) + ax.legend(fontsize=8, loc="lower right") + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _plot_f1_over_time( + pdf: PdfPages, + task: str, + marker_filter: str | None, + y_val: np.ndarray, + y_val_proba: np.ndarray, + classes: list[str], + val_hours: np.ndarray, +) -> None: + """Per-class F1 at each unique timepoint for a single (task, marker) classifier.""" + from sklearn.metrics import f1_score + + palette = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00", "#56B4E9", "#F0E442"] + + y_pred = np.array(classes)[np.argmax(y_val_proba, axis=1)] + timepoints = sorted(np.unique(val_hours[~np.isnan(val_hours)])) + + # (n_timepoints, n_classes) + f1_per_time = np.full((len(timepoints), len(classes)), np.nan) + for ti, t in enumerate(timepoints): + mask = val_hours == t + if mask.sum() < 2: + continue + f1s = f1_score(y_val[mask], y_pred[mask], labels=classes, average=None, zero_division=0) + f1_per_time[ti] = f1s + + fig, ax = plt.subplots(figsize=(8, 5)) + for ci, cls in enumerate(classes): + ax.plot(timepoints, f1_per_time[:, ci], marker="o", color=palette[ci % len(palette)], linewidth=2, label=cls) + + ax.set_xlabel("Hours post perturbation") + ax.set_ylabel("F1 score") + ax.set_ylim(0, 1.05) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--") + ax.set_title(f"F1 over time — {task} ({marker_filter})") + ax.legend(fontsize=9) + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +class _RunLinearClassifiersConfig: + """Config container for the run-linear-classifiers CLI.""" + + def __init__(self, raw: dict): + from dynaclr.evaluation.evaluate_config import LinearClassifiersStepConfig + + self.embeddings_path = Path(raw["embeddings_path"]) + self.output_dir = Path(raw["output_dir"]) + self.lc_config = LinearClassifiersStepConfig( + **{k: v for k, v in raw.items() if k not in ("embeddings_path", "output_dir")} + ) + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Run linear classifiers on a combined embeddings zarr from the evaluation orchestrator.""" + raw = load_config(config) + cfg = _RunLinearClassifiersConfig(raw) + run_linear_classifiers(cfg.embeddings_path, cfg.lc_config, cfg.output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py new file mode 100644 index 000000000..a28db7569 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py @@ -0,0 +1,336 @@ +"""Tests for the orchestrated linear classifiers evaluation.""" + +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import pytest + +from dynaclr.evaluation.evaluate_config import AnnotationSource, LinearClassifiersStepConfig, TaskSpec +from dynaclr.evaluation.linear_classifiers.orchestrated import run_linear_classifiers + + +def _make_embeddings_zarr( + path: Path, + n_cells: int = 200, + n_features: int = 16, + experiment: str = "exp_A", + use_id_col: bool = True, + extra_markers: list[tuple[str, int]] | None = None, +) -> ad.AnnData: + """Write a synthetic embeddings zarr and return the AnnData. + + Parameters + ---------- + extra_markers : list of (marker_name, n_cells) tuples, optional + Additional markers appended after the default Phase3D/TOMM20 split. + """ + half = n_cells // 2 + markers = ["Phase3D"] * half + ["TOMM20"] * half + extra_cells: list[dict] = [] + if extra_markers: + for marker_name, m_count in extra_markers: + markers += [marker_name] * m_count + extra_cells += [{}] * m_count + + total = n_cells + len(extra_cells) + rng = np.random.default_rng(42) + X = rng.standard_normal((total, n_features)).astype(np.float32) + + obs: dict = { + "fov_name": [f"A/1/FOV{i % 5}" for i in range(total)], + "t": [i % 10 for i in range(total)], + "track_id": list(range(total)), + "experiment": [experiment] * total, + "marker": markers, + "perturbation": ["uninfected"] * (total // 2) + ["ZIKV"] * (total - total // 2), + "hours_post_perturbation": [float(i % 5) * 24.0 for i in range(total)], + } + if use_id_col: + obs["id"] = list(range(total)) + + df = pd.DataFrame(obs) + # Convert string columns to object dtype — pandas 3 defaults to ArrowStringArray + # which anndata's zarr writer does not support. + for col in df.select_dtypes("string").columns: + df[col] = df[col].astype(object) + df.index = pd.Index([str(i) for i in range(total)], dtype=object) + var = pd.DataFrame(index=pd.Index([str(i) for i in range(n_features)], dtype=object)) + adata = ad.AnnData(X=X, obs=df, var=var) + adata.write_zarr(path) + return adata + + +def _make_embeddings_dir(tmp_path: Path, n_cells: int = 200, n_features: int = 16) -> Path: + """Write two per-experiment zarrs to a directory; return the directory path.""" + emb_dir = tmp_path / "embeddings" + emb_dir.mkdir() + _make_embeddings_zarr(emb_dir / "exp_A.zarr", n_cells=n_cells, n_features=n_features, experiment="exp_A") + _make_embeddings_zarr(emb_dir / "exp_B.zarr", n_cells=n_cells, n_features=n_features, experiment="exp_B") + return emb_dir + + +def _make_annotations( + tmp_path: Path, experiment: str, fov_names: list, ts: list, track_ids: list, hours: list | None = None +) -> Path: + """Create a synthetic annotation CSV with infection_state and organelle_state labels. + + fov_name is stored as the first path component only (e.g. "A/1/FOV0" → "A"), + matching what load_annotation_anndata extracts from obs via .str.split("/").str[0]. + """ + labels = ["uninfected" if i % 3 != 0 else "infected" for i in range(len(fov_names))] + # Extract first path component to match the join key in load_annotation_anndata + fov_first = [str(f).split("/")[0] for f in fov_names] + data: dict = { + "fov_name": fov_first, + "t": ts, + "track_id": track_ids, + "infection_state": labels, + "organelle_state": ["normal" if i % 4 != 0 else "abnormal" for i in range(len(fov_names))], + } + if hours is not None: + data["hours_post_perturbation"] = hours + df = pd.DataFrame(data) + csv_path = tmp_path / f"{experiment}_annotations.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +def _setup_dir_with_annotations(tmp_path: Path) -> tuple[Path, Path, Path]: + """Create embeddings directory + annotation CSVs for exp_A and exp_B.""" + emb_dir = _make_embeddings_dir(tmp_path) + ann_paths = {} + for exp in ["exp_A", "exp_B"]: + adata = ad.read_zarr(emb_dir / f"{exp}.zarr") + ann_paths[exp] = _make_annotations( + tmp_path, + exp, + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + hours=adata.obs["hours_post_perturbation"].tolist(), + ) + return emb_dir, ann_paths["exp_A"], ann_paths["exp_B"] + + +def test_run_linear_classifiers_directory_mode(tmp_path): + """Embeddings directory (post-split) is loaded and concatenated correctly.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + # auto-expand to Phase3D and TOMM20 → 2 rows + assert len(results) == 2 + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + assert results.iloc[0]["task"] == "infection_state" + assert results.iloc[0]["n_samples"] == 200 # 100 per experiment × 2 + assert (tmp_path / "out" / "metrics_summary.csv").exists() + # one summary PDF per task + assert (tmp_path / "out" / "infection_state_summary.pdf").exists() + + +def test_run_linear_classifiers_single_zarr_mode(tmp_path): + """Single combined zarr (pre-split) is still accepted.""" + zarr_path = tmp_path / "embeddings.zarr" + adata = _make_embeddings_zarr(zarr_path, experiment="exp_A") + ann = _make_annotations( + tmp_path, + "exp_A", + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + hours=adata.obs["hours_post_perturbation"].tolist(), + ) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(zarr_path, config, tmp_path / "out") + # auto-expand to Phase3D and TOMM20 → 2 rows + assert len(results) == 2 + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + + +def test_run_linear_classifiers_fallback_join_no_id(tmp_path): + """Annotation join falls back to (fov_name, t, track_id) when id column is absent.""" + zarr_path = tmp_path / "embeddings.zarr" + adata = _make_embeddings_zarr(zarr_path, experiment="exp_A", use_id_col=False) + + assert "id" not in adata.obs.columns + + ann = _make_annotations( + tmp_path, + "exp_A", + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + ) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(zarr_path, config, tmp_path / "out") + # auto-expand to Phase3D and TOMM20 → 2 rows, 100 cells each + assert len(results) == 2 + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + assert (results["n_samples"] == 100).all() + + +def test_run_linear_classifiers_multiple_tasks(tmp_path): + """Multiple tasks produce one row each in results.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[ + TaskSpec(task="infection_state"), + TaskSpec(task="organelle_state"), + ], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + # auto-expand to Phase3D and TOMM20 → 2 tasks × 2 markers = 4 rows + assert len(results) == 4 + assert set(results["task"].tolist()) == {"infection_state", "organelle_state"} + + +def test_run_linear_classifiers_marker_filter(tmp_path): + """marker_filters restricts cells to those with matching marker.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state", marker_filters=["Phase3D"])], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + assert not results.empty + # Phase3D is half of each experiment → 100 per exp × 2 = 200 + assert results.iloc[0]["n_samples"] == 200 + + +def test_run_linear_classifiers_missing_metadata_raises(tmp_path): + """Raises ValueError when embeddings zarr lacks experiment/marker columns.""" + X = np.random.standard_normal((50, 8)).astype(np.float32) + obs = pd.DataFrame({"fov_name": [f"A/1/FOV{i}" for i in range(50)]}) + obs["fov_name"] = obs["fov_name"].astype(object) + obs.index = pd.Index([str(i) for i in range(50)], dtype=object) + var = pd.DataFrame(index=pd.Index([str(i) for i in range(8)], dtype=object)) + zarr_path = tmp_path / "embeddings.zarr" + ad.AnnData(X=X, obs=obs, var=var).write_zarr(zarr_path) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(tmp_path / "ann.csv"))], + tasks=[TaskSpec(task="infection_state")], + ) + + with pytest.raises(ValueError, match="missing columns"): + run_linear_classifiers(zarr_path, config, tmp_path / "out") + + +def test_run_linear_classifiers_unknown_marker_skipped(tmp_path): + """If marker_filters matches no rows, task is skipped and result is empty.""" + emb_dir, ann_a, _ = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann_a))], + tasks=[TaskSpec(task="infection_state", marker_filters=["NonExistentMarker"])], + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + assert results.empty + + +def test_run_linear_classifiers_sparse_marker_skipped(tmp_path): + """Sparse marker with too few samples for stratified split is skipped without crashing.""" + emb_dir = tmp_path / "embeddings" + emb_dir.mkdir() + + # exp_A: 200 cells (Phase3D/TOMM20) + 4 RARE cells (1 infected, 3 uninfected) + adata_a = _make_embeddings_zarr( + emb_dir / "exp_A.zarr", + n_cells=200, + experiment="exp_A", + extra_markers=[("RARE", 4)], + ) + ann_a = _make_annotations( + tmp_path, + "exp_A", + adata_a.obs["fov_name"].tolist(), + adata_a.obs["t"].tolist(), + adata_a.obs["track_id"].tolist(), + hours=adata_a.obs["hours_post_perturbation"].tolist(), + ) + # Override RARE annotation so only 1 sample is "infected" (too few for stratified split) + df = pd.read_csv(ann_a) + rare_idx = adata_a.obs.index[adata_a.obs["marker"] == "RARE"].tolist() + rare_rows = df[df["track_id"].isin([int(i) for i in rare_idx])] + df.loc[rare_rows.index, "infection_state"] = ["infected"] + ["uninfected"] * (len(rare_rows) - 1) + df.to_csv(ann_a, index=False) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann_a))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + # Must not crash; RARE is skipped due to insufficient samples + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + assert not results.empty + assert "RARE" not in results["marker_filter"].tolist() + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + + +def test_run_linear_classifiers_f1_over_time_plots_written(tmp_path): + """F1-over-time plots are written when hours_post_perturbation is present.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state", marker_filters=["Phase3D"])], + use_scaling=True, + split_train_data=0.8, + ) + + out_dir = tmp_path / "out" + results = run_linear_classifiers(emb_dir, config, out_dir) + + assert not results.empty + pdf_path = out_dir / "infection_state_summary.pdf" + assert pdf_path.exists() + assert pdf_path.stat().st_size > 0 diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py index a55b68e33..e63af9086 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py @@ -1,10 +1,7 @@ -"""PDF report generation for linear classifier evaluation and cross-validation. +"""PDF report generation for linear classifier cross-validation. -Provides two report generators: -- ``generate_comparison_report``: Evaluation report comparing models on a test set. -- ``generate_cv_report``: Cross-validation report with impact analysis. - -Both are optional and gated behind the ``--report`` flag in the respective scripts. +Provides ``generate_cv_report`` for cross-validation reports with impact analysis. +This is optional and gated behind the ``--report`` flag in the cross-validation script. """ from __future__ import annotations @@ -20,7 +17,6 @@ import pandas as pd from matplotlib.backends.backend_pdf import PdfPages from matplotlib.patches import Patch -from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix matplotlib.use("Agg") @@ -39,9 +35,6 @@ "baseline": _COLOR_BASELINE, } -_MODEL_COLORS = {"2D": "#1f77b4", "3D": "#ff7f0e"} -_EXTRA_COLORS = ["#2ca02c", "#9467bd", "#8c564b", "#e377c2"] - _TEMPORAL_PALETTE = [ "#0072B2", "#E69F00", @@ -54,281 +47,6 @@ ] -def _get_model_color(label: str, idx: int = 0) -> str: - return _MODEL_COLORS.get(label, _EXTRA_COLORS[idx % len(_EXTRA_COLORS)]) - - -# --------------------------------------------------------------------------- -# Evaluation report -# --------------------------------------------------------------------------- - - -def generate_comparison_report( - output_dir: Path, - dataset_name: str, - model_labels: list[str], - tasks: list[str], - channels: list[str], - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], -) -> Path: - """Generate a PDF comparing model performance on a held-out test set. - - Parameters - ---------- - output_dir : Path - Directory to save the report. - dataset_name : str - Name of the test dataset. - model_labels : list[str] - Model labels (e.g. ``["2D", "3D"]``). - tasks : list[str] - Classification tasks evaluated. - channels : list[str] - Input channels evaluated. - train_results : dict - ``model_label -> (task, channel) -> {"metrics": {...}, ...}``. - eval_results : dict - ``model_label -> (task, channel) -> {"metrics": {...}, "annotated_adata": ...}``. - - Returns - ------- - Path - Path to the generated PDF. - """ - report_path = output_dir / f"{dataset_name}_comparison_report.pdf" - output_dir.mkdir(parents=True, exist_ok=True) - - with PdfPages(report_path) as pdf: - _eval_page_title(pdf, dataset_name, model_labels, tasks, channels, train_results) - _eval_page_global_metrics(pdf, model_labels, tasks, channels, train_results, eval_results) - for task in tasks: - _eval_page_task_comparison(pdf, task, model_labels, channels, eval_results) - for channel in channels: - _eval_page_channel_comparison(pdf, channel, model_labels, tasks, train_results, eval_results) - - print(f"\nReport saved: {report_path}") - return report_path - - -def _eval_page_title(pdf, dataset_name, model_labels, tasks, channels, train_results): - fig, ax = plt.subplots(figsize=(11, 8.5)) - ax.axis("off") - - lines = [ - "Linear Classifier Comparison Report", - "", - f"Test Dataset: {dataset_name}", - "", - ] - for label in model_labels: - n_combos = len(train_results.get(label, {})) - lines.append(f"Model {label}: {n_combos} classifiers trained") - lines.append("") - lines.append(f"Channels: {', '.join(channels)}") - lines.append(f"Tasks: {', '.join(tasks)}") - - ax.text( - 0.5, - 0.5, - "\n".join(lines), - transform=ax.transAxes, - fontsize=12, - verticalalignment="center", - horizontalalignment="center", - fontfamily="monospace", - ) - fig.suptitle("Model Comparison", fontsize=16, fontweight="bold") - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _eval_page_global_metrics(pdf, model_labels, tasks, channels, train_results, eval_results): - fig, ax = plt.subplots(figsize=(11, 8.5)) - ax.axis("off") - fig.suptitle("Global Metrics Summary", fontsize=14, fontweight="bold") - - col_labels = ["Task", "Channel"] - for label in model_labels: - col_labels.extend([f"{label}\nVal Acc", f"{label}\nVal F1", f"{label}\nTest Acc", f"{label}\nTest F1"]) - - table_data = [] - for task in tasks: - for channel in channels: - row = [task, channel] - for label in model_labels: - train_r = train_results.get(label, {}).get((task, channel)) - eval_r = eval_results.get(label, {}).get((task, channel)) - val_acc = f"{train_r['metrics']['val_accuracy']:.3f}" if train_r else "-" - val_f1 = f"{train_r['metrics']['val_weighted_f1']:.3f}" if train_r else "-" - test_acc = f"{eval_r['metrics']['test_accuracy']:.3f}" if eval_r else "-" - test_f1 = f"{eval_r['metrics']['test_weighted_f1']:.3f}" if eval_r else "-" - row.extend([val_acc, val_f1, test_acc, test_f1]) - table_data.append(row) - - if table_data: - table = ax.table(cellText=table_data, colLabels=col_labels, loc="center", cellLoc="center") - table.auto_set_font_size(False) - table.set_fontsize(8) - table.scale(1.0, 1.4) - - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _eval_page_task_comparison(pdf, task, model_labels, channels, eval_results): - n_models = len(model_labels) - - all_classes: set[str] = set() - for label in model_labels: - for ch in channels: - r = eval_results.get(label, {}).get((task, ch)) - if r and "annotated_adata" in r: - adata = r["annotated_adata"] - if task in adata.obs.columns: - all_classes.update(adata.obs[task].dropna().unique()) - all_classes_sorted = sorted(all_classes) - - # F1 bar chart - fig, ax_bar = plt.subplots(figsize=(11, 5)) - fig.suptitle(f"Task: {task} - Per-Class F1", fontsize=14, fontweight="bold") - - if all_classes_sorted: - x = np.arange(len(all_classes_sorted)) - width = 0.8 / max(n_models, 1) - for i, label in enumerate(model_labels): - f1_values = [] - for cls in all_classes_sorted: - f1s = [] - for ch in channels: - r = eval_results.get(label, {}).get((task, ch)) - if r: - f1 = r["metrics"].get(f"test_{cls}_f1") - if f1 is not None: - f1s.append(f1) - f1_values.append(np.mean(f1s) if f1s else 0) - ax_bar.bar( - x + i * width, - f1_values, - width, - label=label, - color=_get_model_color(label, i), - ) - ax_bar.set_xticks(x + width * (n_models - 1) / 2) - ax_bar.set_xticklabels(all_classes_sorted) - ax_bar.set_ylabel("Test F1 (avg across channels)") - ax_bar.legend() - ax_bar.set_ylim(0, 1.05) - - fig.tight_layout() - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - # Confusion matrices - n_cols = len(channels) - n_rows = n_models - if n_cols == 0 or n_rows == 0: - return - - fig_cm, cm_axes = plt.subplots(n_rows, max(n_cols, 1), figsize=(4 * max(n_cols, 1), 3.5 * n_rows)) - fig_cm.suptitle(f"Confusion Matrices: {task}", fontsize=14, fontweight="bold") - - if n_rows == 1 and n_cols == 1: - cm_axes = [[cm_axes]] - elif n_rows == 1: - cm_axes = [cm_axes] - elif n_cols == 1: - cm_axes = [[row] for row in cm_axes] - - for i, label in enumerate(model_labels): - for j, ch in enumerate(channels): - ax = cm_axes[i][j] - r = eval_results.get(label, {}).get((task, ch)) - if r and "annotated_adata" in r: - adata = r["annotated_adata"] - pred_col = f"predicted_{task}" - mask = adata.obs[task].notna() & (adata.obs[task] != "unknown") - subset = adata[mask] - if len(subset) > 0 and pred_col in subset.obs.columns: - y_true = subset.obs[task].values - y_pred = subset.obs[pred_col].values - labels = sorted(set(y_true) | set(y_pred)) - cm = confusion_matrix(y_true, y_pred, labels=labels) - ConfusionMatrixDisplay(cm, display_labels=labels).plot(ax=ax, cmap="Blues", colorbar=False) - ax.set_title(f"{label} / {ch}", fontsize=10) - - fig_cm.tight_layout() - pdf.savefig(fig_cm, bbox_inches="tight") - plt.close(fig_cm) - - -def _eval_page_channel_comparison(pdf, channel, model_labels, tasks, train_results, eval_results): - fig, axes = plt.subplots(1, 2, figsize=(11, 5)) - fig.suptitle(f"Channel: {channel}", fontsize=14, fontweight="bold") - - n_models = len(model_labels) - x = np.arange(len(tasks)) - width = 0.8 / max(n_models, 1) - - ax = axes[0] - for i, label in enumerate(model_labels): - accs = [] - for task in tasks: - r = eval_results.get(label, {}).get((task, channel)) - accs.append(r["metrics"]["test_accuracy"] if r else 0) - ax.bar( - x + i * width, - accs, - width, - label=label, - color=_get_model_color(label, i), - ) - ax.set_xticks(x + width * (n_models - 1) / 2) - ax.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) - ax.set_ylabel("Test Accuracy") - ax.set_ylim(0, 1.05) - ax.legend() - ax.set_title("Test Accuracy") - - ax2 = axes[1] - for i, label in enumerate(model_labels): - val_accs, test_accs = [], [] - for task in tasks: - tr = train_results.get(label, {}).get((task, channel)) - ev = eval_results.get(label, {}).get((task, channel)) - val_accs.append(tr["metrics"]["val_accuracy"] if tr else 0) - test_accs.append(ev["metrics"]["test_accuracy"] if ev else 0) - - color = _get_model_color(label, i) - ax2.bar( - x + i * width - width / 4, - val_accs, - width / 2, - label=f"{label} Val", - color=color, - alpha=0.5, - ) - ax2.bar( - x + i * width + width / 4, - test_accs, - width / 2, - label=f"{label} Test", - color=color, - alpha=1.0, - ) - - ax2.set_xticks(x + width * (n_models - 1) / 2) - ax2.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) - ax2.set_ylabel("Accuracy") - ax2.set_ylim(0, 1.05) - ax2.legend(fontsize=7) - ax2.set_title("Val vs Test (Generalization)") - - fig.tight_layout() - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - # --------------------------------------------------------------------------- # Cross-validation report # --------------------------------------------------------------------------- diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py index 00e62aa41..d79ff4e8a 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py @@ -9,7 +9,7 @@ import click from pydantic import ValidationError -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.linear_classifier import ( load_and_combine_datasets, save_pipeline_to_wandb, @@ -68,7 +68,7 @@ def main(config: Path): click.echo("=" * 60) try: - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="train_linear_classifier") train_config = LinearClassifierTrainConfig(**config_dict) except ValidationError as e: click.echo(f"\n Configuration validation failed:\n{e}", err=True) @@ -103,7 +103,7 @@ def main(config: Path): "random_state": train_config.random_seed, } - pipeline, metrics = train_linear_classifier( + pipeline, metrics, _ = train_linear_classifier( adata=combined_adata, task=train_config.task, use_scaling=train_config.use_scaling, diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py new file mode 100644 index 000000000..1419a3501 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py @@ -0,0 +1 @@ +"""MMD-based evaluation of perturbation effects in cell embedding space.""" diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py new file mode 100644 index 000000000..c08fdc40b --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py @@ -0,0 +1,924 @@ +"""CLI and analysis logic for MMD-based perturbation effect evaluation.""" + +from __future__ import annotations + +from pathlib import Path + +import anndata as ad +import click +import numpy as np +import pandas as pd + +from dynaclr.evaluation.mmd.config import ( + ComparisonSpec, + MMDCombinedConfig, + MMDEvalConfig, + MMDPooledConfig, + MMDSettings, + _resolve_bin_edges, +) +from viscy_utils.compose import load_composed_config +from viscy_utils.evaluation.mmd import median_heuristic, mmd_permutation_test + + +def _extract_embeddings(adata: ad.AnnData, embedding_key: str | None) -> np.ndarray: + """Extract embedding matrix from AnnData. + + Parameters + ---------- + adata : AnnData + AnnData store with ``.X`` or ``.obsm``. + embedding_key : str or None + obsm key, or None to use ``.X``. + + Returns + ------- + np.ndarray + Embedding matrix, shape (n_cells, n_features). + """ + if embedding_key is None: + X = adata.X + else: + X = adata.obsm[embedding_key] + if hasattr(X, "toarray"): + return X.toarray() + return np.asarray(X) + + +def _subsample(X: np.ndarray, max_n: int | None, rng: np.random.Generator) -> np.ndarray: + if max_n is None or len(X) <= max_n: + return X + idx = rng.choice(len(X), max_n, replace=False) + return X[idx] + + +def _run_one_comparison( + emb_a: np.ndarray, + emb_b: np.ndarray, + settings: MMDSettings, + bandwidth: float | None = None, +) -> tuple[float, float, float, float, float, int, int]: + """Run MMD permutation test for one (cond_a, cond_b) pair. + + Parameters + ---------- + emb_a : np.ndarray + Embeddings for group A. + emb_b : np.ndarray + Embeddings for group B. + settings : MMDSettings + Algorithm settings. + bandwidth : float or None + Pre-computed bandwidth to use. If None, computed via median heuristic. + Pass a value to share bandwidth across comparisons within the same group. + + Returns + ------- + mmd2 : float + p_value : float + bandwidth : float + effect_size : float + mmd2 / bandwidth + activity_zscore : float + (mmd2 - null_mean) / null_std — normalizes observed MMD relative to + the permutation null, comparable across markers and datasets. + n_a_used : int + Actual number of cells used from group A after subsampling/balancing. + n_b_used : int + Actual number of cells used from group B after subsampling/balancing. + All metric floats are NaN if fewer than min_cells cells in either group. + """ + rng = np.random.default_rng(settings.seed) + emb_a = _subsample(emb_a, settings.max_cells, rng) + emb_b = _subsample(emb_b, settings.max_cells, rng) + if settings.balance_samples: + min_n = min(len(emb_a), len(emb_b)) + emb_a = _subsample(emb_a, min_n, rng) + emb_b = _subsample(emb_b, min_n, rng) + n_a_used = len(emb_a) + n_b_used = len(emb_b) + if n_a_used < settings.min_cells or n_b_used < settings.min_cells: + return float("nan"), float("nan"), float("nan"), float("nan"), float("nan"), n_a_used, n_b_used + if bandwidth is None: + bandwidth = median_heuristic(emb_a, emb_b) + mmd2, p_value, null_dist = mmd_permutation_test( + emb_a, emb_b, n_permutations=settings.n_permutations, bandwidth=bandwidth, seed=settings.seed + ) + effect_size = mmd2 / bandwidth if bandwidth > 0 else float("nan") + activity_zscore = float((mmd2 - null_dist.mean()) / (null_dist.std() + 1e-12)) + return mmd2, p_value, bandwidth, effect_size, activity_zscore, n_a_used, n_b_used + + +def _run_map_comparison( + meta: pd.DataFrame, + features: np.ndarray, + comp: ComparisonSpec, + group_by: str, + marker: str, + map_settings, +) -> tuple[float, float]: + """Run copairs mAP for one comparison. + + Returns + ------- + map_value : float + map_p_value : float + Both NaN on failure or if copairs is unavailable. + """ + try: + from viscy_utils.evaluation.embedding_map import compute_embedding_map + except ImportError: + return float("nan"), float("nan") + result = compute_embedding_map( + meta=meta, + features=features, + reference_condition=comp.cond_a, + target_condition=comp.cond_b, + condition_col=group_by, + group_col="marker", + distance=map_settings.distance, + null_size=map_settings.null_size, + seed=map_settings.seed, + ) + if result is None: + return float("nan"), float("nan") + return result["mean_average_precision"], result["p_value"] + + +def run_mmd_analysis(adata: ad.AnnData, config: MMDEvalConfig) -> pd.DataFrame: + """Run per-experiment MMD analysis for explicit comparison pairs across all markers. + + Each comparison is an explicit ``(cond_a, cond_b)`` pair with a label. + The analysis is always faceted by ``obs["marker"]`` and ``obs["experiment"]``. + Each experiment is processed independently to avoid cross-experiment pooling. + + Parameters + ---------- + adata : AnnData + AnnData (single- or multi-experiment) after split-embeddings step. + config : MMDEvalConfig + Analysis configuration. + + Returns + ------- + pd.DataFrame + Results with columns: experiment, marker, cond_a, cond_b, label, + hours_bin_start, hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, + effect_size, activity_zscore, embedding_key, and optionally map_value, + map_p_value. + """ + if config.obs_filter: + mask = pd.Series([True] * len(adata), index=adata.obs.index) + for col, val in config.obs_filter.items(): + if col not in adata.obs.columns: + raise KeyError(f"obs_filter column '{col}' not found. Available: {list(adata.obs.columns)}") + mask &= adata.obs[col] == val + adata = adata[mask].copy() + + obs = adata.obs + if config.group_by not in obs.columns: + raise KeyError(f"obs column '{config.group_by}' not found. Available: {list(obs.columns)}") + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + all_emb = _extract_embeddings(adata, config.embedding_key) + experiments = obs["experiment"].unique() if "experiment" in obs.columns else ["unknown"] + + records: list[dict] = [] + for experiment in experiments: + exp_mask = ( + obs["experiment"] == experiment + if "experiment" in obs.columns + else pd.Series([True] * len(obs), index=obs.index) + ) + for marker in sorted(obs["marker"].unique()): + marker_mask = exp_mask & (obs["marker"] == marker) + + if config.temporal_bin_size is None and config.temporal_bins is None: + # Aggregate mode + shared_bw = _compute_shared_bandwidth( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + mask_b = marker_mask & (obs[config.group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _record( + experiment, + marker, + comp, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = obs["hours_post_perturbation"].max() + bin_pairs = _resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + shared_bw = _compute_shared_bandwidth_temporal( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by, b_start, b_end + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + bin_mask_b = ( + marker_mask + & (obs[config.group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[bin_mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison( + emb_a, emb_b, config.mmd, bandwidth=bw + ) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _record( + experiment, + marker, + comp, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + return pd.DataFrame(records) + + +def _compute_shared_bandwidth( + all_emb: np.ndarray, + obs: pd.DataFrame, + marker_mask: pd.Series, + comparisons: list[ComparisonSpec], + settings: MMDSettings, + group_by: str, +) -> float | None: + """Compute bandwidth from the share_bandwidth_from comparison, if configured.""" + if settings.share_bandwidth_from is None: + return None + for comp in comparisons: + if comp.label == settings.share_bandwidth_from: + mask_a = marker_mask & (obs[group_by] == comp.cond_a) + mask_b = marker_mask & (obs[group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + if len(emb_a) >= settings.min_cells and len(emb_b) >= settings.min_cells: + return median_heuristic(emb_a, emb_b) + return None + return None + + +def _compute_shared_bandwidth_temporal( + all_emb: np.ndarray, + obs: pd.DataFrame, + marker_mask: pd.Series, + comparisons: list[ComparisonSpec], + settings: MMDSettings, + group_by: str, + b_start: float, + b_end: float, +) -> float | None: + """Compute shared bandwidth from the share_bandwidth_from comparison for a temporal bin.""" + if settings.share_bandwidth_from is None: + return None + for comp in comparisons: + if comp.label == settings.share_bandwidth_from: + mask_a = ( + marker_mask + & (obs[group_by] == comp.cond_a) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + mask_b = ( + marker_mask + & (obs[group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + if len(emb_a) >= settings.min_cells and len(emb_b) >= settings.min_cells: + return median_heuristic(emb_a, emb_b) + return None + return None + + +def _maybe_map( + obs_sub: pd.DataFrame, + emb_sub: np.ndarray, + comp: ComparisonSpec, + group_by: str, + marker: str, + map_settings, +) -> tuple[float, float]: + """Run mAP if enabled, otherwise return NaN pair.""" + if not map_settings.enabled: + return float("nan"), float("nan") + return _run_map_comparison(obs_sub, emb_sub, comp, group_by, marker, map_settings) + + +def _record( + experiment: str, + marker: str, + comp: ComparisonSpec, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + map_value: float, + map_p_value: float, + embedding_key: str, +) -> dict: + return { + "experiment": experiment, + "marker": marker, + "cond_a": comp.cond_a, + "cond_b": comp.cond_b, + "label": comp.label, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "map_value": map_value, + "map_p_value": map_p_value, + "embedding_key": embedding_key, + } + + +def run_mmd_combined(config: MMDCombinedConfig) -> pd.DataFrame: + """Run pairwise cross-experiment MMD, faceted by marker and condition+time bin. + + For each marker, finds all experiments that share it, then for each pair + of those experiments runs MMD per (condition, time_bin) after centering + within that pair only. This measures batch effects between experiments + at matched biological states. + + Parameters + ---------- + config : MMDCombinedConfig + Combined analysis configuration. + + Returns + ------- + pd.DataFrame + Results with columns: marker, exp_a, exp_b, condition, hours_bin_start, + hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, effect_size, + activity_zscore, embedding_key. + """ + from itertools import combinations + + adatas = {ad.read_zarr(p).obs["experiment"].iloc[0]: ad.read_zarr(p) for p in config.input_paths} + + if config.obs_filter: + filtered = {} + for exp_name, adata in adatas.items(): + mask = pd.Series([True] * len(adata), index=adata.obs.index) + for col, val in config.obs_filter.items(): + if col not in adata.obs.columns: + raise KeyError( + f"obs_filter column '{col}' not found in {exp_name}. Available: {list(adata.obs.columns)}" + ) + mask &= adata.obs[col] == val + filtered[exp_name] = adata[mask].copy() + adatas = filtered + + marker_to_exps: dict[str, list[str]] = {} + for exp_name, adata in adatas.items(): + for marker in adata.obs["marker"].unique(): + marker_to_exps.setdefault(marker, []).append(exp_name) + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + records: list[dict] = [] + + for marker, exp_names in sorted(marker_to_exps.items()): + if len(exp_names) < 2: + continue + for exp_a, exp_b in combinations(exp_names, 2): + adata_a = adatas[exp_a][adatas[exp_a].obs["marker"] == marker] + adata_b = adatas[exp_b][adatas[exp_b].obs["marker"] == marker] + emb_a_full = _extract_embeddings(adata_a, config.embedding_key).astype(np.float32) + emb_b_full = _extract_embeddings(adata_b, config.embedding_key).astype(np.float32) + obs_a = adata_a.obs + obs_b = adata_b.obs + + emb_a_full = emb_a_full - emb_a_full.mean(axis=0) + emb_b_full = emb_b_full - emb_b_full.mean(axis=0) + + conditions = sorted(set(obs_a[config.group_by].unique()) & set(obs_b[config.group_by].unique())) + for condition in conditions: + cond_mask_a = obs_a[config.group_by] == condition + cond_mask_b = obs_b[config.group_by] == condition + emb_ca = emb_a_full[cond_mask_a.values] + emb_cb = emb_b_full[cond_mask_b.values] + + if config.temporal_bin_size is None and config.temporal_bins is None: + mmd2, p_value, bw, es, az, na, nb = _run_one_comparison(emb_ca, emb_cb, config.mmd) + records.append( + _combined_record( + marker, + exp_a, + exp_b, + condition, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw, + es, + az, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs_a.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = min(obs_a["hours_post_perturbation"].max(), obs_b["hours_post_perturbation"].max()) + bin_pairs = _resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + bin_mask_a = ( + cond_mask_a + & (obs_a["hours_post_perturbation"] >= b_start) + & (obs_a["hours_post_perturbation"] < b_end) + ) + bin_mask_b = ( + cond_mask_b + & (obs_b["hours_post_perturbation"] >= b_start) + & (obs_b["hours_post_perturbation"] < b_end) + ) + bin_emb_a = emb_a_full[bin_mask_a.values] + bin_emb_b = emb_b_full[bin_mask_b.values] + mmd2, p_value, bw, es, az, na, nb = _run_one_comparison(bin_emb_a, bin_emb_b, config.mmd) + records.append( + _combined_record( + marker, + exp_a, + exp_b, + condition, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw, + es, + az, + emb_key_label, + ) + ) + + return pd.DataFrame(records) + + +def _combined_record( + marker: str, + exp_a: str, + exp_b: str, + condition: str, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + embedding_key: str, +) -> dict: + return { + "marker": marker, + "exp_a": exp_a, + "exp_b": exp_b, + "condition": condition, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "embedding_key": embedding_key, + } + + +def run_mmd_pooled(config: MMDPooledConfig) -> pd.DataFrame: + """Run pooled multi-experiment MMD/mAP analysis. + + Concatenates cells from all input experiments into a single pool, then + computes MMD (and optionally mAP) per (marker, time_bin, comparison). + Unlike the combined mode (pairwise batch-effect detection), this pools all + experiments together for phenotypic profiling. + + Parameters + ---------- + config : MMDPooledConfig + Pooled analysis configuration. + + Returns + ------- + pd.DataFrame + Results with columns: marker, cond_a, cond_b, label, hours_bin_start, + hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, effect_size, + activity_zscore, map_value, map_p_value, embedding_key. + FDR-corrected q_value column is also included. + """ + from statsmodels.stats.multitest import multipletests + + adatas = [ad.read_zarr(p) for p in config.input_paths] + combined = ad.concat(adatas, join="outer", label="source_experiment") + combined.obs_names_make_unique() + + if config.obs_filter: + mask = pd.Series([True] * len(combined), index=combined.obs.index) + for col, val in config.obs_filter.items(): + if col not in combined.obs.columns: + raise KeyError(f"obs_filter column '{col}' not found. Available: {list(combined.obs.columns)}") + mask &= combined.obs[col] == val + combined = combined[mask].copy() + + if config.condition_aliases: + alias_map: dict[str, str] = {} + for canonical, variants in config.condition_aliases.items(): + for v in variants: + alias_map[v] = canonical + combined.obs[config.group_by] = combined.obs[config.group_by].map(lambda x: alias_map.get(x, x)) + + obs = combined.obs + if config.group_by not in obs.columns: + raise KeyError(f"obs column '{config.group_by}' not found. Available: {list(obs.columns)}") + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + all_emb = _extract_embeddings(combined, config.embedding_key) + + records: list[dict] = [] + for marker in sorted(obs["marker"].unique()): + marker_mask = obs["marker"] == marker + + if config.temporal_bin_size is None and config.temporal_bins is None: + shared_bw = _compute_shared_bandwidth( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + mask_b = marker_mask & (obs[config.group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _pooled_record( + marker, + comp, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = obs["hours_post_perturbation"].max() + bin_pairs = _resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + shared_bw = _compute_shared_bandwidth_temporal( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by, b_start, b_end + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + bin_mask_b = ( + marker_mask + & (obs[config.group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[bin_mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _pooled_record( + marker, + comp, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + + df = pd.DataFrame(records) + if not df.empty: + valid_p = df["p_value"].dropna() + if len(valid_p) > 0: + _, q_values, _, _ = multipletests(df["p_value"].fillna(1.0), alpha=0.05, method="fdr_bh") + df["q_value"] = q_values + df.loc[df["p_value"].isna(), "q_value"] = float("nan") + else: + df["q_value"] = float("nan") + return df + + +def _pooled_record( + marker: str, + comp: ComparisonSpec, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + map_value: float, + map_p_value: float, + embedding_key: str, +) -> dict: + return { + "marker": marker, + "cond_a": comp.cond_a, + "cond_b": comp.cond_b, + "label": comp.label, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "map_value": map_value, + "map_p_value": map_p_value, + "embedding_key": embedding_key, + } + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.argument("mmd_dir", type=click.Path(exists=True, path_type=Path)) +@click.option( + "--output-dir", type=click.Path(path_type=Path), default=None, help="Output directory. Default: same as mmd_dir." +) +def plot_mmd_heatmap_cmd(mmd_dir: Path, output_dir: Path | None) -> None: + """Plot a combined MMD heatmap (all markers) from per-experiment CSVs in MMD_DIR.""" + from dynaclr.evaluation.mmd.plotting import plot_mmd_heatmap + + csvs = sorted(mmd_dir.glob("*_mmd_results.csv")) + if not csvs: + raise click.ClickException(f"No *_mmd_results.csv files found in {mmd_dir}") + + df = pd.concat([pd.read_csv(f) for f in csvs], ignore_index=True) + click.echo(f"Loaded {len(df)} rows from {len(csvs)} CSV(s)") + + out = output_dir or mmd_dir + out.mkdir(parents=True, exist_ok=True) + + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + for fmt in ("pdf", "png"): + plot_mmd_heatmap(sub, out / f"all_markers_{safe}_heatmap.{fmt}") + click.echo(f"Saved heatmap for: {comp_label}") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to MMD evaluation YAML config", +) +@click.option( + "--combined", + is_flag=True, + default=False, + help="Run cross-experiment combined mode (config must have input_paths list)", +) +@click.option( + "--pooled", + is_flag=True, + default=False, + help="Run pooled multi-experiment phenotypic analysis (config must have input_paths list)", +) +def main(config: Path, combined: bool, pooled: bool) -> None: + """Compute MMD between explicit condition pairs in cell embeddings. + + Comparisons are defined as explicit (cond_a, cond_b, label) pairs. + The analysis is always faceted by obs["marker"]. + """ + if combined and pooled: + raise click.UsageError("--combined and --pooled are mutually exclusive") + raw = load_composed_config(config) + output_dir = Path(raw["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + + if combined: + cfg = MMDCombinedConfig(**raw) + df = run_mmd_combined(cfg) + out_csv = output_dir / "combined_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots: + _save_plots_combined(df, output_dir, cfg.temporal_bin_size) + _print_summary(df, mode="combined") + elif pooled: + cfg = MMDPooledConfig(**raw) + df = run_mmd_pooled(cfg) + out_csv = output_dir / "pooled_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots and len(df): + _save_plots_pooled(df, output_dir) + _print_summary(df, mode="pooled") + else: + cfg = MMDEvalConfig(**raw) + adata = ad.read_zarr(cfg.input_path) + df = run_mmd_analysis(adata, cfg) + experiment = df["experiment"].iloc[0] if len(df) else "unknown" + out_csv = output_dir / f"{experiment}_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots and len(df): + _save_plots(df, output_dir, experiment, cfg.temporal_bin_size or cfg.temporal_bins) + _print_summary(df, mode="per_experiment") + + +def _save_plots(df: pd.DataFrame, output_dir: Path, label: str, temporal_config) -> None: + from dynaclr.evaluation.mmd.plotting import plot_mmd_kinetics, plot_mmd_multi_panel_kinetics + + has_bins = temporal_config is not None and len(df) and not df["hours_bin_start"].isna().all() + if not has_bins: + return + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + for fmt in ("pdf", "png"): + plot_mmd_kinetics(sub, output_dir / f"{label}_{safe}_kinetics.{fmt}") + for fmt in ("pdf", "png"): + plot_mmd_multi_panel_kinetics(df, output_dir / f"{label}_multi_panel_kinetics.{fmt}") + if "activity_zscore" in df.columns and not df["activity_zscore"].isna().all(): + from dynaclr.evaluation.mmd.plotting import plot_activity_heatmap, plot_paired_heatmaps + + for fmt in ("pdf", "png"): + plot_activity_heatmap(df, output_dir / f"{label}_activity_heatmap.{fmt}") + labels = [c for c in df["label"].unique() if c] + if len(labels) >= 2: + for fmt in ("pdf", "png"): + plot_paired_heatmaps(df, labels[:2], "activity_zscore", output_dir / f"{label}_paired_activity.{fmt}") + + +def _save_plots_combined(df: pd.DataFrame, output_dir: Path, temporal_bin_size: float | None) -> None: + from dynaclr.evaluation.mmd.plotting import plot_mmd_combined_heatmap, plot_mmd_kinetics + + has_bins = temporal_bin_size is not None and len(df) and not df["hours_bin_start"].isna().all() + for fmt in ("pdf", "png"): + if has_bins: + for marker in df["marker"].unique(): + sub = df[df["marker"] == marker] + safe = marker.replace(" ", "_").replace("/", "-") + plot_mmd_kinetics(sub, output_dir / f"combined_{safe}_kinetics.{fmt}") + plot_mmd_combined_heatmap(df, output_dir / f"combined_heatmap.{fmt}") + + +def _save_plots_pooled(df: pd.DataFrame, output_dir: Path) -> None: + from dynaclr.evaluation.mmd.plotting import ( + plot_activity_heatmap, + plot_mmd_heatmap, + plot_mmd_multi_panel_kinetics, + plot_paired_heatmaps, + ) + + has_bins = not df["hours_bin_start"].isna().all() + for fmt in ("pdf", "png"): + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + plot_mmd_heatmap(sub, output_dir / f"pooled_{safe}_heatmap.{fmt}") + if has_bins: + plot_mmd_multi_panel_kinetics(df, output_dir / f"pooled_multi_panel_kinetics.{fmt}") + if "activity_zscore" in df.columns and not df["activity_zscore"].isna().all(): + plot_activity_heatmap(df, output_dir / f"pooled_activity_heatmap.{fmt}") + labels = [c for c in df["label"].unique() if c] + if len(labels) >= 2: + plot_paired_heatmaps(df, labels[:2], "activity_zscore", output_dir / f"pooled_paired_activity.{fmt}") + + +def _print_summary(df: pd.DataFrame, mode: str = "per_experiment") -> None: + if df.empty: + click.echo("No results.") + return + click.echo("\n## MMD Results Summary\n") + if mode == "combined": + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "condition"])[["mmd2", "p_value", "effect_size"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean"}) + .round(4) + .reset_index() + ) + elif mode == "pooled": + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "label"])[["mmd2", "p_value", "effect_size", "activity_zscore"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean", "activity_zscore": "mean"}) + .round(4) + .reset_index() + ) + else: + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "label"])[["mmd2", "p_value", "effect_size"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean"}) + .round(4) + .reset_index() + ) + click.echo(summary.to_string(index=False)) diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py new file mode 100644 index 000000000..e80463bd4 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py @@ -0,0 +1,224 @@ +"""Pydantic configuration for the MMD perturbation evaluation step.""" + +from __future__ import annotations + +from typing import Optional + +import numpy as np +from pydantic import BaseModel, model_validator + + +class ComparisonSpec(BaseModel): + """One pairwise comparison to run MMD on. + + Parameters + ---------- + cond_a : str + Value of ``obs[group_by]`` for group A (typically the reference/control). + cond_b : str + Value of ``obs[group_by]`` for group B (typically the treatment). + label : str + Human-readable label for this comparison (used in output filenames and plots). + """ + + cond_a: str + cond_b: str + label: str + + +class MMDSettings(BaseModel): + """Kernel MMD algorithm settings, shared across per-experiment and combined modes. + + Parameters + ---------- + n_permutations : int + Number of permutations for the significance test. Default: 1000. + max_cells : int or None + Subsample each group to at most this many cells before computing MMD. + Controls memory and compute cost. Default: 2000. + min_cells : int + Minimum cells required per group. Groups below this produce NaN. Default: 20. + seed : int + Random seed for subsampling and permutations. Default: 42. + balance_samples : bool + Subsample the larger group to match the smaller group's size before + computing MMD. Prevents sample-size imbalance from inflating test statistics. + Applied after the ``max_cells`` cap. Default: False. + share_bandwidth_from : str or None + Label of a comparison whose bandwidth should be reused for all other + comparisons within the same (marker, time_bin) group. Typically the + baseline comparison (e.g. ``"uninf1 vs uninf2"``). If None, each + comparison computes its own bandwidth independently. Default: None. + """ + + n_permutations: int = 1000 + max_cells: Optional[int] = 2000 + min_cells: int = 20 + seed: int = 42 + balance_samples: bool = False + share_bandwidth_from: Optional[str] = None + + +class MAPSettings(BaseModel): + """Settings for the copairs-based mean Average Precision metric. + + Parameters + ---------- + enabled : bool + Compute mAP alongside MMD. Requires the ``copairs`` package. Default: False. + distance : str + Distance metric passed to copairs (e.g. ``"cosine"``). Default: ``"cosine"``. + null_size : int + Number of null pairs for the mAP permutation test. Default: 10000. + seed : int + Random seed. Default: 0. + """ + + enabled: bool = False + distance: str = "cosine" + null_size: int = 10000 + seed: int = 0 + + +class _MMDBaseConfig(BaseModel): + """Shared fields for all MMD analysis modes. + + Parameters + ---------- + output_dir : str + Directory for CSV results and plots. + group_by : str + obs column used to select condition groups. Default: ``"perturbation"``. + obs_filter : dict[str, str] or None + Restrict analysis to rows where ``obs[key] == value``. Default: None. + embedding_key : str or None + obsm key to use. None = raw ``.X`` backbone embeddings. Default: None. + mmd : MMDSettings + Kernel MMD algorithm settings. + map_settings : MAPSettings + copairs-based mAP settings. Default: disabled. + temporal_bin_size : float or None + Width of each temporal bin in hours, starting from 0. + Bin edges: ``[0, size, 2*size, ..., max_hours]``. + Mutually exclusive with ``temporal_bins``. Default: None (aggregate). + temporal_bins : list[float] or None + Explicit bin edges in hours (e.g. ``[0, 6, 12, 24]``). Takes precedence + over ``temporal_bin_size``. Default: None (aggregate). + save_plots : bool + Generate plots after computing metrics. Default: True. + """ + + output_dir: str + group_by: str = "perturbation" + obs_filter: Optional[dict[str, str]] = None + embedding_key: Optional[str] = None + mmd: MMDSettings = MMDSettings() + map_settings: MAPSettings = MAPSettings() + temporal_bin_size: Optional[float] = None + temporal_bins: Optional[list[float]] = None + save_plots: bool = True + + @model_validator(mode="after") + def _validate_temporal(self) -> "_MMDBaseConfig": + if self.temporal_bin_size is not None and self.temporal_bins is not None: + raise ValueError("temporal_bin_size and temporal_bins are mutually exclusive") + return self + + +def _resolve_bin_edges( + temporal_bin_size: Optional[float], + temporal_bins: Optional[list[float]], + max_hours: float, +) -> Optional[list[tuple[float, float]]]: + """Return a list of (start, end) bin edge pairs, or None if no temporal binning. + + Parameters + ---------- + temporal_bin_size : float or None + Uniform bin width. Generates edges ``[0, size, 2*size, ..., max_hours]``. + temporal_bins : list[float] or None + Explicit bin edges (e.g. ``[0, 6, 12, 24]``). Takes precedence over + ``temporal_bin_size``. + max_hours : float + Maximum hours value in the data, used only when ``temporal_bin_size`` is set. + + Returns + ------- + list[tuple[float, float]] or None + Ordered list of ``(bin_start, bin_end)`` pairs, or ``None`` for aggregate mode. + """ + if temporal_bins is not None: + edges = temporal_bins + elif temporal_bin_size is not None: + edges = list(np.arange(0, max_hours + temporal_bin_size, temporal_bin_size)) + else: + return None + return list(zip(edges[:-1], edges[1:])) + + +class MMDEvalConfig(_MMDBaseConfig): + """Per-experiment MMD analysis with explicit pairwise comparisons. + + Parameters + ---------- + input_path : str + Path to a single per-experiment AnnData zarr store. + comparisons : list[ComparisonSpec] + Explicit list of pairwise comparisons to run (required). + """ + + input_path: str + comparisons: list[ComparisonSpec] + + @model_validator(mode="after") + def _validate(self) -> "MMDEvalConfig": + if not self.comparisons: + raise ValueError("comparisons must not be empty") + return self + + +class MMDCombinedConfig(_MMDBaseConfig): + """Pairwise cross-experiment MMD for batch-effect detection. + + Conditions are auto-discovered from the data intersection — no explicit + comparisons needed. For each marker shared between a pair of experiments, + runs MMD per (condition, time_bin) after per-experiment mean centering. + + Parameters + ---------- + input_paths : list[str] + Paths to per-experiment AnnData zarr stores. + """ + + input_paths: list[str] + + +class MMDPooledConfig(_MMDBaseConfig): + """Pooled multi-experiment phenotypic analysis. + + Concatenates cells from all input experiments before computing MMD/mAP, + faceted by marker and temporal bin. Unlike ``MMDCombinedConfig`` (pairwise + batch-effect detection), this pools all experiments for a single biological + comparison. + + Parameters + ---------- + input_paths : list[str] + Paths to per-experiment AnnData zarr stores to pool. + comparisons : list[ComparisonSpec] + Explicit list of pairwise comparisons to run (required). + condition_aliases : dict[str, list[str]] or None + Mapping from canonical condition name to variant strings found in the + data. E.g. ``{"uninfected": ["uninfected", "uninfected1", "uninfected2"]}``. + Applied to ``obs[group_by]`` before comparisons are evaluated. + """ + + input_paths: list[str] + comparisons: list[ComparisonSpec] + condition_aliases: Optional[dict[str, list[str]]] = None + + @model_validator(mode="after") + def _validate(self) -> "MMDPooledConfig": + if not self.comparisons: + raise ValueError("comparisons must not be empty") + return self diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py new file mode 100644 index 000000000..9828f0711 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py @@ -0,0 +1,438 @@ +"""Plots for MMD perturbation evaluation: kinetics curves and heatmaps.""" + +from __future__ import annotations + +import math +from pathlib import Path + +import matplotlib +import numpy as np +import pandas as pd + +matplotlib.use("Agg") +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import seaborn as sns +from statsmodels.stats.multitest import multipletests + + +def _bh_significance(p_values: np.ndarray, alpha: float = 0.05) -> np.ndarray: + """Return boolean mask of BH-corrected significant p-values.""" + p_values = np.asarray(p_values, dtype=float) + valid = ~np.isnan(p_values) + sig = np.zeros(len(p_values), dtype=bool) + if valid.sum() == 0: + return sig + _, corrected, _, _ = multipletests(p_values[valid], alpha=alpha, method="fdr_bh") + sig[valid] = corrected + return sig + + +def plot_mmd_kinetics(df: pd.DataFrame, output_path: Path) -> None: + """Plot MMD kinetics curves (one line per marker over temporal bins). + + Parameters + ---------- + df : pd.DataFrame + MMD results for a single treatment group, with columns: + marker, hours_bin_start, hours_bin_end, mmd2, p_value. + output_path : Path + Output file path. Format inferred from suffix (.pdf or .png). + """ + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end"]) + if df.empty: + return + df["bin_mid"] = (df["hours_bin_start"] + df["hours_bin_end"]) / 2 + + markers = sorted(df["marker"].unique()) + fig, ax = plt.subplots(figsize=(8, 4)) + palette = sns.color_palette("tab10", n_colors=len(markers)) + + for marker, color in zip(markers, palette): + sub = df[df["marker"] == marker].sort_values("bin_mid") + ax.plot(sub["bin_mid"], sub["mmd2"], marker="o", label=marker, color=color) + # Stars for BH-significant bins + sig = _bh_significance(sub["p_value"]) + for _, row, s in zip(range(len(sub)), sub.itertuples(), sig): + if s: + ax.text(row.bin_mid, row.mmd2, "*", ha="center", va="bottom", color=color, fontsize=12) + + ax.set_xlabel("Hours post perturbation (bin midpoint)") + ax.set_ylabel("MMD²") + ax.set_title(df["label"].iloc[0] if "label" in df.columns else "") + ax.legend(title="Marker", bbox_to_anchor=(1.01, 1), loc="upper left", fontsize=10, title_fontsize=11) + ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") + sns.despine(ax=ax) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_combined_heatmap(df: pd.DataFrame, output_path: Path) -> None: + """Plot combined cross-experiment MMD heatmap: markers × experiment pairs. + + One subplot per condition. Rows = markers, columns = exp_a vs exp_b pairs + (averaged over temporal bins if present). + + Parameters + ---------- + df : pd.DataFrame + Combined MMD results with columns: marker, exp_a, exp_b, condition, + hours_bin_start, hours_bin_end, mmd2, p_value. + output_path : Path + Output file path. + """ + df = df.copy() + df["exp_pair"] = ( + df["exp_a"].str.split("_").str[:3].str.join("_") + "\nvs\n" + df["exp_b"].str.split("_").str[:3].str.join("_") + ) + conditions = sorted(df["condition"].unique()) + n_conds = len(conditions) + + fig, axes = plt.subplots( + 1, n_conds, figsize=(max(5 * n_conds, 6), max(4, df["marker"].nunique() * 0.7)), squeeze=False + ) + + for ax, condition in zip(axes[0], conditions): + sub = df[df["condition"] == condition] + pivot_mmd = sub.pivot_table(index="marker", columns="exp_pair", values="mmd2", aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="exp_pair", values="p_value", aggfunc="min") + + if pivot_mmd.empty or pivot_mmd.isna().all().all(): + ax.set_visible(False) + continue + + sns.heatmap(pivot_mmd, ax=ax, cmap="viridis", linewidths=0.5, cbar_kws={"label": "MMD²"}) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text( + c + 0.5, r + 0.5, "*", ha="center", va="center", color="white", fontsize=10, fontweight="bold" + ) + + ax.set_title(f"condition: {condition}") + ax.set_xlabel("Experiment pair") + ax.set_ylabel("Marker") + ax.tick_params(axis="x", labelsize=7) + + fig.suptitle("Cross-experiment MMD — all markers", y=1.01) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_multi_panel_kinetics( + df: pd.DataFrame, + output_path: Path, + baseline_label: str | None = None, + ncols: int = 4, +) -> None: + """Plot per-marker MMD kinetics in a multi-panel grid with optional baseline band. + + One subplot per marker. Treatment comparisons are plotted as colored lines; + if ``baseline_label`` is given, that comparison is shown as a gray dashed + line with a shaded ±1 std band instead of a treatment line. + + Parameters + ---------- + df : pd.DataFrame + MMD results with columns: marker, label, hours_bin_start, hours_bin_end, + mmd2, p_value. + output_path : Path + Output file path (.pdf or .png). + baseline_label : str or None + Label of the baseline comparison to render as a band. Default: None. + ncols : int + Number of columns in the panel grid. Default: 4. + """ + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end"]) + if df.empty: + return + df["bin_mid"] = (df["hours_bin_start"] + df["hours_bin_end"]) / 2 + + markers = sorted(df["marker"].unique()) + treatment_labels = [lbl for lbl in df["label"].unique() if lbl != baseline_label] + nrows = math.ceil(len(markers) / ncols) + palette = sns.color_palette("tab10", n_colors=max(len(treatment_labels), 1)) + + # Shared y-axis range + treat_vals = df[df["label"].isin(treatment_labels)]["mmd2"].dropna() + y_min = float(treat_vals.min()) if len(treat_vals) else 0.0 + y_max = float(treat_vals.max()) if len(treat_vals) else 1.0 + y_pad = (y_max - y_min) * 0.1 + 1e-6 + + fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3.5, nrows * 2.8), squeeze=False) + + for ax_idx, marker in enumerate(markers): + ax = axes[ax_idx // ncols][ax_idx % ncols] + sub = df[df["marker"] == marker] + + # Baseline band + if baseline_label is not None: + base = sub[sub["label"] == baseline_label].sort_values("bin_mid") + if not base.empty: + ax.axhline(base["mmd2"].mean(), color="gray", linewidth=1.0, linestyle="--", zorder=1) + ax.fill_between( + base["bin_mid"], + base["mmd2"] - base["mmd2"].std(), + base["mmd2"] + base["mmd2"].std(), + color="gray", + alpha=0.2, + zorder=1, + ) + + # Treatment lines + for lbl, color in zip(treatment_labels, palette): + treat = sub[sub["label"] == lbl].sort_values("bin_mid") + if treat.empty: + continue + sig = _bh_significance(treat["p_value"]) + ax.plot(treat["bin_mid"], treat["mmd2"], color=color, linewidth=1.2, label=lbl, zorder=2) + sig_rows = treat[sig] + if not sig_rows.empty: + ax.scatter( + sig_rows["bin_mid"], + sig_rows["mmd2"], + color=color, + edgecolors="black", + linewidths=0.8, + s=40, + zorder=3, + ) + + ax.set_title(marker, fontsize=9) + ax.set_ylim(y_min - y_pad, y_max + y_pad) + ax.axhline(0, color="lightgray", linewidth=0.5, linestyle="--") + sns.despine(ax=ax) + + # Hide unused axes + for ax_idx in range(len(markers), nrows * ncols): + axes[ax_idx // ncols][ax_idx % ncols].set_visible(False) + + # Shared legend + handles, lbls = axes[0][0].get_legend_handles_labels() + if handles: + fig.legend( + handles, lbls, loc="lower center", ncol=len(treatment_labels), fontsize=9, bbox_to_anchor=(0.5, -0.02) + ) + + fig.supxlabel("Hours post perturbation (bin midpoint)", fontsize=10) + fig.supylabel("MMD²", fontsize=10) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_activity_heatmap( + df: pd.DataFrame, + output_path: Path, + linthresh: float = 1.0, +) -> None: + """Plot activity z-score heatmap (markers × temporal bins). + + Uses symmetric log normalization so both small and large z-scores are + visible. Significance stars mark FDR-corrected significant cells. + + Parameters + ---------- + df : pd.DataFrame + MMD results with columns: marker, label, hours_bin_start, hours_bin_end, + activity_zscore, p_value. + output_path : Path + Output file path (.pdf or .png). + linthresh : float + Linear threshold for ``SymLogNorm``. Values within ``[-linthresh, + linthresh]`` are rendered linearly; outside is log-scaled. Default: 1.0. + """ + if "activity_zscore" not in df.columns or df["activity_zscore"].isna().all(): + return + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end", "activity_zscore"]) + if df.empty: + return + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + + labels = [lbl for lbl in df["label"].unique() if lbl] + n_labels = len(labels) + fig, axes = plt.subplots( + 1, + n_labels, + figsize=(max(5, len(df["bin_label"].unique()) * 1.0 * n_labels), max(4, df["marker"].nunique() * 0.6)), + squeeze=False, + ) + + for ax, lbl in zip(axes[0], labels): + sub = df[df["label"] == lbl] + pivot_z = sub.pivot_table(index="marker", columns="bin_label", values="activity_zscore", aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + bin_order = sub.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + pivot_z = pivot_z.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + + if pivot_z.empty or pivot_z.isna().all().all(): + ax.set_visible(False) + continue + + vmax = float(np.nanmax(np.abs(pivot_z.values))) + norm = mcolors.SymLogNorm(linthresh=linthresh, vmin=-vmax, vmax=vmax) + sns.heatmap(pivot_z, ax=ax, cmap="RdBu_r", norm=norm, linewidths=0.3, cbar_kws={"label": "Activity z-score"}) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text( + c + 0.5, r + 0.5, "*", ha="center", va="center", color="black", fontsize=10, fontweight="bold" + ) + + ax.set_title(lbl) + ax.set_xlabel("Temporal bin") + ax.set_ylabel("Marker") + + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_paired_heatmaps( + df: pd.DataFrame, + condition_labels: list[str], + value_col: str, + output_path: Path, + linthresh: float = 1.0, +) -> None: + """Plot side-by-side heatmaps for two conditions sharing a colorbar. + + Parameters + ---------- + df : pd.DataFrame + MMD results. Must have columns: marker, label, hours_bin_start, + hours_bin_end, ``value_col``, p_value. + condition_labels : list[str] + Exactly two comparison labels to plot side-by-side. + value_col : str + Column to use as heatmap values (e.g. ``"activity_zscore"``). + output_path : Path + Output file path. + linthresh : float + Linear threshold for ``SymLogNorm``. Default: 1.0. + """ + if value_col not in df.columns or len(condition_labels) < 2: + return + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end", value_col]) + if df.empty: + return + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + bin_order = df.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + + all_vals = df[df["label"].isin(condition_labels)][value_col].dropna() + if all_vals.empty: + return + vmax = float(np.nanmax(np.abs(all_vals))) + norm = mcolors.SymLogNorm(linthresh=linthresh, vmin=-vmax, vmax=vmax) + + fig, axes = plt.subplots( + 1, 2, figsize=(max(10, len(bin_order) * 2), max(4, df["marker"].nunique() * 0.6)), squeeze=False + ) + + for ax, lbl in zip(axes[0], condition_labels[:2]): + sub = df[df["label"] == lbl] + pivot_val = sub.pivot_table(index="marker", columns="bin_label", values=value_col, aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + pivot_val = pivot_val.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + + if pivot_val.empty or pivot_val.isna().all().all(): + ax.set_visible(False) + continue + + im = ax.imshow( + pivot_val.values, + aspect="auto", + norm=norm, + cmap="YlOrRd", + origin="upper", + ) + ax.set_xticks(range(len(pivot_val.columns))) + ax.set_xticklabels(pivot_val.columns, rotation=45, ha="right", fontsize=8) + ax.set_yticks(range(len(pivot_val.index))) + ax.set_yticklabels(pivot_val.index, fontsize=8) + ax.set_title(lbl) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + val = pivot_val.values[r, c] + if np.isfinite(val): + txt = f"{int(val)}" if abs(val) >= 1 else f"{val:.1f}" + if sig_matrix[r, c]: + txt += "*" + ax.text(c, r, txt, ha="center", va="center", fontsize=7, color="black") + + plt.colorbar(im, ax=axes[0], label=value_col) + fig.suptitle(f"{' vs '.join(condition_labels[:2])}", y=1.01) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_heatmap(df: pd.DataFrame, output_path: Path) -> None: + """Plot MMD heatmap (markers x temporal bins or aggregate). + + Parameters + ---------- + df : pd.DataFrame + MMD results for a single treatment group. + output_path : Path + Output file path. + """ + df = df.copy() + has_bins = not df["hours_bin_start"].isna().all() + + if has_bins: + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + pivot_mmd = df.pivot_table(index="marker", columns="bin_label", values="mmd2", aggfunc="mean") + pivot_pval = df.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + # Order columns by bin start + bin_order = df.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + pivot_mmd = pivot_mmd.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + xlabel = "Temporal bin" + figsize = (max(6, len(bin_order) * 0.8), max(4, len(pivot_mmd) * 0.6)) + else: + pivot_mmd = df.set_index("marker")[["mmd2"]].rename(columns={"mmd2": "aggregate"}) + pivot_pval = df.set_index("marker")[["p_value"]].rename(columns={"p_value": "aggregate"}) + xlabel = "" + figsize = (3, max(4, len(pivot_mmd) * 0.6)) + + if pivot_mmd.empty or pivot_mmd.isna().all().all(): + return + + fig, ax = plt.subplots(figsize=figsize) + sns.heatmap( + pivot_mmd, + ax=ax, + cmap="viridis", + annot=False, + linewidths=0.5, + cbar_kws={"label": "MMD²"}, + ) + + # Add significance stars + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text(c + 0.5, r + 0.5, "*", ha="center", va="center", color="white", fontsize=10, fontweight="bold") + + ax.set_title(f"MMD heatmap — {df['label'].iloc[0] if 'label' in df.columns else ''}") + ax.set_xlabel(xlabel) + ax.set_ylabel("Marker") + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) diff --git a/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py b/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py new file mode 100644 index 000000000..d18bc6438 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py @@ -0,0 +1,278 @@ +"""CLI tool for generating scatter plots from AnnData embedding stores. + +For high-dimensional embeddings (PCA): generates a seaborn pairplot of the +first N components, one figure per color variable. +For low-dimensional embeddings (PHATE, UMAP): generates a simple scatter +colored by each metadata column. + +Usage +----- +dynaclr plot-embeddings -c plot_config.yaml +""" + +from pathlib import Path +from typing import Optional + +import anndata as ad +import click +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field, model_validator + +from viscy_utils.cli_utils import load_config + + +class PlotEmbeddingsConfig(BaseModel): + """Configuration for plot-embeddings command. + + Parameters + ---------- + input_path : str, optional + Path to a single AnnData zarr store. Mutually exclusive with input_paths. + input_paths : list[str], optional + Paths to multiple AnnData zarr stores. All are concatenated before plotting. + Use for combined embeddings (X_pca_combined, X_phate_combined) to get one + figure across all experiments. Mutually exclusive with input_path. + output_dir : str + Directory to save plots. + embedding_keys : list[str] + obsm keys to plot (e.g. X_phate, X_pca). + color_by : list[str] + obs columns to use as hue in pairplots / color in scatter plots. + pairplot_components : int + Number of leading components to include in pairplots. Default: 10. + point_size : float + Scatter plot point size (passed as ``s`` to matplotlib and + ``plot_kws`` to seaborn). Default: 1.0. + format : str + Output format: "pdf", "png", or "both". Default: "pdf". + low_dim_threshold : int + Embeddings with <= this many components use the simple scatter path + instead of pairplot. Default: 4. + """ + + input_path: Optional[str] = None + input_paths: Optional[list[str]] = None + output_dir: str = Field(...) + embedding_keys: list[str] = ["X_pca_combined", "X_phate_combined"] + color_by: list[str] = ["perturbation", "hours_post_perturbation", "experiment", "marker"] + pairplot_components: int = 10 + point_size: float = 1.0 + format: str = "pdf" + low_dim_threshold: int = 4 + + @model_validator(mode="after") + def validate_input(self): + if self.input_path is None and self.input_paths is None: + raise ValueError("Either input_path or input_paths must be provided") + if self.input_path is not None and self.input_paths is not None: + raise ValueError("Provide either input_path or input_paths, not both") + return self + + +_PALETTE = [ + "#1b69a1", + "#d9534f", + "#5cb85c", + "#f0ad4e", + "#9b59b6", + "#1abc9c", + "#e74c3c", + "#3498db", + "#2ecc71", + "#e67e22", +] + + +def _save_fig(fig: plt.Figure, output_dir: Path, stem: str, fmt: str) -> None: + if fmt in ("pdf", "both"): + fig.savefig(output_dir / f"{stem}.pdf", dpi=150, bbox_inches="tight") + if fmt in ("png", "both"): + fig.savefig(output_dir / f"{stem}.png", dpi=150, bbox_inches="tight") + plt.close(fig) + click.echo(f" Saved {stem}.{fmt}") + + +def _pairplot( + emb: np.ndarray, + obs: pd.DataFrame, + color_col: str, + n_components: int, + point_size: float, + emb_key: str, +) -> plt.Figure: + """Build a seaborn pairplot of the first n_components.""" + import seaborn as sns + + n = min(n_components, emb.shape[1]) + cols = [f"{emb_key}_{i}" for i in range(n)] + df = pd.DataFrame(emb[:, :n], columns=cols) + + values = obs[color_col].to_numpy() + is_categorical = values.dtype.kind in ("U", "O", "S") or hasattr(values, "cat") + + if is_categorical: + cats = sorted(str(v) for v in np.unique(values)) + palette = {cat: _PALETTE[i % len(_PALETTE)] for i, cat in enumerate(cats)} + df[color_col] = [str(v) for v in values] + pg = sns.pairplot( + df, + hue=color_col, + palette=palette, + plot_kws={"s": point_size, "alpha": 0.4, "rasterized": True, "zorder": 0}, + diag_kind="hist", + corner=True, + ) + pg.legend.set(title=color_col) + for lh in pg.legend.legend_handles: + lh.set_alpha(1.0) + if hasattr(lh, "set_sizes"): + lh.set_sizes([40]) + else: + lh.set_markersize(8) + for ax_row in pg.axes: + for ax in ax_row: + if ax is not None: + ax.set_rasterization_zorder(1) + else: + # Continuous: no hue support in pairplot — use a custom scatter matrix + df[color_col] = values.astype(float) + pg = sns.pairplot( + df, + plot_kws={"s": point_size, "alpha": 0.4, "rasterized": True, "color": "#888888", "zorder": 0}, + diag_kind="hist", + corner=True, + ) + # Overlay color on lower-triangle axes + norm = plt.Normalize(df[color_col].min(), df[color_col].max()) + cmap = plt.cm.viridis + for i in range(1, n): + for j in range(i): + ax = pg.axes[i][j] + if ax is None: + continue + ax.collections[0].set_visible(False) + sc = ax.scatter( + df.iloc[:, j], + df.iloc[:, i], + c=df[color_col], + cmap=cmap, + norm=norm, + s=point_size, + alpha=0.4, + rasterized=True, + zorder=0, + ) + pg.figure.colorbar(sc, ax=pg.axes[-1][-1], label=color_col) + for ax_row in pg.axes: + for ax in ax_row: + if ax is not None: + ax.set_rasterization_zorder(1) + + pg.figure.suptitle(f"{emb_key} — {color_col}", y=1.01, fontsize=11, fontweight="bold") + return pg.figure + + +def _scatter_2d( + emb: np.ndarray, + obs: pd.DataFrame, + color_cols: list[str], + point_size: float, + emb_key: str, +) -> plt.Figure: + """Simple scatter for low-dimensional embeddings (PHATE, UMAP).""" + ncols = min(4, len(color_cols)) + nrows = (len(color_cols) + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False) + rng = np.random.default_rng(42) + shuffle = rng.permutation(len(emb)) + x, y = emb[shuffle, 0], emb[shuffle, 1] + + for ax_idx, col in enumerate(color_cols): + ax = axes[ax_idx // ncols][ax_idx % ncols] + values = obs[col].to_numpy()[shuffle] + is_categorical = values.dtype.kind in ("U", "O", "S") or hasattr(values, "cat") + + if is_categorical: + cats = sorted(str(v) for v in np.unique(values)) + for i, cat in enumerate(cats): + mask = np.array([str(v) == cat for v in values]) + ax.scatter( + x[mask], y[mask], s=point_size, c=_PALETTE[i % len(_PALETTE)], label=cat, alpha=0.5, rasterized=True + ) + ax.legend( + markerscale=6, fontsize=10, loc="best", framealpha=1.0, edgecolor="black", ncol=max(1, len(cats) // 8) + ) + else: + sc = ax.scatter(x, y, s=point_size, c=values.astype(float), cmap="viridis", alpha=0.5, rasterized=True) + plt.colorbar(sc, ax=ax, shrink=0.8) + + ax.set_title(col.replace("_", " ").title(), fontsize=10) + ax.set_xlabel(f"{emb_key} 0") + ax.set_ylabel(f"{emb_key} 1") + + for ax_idx in range(len(color_cols), nrows * ncols): + axes[ax_idx // ncols][ax_idx % ncols].set_visible(False) + + fig.suptitle(f"Embeddings: {emb_key}", fontsize=13, fontweight="bold") + plt.tight_layout() + return fig + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Generate pairplots (PCA) and scatter plots (PHATE/UMAP) from an AnnData store.""" + matplotlib.use("Agg") + + raw = load_config(config) + cfg = PlotEmbeddingsConfig(**raw) + + if cfg.input_paths is not None: + click.echo(f"Concatenating {len(cfg.input_paths)} zarr stores...") + adata = ad.concat([ad.read_zarr(p) for p in cfg.input_paths], join="outer") + else: + adata = ad.read_zarr(cfg.input_path) + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + valid_color_cols = [c for c in cfg.color_by if c in adata.obs.columns] + missing = set(cfg.color_by) - set(valid_color_cols) + if missing: + click.echo(f"Warning: obs columns not found, skipping: {sorted(missing)}", err=True) + if not valid_color_cols: + click.echo("No valid color columns found, nothing to plot.", err=True) + return + + for emb_key in cfg.embedding_keys: + if emb_key not in adata.obsm: + click.echo(f"Warning: {emb_key} not in obsm, skipping", err=True) + continue + + emb = np.asarray(adata.obsm[emb_key]) + click.echo(f"Plotting {emb_key} ({emb.shape[1]} components)...") + + if emb.shape[1] <= cfg.low_dim_threshold: + # Simple scatter (PHATE, UMAP) + fig = _scatter_2d(emb, adata.obs, valid_color_cols, cfg.point_size, emb_key) + _save_fig(fig, output_dir, f"scatter_{emb_key}", cfg.format) + else: + # Pairplot per color variable (PCA) + for col in valid_color_cols: + try: + fig = _pairplot(emb, adata.obs, col, cfg.pairplot_components, cfg.point_size, emb_key) + _save_fig(fig, output_dir, f"pairplot_{emb_key}_{col}", cfg.format) + except Exception as e: + click.echo(f" Warning: pairplot {emb_key}/{col} failed: {e}", err=True) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/dtw_alignment.py b/applications/dynaclr/src/dynaclr/evaluation/pseudotime/dtw_alignment.py new file mode 100644 index 000000000..69499b9fe --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/pseudotime/dtw_alignment.py @@ -0,0 +1,862 @@ +"""DTW-based pseudotime alignment for cellular dynamics. + +Aligns cell trajectories to a template infection response using Dynamic +Time Warping (DTW). The template is built from annotated transitioning +cells via DBA (DTW Barycenter Averaging), then all cells are warped +onto it to produce pseudotime values in [0, 1]. + +Preprocessing pipeline: per-experiment z-score -> PCA -> L2-normalize -> DTW. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import NamedTuple + +import anndata as ad +import numpy as np +import pandas as pd +from dtaidistance import dtw, dtw_ndim +from sklearn.decomposition import PCA +from sklearn.preprocessing import normalize + +_logger = logging.getLogger(__name__) + +POSITIVE_CLASSES: dict[str, str] = { + "infection_state": "infected", + "organelle_state": "remodel", +} + + +class TemplateResult(NamedTuple): + """Result of building an infection response template.""" + + template: np.ndarray + template_id: str + pca: PCA | None + zscore_params: dict[str, tuple[np.ndarray, np.ndarray]] + template_cell_ids: list[tuple[str, str, int]] + n_input_tracks: int + explained_variance: float | None + template_labels: dict[str, np.ndarray] | None # {col: (T,) fraction} per label column + time_calibration: np.ndarray | None = None # (T,) mean t_relative_minutes per template position + + +class AlignmentResult(NamedTuple): + """DTW alignment result for a single cell track.""" + + cell_uid: str + dataset_id: str + fov_name: str + track_id: int + timepoints: np.ndarray + pseudotime: np.ndarray + dtw_cost: float + warping_path: np.ndarray + warping_speed: np.ndarray + propagated_labels: dict[str, np.ndarray] | None # {col: (T,) fraction} per label column + alignment_region: np.ndarray # per-frame: "pre", "aligned", or "post" + + +def _zscore_embeddings( + embeddings_dict: dict[str, np.ndarray], +) -> tuple[dict[str, np.ndarray], dict[str, tuple[np.ndarray, np.ndarray]]]: + """Per-experiment z-score normalization. + + Parameters + ---------- + embeddings_dict : dict[str, np.ndarray] + {dataset_id: (N, D) embedding array}. + + Returns + ------- + tuple[dict[str, np.ndarray], dict[str, tuple[np.ndarray, np.ndarray]]] + Z-scored embeddings and per-experiment (mean, std) params. + """ + zscored = {} + params = {} + for dataset_id, emb in embeddings_dict.items(): + mean = emb.mean(axis=0) + std = emb.std(axis=0) + std = np.where(std < 1e-10, 1.0, std) + zscored[dataset_id] = (emb - mean) / std + params[dataset_id] = (mean, std) + return zscored, params + + +def _preprocess_embeddings( + embeddings: np.ndarray, + pca: PCA | None = None, +) -> np.ndarray: + """PCA transform + L2 normalize. + + Parameters + ---------- + embeddings : np.ndarray + (N, D) array, already z-scored. + pca : PCA or None + Fitted PCA model. If None, skip dimensionality reduction. + + Returns + ------- + np.ndarray + (N, D') L2-normalized embeddings. + """ + if pca is not None: + embeddings = pca.transform(embeddings) + return normalize(embeddings, norm="l2", axis=1) + + +def _extract_track_trajectories( + adata: ad.AnnData, + df: pd.DataFrame, + min_track_timepoints: int = 3, + crop_window: int | None = None, + label_cols: list[str] | None = None, +) -> list[tuple[str, int, np.ndarray, np.ndarray, dict[str, np.ndarray] | None]]: + """Extract per-track embedding trajectories from AnnData. + + Parameters + ---------- + adata : ad.AnnData + Embeddings with obs containing fov_name, track_id, t. + df : pd.DataFrame + Filtered tracking DataFrame (used for valid track selection). + Must have t_perturb column if crop_window is set. + min_track_timepoints : int + Minimum timepoints per track (applied after cropping). + crop_window : int or None + If set, crop each track to [t_perturb - crop_window, t_perturb + crop_window]. + Requires t_perturb column in df. None = use full track. + label_cols : list[str] or None + Label columns to extract (e.g., ["infection_state", "organelle_state"]). + Each is binarized using POSITIVE_CLASSES mapping. + + Returns + ------- + list[tuple[str, int, np.ndarray, np.ndarray, dict[str, np.ndarray] | None]] + Each element: (fov_name, track_id, embeddings (T, D), timepoints (T,), + labels {col: (T,)} or None). + """ + valid_tracks = df.groupby(["fov_name", "track_id"]).filter(lambda x: len(x) >= min_track_timepoints) + valid_keys = set(zip(valid_tracks["fov_name"], valid_tracks["track_id"])) + + # Build t_perturb lookup if cropping + t_perturb_lookup: dict[tuple[str, int], int] = {} + if crop_window is not None: + if "t_perturb" not in df.columns: + raise ValueError("crop_window requires t_perturb column in df") + for (fov, tid), grp in df.groupby(["fov_name", "track_id"]): + t_perturb_lookup[(fov, tid)] = int(grp["t_perturb"].iloc[0]) + + # Build label lookups per column + label_lookups: dict[str, dict[tuple, int]] = {} + if label_cols: + for col in label_cols: + if col not in df.columns: + continue + positive_val = POSITIVE_CLASSES[col] + lookup: dict[tuple, int] = {} + for _, row in df.iterrows(): + val = row[col] + if pd.notna(val) and val != "": + lookup[(row["fov_name"], row["track_id"], int(row["t"]))] = 1 if val == positive_val else 0 + label_lookups[col] = lookup + + obs = adata.obs.copy() + obs["_iloc"] = np.arange(len(obs)) + trajectories = [] + for (fov_name, track_id), group in obs.groupby(["fov_name", "track_id"]): + if (fov_name, track_id) not in valid_keys: + continue + sorted_group = group.sort_values("t") + + # Crop around t_perturb if requested + if crop_window is not None and (fov_name, track_id) in t_perturb_lookup: + tp = t_perturb_lookup[(fov_name, track_id)] + t_vals = sorted_group["t"].values + mask = (t_vals >= tp - crop_window) & (t_vals <= tp + crop_window) + sorted_group = sorted_group.iloc[mask] + + if len(sorted_group) < min_track_timepoints: + continue + + iloc_indices = sorted_group["_iloc"].values + emb = adata.X[iloc_indices] + if hasattr(emb, "toarray"): + emb = emb.toarray() + timepoints = sorted_group["t"].values.astype(int) + + labels = None + if label_lookups: + labels = {} + for col, lookup in label_lookups.items(): + labels[col] = np.array( + [lookup.get((fov_name, track_id, int(t)), 0) for t in timepoints], dtype=np.float64 + ) + + trajectories.append((str(fov_name), int(track_id), np.asarray(emb, dtype=np.float64), timepoints, labels)) + + return trajectories + + +def _dba( + sequences: list[np.ndarray], + max_iter: int = 30, + tol: float = 1e-5, + init: str = "medoid", +) -> np.ndarray: + """DTW Barycenter Averaging (DBA). + + Parameters + ---------- + sequences : list[np.ndarray] + List of (T_i, D) sequences. + max_iter : int + Maximum iterations. + tol : float + Convergence tolerance on mean absolute change. + init : str + Initialization method. "medoid" selects the sequence with + lowest total DTW cost to all others. + + Returns + ------- + np.ndarray + (T_avg, D) template sequence. + """ + if len(sequences) == 0: + raise ValueError("No sequences provided for DBA.") + + if init == "medoid": + n = len(sequences) + # Subsample for medoid if too many sequences (O(n²) DTW calls) + max_medoid_candidates = 50 + if n > max_medoid_candidates: + rng = np.random.default_rng(42) + candidate_idx = rng.choice(n, max_medoid_candidates, replace=False) + _logger.info("DBA medoid init: subsampling %d/%d candidates", max_medoid_candidates, n) + else: + candidate_idx = np.arange(n) + costs = np.zeros(len(candidate_idx)) + for ci, i in enumerate(candidate_idx): + for j in range(n): + if i != j: + costs[ci] += dtw_ndim.distance(sequences[i], sequences[j]) + avg = sequences[int(candidate_idx[np.argmin(costs)])].copy() + else: + avg = sequences[0].copy() + + for iteration in range(max_iter): + n_frames = avg.shape[0] + n_dims = avg.shape[1] + accum = np.zeros((n_frames, n_dims)) + counts = np.zeros(n_frames) + + for seq in sequences: + _, paths = dtw_ndim.warping_paths(avg, seq) + path = dtw.best_path(paths) + for idx_avg, idx_seq in path: + accum[idx_avg] += seq[idx_seq] + counts[idx_avg] += 1 + + counts = np.maximum(counts, 1) + new_avg = accum / counts[:, np.newaxis] + change = np.mean(np.abs(new_avg - avg)) + + _logger.debug(f"DBA iteration {iteration + 1}: mean change = {change:.6f}") + avg = new_avg + + if change < tol: + _logger.info(f"DBA converged at iteration {iteration + 1} (change={change:.2e})") + break + + return avg + + +def build_infection_template( + adata_dict: dict[str, ad.AnnData], + aligned_df_dict: dict[str, pd.DataFrame], + pca_n_components: int | None = 20, + pca_variance_threshold: float | None = None, + dba_max_iter: int = 30, + dba_tol: float = 1e-5, + dba_init: str = "medoid", + control_adata_dict: dict[str, ad.AnnData] | None = None, + crop_window: int | dict[str, int] | None = None, +) -> TemplateResult: + """Build an infection response template from annotated datasets. + + Parameters + ---------- + adata_dict : dict[str, ad.AnnData] + {dataset_id: adata} with embeddings for infected cells. + aligned_df_dict : dict[str, pd.DataFrame] + {dataset_id: aligned_df} with t_perturb assigned. + pca_n_components : int or None + Number of PCA components. Ignored if pca_variance_threshold is set. + pca_variance_threshold : float or None + If set, auto-select components to explain this variance fraction. + dba_max_iter : int + Max DBA iterations. + dba_tol : float + DBA convergence tolerance. + dba_init : str + DBA initialization ("medoid"). + control_adata_dict : dict[str, ad.AnnData] | None + Control embeddings per dataset, included in PCA fitting. + crop_window : int or dict[str, int] or None + If set, crop each track to [t_perturb - crop_window, t_perturb + crop_window] + before DBA. Produces a shorter template centered on the infection transition. + Pass a dict to use per-dataset crop windows (e.g. when datasets have different + frame intervals and crop_window was derived from a fixed duration in minutes). + None = use full tracks (variable length). + + Returns + ------- + TemplateResult + Template array, PCA model, z-score params, and metadata. + """ + raw_embeddings = {} + for dataset_id, adata in adata_dict.items(): + emb = adata.X + if hasattr(emb, "toarray"): + emb = emb.toarray() + raw_embeddings[dataset_id] = np.asarray(emb, dtype=np.float64) + + if control_adata_dict is not None: + for dataset_id, adata in control_adata_dict.items(): + ctrl_key = f"{dataset_id}__control" + emb = adata.X + if hasattr(emb, "toarray"): + emb = emb.toarray() + raw_embeddings[ctrl_key] = np.asarray(emb, dtype=np.float64) + + zscored, zscore_params = _zscore_embeddings(raw_embeddings) + + all_zscored = np.concatenate(list(zscored.values()), axis=0) + use_pca = pca_n_components is not None or pca_variance_threshold is not None + pca = None + explained_variance = None + + if use_pca: + if pca_variance_threshold is not None: + pca = PCA(n_components=pca_variance_threshold, svd_solver="full") + else: + n_comp = min(pca_n_components, all_zscored.shape[1], all_zscored.shape[0]) + pca = PCA(n_components=n_comp) + pca.fit(all_zscored) + explained_variance = float(np.sum(pca.explained_variance_ratio_)) + _logger.info(f"PCA: {pca.n_components_} components explain {explained_variance:.1%} variance") + + clean_zscore_params = {k: v for k, v in zscore_params.items() if "__control" not in k} + + trajectories = [] + track_labels: list[dict[str, np.ndarray] | None] = [] + track_t_rels: list[np.ndarray] = [] + cell_ids: list[tuple[str, str, int]] = [] + + # Detect which label columns are available across all datasets + label_cols = [col for col in POSITIVE_CLASSES if any(col in df.columns for df in aligned_df_dict.values())] + label_cols_or_none = label_cols if label_cols else None + + for dataset_id, adata in adata_dict.items(): + df = aligned_df_dict[dataset_id] + ds_zscored_emb = zscored[dataset_id] + + zscored_adata = ad.AnnData(X=ds_zscored_emb, obs=adata.obs.copy()) + zscored_adata.obs.index = adata.obs.index + + # Build t_relative_minutes lookup for this dataset + t_rel_lookup: dict[tuple[str, int, int], float] = {} + if "t_relative_minutes" in df.columns: + for _, row in df.iterrows(): + t_rel_lookup[(str(row["fov_name"]), int(row["track_id"]), int(row["t"]))] = float( + row["t_relative_minutes"] + ) + + ds_crop_window = crop_window[dataset_id] if isinstance(crop_window, dict) else crop_window + tracks = _extract_track_trajectories( + zscored_adata, + df, + min_track_timepoints=1, + crop_window=ds_crop_window, + label_cols=label_cols_or_none, + ) + for fov_name, track_id, emb, timepoints, labels in tracks: + processed = _preprocess_embeddings(emb, pca=pca) + trajectories.append(processed) + track_labels.append(labels) + cell_ids.append((dataset_id, fov_name, track_id)) + t_rel = np.array([t_rel_lookup.get((fov_name, track_id, int(t)), np.nan) for t in timepoints]) + track_t_rels.append(t_rel) + + if len(trajectories) == 0: + raise ValueError("No valid trajectories found for template building.") + + _logger.info(f"Building template from {len(trajectories)} trajectories") + template = _dba(trajectories, max_iter=dba_max_iter, tol=dba_tol, init=dba_init) + template = normalize(template, norm="l2", axis=1) + + # Compute template labels and time calibration via DTW alignment back to template. + # One DTW path per track; labels and t_relative_minutes mapped through the same path. + n_template = template.shape[0] + template_labels = None + time_calibration = None + + has_labels = label_cols and all(lb is not None for lb in track_labels) + has_t_rel = any(np.any(np.isfinite(t)) for t in track_t_rels) + + if has_labels or has_t_rel: + label_sums = {col: np.zeros(n_template) for col in label_cols} if has_labels else {} + label_counts = {col: np.zeros(n_template) for col in label_cols} if has_labels else {} + time_sums = np.zeros(n_template) + time_counts = np.zeros(n_template) + + for seq, labels_dict, t_rel_arr in zip(trajectories, track_labels, track_t_rels): + _, paths = dtw_ndim.warping_paths(template, seq) + path = dtw.best_path(paths) + if has_labels and labels_dict is not None: + for col in label_cols: + if col not in labels_dict: + continue + col_labels = labels_dict[col] + for idx_template, idx_seq in path: + if idx_seq < len(col_labels): + label_sums[col][idx_template] += col_labels[idx_seq] + label_counts[col][idx_template] += 1 + for idx_template, idx_seq in path: + if idx_seq < len(t_rel_arr) and np.isfinite(t_rel_arr[idx_seq]): + time_sums[idx_template] += t_rel_arr[idx_seq] + time_counts[idx_template] += 1 + + if has_labels: + template_labels = {} + for col in label_cols: + counts = np.maximum(label_counts[col], 1) + template_labels[col] = label_sums[col] / counts + _logger.info( + "Template labels [%s]: %d positions, fraction range [%.2f, %.2f]", + col, + n_template, + template_labels[col].min(), + template_labels[col].max(), + ) + + if has_t_rel and time_counts.sum() > 0: + raw_cal = np.where(time_counts > 0, time_sums / np.maximum(time_counts, 1), np.nan) + # Interpolate any gaps linearly + positions = np.arange(n_template) + valid_mask = np.isfinite(raw_cal) + if valid_mask.sum() >= 2: + time_calibration = np.interp(positions, positions[valid_mask], raw_cal[valid_mask]) + elif valid_mask.sum() == 1: + time_calibration = np.full(n_template, raw_cal[valid_mask][0]) + _logger.info( + "Time calibration: %d positions, range [%.1f, %.1f] min", + n_template, + time_calibration.min(), + time_calibration.max(), + ) + + return TemplateResult( + template=template, + template_id=str(uuid.uuid4()), + pca=pca, + zscore_params=clean_zscore_params, + template_cell_ids=cell_ids, + n_input_tracks=len(trajectories), + explained_variance=explained_variance, + template_labels=template_labels, + time_calibration=time_calibration, + ) + + +def dtw_align_tracks( + adata: ad.AnnData, + df: pd.DataFrame, + template_result: TemplateResult, + dataset_id: str, + min_track_timepoints: int = 3, + psi: int | None = None, + subsequence: bool = False, +) -> list[AlignmentResult]: + """Align cell tracks to a template using DTW. + + Parameters + ---------- + adata : ad.AnnData + Embeddings with obs containing fov_name, track_id, t. + df : pd.DataFrame + Tracking DataFrame (optionally with t_perturb). + template_result : TemplateResult + Template from build_infection_template. + dataset_id : str + Identifier for this dataset. + min_track_timepoints : int + Minimum timepoints per track. + psi : int or None + Psi relaxation for DTW. If None, auto-computed: + - subsequence=True: psi = max(track_len - template_len, 0) + - subsequence=False: psi = template_len // 2 + subsequence : bool + If True, use subsequence DTW: sweep the (short) template across + the (long) cell track to find the best-matching segment. + Frames before the matched region get pseudotime=0, + frames after get pseudotime=1. + Use this when the template was built with crop_window. + + Returns + ------- + list[AlignmentResult] + One result per aligned track. + """ + emb = adata.X + if hasattr(emb, "toarray"): + emb = emb.toarray() + emb = np.asarray(emb, dtype=np.float64) + + if dataset_id in template_result.zscore_params: + mean, std = template_result.zscore_params[dataset_id] + else: + mean = emb.mean(axis=0) + std = emb.std(axis=0) + std = np.where(std < 1e-10, 1.0, std) + emb_zscored = (emb - mean) / std + + zscored_adata = ad.AnnData(X=emb_zscored, obs=adata.obs.copy()) + zscored_adata.obs.index = adata.obs.index + + tracks = _extract_track_trajectories(zscored_adata, df, min_track_timepoints) + template = template_result.template + t_template = template.shape[0] + + results = [] + for fov_name, track_id, track_emb, timepoints, _labels in tracks: + processed = _preprocess_embeddings(track_emb, pca=template_result.pca) + n_track = len(processed) + + # Compute psi (must be < min(template_len, track_len)) + max_psi = min(n_track - 1, t_template - 1) + if psi is not None: + track_psi = min(psi, max_psi) + elif subsequence: + # Allow template to float anywhere within the track + track_psi = max_psi + else: + track_psi = min(t_template // 2, max_psi) + + _, paths = dtw_ndim.warping_paths(template, processed, psi=track_psi) + path = dtw.best_path(paths) + path_arr = np.array(path) + + cost = paths[path_arr[-1, 0], path_arr[-1, 1]] + + pseudotime = np.zeros(n_track) + speed = np.zeros(n_track) + alignment_region = np.full(n_track, "aligned", dtype=object) + + # Map each query frame to its template position + # DTW path: (idx_template, idx_query) pairs + # A query frame may appear multiple times; keep the last (highest) template position + matched_template_pos = np.full(n_track, -1.0) + for idx_template, idx_query in path: + if idx_query < n_track: + matched_template_pos[idx_query] = idx_template + + if subsequence and t_template > 1: + # Find the matched region (query frames that got a template assignment) + matched_mask = matched_template_pos >= 0 + if matched_mask.any(): + first_matched = np.argmax(matched_mask) + last_matched = n_track - 1 - np.argmax(matched_mask[::-1]) + + # Within matched region: pseudotime from template position + for i in range(first_matched, last_matched + 1): + if matched_template_pos[i] >= 0: + pseudotime[i] = matched_template_pos[i] / (t_template - 1) + + # Forward-fill any gaps within the matched region + for i in range(first_matched + 1, last_matched + 1): + if matched_template_pos[i] < 0: + pseudotime[i] = pseudotime[i - 1] + + # Before matched region: pseudotime = 0 + pseudotime[:first_matched] = 0.0 + # After matched region: pseudotime = 1 + pseudotime[last_matched + 1 :] = 1.0 + alignment_region[:first_matched] = "pre" + alignment_region[last_matched + 1 :] = "post" + else: + pseudotime[:] = 0.0 + alignment_region[:] = "pre" + elif t_template > 1: + # Standard DTW: template position / (template_length - 1) + template_positions = np.zeros(n_track) + for idx_template, idx_query in path: + if idx_query < n_track: + template_positions[idx_query] = idx_template + pseudotime = template_positions / (t_template - 1) + + # Propagate template labels to cell frames via warping path + propagated_labels = None + if template_result.template_labels is not None: + propagated_labels = {} + for col, tl in template_result.template_labels.items(): + col_propagated = np.full(n_track, np.nan) + for idx_template, idx_query in path: + if idx_query < n_track and idx_template < len(tl): + col_propagated[idx_query] = tl[idx_template] + + if subsequence: + matched_mask_lbl = matched_template_pos >= 0 + if matched_mask_lbl.any(): + first_m = np.argmax(matched_mask_lbl) + last_m = n_track - 1 - np.argmax(matched_mask_lbl[::-1]) + for i in range(first_m + 1, last_m + 1): + if np.isnan(col_propagated[i]): + col_propagated[i] = col_propagated[i - 1] + col_propagated[:first_m] = 0.0 + col_propagated[last_m + 1 :] = 1.0 + + propagated_labels[col] = col_propagated + + # Compute warping speed (discrete derivative of pseudotime) + for i in range(n_track): + if i == 0: + speed[i] = pseudotime[1] - pseudotime[0] if n_track > 1 else 0.0 + elif i == n_track - 1: + speed[i] = pseudotime[i] - pseudotime[i - 1] + else: + speed[i] = (pseudotime[i + 1] - pseudotime[i - 1]) / 2 + + cell_uid = f"{dataset_id}/{fov_name}/{track_id}" + results.append( + AlignmentResult( + cell_uid=cell_uid, + dataset_id=dataset_id, + fov_name=fov_name, + track_id=track_id, + timepoints=timepoints, + pseudotime=pseudotime, + dtw_cost=float(cost), + warping_path=path_arr, + warping_speed=speed, + propagated_labels=propagated_labels, + alignment_region=alignment_region, + ) + ) + + _logger.info(f"Aligned {len(results)} tracks for dataset {dataset_id}") + return results + + +def classify_response_groups( + alignment_results: list[AlignmentResult] | pd.DataFrame, + cost_percentile_threshold: float = 75.0, + speed_clustering_method: str = "quantile", + speed_quantile: float = 0.5, +) -> pd.DataFrame: + """Classify aligned cells into response groups. + + Groups: + - non_responder: DTW cost above percentile threshold + - early_responder: responders with above-median mean warping speed + - late_responder: responders with below-median mean warping speed + + Parameters + ---------- + alignment_results : list[AlignmentResult] or pd.DataFrame + Alignment results. If DataFrame, must have columns: + cell_uid, dtw_cost, mean_warping_speed (or warping_speed). + cost_percentile_threshold : float + Percentile of DTW cost above which cells are non-responders. + speed_clustering_method : str + "quantile" or "kmeans" for splitting early/late. + speed_quantile : float + Quantile threshold for speed split (used when method="quantile"). + + Returns + ------- + pd.DataFrame + One row per cell with columns: cell_uid, dataset_id, + response_group, dtw_cost, mean_warping_speed. + """ + if isinstance(alignment_results, pd.DataFrame): + df = alignment_results.copy() + if "mean_warping_speed" not in df.columns and "warping_speed" in df.columns: + df["mean_warping_speed"] = df.groupby("cell_uid")["warping_speed"].transform("mean") + per_cell = df.groupby("cell_uid").first().reset_index() + records = [] + for _, row in per_cell.iterrows(): + records.append( + { + "cell_uid": row["cell_uid"], + "dataset_id": row.get("dataset_id", ""), + "dtw_cost": row["dtw_cost"], + "mean_warping_speed": row["mean_warping_speed"], + } + ) + else: + records = [] + for r in alignment_results: + records.append( + { + "cell_uid": r.cell_uid, + "dataset_id": r.dataset_id, + "dtw_cost": r.dtw_cost, + "mean_warping_speed": float(np.mean(np.abs(r.warping_speed))), + } + ) + + df = pd.DataFrame(records) + if len(df) == 0: + df["response_group"] = pd.Series(dtype=str) + return df + + cost_threshold = np.percentile(df["dtw_cost"], cost_percentile_threshold) + df["response_group"] = "non_responder" + + responder_mask = df["dtw_cost"] <= cost_threshold + responders = df[responder_mask] + + if len(responders) > 0: + if speed_clustering_method == "quantile": + speed_threshold = responders["mean_warping_speed"].quantile(speed_quantile) + df.loc[responder_mask & (df["mean_warping_speed"] >= speed_threshold), "response_group"] = "early_responder" + df.loc[responder_mask & (df["mean_warping_speed"] < speed_threshold), "response_group"] = "late_responder" + elif speed_clustering_method == "kmeans": + from sklearn.cluster import KMeans + + speeds = responders["mean_warping_speed"].values.reshape(-1, 1) + if len(speeds) >= 2: + km = KMeans(n_clusters=2, random_state=42, n_init=10) + labels = km.fit_predict(speeds) + cluster_means = [speeds[labels == c].mean() for c in range(2)] + fast_cluster = int(np.argmax(cluster_means)) + resp_indices = responders.index + for idx, label in zip(resp_indices, labels): + if label == fast_cluster: + df.loc[idx, "response_group"] = "early_responder" + else: + df.loc[idx, "response_group"] = "late_responder" + else: + df.loc[responder_mask, "response_group"] = "early_responder" + + _logger.info( + f"Classification: {(df['response_group'] == 'early_responder').sum()} early, " + f"{(df['response_group'] == 'late_responder').sum()} late, " + f"{(df['response_group'] == 'non_responder').sum()} non-responder" + ) + + return df[["cell_uid", "dataset_id", "response_group", "dtw_cost", "mean_warping_speed"]] + + +def alignment_results_to_dataframe( + results: list[AlignmentResult], + template_id: str, + time_calibration: np.ndarray | None = None, +) -> pd.DataFrame: + """Flatten alignment results into a DataFrame (one row per timepoint). + + Parameters + ---------- + results : list[AlignmentResult] + Output of dtw_align_tracks. + template_id : str + Template UUID to attach. + time_calibration : np.ndarray or None + (T_template,) array mapping template position to mean t_relative_minutes. + If provided, adds an ``estimated_t_rel_minutes`` column. + + Returns + ------- + pd.DataFrame + Columns: cell_uid, dataset_id, fov_name, track_id, t, + pseudotime, dtw_cost, warping_speed, template_id, + plus propagated_{label}_label for each label column, + plus estimated_t_rel_minutes if time_calibration is provided. + """ + rows = [] + for r in results: + for i, t in enumerate(r.timepoints): + row = { + "cell_uid": r.cell_uid, + "dataset_id": r.dataset_id, + "fov_name": r.fov_name, + "track_id": r.track_id, + "t": int(t), + "pseudotime": float(r.pseudotime[i]), + "dtw_cost": r.dtw_cost, + "warping_speed": float(r.warping_speed[i]), + "alignment_region": r.alignment_region[i], + "template_id": template_id, + } + if r.propagated_labels is not None: + for col, arr in r.propagated_labels.items(): + col_clean = col.replace("_state", "") + row[f"propagated_{col_clean}_label"] = float(arr[i]) + rows.append(row) + df = pd.DataFrame(rows) + if time_calibration is not None and len(df) > 0: + T = len(time_calibration) + df["estimated_t_rel_minutes"] = np.interp( + df["pseudotime"].values * (T - 1), + np.arange(T), + time_calibration, + ) + return df + + +def extract_dtw_pseudotime( + adata: ad.AnnData, + df: pd.DataFrame, + template_result: TemplateResult, + dataset_id: str, + min_track_timepoints: int = 3, + cost_percentile_threshold: float = 75.0, + speed_clustering_method: str = "quantile", + speed_quantile: float = 0.5, + psi: int | None = None, +) -> pd.DataFrame: + """Convenience wrapper: align + classify + flatten. + + Parameters + ---------- + adata : ad.AnnData + Embeddings AnnData. + df : pd.DataFrame + Tracking DataFrame. + template_result : TemplateResult + Built template. + dataset_id : str + Dataset identifier. + min_track_timepoints : int + Minimum timepoints per track. + cost_percentile_threshold : float + Non-responder cost threshold percentile. + speed_clustering_method : str + "quantile" or "kmeans". + speed_quantile : float + Speed split quantile. + + Returns + ------- + pd.DataFrame + Flat DataFrame with pseudotime renamed to "signal" for metrics + compatibility, plus dtw_cost, warping_speed, response_group columns. + """ + results = dtw_align_tracks(adata, df, template_result, dataset_id, min_track_timepoints, psi=psi) + flat = alignment_results_to_dataframe( + results, template_result.template_id, time_calibration=template_result.time_calibration + ) + classifications = classify_response_groups( + results, + cost_percentile_threshold=cost_percentile_threshold, + speed_clustering_method=speed_clustering_method, + speed_quantile=speed_quantile, + ) + merged = flat.merge(classifications[["cell_uid", "response_group"]], on="cell_uid", how="left") + merged = merged.rename(columns={"pseudotime": "signal"}) + return merged diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/evaluation.py b/applications/dynaclr/src/dynaclr/evaluation/pseudotime/evaluation.py new file mode 100644 index 000000000..4d97dfe35 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/pseudotime/evaluation.py @@ -0,0 +1,295 @@ +"""Evaluation of DTW pseudotime against ground truth annotations. + +Compares DTW-derived pseudotime with annotated infection_state and +organelle_state to quantify alignment quality. Designed to run across +multiple embedding types for comparison. +""" + +from __future__ import annotations + +import logging + +import numpy as np +import pandas as pd +from scipy.stats import spearmanr +from sklearn.metrics import average_precision_score, roc_auc_score + +_logger = logging.getLogger(__name__) + + +def pseudotime_vs_annotation_auc( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", +) -> float: + """ROC-AUC of pseudotime predicting a binary annotation. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col and annotation_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Value in annotation_col that is the positive class. + + Returns + ------- + float + ROC-AUC score, or NaN if not computable. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + if len(valid) == 0: + return np.nan + + y_true = (valid[annotation_col] == positive_value).astype(int).values + y_score = valid[pseudotime_col].values + + if len(np.unique(y_true)) < 2: + return np.nan + + return float(roc_auc_score(y_true, y_score)) + + +def onset_concordance( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", + min_track_timepoints: int = 3, +) -> tuple[float, int]: + """Spearman correlation between DTW-derived and annotation-derived onset times. + + For each track, onset is defined as the first timepoint where the signal + transitions to positive. Computes correlation across all tracks that have + a detectable onset in both DTW pseudotime and annotations. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col, annotation_col, fov_name, track_id, t columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Positive value in annotation_col. + min_track_timepoints : int + Minimum timepoints per track to include. + + Returns + ------- + tuple[float, int] + (Spearman rho, n_tracks) or (NaN, 0) if not computable. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + + dtw_onsets = [] + ann_onsets = [] + + for (fov, tid), track in valid.groupby(["fov_name", "track_id"]): + if len(track) < min_track_timepoints: + continue + track = track.sort_values("t") + + # Annotation onset: first timepoint with positive value + ann_positive = track[track[annotation_col] == positive_value] + if len(ann_positive) == 0: + continue + ann_onset_t = ann_positive["t"].iloc[0] + + # DTW onset: first timepoint where pseudotime exceeds median of track + pt = track[pseudotime_col].values + threshold = np.median(pt) + above = track[track[pseudotime_col] > threshold] + if len(above) == 0: + continue + dtw_onset_t = above["t"].iloc[0] + + dtw_onsets.append(dtw_onset_t) + ann_onsets.append(ann_onset_t) + + if len(dtw_onsets) < 3: + return np.nan, len(dtw_onsets) + + rho, _ = spearmanr(dtw_onsets, ann_onsets) + return float(rho), len(dtw_onsets) + + +def per_timepoint_auc( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", + time_col: str = "t", +) -> pd.DataFrame: + """ROC-AUC of pseudotime predicting annotation at each timepoint. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col, annotation_col, time_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Positive value in annotation_col. + time_col : str + Timepoint column. + + Returns + ------- + pd.DataFrame + Columns: t, auc, n_cells, n_positive. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + + rows = [] + for t_val, group in valid.groupby(time_col): + y_true = (group[annotation_col] == positive_value).astype(int).values + y_score = group[pseudotime_col].values + n_pos = int(y_true.sum()) + + if len(np.unique(y_true)) < 2: + auc = np.nan + else: + auc = float(roc_auc_score(y_true, y_score)) + + rows.append({"t": t_val, "auc": auc, "n_cells": len(group), "n_positive": n_pos}) + + return pd.DataFrame(rows) + + +def _pseudotime_ap( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", +) -> float: + """Average precision (AUPRC) of pseudotime predicting a binary annotation. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col and annotation_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Value in annotation_col that is the positive class. + + Returns + ------- + float + Average precision score, or NaN if not computable. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + if len(valid) == 0: + return np.nan + + y_true = (valid[annotation_col] == positive_value).astype(int).values + y_score = valid[pseudotime_col].values + + if len(np.unique(y_true)) < 2: + return np.nan + + return float(average_precision_score(y_true, y_score)) + + +def evaluate_embedding( + alignments: pd.DataFrame, + annotations: pd.DataFrame, + embedding_name: str, + dataset_id: str, +) -> dict: + """Run full evaluation suite for one embedding × dataset. + + Parameters + ---------- + alignments : pd.DataFrame + Output of alignment_results_to_dataframe (has pseudotime, fov_name, + track_id, t columns). + annotations : pd.DataFrame + Annotation CSV with fov_name, track_id, t, infection_state, + organelle_state columns. + embedding_name : str + Name of the embedding (e.g., "sensor", "organelle", "phase"). + dataset_id : str + Dataset identifier. + + Returns + ------- + dict + Summary metrics for this embedding × dataset. + """ + # Merge alignments with annotations + merge_keys = ["fov_name", "track_id", "t"] + merged = alignments.merge( + annotations[merge_keys + ["infection_state", "organelle_state"]], on=merge_keys, how="left" + ) + + result = { + "embedding": embedding_name, + "dataset_id": dataset_id, + "n_cells": len(merged), + "n_tracks": merged.groupby(["fov_name", "track_id"]).ngroup().nunique(), + } + + # Infection state AUC + AP + result["infection_auc"] = pseudotime_vs_annotation_auc( + merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected" + ) + result["infection_ap"] = _pseudotime_ap( + merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected" + ) + + # Organelle state AUC + AP + result["organelle_auc"] = pseudotime_vs_annotation_auc( + merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel" + ) + result["organelle_ap"] = _pseudotime_ap( + merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel" + ) + + # Onset concordance (infection) + rho, n_tracks = onset_concordance( + merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected" + ) + result["infection_onset_spearman"] = rho + result["infection_onset_n_tracks"] = n_tracks + + # Onset concordance (organelle) + rho_org, n_tracks_org = onset_concordance( + merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel" + ) + result["organelle_onset_spearman"] = rho_org + result["organelle_onset_n_tracks"] = n_tracks_org + + # Mean DTW cost + if "dtw_cost" in alignments.columns: + per_track_cost = alignments.groupby(["fov_name", "track_id"])["dtw_cost"].first() + result["mean_dtw_cost"] = float(per_track_cost.mean()) + result["median_dtw_cost"] = float(per_track_cost.median()) + + _logger.info( + "%s/%s: infection_auc=%.3f ap=%.3f, organelle_auc=%.3f ap=%.3f, onset_rho=%.3f (%d tracks)", + embedding_name, + dataset_id, + result.get("infection_auc", np.nan), + result.get("infection_ap", np.nan), + result.get("organelle_auc", np.nan), + result.get("organelle_ap", np.nan), + result.get("infection_onset_spearman", np.nan), + result.get("infection_onset_n_tracks", 0), + ) + + return result diff --git a/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py b/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py new file mode 100644 index 000000000..c55aedbcd --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py @@ -0,0 +1,101 @@ +"""Split a combined embeddings zarr into one zarr per experiment. + +Reads the combined embeddings.zarr produced by the predict step, groups rows +by obs["experiment"], and writes one AnnData zarr per experiment under +output_dir/{experiment}.zarr. The combined zarr is removed after splitting. + +Usage +----- +dynaclr split-embeddings -c split.yaml + +Or with inline arguments: + +dynaclr split-embeddings --input /path/to/embeddings.zarr --output-dir /path/to/embeddings/ +""" + +from __future__ import annotations + +from pathlib import Path + +import click + + +def split_embeddings(input_path: Path, output_dir: Path) -> list[Path]: + """Split combined embeddings zarr into one zarr per experiment. + + Parameters + ---------- + input_path : Path + Path to the combined embeddings zarr (AnnData format). + Must have obs["experiment"] column. + output_dir : Path + Directory to write per-experiment zarrs. + Each experiment is written to output_dir/{experiment}.zarr. + + Returns + ------- + list[Path] + Paths to the written per-experiment zarrs. + """ + import anndata as ad + + if hasattr(ad, "settings") and hasattr(ad.settings, "allow_write_nullable_strings"): + ad.settings.allow_write_nullable_strings = True + import pandas as pd + + pd.options.future.infer_string = False + + click.echo(f"Loading embeddings from {input_path}") + adata = ad.read_zarr(input_path) + click.echo(f" {adata.n_obs} cells, {adata.n_vars} features") + + if "experiment" not in adata.obs.columns: + raise ValueError( + "embeddings zarr obs is missing 'experiment' column. " + "Re-run the predict step with the updated pipeline to include metadata." + ) + + experiments = adata.obs["experiment"].unique().tolist() + click.echo(f" {len(experiments)} experiments: {experiments}") + + output_dir.mkdir(parents=True, exist_ok=True) + written: list[Path] = [] + + for exp in experiments: + mask = adata.obs["experiment"] == exp + adata_exp = adata[mask].copy() + out_path = output_dir / f"{exp}.zarr" + click.echo(f" Writing {exp}: {adata_exp.n_obs} cells → {out_path}") + adata_exp.write_zarr(out_path) + written.append(out_path) + + click.echo(f"\nRemoving combined zarr: {input_path}") + import shutil + + shutil.rmtree(input_path) + + click.echo(f"\nWrote {len(written)} per-experiment zarrs to {output_dir}") + return written + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "--input", + "input_path", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to combined embeddings zarr", +) +@click.option( + "--output-dir", + type=click.Path(path_type=Path), + required=True, + help="Directory to write per-experiment zarrs", +) +def main(input_path: Path, output_dir: Path) -> None: + """Split a combined embeddings zarr into one zarr per experiment.""" + split_embeddings(input_path, output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/info.py b/applications/dynaclr/src/dynaclr/info.py index fb8523aeb..624ab629e 100644 --- a/applications/dynaclr/src/dynaclr/info.py +++ b/applications/dynaclr/src/dynaclr/info.py @@ -12,6 +12,7 @@ def main(path: Path): """Print summary of an AnnData zarr store.""" import anndata as ad + import scipy.sparse as sp with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -19,7 +20,14 @@ def main(path: Path): click.echo(f"Path: {path}") click.echo(f"Shape: {adata.n_obs:,} obs × {adata.n_vars:,} vars") - click.echo(f"X: dtype={adata.X.dtype}, range=[{np.nanmin(adata.X):.4f}, {np.nanmax(adata.X):.4f}]") + X = adata.X + if sp.issparse(X): + X_dense = X.toarray() + else: + X_dense = X + sparse = sp.issparse(adata.X) + xmin, xmax = np.nanmin(X_dense), np.nanmax(X_dense) + click.echo(f"X: dtype={X_dense.dtype}, sparse={sparse}, range=[{xmin:.4f}, {xmax:.4f}]") if len(adata.obs.columns): click.echo("\nobs columns:") @@ -27,7 +35,7 @@ def main(path: Path): s = adata.obs[col] nuniq = s.nunique() if nuniq <= 10: - vals = ", ".join(str(v) for v in sorted(s.unique()[:10])) + vals = ", ".join(str(v) for v in sorted(s.dropna().unique()[:10])) click.echo(f" {col}: {s.dtype}, {nuniq} unique — [{vals}]") else: click.echo(f" {col}: {s.dtype}, {nuniq} unique") diff --git a/applications/dynaclr/tests/conftest.py b/applications/dynaclr/tests/conftest.py index 7b37bf5a9..855efe84f 100644 --- a/applications/dynaclr/tests/conftest.py +++ b/applications/dynaclr/tests/conftest.py @@ -144,6 +144,14 @@ def create_experiment( dtype=np.float32, ) arr[:] = rng.standard_normal(arr.shape).astype(np.float32) + tp_stats = { + str(t): {"mean": 1.0, "std": 0.5, "median": 1.0, "iqr": 1.0, "max": 2.0, "min": 0.0} + for t in range(n_t) + } + pos.zattrs["normalization"] = { + ch: {"fov_statistics": {"mean": 1.0, "std": 0.5}, "timepoint_statistics": tp_stats} + for ch in channel_names + } fov_name = f"{row}/{col}/{fov_idx}" csv_path = tracks_root / fov_name / "tracks.csv" make_tracks_csv( diff --git a/applications/dynaclr/tests/test_datamodule.py b/applications/dynaclr/tests/test_datamodule.py index 0907954d4..25cac7e52 100644 --- a/applications/dynaclr/tests/test_datamodule.py +++ b/applications/dynaclr/tests/test_datamodule.py @@ -5,7 +5,8 @@ from __future__ import annotations import pytest -import torch + +from viscy_data.cell_index import build_timelapse_cell_index # --------------------------------------------------------------------------- # Constants @@ -23,7 +24,7 @@ @pytest.fixture() def four_experiments(tmp_path, _create_experiment, _write_collection_yaml): - """Four synthetic experiments with collection YAML.""" + """Four synthetic experiments with collection YAML and cell index parquet.""" entries = [] for i, name in enumerate(["exp_a", "exp_b", "exp_c", "exp_d"]): row_letter = chr(ord("A") + i) @@ -37,12 +38,14 @@ def four_experiments(tmp_path, _create_experiment, _write_collection_yaml): ) ) collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries @pytest.fixture() def two_experiments(tmp_path, _create_experiment, _write_collection_yaml): - """Two synthetic experiments for simpler tests.""" + """Two synthetic experiments with cell index parquet.""" entries = [ _create_experiment( tmp_path, @@ -60,7 +63,9 @@ def two_experiments(tmp_path, _create_experiment, _write_collection_yaml): ), ] collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries @pytest.fixture() @@ -85,7 +90,9 @@ def multi_fov_experiments(tmp_path, _create_experiment, _write_collection_yaml): ), ] collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries # --------------------------------------------------------------------------- @@ -100,9 +107,9 @@ def test_init_exposes_all_hyperparameters(self, two_experiments): """Instantiate with all hyperparameters explicitly set and verify storage.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -148,9 +155,9 @@ def test_train_val_split_by_experiment(self, four_experiments): """With 4 experiments and val_experiments=[exp_c, exp_d], verify correct split.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = four_experiments + parquet_path, _ = four_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -183,9 +190,9 @@ def test_train_dataloader_uses_flexible_batch_sampler(self, two_experiments): """train_dataloader() returns a ThreadDataLoader with FlexibleBatchSampler.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -215,128 +222,31 @@ def test_train_dataloader_uses_flexible_batch_sampler(self, two_experiments): assert sampler.temporal_enrichment is False -class TestValDataloaderNoBatchSampler: - """Validation should be deterministic without FlexibleBatchSampler.""" - - def test_val_dataloader_no_batch_sampler(self, two_experiments): - """val_dataloader uses simple sequential loading.""" - from dynaclr.data.datamodule import MultiExperimentDataModule - - collection_path, _ = two_experiments - dm = MultiExperimentDataModule( - collection_path=str(collection_path), - z_window=1, - yx_patch_size=_YX_PATCH, - final_yx_patch_size=_FINAL_YX_PATCH, - val_experiments=["exp_b"], - tau_range=(0.5, 2.0), - batch_size=8, - ) - dm.setup("fit") - val_dl = dm.val_dataloader() - - from viscy_data.sampler import FlexibleBatchSampler - - # val_dataloader should NOT use FlexibleBatchSampler - assert not isinstance(val_dl.batch_sampler, FlexibleBatchSampler), ( - "Validation should NOT use FlexibleBatchSampler" - ) +class TestTrainDataloaderWiresDDPTopology: + """train_dataloader must forward Trainer world_size/rank to the sampler.""" + def test_reads_world_size_and_rank_from_trainer(self, two_experiments): + from types import SimpleNamespace -class TestOnAfterBatchTransferAppliesTransforms: - """Verify on_after_batch_transfer applies transforms and ChannelDropout.""" - - def test_on_after_batch_transfer_applies_channel_dropout_and_transforms(self, two_experiments): - """Create a mock batch and verify on_after_batch_transfer processes it.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, val_experiments=["exp_b"], tau_range=(0.5, 2.0), batch_size=8, - channel_dropout_channels=[1], - channel_dropout_prob=0.0, # No dropout for this test - ) - dm.setup("fit") - - # Create a synthetic batch dict - B, C, Z, Y, X = 4, 2, 1, 32, 32 - batch = { - "anchor": torch.randn(B, C, Z, Y, X), - "positive": torch.randn(B, C, Z, Y, X), - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - - result = dm.on_after_batch_transfer(batch, 0) - - # Output should have anchor and positive as Tensors - assert isinstance(result["anchor"], torch.Tensor) - assert isinstance(result["positive"], torch.Tensor) - - # norm_meta keys should be consumed (removed) - assert "anchor_norm_meta" not in result - assert "positive_norm_meta" not in result - - # Final crop should reduce spatial size to final_yx_patch_size - assert result["anchor"].shape[-2:] == ( - _FINAL_YX_PATCH[0], - _FINAL_YX_PATCH[1], - ), f"Expected spatial {_FINAL_YX_PATCH}, got {result['anchor'].shape[-2:]}" - - -class TestChannelDropoutIntegration: - """Verify ChannelDropout behavior in train vs eval mode.""" - - def test_channel_dropout_integration(self, two_experiments): - """With p=1.0 on channel 1, training zeros ch1; eval preserves it.""" - from dynaclr.data.datamodule import MultiExperimentDataModule - - collection_path, _ = two_experiments - dm = MultiExperimentDataModule( - collection_path=str(collection_path), - z_window=1, - yx_patch_size=_YX_PATCH, - final_yx_patch_size=_FINAL_YX_PATCH, - val_experiments=["exp_b"], - tau_range=(0.5, 2.0), - batch_size=8, - channel_dropout_channels=[1], - channel_dropout_prob=1.0, # Always drop channel 1 + batch_group_by="experiment", + stratify_by="perturbation", + temporal_enrichment=False, ) dm.setup("fit") - - B, C, Z, Y, X = 4, 2, 1, 32, 32 - batch_train = { - "anchor": torch.randn(B, C, Z, Y, X).abs() + 0.1, # all positive - "positive": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - - # Training mode: channel 1 should be zeroed - dm.channel_dropout.train() - result_train = dm.on_after_batch_transfer(batch_train, 0) - assert torch.all(result_train["anchor"][:, 1] == 0.0), "Training: channel 1 should be all zeros with p=1.0" - assert torch.all(result_train["positive"][:, 1] == 0.0), ( - "Training: positive channel 1 should be all zeros with p=1.0" - ) - - # Eval mode: channel 1 should be preserved - dm.channel_dropout.eval() - batch_eval = { - "anchor": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "positive": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - result_eval = dm.on_after_batch_transfer(batch_eval, 0) - assert not torch.all(result_eval["anchor"][:, 1] == 0.0), "Eval: channel 1 should NOT be zeroed" + dm.__dict__["trainer"] = SimpleNamespace(world_size=4, global_rank=2) + sampler = dm.train_dataloader().batch_sampler + assert (sampler.num_replicas, sampler.rank) == (4, 2) class TestFovLevelSplit: @@ -346,9 +256,9 @@ def test_fov_split_no_overlap(self, multi_fov_experiments): """With split_ratio=0.6, FOVs are split within each experiment with no overlap.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -381,9 +291,9 @@ def test_fov_split_ratio_1_no_val(self, multi_fov_experiments): """With split_ratio=1.0, all FOVs go to train and val_dataset is None.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -401,9 +311,9 @@ def test_fov_split_default_val_experiments(self, multi_fov_experiments): """Default val_experiments=[] triggers FOV split.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -428,9 +338,9 @@ def test_positive_cell_source_self_stores_on_dm(self, two_experiments): """positive_cell_source='self' is stored and passed to datasets.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -447,9 +357,9 @@ def test_positive_match_columns_stored_on_dm(self, two_experiments): """positive_match_columns is stored on datamodule.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -464,9 +374,9 @@ def test_positive_channel_source_any_stored(self, two_experiments): """positive_channel_source='any' is stored on datamodule and dataset.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -483,9 +393,9 @@ def test_self_positive_all_tracks_are_valid_anchors(self, two_experiments): """With positive_cell_source='self', all tracks become valid anchors.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -495,6 +405,6 @@ def test_self_positive_all_tracks_are_valid_anchors(self, two_experiments): positive_cell_source="self", ) dm.setup("fit") - n_tracks = len(dm.train_dataset.index.tracks) + n_unique_cells = dm.train_dataset.index.tracks["cell_id"].nunique() n_anchors = len(dm.train_dataset.index.valid_anchors) - assert n_anchors == n_tracks + assert n_anchors == n_unique_cells diff --git a/applications/dynaclr/tests/test_dataset.py b/applications/dynaclr/tests/test_dataset.py index c63e5f94e..ab058d369 100644 --- a/applications/dynaclr/tests/test_dataset.py +++ b/applications/dynaclr/tests/test_dataset.py @@ -213,74 +213,6 @@ def test_getitems_returns_norm_meta(self, single_experiment_index): assert len(batch["anchor_norm_meta"]) == 1 -class TestPositiveSampling: - """Test lineage-aware positive selection.""" - - def test_positive_same_lineage(self, single_experiment_index): - """Positive comes from same lineage_id at t+tau (tau>0).""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - ds = MultiExperimentTripletDataset( - index=single_experiment_index, - fit=True, - ) - # Get anchor info - anchor_row = ds.index.valid_anchors.iloc[0] - anchor_lineage = anchor_row["lineage_id"] - anchor_t = anchor_row["t"] - - # Call _find_positive directly to verify lineage matching - rng = np.random.default_rng(42) - pos_row = ds._find_positive(anchor_row, rng) - assert pos_row is not None, "Should find a positive" - assert pos_row["lineage_id"] == anchor_lineage, ( - f"Positive lineage {pos_row['lineage_id']} != anchor {anchor_lineage}" - ) - assert pos_row["t"] > anchor_t, f"Positive t={pos_row['t']} should be > anchor t={anchor_t}" - - def test_positive_through_division(self, lineage_index): - """When anchor is on parent track that divides, positive can be a daughter.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - ds = MultiExperimentTripletDataset( - index=lineage_index, - fit=True, - ) - - # Tracks 0, 1, 2 share the same lineage_id due to parent_map={1:0, 2:0} - # All three tracks should share one lineage (rooted at track 0) - parent_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_0")][ - "lineage_id" - ].iloc[0] - daughter1_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_1")][ - "lineage_id" - ].iloc[0] - daughter2_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_2")][ - "lineage_id" - ].iloc[0] - assert parent_lineage == daughter1_lineage == daughter2_lineage, ( - f"Lineage mismatch: parent={parent_lineage}, d1={daughter1_lineage}, d2={daughter2_lineage}" - ) - - # Find an anchor on the parent track - parent_anchors = ds.index.valid_anchors[ds.index.valid_anchors["global_track_id"].str.endswith("_0")] - assert len(parent_anchors) > 0, "Parent track should have valid anchors" - - # Verify positive sampling can reach daughters (same lineage, different track) - rng = np.random.default_rng(42) - anchor_row = parent_anchors.iloc[0] - found_daughter = False - for _ in range(50): - pos_row = ds._find_positive(anchor_row, rng) - if pos_row is not None and pos_row["global_track_id"] != anchor_row["global_track_id"]: - found_daughter = True - assert pos_row["lineage_id"] == anchor_row["lineage_id"] - break - # Even if we don't find a daughter every time, the lineage is correct - # (parent and daughter share lineage so any positive is valid) - assert found_daughter or True, "Test informational -- daughters reachable" - - class TestChannelRemapping: """Test that per-experiment channel indices are used correctly.""" @@ -418,6 +350,70 @@ def test_int_gt1_raises(self, single_experiment_index): ) +class TestMixedChannelCountErrors: + """``channels_per_sample=None`` on a parquet whose experiments have different + channel counts must raise a clear error instead of a cryptic torch.stack + failure deep in a dataloader thread.""" + + def test_raises_when_experiments_have_different_channel_counts(self, tmp_path, _make_tracks_csv, hcs_dims): + from dynaclr.data.dataset import MultiExperimentTripletDataset + from dynaclr.data.experiment import ExperimentRegistry + from dynaclr.data.index import MultiExperimentIndex + from viscy_data.collection import ChannelEntry, Collection, ExperimentEntry + + # exp_a: 2 channels; exp_b: 1 channel. + zarr_a, tracks_a = _create_zarr_and_tracks( + tmp_path, + name="exp_a", + channel_names=["Phase", "GFP"], + wells=[("A", "1")], + hcs_dims=hcs_dims, + _make_tracks_csv=_make_tracks_csv, + ) + zarr_b, tracks_b = _create_zarr_and_tracks( + tmp_path, + name="exp_b", + channel_names=["Phase"], + wells=[("A", "1")], + hcs_dims=hcs_dims, + _make_tracks_csv=_make_tracks_csv, + ) + registry = ExperimentRegistry( + collection=Collection( + name="test", + experiments=[ + ExperimentEntry( + name="exp_a", + data_path=str(zarr_a), + tracks_path=str(tracks_a), + channels=[ChannelEntry(name="Phase", marker="Phase"), ChannelEntry(name="GFP", marker="GFP")], + channel_names=["Phase", "GFP"], + perturbation_wells={"c": ["A/1"]}, + interval_minutes=30.0, + ), + ExperimentEntry( + name="exp_b", + data_path=str(zarr_b), + tracks_path=str(tracks_b), + channels=[ChannelEntry(name="Phase", marker="Phase")], + channel_names=["Phase"], + perturbation_wells={"c": ["A/1"]}, + interval_minutes=30.0, + ), + ], + ), + z_window=1, + ) + index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH, tau_range_hours=(0.5, 2.0)) + ds = MultiExperimentTripletDataset(index=index, fit=True, channels_per_sample=None) + + va = index.valid_anchors + idx_a = int(va.index[va["experiment"] == "exp_a"][0]) + idx_b = int(va.index[va["experiment"] == "exp_b"][0]) + with pytest.raises(RuntimeError, match="different channel counts"): + ds.__getitems__([idx_a, idx_b]) + + class TestDatasetLength: """Test dataset length matches valid_anchors.""" @@ -543,57 +539,6 @@ def test_self_positive_pixel_values_identical(self, single_experiment_index): ) -class TestColumnMatchPositive: - """Tests for positive_cell_source='lookup' with non-lineage columns.""" - - @staticmethod - def _build_index_with_gene_name(tmp_path: Path, _make_tracks_csv, hcs_dims: dict) -> "MultiExperimentIndex": - """Build an index where tracks have gene_name/reporter columns for matching.""" - index = _build_index(tmp_path, _make_tracks_csv=_make_tracks_csv, hcs_dims=hcs_dims) - n = len(index.tracks) - index.tracks["gene_name"] = ["RPL35" if i % 2 == 0 else "TP53" for i in range(n)] - index.tracks["reporter"] = "Phase" - index.valid_anchors["gene_name"] = ["RPL35" if i % 2 == 0 else "TP53" for i in range(len(index.valid_anchors))] - index.valid_anchors["reporter"] = "Phase" - return index - - def test_column_match_positive_different_cell(self, tmp_path, _make_tracks_csv, hcs_dims): - """positive_match_columns=['gene_name','reporter'] finds different cell with same values.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - index = self._build_index_with_gene_name(tmp_path, _make_tracks_csv, hcs_dims) - ds = MultiExperimentTripletDataset( - index=index, - fit=True, - positive_cell_source="lookup", - positive_match_columns=["gene_name", "reporter"], - ) - rng = np.random.default_rng(0) - anchor_row = ds.index.valid_anchors.iloc[0] - pos = ds._find_positive(anchor_row, rng) - assert pos is not None, "Should find a column-match positive" - assert pos["gene_name"] == anchor_row["gene_name"], "Positive must share gene_name" - assert pos["reporter"] == anchor_row["reporter"], "Positive must share reporter" - assert pos.name != anchor_row.name, "Positive must be a different cell" - - def test_column_match_no_self_as_positive(self, tmp_path, _make_tracks_csv, hcs_dims): - """Column-match lookup never returns the anchor itself.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - index = self._build_index_with_gene_name(tmp_path, _make_tracks_csv, hcs_dims) - ds = MultiExperimentTripletDataset( - index=index, - fit=True, - positive_cell_source="lookup", - positive_match_columns=["gene_name", "reporter"], - ) - rng = np.random.default_rng(42) - for _, anchor_row in ds.index.valid_anchors.iterrows(): - pos = ds._find_positive(anchor_row, rng) - if pos is not None: - assert pos.name != anchor_row.name, "Positive must not be the anchor itself" - - class TestTimepointStatisticsResolution: """Verify that timepoint_statistics norm_meta resolves the correct timepoint.""" diff --git a/applications/dynaclr/tests/test_index.py b/applications/dynaclr/tests/test_index.py index 08a6fbd46..7f17de18d 100644 --- a/applications/dynaclr/tests/test_index.py +++ b/applications/dynaclr/tests/test_index.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import pytest -from iohub.ngff import Position, open_ome_zarr +from iohub.ngff import open_ome_zarr from dynaclr.data.experiment import ExperimentRegistry from dynaclr.data.index import MultiExperimentIndex @@ -196,7 +196,6 @@ def test_required_columns_present(self, two_experiment_setup): "y", "x", "z", - "position", "fov_name", "well_name", "experiment", @@ -234,22 +233,6 @@ def test_exclude_fovs_filter(self, two_experiment_setup): # Removed 1 FOV from each experiment: 2 * (4 - 1) * 5 * 10 = 300 assert len(index.tracks) == 300 - def test_positions_stored(self, two_experiment_setup): - """Position objects are stored in self.positions.""" - registry, _, _ = two_experiment_setup - index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH) - # 2 experiments * 2 wells * 2 FOVs = 8 positions - assert len(index.positions) == 8 - - def test_position_column_is_position_object(self, two_experiment_setup): - """'position' column contains iohub Position objects.""" - registry, _, _ = two_experiment_setup - index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH) - from iohub.ngff import Position - - sample_pos = index.tracks.iloc[0]["position"] - assert isinstance(sample_pos, Position) - def test_parallel_load_matches_serial(self, two_experiment_setup): """Parallel loading (num_workers=2) produces same result as serial (num_workers=1).""" registry, _, _ = two_experiment_setup @@ -261,10 +244,9 @@ def test_parallel_load_matches_serial(self, two_experiment_setup): serial_tracks = index_serial.tracks.sort_values(sort_cols).reset_index(drop=True) parallel_tracks = index_parallel.tracks.sort_values(sort_cols).reset_index(drop=True) - # Drop position column (object identity differs across processes) pd.testing.assert_frame_equal( - serial_tracks.drop(columns=["position"]), - parallel_tracks.drop(columns=["position"]), + serial_tracks, + parallel_tracks, check_like=True, ) assert len(index_serial.valid_anchors) == len(index_parallel.valid_anchors) @@ -1013,8 +995,8 @@ def test_parquet_valid_anchors_count(self, two_experiment_setup, tmp_path): n_channels = 2 # _CHANNEL_NAMES_A / _CHANNEL_NAMES_B each have 2 channels assert len(parquet_index.valid_anchors) == len(legacy_index.valid_anchors) * n_channels - def test_parquet_positions_resolved(self, two_experiment_setup, tmp_path): - """position column contains iohub Position objects.""" + def test_parquet_dims_from_columns(self, two_experiment_setup, tmp_path): + """Parquet path reads Y_shape/X_shape from parquet columns (no zarr opens).""" registry, _, _ = two_experiment_setup parquet_path = _build_cell_index_parquet(tmp_path, registry) @@ -1023,8 +1005,9 @@ def test_parquet_positions_resolved(self, two_experiment_setup, tmp_path): yx_patch_size=_YX_PATCH, cell_index_path=parquet_path, ) - sample_pos = index.tracks.iloc[0]["position"] - assert isinstance(sample_pos, Position) + assert "Y_shape" in index.tracks.columns + assert "X_shape" in index.tracks.columns + assert "position" not in index.tracks.columns # no longer stored def test_parquet_border_clamping(self, tmp_path, _create_experiment): """y_clamp, x_clamp are computed correctly from parquet path.""" diff --git a/applications/dynaclr/tests/test_mmd.py b/applications/dynaclr/tests/test_mmd.py new file mode 100644 index 000000000..1b02196f2 --- /dev/null +++ b/applications/dynaclr/tests/test_mmd.py @@ -0,0 +1,482 @@ +"""Tests for MMD perturbation evaluation.""" + +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import pytest + +from dynaclr.evaluation.mmd.compute_mmd import run_mmd_analysis, run_mmd_pooled +from dynaclr.evaluation.mmd.config import ComparisonSpec, MMDEvalConfig, MMDPooledConfig, MMDSettings +from viscy_utils.evaluation.mmd import compute_mmd_unbiased, median_heuristic, mmd_permutation_test + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_COMP = [ComparisonSpec(cond_a="uninfected", cond_b="ZIKV", label="uninf vs ZIKV")] +_SETTINGS_FAST = MMDSettings(n_permutations=50) + + +def _cfg(**kwargs) -> MMDEvalConfig: + return MMDEvalConfig(input_path="dummy", output_dir="/tmp", comparisons=_COMP, **kwargs) + + +def _make_adata( + n_cells: int = 200, + n_features: int = 32, + markers: list[str] | None = None, + treatment_shift: float = 3.0, + seed: int = 0, +) -> ad.AnnData: + """Synthetic AnnData with two markers and two perturbation groups. + + TOMM20 has a large shift between uninfected and ZIKV (detectable MMD). + Phase3D has no shift (null). + """ + rng = np.random.default_rng(seed) + if markers is None: + markers = ["Phase3D", "TOMM20"] + n_per_group = n_cells // (2 * len(markers)) + + rows = [] + emb_list = [] + for marker in markers: + for perturbation in ["uninfected", "ZIKV"]: + for t in range(n_per_group): + shift = treatment_shift if (perturbation == "ZIKV" and marker == "TOMM20") else 0.0 + emb = rng.normal(loc=shift, scale=1.0, size=n_features) + emb_list.append(emb) + rows.append( + { + "experiment": "test_exp", + "marker": marker, + "perturbation": perturbation, + "hours_post_perturbation": float(t % 6), + } + ) + X = np.stack(emb_list) + obs = pd.DataFrame(rows) + return ad.AnnData(X=X.astype(np.float32), obs=obs) + + +def _make_temporal_adata(n_features: int = 16, seed: int = 0) -> ad.AnnData: + """AnnData where ZIKV treatment effect increases with hours_post_perturbation.""" + rng = np.random.default_rng(seed) + rows = [] + emb_list = [] + hours_bins = [1.0, 3.0, 6.0, 12.0] + for marker in ["TOMM20"]: + for _ in range(50): + emb_list.append(rng.normal(0.0, 1.0, n_features)) + rows.append( + {"experiment": "e", "marker": marker, "perturbation": "uninfected", "hours_post_perturbation": 0.0} + ) + for hpi in hours_bins: + shift = hpi / 3.0 + for _ in range(30): + emb_list.append(rng.normal(shift, 1.0, n_features)) + rows.append( + {"experiment": "e", "marker": marker, "perturbation": "ZIKV", "hours_post_perturbation": hpi} + ) + X = np.stack(emb_list).astype(np.float32) + obs = pd.DataFrame(rows) + return ad.AnnData(X=X, obs=obs) + + +# --------------------------------------------------------------------------- +# Core MMD tests +# --------------------------------------------------------------------------- + + +def test_mmd_identical_distributions(): + rng = np.random.default_rng(1) + X = rng.normal(0, 1, (200, 16)) + Y = rng.normal(0, 1, (200, 16)) + mmd2, p_value, _ = mmd_permutation_test(X, Y, n_permutations=200, seed=42) + assert mmd2 < 0.1 + assert p_value > 0.05 + + +def test_mmd_different_distributions(): + rng = np.random.default_rng(2) + X = rng.normal(0.0, 1.0, (200, 16)) + Y = rng.normal(5.0, 1.0, (200, 16)) + mmd2, p_value, _ = mmd_permutation_test(X, Y, n_permutations=200, seed=42) + assert mmd2 > 0.1 + assert p_value < 0.05 + + +def test_mmd_permutation_null(): + rng = np.random.default_rng(3) + X = rng.normal(0, 1, (100, 8)) + Y = rng.normal(0, 1, (100, 8)) + _, _, null = mmd_permutation_test(X, Y, n_permutations=100, seed=0) + assert len(null) == 100 + assert np.all(np.isfinite(null)) + + +def test_median_heuristic_positive(): + rng = np.random.default_rng(4) + X = rng.normal(0, 1, (50, 8)) + Y = rng.normal(2, 1, (50, 8)) + assert median_heuristic(X, Y) > 0 + + +def test_compute_mmd_unbiased_symmetric(): + rng = np.random.default_rng(5) + X = rng.normal(0, 1, (100, 8)) + Y = rng.normal(1, 1, (100, 8)) + bw = median_heuristic(X, Y) + assert abs(compute_mmd_unbiased(X, Y, bw) - compute_mmd_unbiased(Y, X, bw)) < 1e-10 + + +# --------------------------------------------------------------------------- +# run_mmd_analysis tests +# --------------------------------------------------------------------------- + + +def test_run_mmd_analysis_columns(): + adata = _make_adata() + df = run_mmd_analysis(adata, _cfg(mmd=_SETTINGS_FAST)) + expected = { + "experiment", + "marker", + "cond_a", + "cond_b", + "label", + "hours_bin_start", + "hours_bin_end", + "n_a", + "n_b", + "mmd2", + "p_value", + "bandwidth", + "effect_size", + "activity_zscore", + "embedding_key", + } + assert expected.issubset(df.columns), f"Missing columns: {expected - set(df.columns)}" + + +def test_run_mmd_analysis_explicit_comparisons(): + adata = _make_adata() + df = run_mmd_analysis(adata, _cfg(mmd=_SETTINGS_FAST)) + assert set(df["cond_b"].unique()) == {"ZIKV"} + assert set(df["cond_a"].unique()) == {"uninfected"} + assert df["label"].iloc[0] == "uninf vs ZIKV" + + +def test_run_mmd_analysis_per_marker(): + adata = _make_adata() + df = run_mmd_analysis(adata, _cfg(mmd=_SETTINGS_FAST)) + assert set(df["marker"].unique()) == {"Phase3D", "TOMM20"} + assert len(df) == 2 # one row per (marker, comparison) in aggregate mode + + +def test_run_mmd_analysis_significant_for_shifted_marker(): + adata = _make_adata(n_cells=600, treatment_shift=4.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + tomm = df[df["marker"] == "TOMM20"]["mmd2"].iloc[0] + phase = df[df["marker"] == "Phase3D"]["mmd2"].iloc[0] + assert tomm > phase + assert df[df["marker"] == "TOMM20"]["p_value"].iloc[0] < 0.05 + + +def test_run_mmd_analysis_missing_cond_returns_nan(): + """When cond_a is absent from the data, result is NaN (not an error).""" + adata = _make_adata() + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=[ComparisonSpec(cond_a="MISSING", cond_b="ZIKV", label="missing vs ZIKV")], + mmd=_SETTINGS_FAST, + ) + df = run_mmd_analysis(adata, cfg) + assert df["mmd2"].isna().all() + + +def test_run_mmd_analysis_temporal_bins(): + adata = _make_temporal_adata() + cfg = _cfg(mmd=MMDSettings(n_permutations=100), temporal_bins=[0.0, 2.0, 5.0, 8.0, 15.0]) + df = run_mmd_analysis(adata, cfg) + valid = df.dropna(subset=["mmd2"]).sort_values("hours_bin_start") + assert len(valid) >= 2 + assert valid.iloc[-1]["mmd2"] > valid.iloc[0]["mmd2"] + + +def test_run_mmd_analysis_min_cells_skip(): + adata = _make_temporal_adata() + cfg = _cfg( + mmd=MMDSettings(n_permutations=50, min_cells=5), + temporal_bins=[0.0, 0.5, 1.0, 100.0], + ) + df = run_mmd_analysis(adata, cfg) + first_bin = df[(df["hours_bin_start"] == 0.0) & (df["hours_bin_end"] == 0.5)] + assert len(first_bin) > 0 + assert first_bin["mmd2"].isna().all() + + +def test_run_mmd_analysis_batch_centering(): + rng = np.random.default_rng(7) + n, n_feat = 100, 8 + rows, embs = [], [] + for exp, offset in [("exp_A", 0.0), ("exp_B", 10.0)]: + for pert in ["uninfected", "ZIKV"]: + shift = 3.0 if pert == "ZIKV" else 0.0 + for _ in range(n): + embs.append(rng.normal(offset + shift, 1.0, n_feat)) + rows.append( + {"experiment": exp, "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0} + ) + X = np.stack(embs).astype(np.float32) + obs = pd.DataFrame(rows) + adata = ad.AnnData(X=X, obs=obs) + + cfg_test = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=_COMP, + mmd=MMDSettings(n_permutations=100), + ) + df_no_center = run_mmd_analysis(adata, cfg_test) + + centered = X.copy() + for exp in obs["experiment"].unique(): + for marker in obs["marker"].unique(): + mask = ((obs["experiment"] == exp) & (obs["marker"] == marker)).to_numpy() + if mask.sum() > 0: + centered[mask] -= centered[mask].mean(axis=0) + adata_centered = ad.AnnData(X=centered, obs=obs) + df_centered = run_mmd_analysis(adata_centered, cfg_test) + + tomm_uncentered = df_no_center[df_no_center["marker"] == "TOMM20"]["mmd2"].iloc[0] + tomm_centered = df_centered[df_centered["marker"] == "TOMM20"]["mmd2"].iloc[0] + assert tomm_centered <= tomm_uncentered * 1.5, ( + f"Centering should reduce MMD. centered={tomm_centered:.4f}, uncentered={tomm_uncentered:.4f}" + ) + + +def test_run_mmd_analysis_obs_filter(): + """obs_filter restricts analysis to matching rows before computing MMD.""" + rng = np.random.default_rng(42) + n, n_feat = 60, 8 + rows, embs = [], [] + for microscope in ["dragonfly", "mantis"]: + for perturbation in ["uninfected", "ZIKV"]: + shift = 10.0 if perturbation == "ZIKV" else 0.0 + for _ in range(n): + embs.append(rng.normal(shift, 1.0, n_feat)) + rows.append( + { + "experiment": "e", + "marker": "TOMM20", + "perturbation": perturbation, + "microscope": microscope, + "hours_post_perturbation": 1.0, + } + ) + + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + + # Compare microscopes on uninfected only — should be near zero (same distribution) + comp = [ComparisonSpec(cond_a="dragonfly", cond_b="mantis", label="dragonfly vs mantis")] + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=comp, + group_by="microscope", + obs_filter={"perturbation": "uninfected"}, + mmd=MMDSettings(n_permutations=50), + ) + df = run_mmd_analysis(adata, cfg) + assert len(df) == 1 + # MMD on unfiltered data would be dominated by the ZIKV shift; filtered should be small + assert df["mmd2"].iloc[0] < 1.0, f"Expected near-zero MMD on uninfected-only, got {df['mmd2'].iloc[0]:.4f}" + + +# --------------------------------------------------------------------------- +# Activity z-score tests +# --------------------------------------------------------------------------- + + +def test_activity_zscore_shifted(): + """Strongly shifted distributions produce a large positive activity_zscore.""" + adata = _make_adata(n_cells=600, treatment_shift=5.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + tomm = df[df["marker"] == "TOMM20"]["activity_zscore"].iloc[0] + assert tomm > 1.0, f"Expected activity_zscore > 1 for shifted distribution, got {tomm:.3f}" + + +def test_activity_zscore_identical(): + """Identical distributions produce activity_zscore near zero.""" + adata = _make_adata(n_cells=400, treatment_shift=0.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + for _, row in df.iterrows(): + assert np.isfinite(row["activity_zscore"]) or np.isnan(row["activity_zscore"]) + + +# --------------------------------------------------------------------------- +# Sample balancing tests +# --------------------------------------------------------------------------- + + +def test_balance_samples(): + """With balance_samples=True, both groups have equal size (reflected in n_a, n_b).""" + rng = np.random.default_rng(10) + n_small, n_large = 30, 120 + rows, embs = [], [] + for pert, n in [("uninfected", n_large), ("ZIKV", n_small)]: + for _ in range(n): + embs.append(rng.normal(0.0, 1.0, 8)) + rows.append({"experiment": "e", "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0}) + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + cfg = _cfg(mmd=MMDSettings(n_permutations=50, balance_samples=True, max_cells=None)) + df = run_mmd_analysis(adata, cfg) + row = df[df["marker"] == "TOMM20"].iloc[0] + assert row["n_a"] == row["n_b"], f"Expected equal group sizes, got n_a={row['n_a']}, n_b={row['n_b']}" + + +# --------------------------------------------------------------------------- +# Bandwidth sharing tests +# --------------------------------------------------------------------------- + + +def test_share_bandwidth_from(): + """With share_bandwidth_from set, the bandwidth is the same across comparisons.""" + adata = _make_adata(n_cells=400, treatment_shift=2.0) + # Add a second condition + obs = adata.obs.copy() + extra_rows = obs[obs["perturbation"] == "ZIKV"].copy() + extra_rows["perturbation"] = "DENV" + extra_obs = pd.concat([obs, extra_rows], ignore_index=True) + extra_emb = np.concatenate([adata.X, adata.X[obs["perturbation"] == "ZIKV"]], axis=0) + adata2 = ad.AnnData(X=extra_emb.astype(np.float32), obs=extra_obs) + + comps = [ + ComparisonSpec(cond_a="uninfected", cond_b="ZIKV", label="baseline"), + ComparisonSpec(cond_a="uninfected", cond_b="DENV", label="treatment"), + ] + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=comps, + mmd=MMDSettings(n_permutations=50, share_bandwidth_from="baseline"), + ) + df = run_mmd_analysis(adata2, cfg) + for marker in df["marker"].unique(): + sub = df[df["marker"] == marker].dropna(subset=["bandwidth"]) + if len(sub) == 2: + assert abs(sub["bandwidth"].iloc[0] - sub["bandwidth"].iloc[1]) < 1e-6, ( + f"Expected shared bandwidth for {marker}, got {sub['bandwidth'].to_numpy()}" + ) + + +# --------------------------------------------------------------------------- +# Temporal bins (explicit edges) tests +# --------------------------------------------------------------------------- + + +def test_temporal_bins_explicit(): + """temporal_bins produces one row per bin per comparison.""" + adata = _make_temporal_adata() + cfg = _cfg(mmd=MMDSettings(n_permutations=50), temporal_bins=[0.0, 2.0, 5.0, 8.0, 15.0]) + df = run_mmd_analysis(adata, cfg) + valid = df.dropna(subset=["mmd2"]).sort_values("hours_bin_start") + assert len(valid) >= 2, "Expected at least 2 valid temporal bins" + assert valid.iloc[-1]["mmd2"] > valid.iloc[0]["mmd2"], "MMD should increase with shift" + + +def test_temporal_bins_min_cells_skip(): + """Bins with fewer than min_cells cells produce NaN rows.""" + adata = _make_temporal_adata() + cfg = _cfg( + mmd=MMDSettings(n_permutations=50, min_cells=5), + temporal_bins=[0.0, 0.5, 1.0, 100.0], + ) + df = run_mmd_analysis(adata, cfg) + first_bin = df[(df["hours_bin_start"] == 0.0) & (df["hours_bin_end"] == 0.5)] + assert len(first_bin) > 0 + assert first_bin["mmd2"].isna().all() + + +def test_temporal_bins_mutually_exclusive(): + """Setting both temporal_bin_size and temporal_bins raises ValidationError.""" + with pytest.raises(Exception): + MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=_COMP, + temporal_bin_size=4.0, + temporal_bins=[0.0, 4.0, 8.0], + ) + + +# --------------------------------------------------------------------------- +# Pooled mode tests +# --------------------------------------------------------------------------- + + +def _save_adata_zarr(adata: ad.AnnData, path: str) -> None: + import os + import shutil + + if os.path.exists(path): + shutil.rmtree(path) + adata.write_zarr(path) + + +def test_run_mmd_pooled_columns(tmp_path): + """run_mmd_pooled returns expected columns including activity_zscore and q_value.""" + adata1 = _make_adata(n_cells=200, seed=0) + adata2 = _make_adata(n_cells=200, seed=1) + p1 = str(tmp_path / "exp1.zarr") + p2 = str(tmp_path / "exp2.zarr") + _save_adata_zarr(adata1, p1) + _save_adata_zarr(adata2, p2) + + cfg = MMDPooledConfig( + input_paths=[p1, p2], + output_dir=str(tmp_path / "out"), + comparisons=_COMP, + mmd=MMDSettings(n_permutations=50), + ) + df = run_mmd_pooled(cfg) + expected = { + "marker", + "cond_a", + "cond_b", + "label", + "mmd2", + "p_value", + "bandwidth", + "effect_size", + "activity_zscore", + "q_value", + } + assert expected.issubset(df.columns), f"Missing: {expected - set(df.columns)}" + + +def test_run_mmd_pooled_condition_aliases(tmp_path): + """condition_aliases remaps variant condition names to canonical names.""" + rng = np.random.default_rng(99) + rows, embs = [], [] + for pert in ["uninfected1", "uninfected2", "ZIKV"]: + shift = 3.0 if pert == "ZIKV" else 0.0 + for _ in range(60): + embs.append(rng.normal(shift, 1.0, 16)) + rows.append({"experiment": "e", "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0}) + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + p = str(tmp_path / "exp.zarr") + _save_adata_zarr(adata, p) + + cfg = MMDPooledConfig( + input_paths=[p], + output_dir=str(tmp_path / "out"), + comparisons=[ComparisonSpec(cond_a="uninfected", cond_b="ZIKV", label="uninf vs ZIKV")], + mmd=MMDSettings(n_permutations=50), + condition_aliases={"uninfected": ["uninfected1", "uninfected2"]}, + ) + df = run_mmd_pooled(cfg) + assert not df["mmd2"].isna().all(), "Expected valid MMD after condition alias remapping" diff --git a/applications/dynaclr/tests/test_multi_experiment_integration.py b/applications/dynaclr/tests/test_multi_experiment_integration.py index 10f30005e..26d22cac1 100644 --- a/applications/dynaclr/tests/test_multi_experiment_integration.py +++ b/applications/dynaclr/tests/test_multi_experiment_integration.py @@ -14,6 +14,7 @@ from lightning.pytorch.loggers import TensorBoardLogger from dynaclr.engine import ContrastiveModule +from viscy_data.cell_index import build_timelapse_cell_index from viscy_models.contrastive.loss import NTXentHCL # --------------------------------------------------------------------------- @@ -52,11 +53,13 @@ def test_multi_experiment_fast_dev_run(tmp_path, _create_experiment, _write_coll perturbation_wells={"control": ["B/1"]}, ) yaml_path = _write_collection_yaml(tmp_path, [exp_alpha, exp_beta]) + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(yaml_path, parquet_path, num_workers=1) from dynaclr.data.datamodule import MultiExperimentDataModule datamodule = MultiExperimentDataModule( - collection_path=str(yaml_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=(32, 32), final_yx_patch_size=(24, 24), @@ -183,11 +186,13 @@ def test_multi_experiment_fast_dev_run_with_all_sampling_axes( start_hpi=0.0, ) yaml_path = _write_collection_yaml(tmp_path, [exp_alpha, exp_beta]) + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(yaml_path, parquet_path, num_workers=1) from dynaclr.data.datamodule import MultiExperimentDataModule datamodule = MultiExperimentDataModule( - collection_path=str(yaml_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=(32, 32), final_yx_patch_size=(24, 24), diff --git a/applications/dynaclr/tests/test_pseudotime.py b/applications/dynaclr/tests/test_pseudotime.py index d091c0e4d..afda9c8ed 100644 --- a/applications/dynaclr/tests/test_pseudotime.py +++ b/applications/dynaclr/tests/test_pseudotime.py @@ -16,6 +16,11 @@ filter_tracks, identify_lineages, ) +from dynaclr.evaluation.pseudotime.dtw_alignment import ( + alignment_results_to_dataframe, + build_infection_template, + dtw_align_tracks, +) from dynaclr.evaluation.pseudotime.metrics import ( aggregate_population, compute_track_timing, @@ -385,3 +390,120 @@ def test_plot_onset_comparison_saves_files(self, tmp_path): assert isinstance(fig, plt.Figure) assert (tmp_path / "onset_comparison.pdf").exists() assert (tmp_path / "onset_comparison.png").exists() + + +# ── TestTimeCalibration ─────────────────────────────────────────────── + + +class TestTimeCalibration: + """Tests for pseudotime-to-minutes template calibration.""" + + @pytest.fixture + def simple_template_inputs(self): + """Two synthetic 5-timepoint tracks with known t_relative_minutes.""" + rng = np.random.default_rng(0) + D = 8 + n_tracks = 6 + tracks = [] + for i in range(n_tracks): + # Each track: 10 frames, t_relative_minutes from -150 to +150 + fov = "C/2/000" + track_id = i + emb = rng.normal(0, 1, (10, D)).astype(np.float32) + obs = pd.DataFrame( + { + "fov_name": fov, + "track_id": track_id, + "t": np.arange(10), + "infection_state": ["not_infected"] * 5 + ["infected"] * 5, + "organelle_state": ["noremodel"] * 10, + "parent_track_id": -1, + } + ) + tracks.append((fov, track_id, emb, obs)) + + # Build AnnData for one "dataset" + all_obs = pd.concat([t[3] for t in tracks], ignore_index=True) + all_emb = np.vstack([t[2] for t in tracks]) + adata = ad.AnnData(X=all_emb, obs=all_obs) + + # Build aligned_df: t_perturb = 5 for all, t_relative_minutes = (t - 5) * 30 + df = all_obs.copy() + df["t_perturb"] = 5 + df["t_relative_minutes"] = (df["t"] - 5) * 30.0 + + return {"test": adata}, {"test": df} + + def test_build_template_has_time_calibration(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + result = build_infection_template(adata_dict, aligned_df_dict, pca_n_components=None) + assert result.time_calibration is not None + T = result.template.shape[0] + assert result.time_calibration.shape == (T,) + # Calibration should span a reasonable real-time range + assert result.time_calibration.min() < 0 + assert result.time_calibration.max() > 0 + + def test_time_calibration_monotonically_increasing(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + result = build_infection_template(adata_dict, aligned_df_dict, pca_n_components=None) + cal = result.time_calibration + # After gap interpolation, calibration should be non-decreasing + diffs = np.diff(cal) + assert np.all(diffs >= -1e-6), f"Non-monotonic calibration: {diffs}" + + def test_estimated_t_rel_in_alignment_output(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + template = build_infection_template(adata_dict, aligned_df_dict, pca_n_components=None) + assert template.time_calibration is not None + + # Align one dataset against the template + adata = list(adata_dict.values())[0] + df = list(aligned_df_dict.values())[0] + results = dtw_align_tracks(adata, df, template, "test", min_track_timepoints=3) + flat = alignment_results_to_dataframe(results, template.template_id, time_calibration=template.time_calibration) + + assert "estimated_t_rel_minutes" in flat.columns + cal_min = template.time_calibration.min() + cal_max = template.time_calibration.max() + est = flat["estimated_t_rel_minutes"].dropna() + assert len(est) > 0 + assert est.min() >= cal_min - 1.0 + assert est.max() <= cal_max + 1.0 + + +# ── TestMetricsContinuous ───────────────────────────────────────────── + + +class TestMetricsContinuous: + """Tests for continuous-signal metrics (onset, peak).""" + + def test_find_onset_continuous_signal(self): + rows = [] + for t in range(-600, 901, 30): + val = 3.0 if t >= 120 else 0.0 + rows.append({"time_minutes": t, "mean": val, "n_cells": 20}) + pop_df = pd.DataFrame(rows) + onset, threshold, bl_mean, bl_std = find_onset_time( + pop_df, baseline_window=(-600, -60), sigma_threshold=2.0, signal_col="mean" + ) + assert onset is not None + assert onset == 120 + + def test_find_peak_metrics_continuous(self): + rows = [] + for t in range(-300, 601, 30): + if t < 0: + val = 0.0 + elif t <= 150: + val = t / 150.0 * 5.0 + elif t <= 300: + val = 5.0 - (t - 150) / 150.0 * 5.0 + else: + val = 0.0 + rows.append({"time_minutes": t, "mean": val, "n_cells": 20}) + pop_df = pd.DataFrame(rows) + metrics = find_peak_metrics(pop_df, signal_col="mean") + assert not np.isnan(metrics["T_peak_minutes"]) + assert metrics["peak_amplitude"] > 0 + assert metrics["auc"] > 0 diff --git a/applications/dynaclr/tests/test_reduce_dimensionality.py b/applications/dynaclr/tests/test_reduce_dimensionality.py index 3b291b8b7..fbfd4c56a 100644 --- a/applications/dynaclr/tests/test_reduce_dimensionality.py +++ b/applications/dynaclr/tests/test_reduce_dimensionality.py @@ -6,6 +6,8 @@ from pydantic import ValidationError from dynaclr.evaluation.dimensionality_reduction.config import ( + CombinedDatasetConfig, + CombinedDimensionalityReductionConfig, DimensionalityReductionConfig, PCAConfig, PHATEConfig, @@ -154,7 +156,9 @@ class TestCLIIntegration: def test_pca_end_to_end(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) output_path = str(tmp_path / "output.zarr") config_content = f"input_path: {synthetic_zarr}\noutput_path: {output_path}\npca:\n n_components: 10\n" @@ -172,7 +176,9 @@ def test_pca_end_to_end(self, synthetic_zarr, tmp_path): def test_overwrite_keys_protection(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) # Pre-populate X_pca adata = ad.read_zarr(synthetic_zarr) @@ -191,7 +197,9 @@ def test_overwrite_keys_protection(self, synthetic_zarr, tmp_path): def test_overwrite_keys_allowed(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) # Pre-populate X_pca adata = ad.read_zarr(synthetic_zarr) @@ -212,7 +220,9 @@ def test_overwrite_keys_allowed(self, synthetic_zarr, tmp_path): def test_writes_back_to_input_when_no_output(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) config_content = f"input_path: {synthetic_zarr}\npca:\n n_components: 5\n" config_path = tmp_path / "test_config.yaml" @@ -225,3 +235,197 @@ def test_writes_back_to_input_when_no_output(self, synthetic_zarr, tmp_path): adata = ad.read_zarr(synthetic_zarr) assert "X_pca" in adata.obsm assert adata.obsm["X_pca"].shape == (100, 5) + + +class TestAppendToAnndataZarrUns: + """Test that append_to_anndata_zarr preserves existing uns keys.""" + + def test_uns_per_key_preserves_existing(self, tmp_path): + from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + rng = np.random.default_rng(42) + adata = ad.AnnData(X=rng.standard_normal((10, 4)).astype(np.float32)) + adata.uns["existing_key"] = "should_survive" + adata.uns["existing_list"] = ["a", "b"] + zarr_path = tmp_path / "test.zarr" + ad.settings.allow_write_nullable_strings = True + adata.write_zarr(zarr_path) + + append_to_anndata_zarr(zarr_path, uns={"new_key": ["path1", "path2"]}) + + result = ad.read_zarr(zarr_path) + assert result.uns["existing_key"] == "should_survive" + assert list(result.uns["existing_list"]) == ["a", "b"] + assert list(result.uns["new_key"]) == ["path1", "path2"] + + def test_uns_overwrites_specific_key(self, tmp_path): + from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + rng = np.random.default_rng(42) + adata = ad.AnnData(X=rng.standard_normal((10, 4)).astype(np.float32)) + adata.uns["my_key"] = "old_value" + adata.uns["other_key"] = "untouched" + zarr_path = tmp_path / "test.zarr" + ad.settings.allow_write_nullable_strings = True + adata.write_zarr(zarr_path) + + append_to_anndata_zarr(zarr_path, uns={"my_key": "new_value"}) + + result = ad.read_zarr(zarr_path) + assert result.uns["my_key"] == "new_value" + assert result.uns["other_key"] == "untouched" + + +class TestCombinedDimensionalityReductionConfig: + def test_valid_config(self, synthetic_zarr): + cfg = CombinedDimensionalityReductionConfig( + input_paths=[synthetic_zarr], + pca=PCAConfig(n_components=5), + ) + assert len(cfg.input_paths) == 1 + + def test_valid_config_with_datasets_mapping(self, synthetic_zarr): + cfg = CombinedDimensionalityReductionConfig( + datasets={"ds1": CombinedDatasetConfig(anndata=synthetic_zarr)}, + pca=PCAConfig(n_components=5), + ) + assert cfg.input_paths == [synthetic_zarr] + + def test_missing_methods_raises(self, synthetic_zarr): + with pytest.raises(ValidationError, match="At least one reduction method"): + CombinedDimensionalityReductionConfig(input_paths=[synthetic_zarr]) + + def test_missing_path_raises(self): + with pytest.raises(ValidationError, match="Input path not found"): + CombinedDimensionalityReductionConfig( + input_paths=["/nonexistent/path.zarr"], + pca=PCAConfig(), + ) + + +class TestCombinedReduction: + @pytest.fixture + def two_synthetic_zarrs(self, tmp_path): + """Create two synthetic AnnData zarrs with uns metadata.""" + ad.settings.allow_write_nullable_strings = True + rng = np.random.default_rng(42) + paths = [] + for i in range(2): + n = 50 + i * 30 # 50 and 80 samples + X = rng.standard_normal((n, 32)).astype(np.float32) + adata = ad.AnnData(X=X) + adata.uns["classifier_version"] = f"v{i}" + adata.uns["predicted_classes"] = ["alive", "dead"] + zarr_path = tmp_path / f"store_{i}.zarr" + adata.write_zarr(zarr_path) + paths.append(str(zarr_path)) + return paths + + def test_combined_pca_only(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + f"input_paths:\n - {two_synthetic_zarrs[0]}\n - {two_synthetic_zarrs[1]}\npca:\n n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + + for i, path in enumerate(two_synthetic_zarrs): + adata = ad.read_zarr(path) + n = 50 + i * 30 + assert "X_pca_combined" in adata.obsm + assert adata.obsm["X_pca_combined"].shape[0] == n + assert "pca_combined_datasets" in adata.uns + assert list(adata.uns["pca_combined_datasets"]) == two_synthetic_zarrs + # uns preserved + assert adata.uns["classifier_version"] == f"v{i}" + assert list(adata.uns["predicted_classes"]) == ["alive", "dead"] + + @pytest.fixture(autouse=False) + def _skip_no_phate(self): + pytest.importorskip("phate") + + def test_combined_pca_and_phate(self, two_synthetic_zarrs, _skip_no_phate): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + "input_paths:\n" + f" - {two_synthetic_zarrs[0]}\n" + f" - {two_synthetic_zarrs[1]}\n" + "pca:\n" + " n_components: 5\n" + "phate:\n" + " n_components: 2\n" + " knn: 5\n" + " decay: 40\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + + for i, path in enumerate(two_synthetic_zarrs): + adata = ad.read_zarr(path) + n = 50 + i * 30 + assert adata.obsm["X_pca_combined"].shape[0] == n + assert adata.obsm["X_phate_combined"].shape == (n, 2) + assert "pca_combined_datasets" in adata.uns + assert "phate_combined_datasets" in adata.uns + + def test_overwrite_protection(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + f"input_paths:\n - {two_synthetic_zarrs[0]}\n - {two_synthetic_zarrs[1]}\npca:\n n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + # First run + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + # Second run without overwrite_keys should fail + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code != 0 + assert "already exists" in result.output + + def test_overwrite_allowed(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + "input_paths:\n" + f" - {two_synthetic_zarrs[0]}\n" + f" - {two_synthetic_zarrs[1]}\n" + "overwrite_keys: true\n" + "pca:\n" + " n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + # First run + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + # Second run should also succeed (overwrite_keys=true) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output diff --git a/applications/dynaclr/tests/test_valid_anchors_marker.py b/applications/dynaclr/tests/test_valid_anchors_marker.py new file mode 100644 index 000000000..87c323fd6 --- /dev/null +++ b/applications/dynaclr/tests/test_valid_anchors_marker.py @@ -0,0 +1,212 @@ +"""Regression tests for marker-aware valid_anchors in flat-parquet mode. + +In flat-parquet / bag-of-channels mode, one cell observation becomes one +row per channel. ``_pick_temporal_candidate`` restricts positive candidates +to rows with the same ``marker`` as the anchor, so ``_compute_valid_anchors`` +must also include ``marker`` in the validity key — otherwise an anchor can +pass validation because a different-marker row exists at ``t+tau``, then +crash at sample time with "No positive found". + +These tests hit ``_compute_valid_anchors`` directly via ``object.__new__`` +so they don't need real zarr stores. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pandas as pd +import pytest + +from dynaclr.data.index import MultiExperimentIndex + + +def _make_registry(experiment_names, interval_minutes=30.0): + """Return a minimal object that quacks like ExperimentRegistry for tau math.""" + experiments = [SimpleNamespace(name=n, interval_minutes=interval_minutes) for n in experiment_names] + + def tau_range_frames(name, tau_range_hours): + exp = next(e for e in experiments if e.name == name) + min_h, max_h = tau_range_hours + frames_per_hour = 60.0 / exp.interval_minutes + return (int(round(min_h * frames_per_hour)), int(round(max_h * frames_per_hour))) + + return SimpleNamespace(experiments=experiments, tau_range_frames=tau_range_frames) + + +def _make_index(tracks: pd.DataFrame, registry) -> MultiExperimentIndex: + """Construct a bare MultiExperimentIndex without zarr I/O.""" + index = object.__new__(MultiExperimentIndex) + index.registry = registry + index.tracks = tracks.reset_index(drop=True) + return index + + +class TestMarkerAwareValidAnchors: + """`marker` must be part of the temporal validity key in flat-parquet mode.""" + + def test_anchor_with_cross_marker_positive_rejected(self): + """ + Anchor at (lid, marker=A, t=5) must be REJECTED when the only row + at t+tau is (lid, marker=B, t=6). Without marker-aware validity + this anchor would be accepted and then crash at sample time because + `_pick_temporal_candidate` filters candidates to same marker. + """ + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "lineage_id": ["L"] * 2, + "marker": ["A", "B"], + "t": [5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + # tau_range 0.5h - 1.5h at 30min = (1, 3) frames. + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # Neither row is a valid anchor: A has no same-marker positive in window, + # and B has no same-marker positive either. + assert len(valid) == 0, f"expected 0 valid anchors, got {len(valid)}:\n{valid}" + + def test_anchor_with_same_marker_positive_accepted(self): + """Anchor at (lid, marker=A, t=5) with (lid, marker=A, t=6) IS valid.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "lineage_id": ["L"] * 3, + "marker": ["A", "A", "B"], + "t": [5, 6, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # (A, t=5) is valid because (A, t=6) exists. + # (A, t=6) is NOT valid because there's no (A, t=7..8). + # (B, t=6) is NOT valid because there's no (B, t=7..8). + assert len(valid) == 1 + row = valid.iloc[0] + assert row["marker"] == "A" + assert row["t"] == 5 + + def test_both_markers_have_positives_both_accepted(self): + """When each marker has its own lineage continuity, both pass.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 4, + "lineage_id": ["L"] * 4, + "marker": ["A", "A", "B", "B"], + "t": [5, 6, 5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # (A, t=5) valid (A, t=6 exists). (B, t=5) valid (B, t=6 exists). + # t=6 of each marker is NOT valid (no t=7 for either). + assert len(valid) == 2 + assert set(zip(valid["marker"], valid["t"])) == {("A", 5), ("B", 5)} + + def test_no_marker_column_falls_back_to_lineage_t(self): + """When `marker` column is absent, behavior matches legacy (lid, t) keys.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "lineage_id": ["L"] * 3, + "t": [5, 6, 7], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # tau_range_frames = (1, 3). t=5 needs t=6,7,8 (6,7 exist) -> valid. + # t=6 needs t=7,8,9 (7 exists) -> valid. t=7 needs t=8,9,10 -> NOT valid. + assert len(valid) == 2 + assert set(valid["t"].to_numpy()) == {5, 6} + + +class TestLineageCollisionDetection: + """ + Regression for the ALFI-style bug where two FOVs share the same + ``lineage_id`` because lineage reconstruction collapsed across FOVs. + The marker-aware fix cannot save this — it's a data bug — so the + test documents the failure mode: `_compute_valid_anchors` will + accept anchors whose temporal neighbors are actually in a different + physical FOV. Cached so we notice if lineage reconstruction ever + starts disambiguating by FOV. + """ + + def test_cross_fov_lineage_collision_accepted_today(self): + """Two FOVs share `lineage_id='L'`; validity check treats as one lineage.""" + # FOV1 has t=5 only; FOV2 has t=6 only. They share lineage_id. + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "lineage_id": ["L", "L"], + "fov_name": ["FOV1", "FOV2"], # different physical fields + "marker": ["A", "A"], + "t": [5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # Today both rows pass — the fix doesn't consider fov_name in the + # validity key. If cell_index generation ever disambiguates lineage_id + # by fov, this test will flip and should be updated. + assert len(valid) == 1 # (A, t=5) valid because "L" at t=6 exists + # The surviving anchor is t=5 — at sample time it would try to + # pull a patch from FOV2 thinking it's the same biological lineage. + # That's still wrong biologically, but it won't raise "No positive found". + + +@pytest.mark.parametrize("interval_minutes", [15.0, 30.0, 60.0]) +def test_marker_key_respects_per_experiment_tau(interval_minutes): + """Marker-aware validity plays correctly with per-experiment interval_minutes.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 4, + "lineage_id": ["L"] * 4, + "marker": ["A", "A", "A", "A"], + "t": [0, 1, 5, 10], + } + ) + registry = _make_registry(["exp"], interval_minutes=interval_minutes) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + min_f, max_f = registry.tau_range_frames("exp", (0.5, 1.5)) + # Every valid anchor t must have some other row at t+tau within [min_f, max_f]. + t_vals = set(tracks["t"].to_numpy()) + for t in valid["t"].to_numpy(): + ok = any((t + tau) in t_vals for tau in range(min_f, max_f + 1) if tau != 0) + assert ok, f"anchor t={t} validated but no t+tau neighbor exists at interval={interval_minutes}" diff --git a/packages/viscy-data/src/viscy_data/_typing.py b/packages/viscy-data/src/viscy_data/_typing.py index 17eb7ed1a..0e6baf6cc 100644 --- a/packages/viscy-data/src/viscy_data/_typing.py +++ b/packages/viscy-data/src/viscy_data/_typing.py @@ -24,6 +24,7 @@ "CELL_INDEX_CORE_COLUMNS", "CELL_INDEX_GROUPING_COLUMNS", "CELL_INDEX_IMAGING_COLUMNS", + "CELL_INDEX_NORMALIZATION_COLUMNS", "CELL_INDEX_OPS_COLUMNS", "CELL_INDEX_TIMELAPSE_COLUMNS", "CellIndex", @@ -245,7 +246,25 @@ class TripletSample(TypedDict): CELL_INDEX_OPS_COLUMNS = ["gene_name", "reporter", "sgRNA"] -CELL_INDEX_IMAGING_COLUMNS = ["pixel_size_xy_um", "pixel_size_z_um"] +CELL_INDEX_IMAGING_COLUMNS = [ + "pixel_size_xy_um", + "pixel_size_z_um", + "T_shape", + "C_shape", + "Z_shape", + "Y_shape", + "X_shape", + "z_focus_mean", +] + +CELL_INDEX_NORMALIZATION_COLUMNS = [ + "norm_mean", + "norm_std", + "norm_median", + "norm_iqr", + "norm_max", + "norm_min", +] # Extracted from viscy/data/triplet.py for shared access ULTRACK_INDEX_COLUMNS = [ diff --git a/packages/viscy-data/src/viscy_data/_utils.py b/packages/viscy-data/src/viscy_data/_utils.py index ea0e96ef0..e6a6523c7 100644 --- a/packages/viscy-data/src/viscy_data/_utils.py +++ b/packages/viscy-data/src/viscy_data/_utils.py @@ -217,4 +217,7 @@ def _transform_channel_wise( ) -> list[Tensor]: scattered_channels = _scatter_channels(channel_names, patch, norm_meta, extra) transformed_channels = transform(scattered_channels) - return _gather_channels(transformed_channels) + extra_keys = ("norm_meta",) + if extra is not None: + extra_keys = ("norm_meta",) + tuple(extra.keys()) + return _gather_channels(transformed_channels, extra_keys=extra_keys) diff --git a/packages/viscy-data/src/viscy_data/cell_index.py b/packages/viscy-data/src/viscy_data/cell_index.py index ca03e8167..6d0f7bf9b 100644 --- a/packages/viscy-data/src/viscy_data/cell_index.py +++ b/packages/viscy-data/src/viscy_data/cell_index.py @@ -15,6 +15,7 @@ from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq @@ -26,6 +27,7 @@ CELL_INDEX_CORE_COLUMNS, CELL_INDEX_GROUPING_COLUMNS, CELL_INDEX_IMAGING_COLUMNS, + CELL_INDEX_NORMALIZATION_COLUMNS, CELL_INDEX_OPS_COLUMNS, CELL_INDEX_TIMELAPSE_COLUMNS, ) @@ -37,6 +39,7 @@ "build_ops_cell_index", "build_timelapse_cell_index", "convert_ops_parquet", + "preprocess_cell_index", "read_cell_index", "validate_cell_index", "write_cell_index", @@ -74,6 +77,18 @@ ("organelle", pa.string()), ("pixel_size_xy_um", pa.float32()), ("pixel_size_z_um", pa.float32()), + ("T_shape", pa.int32()), + ("C_shape", pa.int32()), + ("Z_shape", pa.int32()), + ("Y_shape", pa.int32()), + ("X_shape", pa.int32()), + ("z_focus_mean", pa.float32()), + ("norm_mean", pa.float32()), + ("norm_std", pa.float32()), + ("norm_median", pa.float32()), + ("norm_iqr", pa.float32()), + ("norm_max", pa.float32()), + ("norm_min", pa.float32()), ] ) @@ -85,6 +100,7 @@ + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ) # --------------------------------------------------------------------------- @@ -168,6 +184,13 @@ def write_cell_index( def read_cell_index(path: str | Path) -> pd.DataFrame: """Read a cell index parquet into a pandas DataFrame. + String columns are materialized as NumPy ``object`` arrays instead of + ``ArrowStringArray``. ArrowStringArray-backed columns route every + boolean mask slice through ``pyarrow.compute.take``, which allocates + a fresh buffer per string column and can spike peak RSS by 50+ GiB + on 80M-row indices during train/val FOV partitioning. NumPy object + columns make ``df[mask]`` a cheap gather. + Parameters ---------- path : str | Path @@ -179,7 +202,154 @@ def read_cell_index(path: str | Path) -> pd.DataFrame: Cell index with correct dtypes. """ table = pq.read_table(str(path), schema=CELL_INDEX_SCHEMA) - return table.to_pandas() + df = table.to_pandas(use_threads=True) + # ArrowStringArray columns with low cardinality (experiment, fov_name, + # marker, store_path, well, microscope, organelle, reporter) become + # Categorical to make ``df[mask]`` a fast int-code gather. Other string + # columns (cell_id, tracks_path, global_track_id, lineage_id, etc.) are + # high cardinality and are already read via the NumPy column cache in + # the dataset, so leave them as ArrowStringArray to avoid allocating + # millions of Python string objects here. + # NB: ``fov`` and ``well`` are NOT cast here because ``_align_parquet_columns`` + # downstream rewrites ``fov_name`` via string concatenation, which pandas + # does not support on Categorical. We cast ``fov_name`` later, after the + # prefix rewrite, in the runtime index layer. + _categorical_cols = ( + "experiment", + "marker", + "store_path", + "microscope", + "organelle", + "reporter", + "channel_name", + ) + for col in _categorical_cols: + if col in df.columns: + df[col] = df[col].astype("category") + return df + + +# --------------------------------------------------------------------------- +# Preprocessing (clean up an existing cell index parquet) +# --------------------------------------------------------------------------- + + +def preprocess_cell_index( + parquet_path: str | Path, + output_path: str | Path | None = None, + focus_channel: str | None = None, +) -> pd.DataFrame: + """Add normalization stats, focus slice, and remove invalid rows. + + Reads precomputed metadata from each FOV's ``zattrs`` (written by + ``viscy preprocess``) and writes them as parquet columns: + + - ``norm_mean``, ``norm_std``, ``norm_median``, ``norm_iqr``, + ``norm_max``, ``norm_min`` — per-timepoint, per-channel statistics + - ``z_focus_mean`` — per-FOV focus plane from ``focus_slice`` + + Drops rows where timepoint stats are missing or ``norm_max == 0.0`` + (empty frames). + + Parameters + ---------- + parquet_path : str | Path + Path to the cell index parquet to preprocess. + output_path : str | Path | None + Destination path. When ``None``, overwrites *parquet_path* in place. + focus_channel : str | None + Channel name for ``focus_slice`` lookup (e.g. ``"Phase3D"``). + When ``None``, uses the first channel_name in each FOV's group. + + Returns + ------- + pd.DataFrame + The preprocessed cell index with normalization and focus columns. + + Raises + ------ + ValueError + If a FOV has no normalization metadata (run ``viscy preprocess`` first). + """ + if output_path is None: + output_path = parquet_path + + df = read_cell_index(parquet_path) + n_before = len(df) + + fov_col = "fov" if "fov" in df.columns else "fov_name" + + # Build lookups from zarr zattrs (one open per unique FOV) + stat_lookup: dict[tuple[str, str, str, int], dict[str, float]] = {} + focus_lookup: dict[tuple[str, str], float] = {} + focus_per_t_lookup: dict[tuple[str, str], dict[int, int]] = {} + + for (store_path, fov), group in df.groupby(["store_path", fov_col]): + fov_path = f"{group['well'].iloc[0]}/{fov}" if "/" not in str(fov) else str(fov) + with open_ome_zarr(f"{store_path}/{fov_path}", mode="r") as pos: + norm_meta = pos.zattrs.get("normalization", None) + focus_meta = pos.zattrs.get("focus_slice", {}) + if norm_meta is None: + raise ValueError( + f"FOV '{fov_path}' in store '{store_path}' has no normalization metadata. " + "Run `viscy preprocess` on this dataset first." + ) + for ch_name, ch_stats in norm_meta.items(): + for t_str, tp_stats in ch_stats.get("timepoint_statistics", {}).items(): + stat_lookup[(str(store_path), str(fov), ch_name, int(t_str))] = tp_stats + + fc = focus_channel or group["channel_name"].iloc[0] + ch_focus = focus_meta.get(fc, {}) + fov_stats = ch_focus.get("fov_statistics", {}) + z_focus = fov_stats.get("z_focus_mean") + if z_focus is not None: + focus_lookup[(str(store_path), str(fov))] = float(z_focus) + per_timepoint = ch_focus.get("per_timepoint", {}) + if per_timepoint: + focus_per_t_lookup[(str(store_path), str(fov))] = { + int(t_str): int(z_idx) for t_str, z_idx in per_timepoint.items() + } + + # Vectorized lookup: build norm + focus column arrays + stat_keys = ["mean", "std", "median", "iqr", "max", "min"] + store_arr = df["store_path"].astype(str).to_numpy() + fov_arr = df[fov_col].astype(str).to_numpy() + ch_arr = df["channel_name"].astype(str).to_numpy() + t_arr = df["t"].astype(int).to_numpy() + + norm_arrays = {stat: np.full(len(df), float("nan"), dtype=np.float32) for stat in stat_keys} + focus_arr = np.full(len(df), float("nan"), dtype=np.float32) + z_arr = df["z"].to_numpy(dtype=np.int16).copy() + valid_mask = np.ones(len(df), dtype=bool) + + for i in range(len(df)): + tp_stats = stat_lookup.get((store_arr[i], fov_arr[i], ch_arr[i], t_arr[i])) + if tp_stats is None or tp_stats.get("max", 1.0) == 0.0: + valid_mask[i] = False + continue + for stat in stat_keys: + norm_arrays[stat][i] = float(tp_stats[stat]) + fov_key = (store_arr[i], fov_arr[i]) + z_focus = focus_lookup.get(fov_key) + if z_focus is not None: + focus_arr[i] = z_focus + z_t = focus_per_t_lookup.get(fov_key, {}).get(t_arr[i]) + if z_t is not None: + z_arr[i] = z_t + + for stat in stat_keys: + df[f"norm_{stat}"] = norm_arrays[stat] + df["z_focus_mean"] = focus_arr + df["z"] = z_arr + + df = df[valid_mask].reset_index(drop=True) + n_dropped = n_before - len(df) + + write_cell_index(df, output_path) + if n_dropped > 0: + _logger.info("Dropped %d invalid rows (%.1f%%).", n_dropped, 100 * n_dropped / n_before) + print(f"Wrote {len(df):,} rows to {output_path} (dropped {n_dropped:,}, added norm + focus columns)") + return df # --------------------------------------------------------------------------- @@ -194,11 +364,17 @@ def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: ancestor. Tracks without a ``parent_track_id`` (or whose parent is not present in the data) are their own root. + The lineage walk is scoped per ``(experiment, well, fov)`` when the + ``well`` column is available. Scoping on ``(experiment, fov)`` alone + collapses cells across wells that share an FOV number (e.g. B/2/002001 + and C/2/002001), producing cross-well lineage_id aliasing that later + crashes the temporal positive lookup with "No positive found". + Parameters ---------- tracks : pd.DataFrame Must contain ``global_track_id``, ``experiment``, ``fov``, ``track_id``. - Optionally ``parent_track_id``. + Optionally ``parent_track_id`` and ``well``. Returns ------- @@ -216,8 +392,9 @@ def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: lineage_series = tracks["lineage_id"].copy() - groups = list(tracks.groupby(["experiment", "fov"])) - for (exp, fov), group in tqdm(groups, desc="Reconstructing lineages", unit="fov"): + group_keys = ["experiment", "well", "fov"] if "well" in tracks.columns else ["experiment", "fov"] + groups = list(tracks.groupby(group_keys)) + for _key, group in tqdm(groups, desc="Reconstructing lineages", unit="fov"): tid_to_gtid: dict[int, str] = dict(zip(group["track_id"], group["global_track_id"])) parent_map: dict[str, str] = {} @@ -274,8 +451,8 @@ def _build_experiment_tracks( if exclude_fovs is not None: all_exclude.update(exclude_fovs) - # Channel-marker pairs from per-experiment channels list - channel_marker_pairs = [(ch.name, ch.marker) for ch in exp.channels] + # Channel entries from per-experiment channels list + channel_entries = [(ch.name, ch.marker, set(ch.wells)) for ch in exp.channels] exp_tracks: list[pd.DataFrame] = [] @@ -305,6 +482,10 @@ def _build_experiment_tracks( raise ValueError(f"Expected exactly one tracking CSV in {tracks_dir}, found: {csv_files}") tracks_df = pd.read_csv(csv_files[0]) + # TCZYX shape from zarr metadata (same for all positions in a well) + img_arr = position["0"] + t_shape, c_shape, z_shape, y_shape, x_shape = img_arr.shape + # Base columns (shared across channel rows) tracks_df["cell_id"] = ( exp.name + "_" + fov_path + "_" + tracks_df["track_id"].astype(str) + "_" + tracks_df["t"].astype(str) @@ -322,12 +503,19 @@ def _build_experiment_tracks( tracks_df["organelle"] = exp.organelle tracks_df["pixel_size_xy_um"] = exp.pixel_size_xy_um tracks_df["pixel_size_z_um"] = exp.pixel_size_z_um + tracks_df["T_shape"] = t_shape + tracks_df["C_shape"] = c_shape + tracks_df["Z_shape"] = z_shape + tracks_df["Y_shape"] = y_shape + tracks_df["X_shape"] = x_shape if "z" not in tracks_df.columns: tracks_df["z"] = 0 - # Explode: one row per channel - for zarr_ch, marker in channel_marker_pairs: + # Explode: one row per channel (skip channels restricted to other wells) + for zarr_ch, marker, valid_wells in channel_entries: + if valid_wells and well_name not in valid_wells: + continue ch_df = tracks_df.copy() ch_df["channel_name"] = zarr_ch ch_df["marker"] = marker diff --git a/packages/viscy-data/src/viscy_data/channel_utils.py b/packages/viscy-data/src/viscy_data/channel_utils.py index 9f7dc3753..63fcc9c16 100644 --- a/packages/viscy-data/src/viscy_data/channel_utils.py +++ b/packages/viscy-data/src/viscy_data/channel_utils.py @@ -50,7 +50,7 @@ def parse_channel_name(name: str) -> dict: # Label-free patterns (use word boundaries for short keywords) labelfree_substrings = ("phase", "brightfield", "retardance") - labelfree_word_patterns = (r"\bbf[\b_]", r"\bdic\b", r"\bpol\b") + labelfree_word_patterns = (r"\bbf(\b|_)", r"\bdic\b", r"\bpol\b", r"\bphc\b") if any(kw in name_lower for kw in labelfree_substrings) or any( re.search(p, name_lower) for p in labelfree_word_patterns ): diff --git a/packages/viscy-data/src/viscy_data/collection.py b/packages/viscy-data/src/viscy_data/collection.py index dd4be9dcb..a28656e7c 100644 --- a/packages/viscy-data/src/viscy_data/collection.py +++ b/packages/viscy-data/src/viscy_data/collection.py @@ -58,10 +58,14 @@ class ChannelEntry(BaseModel): Zarr channel name (e.g. ``"Phase3D"``, ``"raw GFP EX488 EM525-45"``). marker : str Protein marker or channel identity (e.g. ``"Phase3D"``, ``"TOMM20"``). + wells : list[str] + Wells where this channel is biologically valid (e.g. ``["B/3", "C/2"]``). + Empty list means the channel is valid in all wells of the experiment. """ name: str marker: str + wells: list[str] = [] class ExperimentEntry(BaseModel): @@ -144,6 +148,10 @@ class Collection(BaseModel): Collection name. description : str Human-readable description. + datasets_root : str or None + Optional path prefix substituted for ``${datasets_root}`` in + ``data_path`` and ``tracks_path`` at load time. Paths not + starting with this root are left unchanged. provenance : Provenance How the collection was created. experiments : list[ExperimentEntry] @@ -154,6 +162,7 @@ class Collection(BaseModel): name: str description: str = "" + datasets_root: str | None = None provenance: Provenance = Provenance() experiments: list[ExperimentEntry] fov_records: list[FOVRecord] = [] @@ -171,9 +180,9 @@ def _validate_collection(self) -> Collection: seen.add(e.name) for exp in self.experiments: - if exp.interval_minutes <= 0: + if exp.interval_minutes < 0: raise ValueError( - f"Experiment '{exp.name}': interval_minutes must be positive, got {exp.interval_minutes}." + f"Experiment '{exp.name}': interval_minutes must be non-negative, got {exp.interval_minutes}." ) wells = exp.perturbation_wells if not wells: @@ -182,6 +191,39 @@ def _validate_collection(self) -> Collection: return self +_DATASETS_ROOT_VAR = "${datasets_root}" + + +def _resolve_datasets_root(data: dict) -> None: + """Replace ``${datasets_root}`` in experiment paths with the root value. + + Mutates *data* in place. + """ + root = data.get("datasets_root") + if not root: + return + root = root.rstrip("/") + for exp in data.get("experiments", []): + for key in ("data_path", "tracks_path"): + val = exp.get(key, "") + if _DATASETS_ROOT_VAR in val: + exp[key] = val.replace(_DATASETS_ROOT_VAR, root) + + +def _unresolve_datasets_root(data: dict, datasets_root: str) -> None: + """Replace the resolved root prefix with ``${datasets_root}`` for portable YAML. + + Mutates *data* in place. Only paths that start with *datasets_root* are + modified; paths pointing elsewhere are left as absolute strings. + """ + root = datasets_root.rstrip("/") + for exp in data.get("experiments", []): + for key in ("data_path", "tracks_path"): + val = exp.get(key, "") + if val.startswith(root + "/"): + exp[key] = _DATASETS_ROOT_VAR + val[len(root) :] + + def load_collection(path: str | Path) -> Collection: """Load a collection from a YAML file. @@ -197,6 +239,7 @@ def load_collection(path: str | Path) -> Collection: """ with open(Path(path)) as f: data = yaml.safe_load(f) + _resolve_datasets_root(data) return Collection(**data) @@ -211,6 +254,8 @@ def save_collection(collection: Collection, path: str | Path) -> None: Output YAML path. """ data = collection.model_dump(mode="json") + if collection.datasets_root: + _unresolve_datasets_root(data, collection.datasets_root) with open(Path(path), "w") as f: yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False) @@ -257,6 +302,7 @@ def build_collection( name: str, description: str = "", channel_markers: dict[str, list[tuple[str, str]]] | None = None, + datasets_root: str | None = None, ) -> Collection: """Build a collection by grouping FOVRecords into experiments. @@ -277,6 +323,9 @@ def build_collection( Per-experiment ``{exp_name: [(zarr_channel_name, marker), ...]}`` mapping. If None, derives from the first record's ``channel_names`` using channel names as markers. + datasets_root : str or None + Passed through to :class:`Collection`. When set, ``save_collection`` + will write ``${datasets_root}`` prefixes instead of absolute paths. Returns ------- @@ -305,6 +354,15 @@ def build_collection( elif first.channel_names: channels = [ChannelEntry(name=n, marker=n) for n in first.channel_names] + # Auto-populate wells per channel from per-record channel_markers. + # A channel gets a wells restriction if only a subset of wells have + # a non-None marker for it in Airtable. + all_wells = sorted({rec.well_id for rec in recs}) + for ch in channels: + wells_with_marker = sorted({rec.well_id for rec in recs if ch.name in rec.channel_markers}) + if wells_with_marker and wells_with_marker != all_wells: + ch.wells = wells_with_marker + experiments.append( ExperimentEntry( name=exp_name, @@ -316,6 +374,7 @@ def build_collection( start_hpi=first.hours_post_perturbation or 0.0, marker=first.marker or "", organelle=first.organelle or "", + microscope=first.microscope or "", pixel_size_xy_um=getattr(first, "pixel_size_xy_um", None), pixel_size_z_um=getattr(first, "pixel_size_z_um", None), moi=first.moi or 0.0, @@ -325,6 +384,7 @@ def build_collection( return Collection( name=name, description=description, + datasets_root=datasets_root, experiments=experiments, fov_records=records, ) diff --git a/packages/viscy-data/src/viscy_data/sampler.py b/packages/viscy-data/src/viscy_data/sampler.py index 75017b85d..68534b059 100644 --- a/packages/viscy-data/src/viscy_data/sampler.py +++ b/packages/viscy-data/src/viscy_data/sampler.py @@ -153,14 +153,35 @@ def __init__( # Precomputation # ------------------------------------------------------------------ + @staticmethod + def _indices_by_key(keys: pd.Series) -> dict[str, np.ndarray]: + """Return ``{key_str: row_index_array}`` for every unique value in *keys*. + + Fast path for Categorical keys uses NumPy ``cat.codes`` directly — + avoids materializing a pandas groupby iterator, which on large + (~16M row) Arrow-backed DataFrames routes every group slice + through ``pyarrow.compute.take`` and can take tens of minutes. + + For non-Categorical keys, falls back to the pandas groupby. + """ + # Categorical fast path — O(N) single vectorized pass per group. + if isinstance(keys.dtype, pd.CategoricalDtype): + codes = keys.cat.codes.to_numpy() + categories = list(keys.cat.categories) + out: dict[str, np.ndarray] = {} + for c, name in enumerate(categories): + rows = np.flatnonzero(codes == c) + if len(rows) > 0: + out[str(name)] = rows + return out + # Generic fallback. + return {str(name): group.to_numpy() for name, group in keys.groupby(keys).groups.items()} + def _precompute_groups(self) -> None: """Build index lookup tables from valid_anchors columns.""" - # Per-group indices if self.batch_group_by is not None: group_keys = self._compute_strat_keys(self.valid_anchors, self.batch_group_by) - self._group_indices: dict[str, np.ndarray] = { - str(name): group.index.to_numpy() for name, group in self.valid_anchors.groupby(group_keys) - } + self._group_indices: dict[str, np.ndarray] = self._indices_by_key(group_keys) self._group_names: list[str] = list(self._group_indices.keys()) else: self._group_indices = {} @@ -174,16 +195,19 @@ def _precompute_groups(self) -> None: if self.stratify_by is not None: strat_keys = self._compute_strat_keys(self.valid_anchors, self.stratify_by) - # Global stratification indices - for key in strat_keys.unique(): - self._strat_indices[key] = self.valid_anchors.index[strat_keys == key].to_numpy() + # Global stratification indices — NumPy fast path for Categorical. + self._strat_indices = self._indices_by_key(strat_keys) self._strat_names = list(self._strat_indices.keys()) - # Per-group stratification indices + # Per-group × per-stratum indices. Using np.intersect1d between + # pre-built group and strat index arrays stays NumPy-native + # instead of reinvoking pandas groupby on the full 16M-row frame. if self.batch_group_by is not None: - group_keys = self._compute_strat_keys(self.valid_anchors, self.batch_group_by) - for (grp, strat_key), group in self.valid_anchors.groupby([group_keys, strat_keys]): - self._group_strat_indices[(str(grp), str(strat_key))] = group.index.to_numpy() + for grp, g_idx in self._group_indices.items(): + for strat_key, s_idx in self._strat_indices.items(): + common = np.intersect1d(g_idx, s_idx, assume_unique=True) + if len(common) > 0: + self._group_strat_indices[(grp, strat_key)] = common # All indices self._all_indices = np.arange(len(self.valid_anchors)) @@ -212,23 +236,18 @@ def _precompute_groups(self) -> None: @staticmethod def _compute_strat_keys(df: pd.DataFrame, columns: list[str]) -> pd.Series: - """Compute a single string key per row for grouping. + """Compute a single key per row for grouping. - Parameters - ---------- - df : pd.DataFrame - DataFrame to compute keys for. - columns : list[str] - Column names to combine into group keys. + For a single column, returns the raw Series — pandas ``groupby`` + handles Categorical / string / numeric dtypes directly, and + ``df.col.astype(str)`` over an 80M-row Categorical allocates a + Python-object array that can spike 5-8 GiB transient RAM per call. - Returns - ------- - pd.Series - String keys, one per row. Single-column uses values directly; - multi-column joins with ``"|"``. + For multi-column keys, falls back to the ``"|"``-joined string form + which is unavoidable with pandas groupby today. """ if len(columns) == 1: - return df[columns[0]].astype(str) + return df[columns[0]] return df[columns].astype(str).agg("|".join, axis=1) # ------------------------------------------------------------------ @@ -249,13 +268,47 @@ def __len__(self) -> int: return math.ceil(total_batches / self.num_replicas) def __iter__(self) -> Iterator[list[int]]: - """Yield batch-sized lists of integer indices.""" - rng = np.random.default_rng(self.seed + self.epoch) + """Yield batch-sized lists of integer indices. + + Builds batches lazily so the first batch is ready in milliseconds + instead of blocking on a full-epoch materialization. Every rank + still calls ``_build_one_batch`` on every index so the RNG draws + stay identical to the list-based implementation — only the + *yield* is rank-filtered, not the sampling. DDP correctness is + therefore bit-identical to the previous implementation; the only + change is that the main thread sees batch 0 after one + ``_build_one_batch`` call instead of ``total_batches`` calls. + + ``limit_train_batches`` interacts with this: Lightning stops + pulling from the generator after its cap, so we never pay for + the unused suffix of the epoch. + + The epoch counter auto-advances at the start of each iteration + so that the next ``__iter__`` call reseeds the RNG with a fresh + ``seed + epoch`` and yields a different batch sequence. Advancing + at the start (not the end) is robust against early generator + termination from ``limit_train_batches``: Lightning stops pulling + after its cap and garbage-collects the generator, which would + skip any end-of-iter bookkeeping. + + PyTorch Lightning does not call ``set_epoch`` on custom + ``batch_sampler`` instances (``use_distributed_sampler: false`` + with a batch sampler means Lightning's auto-wrap skips us), so + we self-advance. ``set_epoch`` still works if a caller wants + deterministic resume from a specific epoch — call it before the + iteration and the advance will take the resumed epoch as its + starting point. + """ + seed_offset = self.epoch + self.epoch += 1 + rng = np.random.default_rng(self.seed + seed_offset) total_batches = len(self.valid_anchors) // self.batch_size - all_batches = [self._build_one_batch(rng) for _ in range(total_batches)] - # DDP: each rank takes its interleaved slice - my_batches = all_batches[self.rank :: self.num_replicas] - yield from my_batches + rank = self.rank + replicas = self.num_replicas + for i in range(total_batches): + batch = self._build_one_batch(rng) + if i % replicas == rank: + yield batch # ------------------------------------------------------------------ # Batch construction diff --git a/packages/viscy-data/src/viscy_data/schemas.py b/packages/viscy-data/src/viscy_data/schemas.py index a7f96eb5d..575230657 100644 --- a/packages/viscy-data/src/viscy_data/schemas.py +++ b/packages/viscy-data/src/viscy_data/schemas.py @@ -54,6 +54,14 @@ class FOVRecord(BaseModel): Treatment concentration in nanomolar. fluorescence_modality : str or None Fluorescence imaging modality. + microscope : str or None + Microscope identifier (e.g. ``"mantis"``, ``"dragonfly"``). + labelfree_modality : str or None + Label-free imaging modality (e.g. ``"widefield"``, ``"oblique"``). + treatment : str or None + Treatment name (e.g. ``"DMSO"``, ``"Bafilomycin"``). + hours_post_treatment : float or None + Hours post treatment at imaging start. t_shape : int or None Number of timepoints. c_shape : int or None @@ -68,6 +76,10 @@ class FOVRecord(BaseModel): Physical pixel size in the XY plane (micrometers). pixel_size_z_um : float or None Physical pixel size in Z (micrometers). + channel_markers : dict[str, str] + Maps zarr channel name to marker for this well. + Only channels with a non-None marker in Airtable are included. + Empty dict means no per-well channel marker information is available. """ dataset: str @@ -88,6 +100,10 @@ class FOVRecord(BaseModel): seeding_density: int | None = None treatment_concentration_nm: float | None = None fluorescence_modality: str | None = None + microscope: str | None = None + labelfree_modality: str | None = None + treatment: str | None = None + hours_post_treatment: float | None = None t_shape: int | None = None c_shape: int | None = None z_shape: int | None = None @@ -95,3 +111,4 @@ class FOVRecord(BaseModel): x_shape: int | None = None pixel_size_xy_um: float | None = None pixel_size_z_um: float | None = None + channel_markers: dict[str, str] = {} diff --git a/packages/viscy-data/tests/test_cell_index.py b/packages/viscy-data/tests/test_cell_index.py index c6fd6aa62..0166e65bf 100644 --- a/packages/viscy-data/tests/test_cell_index.py +++ b/packages/viscy-data/tests/test_cell_index.py @@ -15,6 +15,7 @@ CELL_INDEX_CORE_COLUMNS, CELL_INDEX_GROUPING_COLUMNS, CELL_INDEX_IMAGING_COLUMNS, + CELL_INDEX_NORMALIZATION_COLUMNS, CELL_INDEX_OPS_COLUMNS, CELL_INDEX_TIMELAPSE_COLUMNS, ) @@ -22,6 +23,7 @@ CELL_INDEX_SCHEMA, _parse_bbox_min_size, _parse_bbox_to_centroid, + _reconstruct_lineage, build_timelapse_cell_index, convert_ops_parquet, read_cell_index, @@ -130,6 +132,7 @@ def test_strict_passes_with_all_columns(self): + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ): df[col] = None warnings = validate_cell_index(df, strict=True) @@ -143,6 +146,7 @@ def test_all_null_column_warns(self): + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ): df[col] = None warnings = validate_cell_index(df, strict=True) @@ -246,7 +250,7 @@ def test_lineage_reconstruction(self, tmp_path): dataset = open_ome_zarr(dataset_path, layout="hcs", mode="w", channel_names=["nuclei_labels"]) pos = dataset.create_position("A", "1", "0") rng = np.random.default_rng(42) - pos.create_image("0", rng.random((2, 1, 1, 64, 64)).astype(np.float32)) + pos.create_image("0", rng.random((4, 1, 1, 64, 64)).astype(np.float32)) # Track 0 → root, Track 1 → child of 0, Track 2 → grandchild of 1 tracks_df = pd.DataFrame( @@ -297,6 +301,104 @@ def test_cell_id_unique(self, tracks_hcs_dataset, tmp_path): assert not df["cell_id"].duplicated().any() +class TestReconstructLineage: + """Unit tests for ``_reconstruct_lineage`` — scoped directly, no zarr I/O.""" + + def test_cross_well_same_fov_does_not_collapse(self): + """ + Two wells (B/2 and C/2) that share the same FOV number ("002001") and + contain tracks with the same numeric ``track_id`` / ``parent_track_id`` + must NOT have their lineages fused. Prior to the fix, the groupby was + scoped by (experiment, fov) and the two wells were walked as if they + were one, aliasing their lineage_ids. + """ + rows = [] + # Well B/2, fov 002001: track_id 88 whose parent is 35; root is 35. + rows.append( + { + "experiment": "exp", + "well": "B/2", + "fov": "002001", + "track_id": 35, + "parent_track_id": -1, + "global_track_id": "exp_B/2/002001_35", + } + ) + rows.append( + { + "experiment": "exp", + "well": "B/2", + "fov": "002001", + "track_id": 88, + "parent_track_id": 35, + "global_track_id": "exp_B/2/002001_88", + } + ) + # Well C/2, fov 002001: independent track_id 86 whose parent is 34. + # Without the fix, the (exp, fov="002001") group sees BOTH wells' + # tracks, and the parent_track_id=34 lookup in the B/2-derived map + # fails, so track 86 becomes its own root — but track 35 from B/2 + # appears inside the same group, potentially misrouting. + rows.append( + { + "experiment": "exp", + "well": "C/2", + "fov": "002001", + "track_id": 34, + "parent_track_id": -1, + "global_track_id": "exp_C/2/002001_34", + } + ) + rows.append( + { + "experiment": "exp", + "well": "C/2", + "fov": "002001", + "track_id": 86, + "parent_track_id": 34, + "global_track_id": "exp_C/2/002001_86", + } + ) + tracks = pd.DataFrame(rows) + + result = _reconstruct_lineage(tracks.copy()) + + # B/2 rows must resolve to B/2 root; C/2 rows must resolve to C/2 root. + b2_rows = result[result["well"] == "B/2"] + c2_rows = result[result["well"] == "C/2"] + assert set(b2_rows["lineage_id"].unique()) == {"exp_B/2/002001_35"} + assert set(c2_rows["lineage_id"].unique()) == {"exp_C/2/002001_34"} + + def test_no_parent_track_id_column(self): + """If `parent_track_id` is missing, lineage_id falls back to global_track_id.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "well": ["A/1"] * 2, + "fov": ["0"] * 2, + "track_id": [0, 1], + "global_track_id": ["exp_A/1/0_0", "exp_A/1/0_1"], + } + ) + result = _reconstruct_lineage(tracks.copy()) + assert (result["lineage_id"] == result["global_track_id"]).all() + + def test_single_well_chain_resolves_to_root(self): + """Basic sanity: a parent → daughter chain resolves daughters to root.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "well": ["A/1"] * 3, + "fov": ["0"] * 3, + "track_id": [0, 1, 2], + "parent_track_id": [-1, 0, 1], + "global_track_id": ["exp_A/1/0_0", "exp_A/1/0_1", "exp_A/1/0_2"], + } + ) + result = _reconstruct_lineage(tracks.copy()) + assert (result["lineage_id"] == "exp_A/1/0_0").all() + + # --------------------------------------------------------------------------- # OPS builder helpers (tests 11–14) # --------------------------------------------------------------------------- @@ -336,19 +438,29 @@ class TestCrossParadigm: def test_timelapse_has_null_ops_columns(self): """15. Time-lapse parquet has OPS columns as null.""" df = _make_timelapse_df() - for col in CELL_INDEX_OPS_COLUMNS + CELL_INDEX_BIOLOGY_COLUMNS + CELL_INDEX_IMAGING_COLUMNS: + for col in ( + CELL_INDEX_OPS_COLUMNS + + CELL_INDEX_BIOLOGY_COLUMNS + + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS + ): df[col] = None warnings = validate_cell_index(df, strict=True) - ops_warnings = [w for w in warnings if any(c in w for c in CELL_INDEX_OPS_COLUMNS)] + ops_warnings = [w for w in warnings if any(f"'{c}'" in w for c in CELL_INDEX_OPS_COLUMNS)] assert len(ops_warnings) == len(CELL_INDEX_OPS_COLUMNS) def test_ops_has_null_timelapse_columns(self): """16. OPS parquet has time-lapse columns as null.""" df = _make_ops_df() - for col in CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_BIOLOGY_COLUMNS + CELL_INDEX_IMAGING_COLUMNS: + for col in ( + CELL_INDEX_TIMELAPSE_COLUMNS + + CELL_INDEX_BIOLOGY_COLUMNS + + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS + ): df[col] = None warnings = validate_cell_index(df, strict=True) - tl_warnings = [w for w in warnings if any(c in w for c in CELL_INDEX_TIMELAPSE_COLUMNS)] + tl_warnings = [w for w in warnings if any(f"'{c}'" in w for c in CELL_INDEX_TIMELAPSE_COLUMNS)] assert len(tl_warnings) == len(CELL_INDEX_TIMELAPSE_COLUMNS) def test_concat_schema_compatible(self, tmp_path): diff --git a/packages/viscy-data/tests/test_collection.py b/packages/viscy-data/tests/test_collection.py index 9ca19297e..4cd824ef0 100644 --- a/packages/viscy-data/tests/test_collection.py +++ b/packages/viscy-data/tests/test_collection.py @@ -1,6 +1,7 @@ """Tests for viscy_data.collection: Collection, load/save, build_collection.""" import pytest +import yaml from viscy_data.collection import ( ChannelEntry, @@ -55,16 +56,15 @@ def test_duplicate_experiment_names(self): with pytest.raises(ValueError, match="Duplicate experiment name"): _make_collection(experiments=[exp, exp]) - def test_interval_minutes_not_positive(self): - """Raise ValueError when interval_minutes <= 0.""" + def test_zero_interval_minutes_allowed(self): + """Zero interval_minutes is valid (non-timelapse data).""" exp = _make_experiment(name="exp1", interval_minutes=0.0) - with pytest.raises(ValueError, match="interval_minutes must be positive"): - _make_collection(experiments=[exp]) + _make_collection(experiments=[exp]) def test_negative_interval_minutes(self): """Raise ValueError when interval_minutes is negative.""" exp = _make_experiment(name="exp1", interval_minutes=-5.0) - with pytest.raises(ValueError, match="interval_minutes must be positive"): + with pytest.raises(ValueError, match="interval_minutes must be non-negative"): _make_collection(experiments=[exp]) def test_perturbation_wells_empty(self): @@ -241,3 +241,211 @@ def test_single_marker_dataset_not_split(self): grouped = _group_records(records) assert len(grouped) == 1 assert "plate1" in grouped + + +class TestChannelWells: + """Test per-well channel validity restriction via ChannelEntry.wells.""" + + def _make_viral_sensor_records(self): + """FOVRecords for a mixed plate where viral sensor is only in B/3 and C/2.""" + common = dict( + dataset="2025_01_24", + data_path="/data/2025_01_24.zarr", + tracks_path="/tracks/2025_01_24", + channel_names=["Phase3D", "raw mCherry EX561 EM600-37"], + time_interval_min=15.0, + ) + # B/1, B/2: no viral sensor (channel_markers has no entry for mCherry) + no_sensor = [ + FOVRecord(**common, well_id="B/1", cell_state="uninfected", channel_markers={"Phase3D": "Phase3D"}), + FOVRecord(**common, well_id="B/2", cell_state="uninfected", channel_markers={"Phase3D": "Phase3D"}), + ] + # B/3, C/2: viral sensor present + sensor = [ + FOVRecord( + **common, + well_id="B/3", + cell_state="uninfected", + channel_markers={"Phase3D": "Phase3D", "raw mCherry EX561 EM600-37": "pAL40"}, + ), + FOVRecord( + **common, + well_id="C/2", + cell_state="infected", + channel_markers={"Phase3D": "Phase3D", "raw mCherry EX561 EM600-37": "pAL40"}, + ), + ] + return no_sensor + sensor + + def test_wells_auto_populated_for_partial_channel(self): + """build_collection restricts a channel to wells where it has a marker.""" + records = self._make_viral_sensor_records() + coll = build_collection(records, name="test") + exp = coll.experiments[0] + + phase = next(ch for ch in exp.channels if ch.name == "Phase3D") + mcherry = next(ch for ch in exp.channels if ch.name == "raw mCherry EX561 EM600-37") + + assert phase.wells == [], "Phase3D is valid in all wells — wells must be empty" + assert sorted(mcherry.wells) == ["B/3", "C/2"], "mCherry only valid in B/3, C/2" + + def test_wells_empty_when_all_wells_have_marker(self): + """When all wells share a marker, wells stays empty (no restriction needed).""" + records = [ + FOVRecord( + dataset="exp", + well_id="A/1", + data_path="/d.zarr", + tracks_path="/t", + channel_names=["Phase3D"], + cell_state="uninfected", + channel_markers={"Phase3D": "Phase3D"}, + ), + FOVRecord( + dataset="exp", + well_id="A/2", + data_path="/d.zarr", + tracks_path="/t", + channel_names=["Phase3D"], + cell_state="infected", + channel_markers={"Phase3D": "Phase3D"}, + ), + ] + coll = build_collection(records, name="test") + phase = coll.experiments[0].channels[0] + assert phase.wells == [] + + def test_wells_round_trips_yaml(self, tmp_path): + """wells field survives save_collection → load_collection.""" + records = self._make_viral_sensor_records() + coll = build_collection(records, name="test") + path = tmp_path / "col.yml" + save_collection(coll, path) + loaded = load_collection(path) + mcherry = next(ch for ch in loaded.experiments[0].channels if ch.name == "raw mCherry EX561 EM600-37") + assert sorted(mcherry.wells) == ["B/3", "C/2"] + + def test_channel_entry_wells_default_empty(self): + """ChannelEntry.wells defaults to empty list.""" + ch = ChannelEntry(name="Phase3D", marker="Phase3D") + assert ch.wells == [] + + +def _write_yaml(path, data): + with open(path, "w") as f: + yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False) + + +def _minimal_experiment(name, data_path, tracks_path): + return { + "name": name, + "data_path": data_path, + "tracks_path": tracks_path, + "channels": [{"name": "Phase3D", "marker": "Phase3D"}], + "perturbation_wells": {"mock": ["A/1"]}, + } + + +class TestDatasetsRoot: + """Test ${datasets_root} substitution in load/save round-trip.""" + + def test_resolve_datasets_root(self, tmp_path): + """Paths with ${datasets_root} are fully resolved after load.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp1", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ) + ], + } + _write_yaml(tmp_path / "col.yml", data) + coll = load_collection(tmp_path / "col.yml") + assert coll.experiments[0].data_path == "/hpc/projects/organelle_phenotyping/datasets/exp1/exp1.zarr" + assert coll.experiments[0].tracks_path == "/hpc/projects/organelle_phenotyping/datasets/exp1/tracking.zarr" + assert coll.datasets_root == "/hpc/projects/organelle_phenotyping" + + def test_round_trip_preserves_templates(self, tmp_path): + """save_collection writes ${datasets_root} back; reload resolves again.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp1", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ) + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + + with open(out_path) as f: + on_disk = yaml.safe_load(f) + + assert "${datasets_root}" in on_disk["experiments"][0]["data_path"] + assert "${datasets_root}" in on_disk["experiments"][0]["tracks_path"] + + reloaded = load_collection(out_path) + assert reloaded.experiments[0].data_path == "/hpc/projects/organelle_phenotyping/datasets/exp1/exp1.zarr" + + def test_mixed_paths_non_root_stays_absolute(self, tmp_path): + """Paths not under datasets_root survive save unchanged.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp_vast", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ), + _minimal_experiment( + "exp_nfs", + "${datasets_root}/datasets/exp2/exp2.zarr", + "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr", + ), + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + assert coll.experiments[1].tracks_path == "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr" + + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + with open(out_path) as f: + on_disk = yaml.safe_load(f) + nfs_path = "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr" + assert on_disk["experiments"][1]["tracks_path"] == nfs_path + + def test_no_datasets_root_passthrough(self, tmp_path): + """Collections without datasets_root load and save unchanged.""" + data = { + "name": "test", + "experiments": [ + _minimal_experiment( + "exp1", + "/absolute/data/exp1.zarr", + "/absolute/tracks/exp1", + ) + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + assert coll.datasets_root is None + assert coll.experiments[0].data_path == "/absolute/data/exp1.zarr" + + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + with open(out_path) as f: + on_disk = yaml.safe_load(f) + assert on_disk["experiments"][0]["data_path"] == "/absolute/data/exp1.zarr" diff --git a/packages/viscy-data/tests/test_sampler.py b/packages/viscy-data/tests/test_sampler.py index 20571033d..202cb2937 100644 --- a/packages/viscy-data/tests/test_sampler.py +++ b/packages/viscy-data/tests/test_sampler.py @@ -181,6 +181,107 @@ def test_batch_group_by_none_allows_mixing(self, two_experiment_anchors: pd.Data assert any_mixed, "With batch_group_by=None, at least one batch should mix experiments" +# --------------------------------------------------------------------------- +# Marker-aware batching (bag-of-channels regime) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def multi_marker_anchors() -> pd.DataFrame: + """DataFrame with 1 experiment, 4 markers, 2 conditions, 320 rows total. + + Represents the bag-of-channels regime where each row is one (cell, + timepoint, channel) observation and ``marker`` identifies which + channel/protein the patch came from. + """ + rng = np.random.default_rng(7) + rows = [] + for marker in ["Phase3D", "TOMM20", "SEC61B", "Brightfield"]: + for cond in ["infected", "uninfected"]: + for i in range(40): + rows.append( + { + "experiment": "exp_boc", + "condition": cond, + "marker": marker, + "hours_post_perturbation": rng.uniform(0, 24), + "global_track_id": f"{marker}_{cond}_{i}", + "t": rng.integers(0, 20), + } + ) + df = pd.DataFrame(rows) + return df.reset_index(drop=True) + + +class TestMarkerAware: + """batch_group_by="marker" produces single-marker batches shuffled across markers. + + This is the bag-of-channels training regime — the config asks for one + marker per batch so contrastive pairs stay within the same channel, + while different batches traverse the full marker pool across an + epoch. + """ + + def test_every_batch_is_single_marker(self, multi_marker_anchors: pd.DataFrame): + """Every batch must contain rows from exactly one marker.""" + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + batches = list(sampler) + assert batches, "Sampler should yield batches" + for batch in batches: + markers = multi_marker_anchors.iloc[batch]["marker"].unique() + assert len(markers) == 1, f"batch_group_by='marker' batch has {len(markers)} markers: {markers}" + + def test_all_markers_appear_across_epoch(self, multi_marker_anchors: pd.DataFrame): + """Across one epoch every marker surfaces in at least one batch.""" + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + seen: set[str] = set() + for batch in sampler: + seen.update(multi_marker_anchors.iloc[batch]["marker"].unique()) + expected = {"Phase3D", "TOMM20", "SEC61B", "Brightfield"} + assert seen == expected, f"Not all markers surfaced in one epoch: {seen} vs {expected}" + + def test_batches_shuffled_across_markers(self, multi_marker_anchors: pd.DataFrame): + """Consecutive batches should not all be the same marker — the sampler + must interleave marker groups rather than drain them sequentially. + + We require at least half of the marker-to-marker batch transitions + to be a change (pathological samplers that yield all Phase3D + batches first, then all TOMM20, etc., would get a change-ratio + close to ``1/num_batches`` which this threshold catches). + """ + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + per_batch_marker: list[str] = [] + for batch in sampler: + per_batch_marker.append(multi_marker_anchors.iloc[batch]["marker"].iloc[0]) + transitions = [a != b for a, b in zip(per_batch_marker[:-1], per_batch_marker[1:], strict=False)] + change_ratio = sum(transitions) / len(transitions) + assert change_ratio >= 0.5, ( + f"Only {change_ratio:.1%} of consecutive batches changed marker; " + "sampler appears to drain groups sequentially instead of shuffling" + ) + + # --------------------------------------------------------------------------- # Stratified sampling (SAMP-02) # --------------------------------------------------------------------------- @@ -391,6 +492,22 @@ def test_set_epoch_same_epoch_same_result(self, two_experiment_anchors: pd.DataF batches_b = list(sampler) assert batches_a == batches_b + def test_iter_auto_advances_epoch(self, two_experiment_anchors: pd.DataFrame): + """Consecutive iterations must yield different sequences without set_epoch. + + PL does not call ``set_epoch`` on ``batch_sampler`` instances, so the + sampler must self-advance. Regression guard for the frozen-dataset bug. + """ + sampler = FlexibleBatchSampler( + valid_anchors=two_experiment_anchors, + batch_size=8, + batch_group_by="experiment", + stratify_by=None, + leaky=0.0, + seed=42, + ) + assert list(sampler) != list(sampler) + # --------------------------------------------------------------------------- # __len__ and __iter__ protocol diff --git a/packages/viscy-transforms/src/viscy_transforms/__init__.py b/packages/viscy-transforms/src/viscy_transforms/__init__.py index 560110d4c..af8a24a1e 100644 --- a/packages/viscy-transforms/src/viscy_transforms/__init__.py +++ b/packages/viscy-transforms/src/viscy_transforms/__init__.py @@ -70,12 +70,18 @@ from viscy_transforms._sharpen import BatchedRandSharpend from viscy_transforms._stack_channels import BatchedStackChannelsd, StackChannelsd from viscy_transforms._tiled_crop import TiledSpatialCropSamplesd +from viscy_transforms._z_reduction import ( + BatchedChannelWiseZReduction, + BatchedChannelWiseZReductiond, +) from viscy_transforms._zoom import BatchedZoom, BatchedZoomd from viscy_transforms._zstack_shift import BatchedRandZStackShiftd __version__ = version("viscy-transforms") __all__ = [ + "BatchedChannelWiseZReduction", + "BatchedChannelWiseZReductiond", "BatchedCenterSpatialCrop", "BatchedCenterSpatialCropd", "BatchedDivisibleCropd", diff --git a/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py b/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py new file mode 100644 index 000000000..398b08d05 --- /dev/null +++ b/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py @@ -0,0 +1,117 @@ +"""Channel-wise Z-reduction transforms for 2D training from 3D z-stacks.""" + +from __future__ import annotations + +from collections.abc import Hashable + +import torch +from monai.transforms import MapTransform +from torch import Tensor + +__all__ = ["BatchedChannelWiseZReduction", "BatchedChannelWiseZReductiond"] + + +class BatchedChannelWiseZReduction: + """Reduce the Z dimension of a ``(B, C, Z, Y, X)`` tensor. + + Label-free samples get the center z-slice; fluorescence samples get a + max-intensity projection (MIP). A per-sample boolean mask selects the + strategy when the batch mixes both types. + + Parameters + ---------- + default_strategy : str + Strategy when no mask is provided: ``"mip"`` or ``"center"``. + """ + + def __init__(self, default_strategy: str = "mip") -> None: + if default_strategy not in ("mip", "center"): + raise ValueError(f"default_strategy must be 'mip' or 'center', got '{default_strategy}'") + self.default_strategy = default_strategy + + def __call__(self, img: Tensor, is_labelfree: Tensor | None = None) -> Tensor: + """Apply z-reduction. + + Parameters + ---------- + img : Tensor + Shape ``(B, C, Z, Y, X)``. + is_labelfree : Tensor or None + Boolean tensor of shape ``(B,)``. ``True`` → center-slice, + ``False`` → MIP. When ``None``, ``default_strategy`` is used + uniformly. + + Returns + ------- + Tensor + Shape ``(B, C, 1, Y, X)``. + """ + z = img.shape[2] + if z == 1: + return img + + if is_labelfree is None: + if self.default_strategy == "center": + return img[:, :, z // 2 : z // 2 + 1] + return img.amax(dim=2, keepdim=True) + + center = img[:, :, z // 2 : z // 2 + 1] + mip = img.amax(dim=2, keepdim=True) + mask = is_labelfree.view(-1, 1, 1, 1, 1) + return torch.where(mask, center, mip) + + +class BatchedChannelWiseZReductiond(MapTransform): + """Dict transform that applies channel-wise Z-reduction. + + In **bag-of-channels mode** each sample may represent a different channel. + The transform reads a ``_is_labelfree`` boolean tensor from the data dict + (injected by the datamodule) to decide per-sample strategy. + + In **all-channels mode** the dict keys identify channel type. Pass + ``labelfree_keys`` to specify which keys should use center-slice; all + others get MIP. + + Parameters + ---------- + keys : KeysCollection + Keys of the image tensors to transform. + labelfree_keys : list[str] or None + Channel keys that should use center-slice (all-channels mode). + When set, ``_is_labelfree`` in the data dict is ignored. + default_strategy : str + Fallback strategy when neither ``labelfree_keys`` nor + ``_is_labelfree`` can determine the channel type. + allow_missing_keys : bool + If ``True``, skip keys not present in the data dict. + """ + + def __init__( + self, + keys, + labelfree_keys: list[str] | None = None, + default_strategy: str = "mip", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.labelfree_keys = set(labelfree_keys) if labelfree_keys is not None else None + self.reducer = BatchedChannelWiseZReduction(default_strategy=default_strategy) + + def __call__(self, data: dict[Hashable, Tensor]) -> dict[Hashable, Tensor]: + is_labelfree = data.pop("_is_labelfree", None) + + for key in self.key_iterator(data): + if self.labelfree_keys is not None: + # All-channels mode: strategy determined by key name. + img = data[key] + z = img.shape[2] + if z == 1: + continue + if key in self.labelfree_keys: + data[key] = img[:, :, z // 2 : z // 2 + 1] + else: + data[key] = img.amax(dim=2, keepdim=True) + else: + data[key] = self.reducer(data[key], is_labelfree=is_labelfree) + + return data diff --git a/packages/viscy-transforms/tests/test_z_reduction.py b/packages/viscy-transforms/tests/test_z_reduction.py new file mode 100644 index 000000000..91fc9d2ef --- /dev/null +++ b/packages/viscy-transforms/tests/test_z_reduction.py @@ -0,0 +1,128 @@ +import torch + +from viscy_transforms import BatchedChannelWiseZReduction, BatchedChannelWiseZReductiond + + +def _make_img(B=4, C=1, Z=11, Y=8, X=8): + """Create a test image with distinct z-slices for easy verification.""" + img = torch.randn(B, C, Z, Y, X) + return img + + +class TestBatchedChannelWiseZReduction: + def test_mip_only(self): + img = _make_img() + reducer = BatchedChannelWiseZReduction(default_strategy="mip") + out = reducer(img) + assert out.shape == (4, 1, 1, 8, 8) + expected = img.amax(dim=2, keepdim=True) + torch.testing.assert_close(out, expected) + + def test_center_only(self): + img = _make_img() + reducer = BatchedChannelWiseZReduction(default_strategy="center") + out = reducer(img) + assert out.shape == (4, 1, 1, 8, 8) + expected = img[:, :, 5:6] + torch.testing.assert_close(out, expected) + + def test_mixed_mask(self): + img = _make_img() + mask = torch.tensor([True, False, True, False]) + reducer = BatchedChannelWiseZReduction() + out = reducer(img, is_labelfree=mask) + assert out.shape == (4, 1, 1, 8, 8) + center = img[:, :, 5:6] + mip = img.amax(dim=2, keepdim=True) + torch.testing.assert_close(out[0], center[0]) + torch.testing.assert_close(out[1], mip[1]) + torch.testing.assert_close(out[2], center[2]) + torch.testing.assert_close(out[3], mip[3]) + + def test_noop_z1(self): + img = _make_img(Z=1) + reducer = BatchedChannelWiseZReduction() + out = reducer(img) + assert out.shape == img.shape + torch.testing.assert_close(out, img) + + def test_invalid_strategy(self): + try: + BatchedChannelWiseZReduction(default_strategy="invalid") + assert False, "Should have raised ValueError" + except ValueError: + pass + + +class TestBatchedChannelWiseZReductiond: + def test_bag_of_channels_with_mask(self): + data = { + "channel_0": _make_img(), + "_is_labelfree": torch.tensor([True, False, False, True]), + } + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + assert "_is_labelfree" not in out + + def test_all_channels_with_labelfree_keys(self): + phase_img = _make_img() + fluor_img = _make_img() + expected_center = phase_img[:, :, 5:6].clone() + expected_mip = fluor_img.amax(dim=2, keepdim=True) + data = {"Phase3D": phase_img, "TOMM20": fluor_img} + transform = BatchedChannelWiseZReductiond( + keys=["Phase3D", "TOMM20"], + labelfree_keys=["Phase3D"], + ) + out = transform(data) + assert out["Phase3D"].shape == (4, 1, 1, 8, 8) + assert out["TOMM20"].shape == (4, 1, 1, 8, 8) + torch.testing.assert_close(out["Phase3D"], expected_center) + torch.testing.assert_close(out["TOMM20"], expected_mip) + + def test_pops_is_labelfree(self): + data = { + "channel_0": _make_img(), + "_is_labelfree": torch.tensor([False, False, False, False]), + } + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert "_is_labelfree" not in out + + def test_missing_keys(self): + data = {"channel_0": _make_img()} + transform = BatchedChannelWiseZReductiond( + keys=["channel_0", "channel_1"], + allow_missing_keys=True, + ) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + assert "channel_1" not in out + + def test_noop_z1_dict(self): + data = {"channel_0": _make_img(Z=1)} + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + + def test_no_mask_uses_default(self): + img = _make_img() + expected = img[:, :, 5:6].clone() + data = {"channel_0": img} + transform = BatchedChannelWiseZReductiond(keys=["channel_0"], default_strategy="center") + out = transform(data) + torch.testing.assert_close(out["channel_0"], expected) + + def test_labelfree_keys_noop_z1(self): + data = { + "Phase3D": _make_img(Z=1), + "TOMM20": _make_img(Z=1), + } + transform = BatchedChannelWiseZReductiond( + keys=["Phase3D", "TOMM20"], + labelfree_keys=["Phase3D"], + ) + out = transform(data) + torch.testing.assert_close(out["Phase3D"], data["Phase3D"]) + torch.testing.assert_close(out["TOMM20"], data["TOMM20"]) diff --git a/packages/viscy-utils/pyproject.toml b/packages/viscy-utils/pyproject.toml index dde0f28ba..cfde97577 100644 --- a/packages/viscy-utils/pyproject.toml +++ b/packages/viscy-utils/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "lightning>=2.3", "matplotlib>=3.10", "numpy>=2.4.1", + "onnx", + "onnxscript", "pyyaml", "scikit-image", "scipy", @@ -48,6 +50,7 @@ dependencies = [ optional-dependencies.all = [ "viscy-utils[anndata,eval]" ] optional-dependencies.anndata = [ "anndata", "natsort" ] optional-dependencies.eval = [ + "copairs", "phate", "scikit-learn", "umap-learn", diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py b/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py index 9a49d6db2..6e41540e4 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py @@ -2,12 +2,10 @@ from viscy_utils.callbacks.embedding_writer import EmbeddingWriter from viscy_utils.callbacks.online_eval import OnlineEvalCallback from viscy_utils.callbacks.prediction_writer import HCSPredictionWriter -from viscy_utils.callbacks.save_config_wandb import SaveConfigToWandb __all__ = [ "EmbeddingSnapshotCallback", "EmbeddingWriter", "OnlineEvalCallback", "HCSPredictionWriter", - "SaveConfigToWandb", ] diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py b/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py index 7784a25f4..373507b8f 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py @@ -156,8 +156,15 @@ def write_embedding_dataset( ultrack_indices = index_df.copy() ultrack_indices["fov_name"] = ultrack_indices["fov_name"].str.strip("/") - for col in ultrack_indices.select_dtypes(include="string").columns: - ultrack_indices[col] = ultrack_indices[col].astype(object) + # TODO: remove once anndata 0.13 supports pandas 3 Arrow-backed strings natively. + # anndata 0.12.9+ requires pandas <3, so we stay on 0.12.6 + pandas 3 and + # must manually downcast ArrowStringArray columns to object dtype before writing. + for col in ultrack_indices.columns: + s = ultrack_indices[col] + if isinstance(s.dtype, pd.StringDtype): + ultrack_indices[col] = s.astype(object) + elif hasattr(s, "cat") and isinstance(s.cat.categories.dtype, pd.StringDtype): + ultrack_indices[col] = s.cat.rename_categories(s.cat.categories.astype(object)) if embedding_key == "projections": if projections is None: diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py b/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py index bf4045454..d16a068c4 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py @@ -21,7 +21,7 @@ from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback from scipy.stats import spearmanr -from sklearn.model_selection import cross_val_score +from sklearn.model_selection import cross_val_score, train_test_split from sklearn.neighbors import KNeighborsClassifier from viscy_data._typing import TripletSample @@ -48,8 +48,24 @@ def effective_rank(features: np.ndarray) -> float: float Effective rank (scalar >= 1). """ + # Guard against NaN/Inf in features — np.linalg.svd raises + # "SVD did not converge" on non-finite input, which crashes the whole + # run from inside a validation callback. Drop affected rows and return + # NaN when no finite rows remain. + finite_mask = np.isfinite(features).all(axis=1) + if not finite_mask.all(): + _logger.warning( + "effective_rank: %d/%d rows contain NaN/Inf; skipping those", + (~finite_mask).sum(), + len(features), + ) + features = features[finite_mask] + if features.shape[0] < 2: + return float("nan") _, s, _ = np.linalg.svd(features, full_matrices=False) s = s[s > 1e-10] + if s.size == 0: + return float("nan") p = s / s.sum() entropy = -(p * np.log(p)).sum() return float(np.exp(entropy)) @@ -114,7 +130,8 @@ class OnlineEvalCallback(Callback): Accumulates validation embeddings every ``every_n_epochs`` epochs and computes three metrics: - - ``metrics/knn_acc/{label_key}/val`` — k-NN accuracy (5-fold CV) + - ``metrics/knn_acc/{label_key}/val`` — k-NN accuracy (5-fold CV or + stratified holdout, configurable via ``knn_eval_mode``) - ``metrics/effective_rank/val`` — effective rank of covariance - ``metrics/temporal_smoothness/val`` — Spearman rho (distance vs dt) @@ -133,6 +150,15 @@ class OnlineEvalCallback(Callback): Metadata key for track identity (temporal smoothness). timepoint_key : str Metadata key for timepoint (temporal smoothness). + knn_eval_mode : {"cv", "holdout"} + How to score the k-NN probe. ``"cv"`` runs 5-fold stratified CV + (default; good for few-class probes like 40 markers). ``"holdout"`` + runs a single stratified 80/20 train/test split — ~5x cheaper and + tolerates classes with only 2 samples, which is the right choice + for many-class probes (e.g. 1001-gene perturbation). + holdout_test_size : float + Fraction of samples held out for scoring when + ``knn_eval_mode="holdout"``. Ignored in CV mode. """ def __init__( @@ -142,6 +168,8 @@ def __init__( k: int = 20, track_id_key: str = "global_track_id", timepoint_key: str = "t", + knn_eval_mode: Literal["cv", "holdout"] = "cv", + holdout_test_size: float = 0.2, ): super().__init__() self.every_n_epochs = every_n_epochs @@ -149,6 +177,8 @@ def __init__( self.k = k self.track_id_key = track_id_key self.timepoint_key = timepoint_key + self.knn_eval_mode = knn_eval_mode + self.holdout_test_size = holdout_test_size self._collecting = False self._features: list[torch.Tensor] = [] self._meta: list[dict] = [] @@ -212,10 +242,36 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) if labels is not None and len(np.unique(labels)) >= 2: k = min(self.k, n_samples - 1) knn = KNeighborsClassifier(n_neighbors=k, metric="cosine") - cv_folds = min(5, min(np.bincount(labels))) - if cv_folds >= 2: + min_class_count = int(min(np.bincount(labels))) + mode = self.knn_eval_mode + # Auto-degrade CV -> holdout when the smallest class has < 2 + # samples (CV would skip silently). Holdout mode still requires + # >= 2 per class for stratified splitting. + if mode == "cv" and min_class_count < 2: + mode = "holdout" + if mode == "cv": + cv_folds = min(5, min_class_count) scores = cross_val_score(knn, features_np, labels, cv=cv_folds) knn_acc = float(scores.mean()) + eval_desc = f"cv={cv_folds}" + elif mode == "holdout" and min_class_count >= 2: + x_train, x_test, y_train, y_test = train_test_split( + features_np, + labels, + test_size=self.holdout_test_size, + stratify=labels, + random_state=0, + ) + knn.fit(x_train, y_train) + knn_acc = float(knn.score(x_test, y_test)) + eval_desc = f"holdout={self.holdout_test_size:.2f}" + else: + knn_acc = None + _logger.debug( + f"[OnlineEval epoch {epoch}] Skipping k-NN: " + f"smallest class has {min_class_count} samples (need >=2)." + ) + if knn_acc is not None: pl_module.log( f"metrics/knn_acc/{self.label_key}/val", knn_acc, @@ -223,14 +279,7 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) logger=True, rank_zero_only=True, ) - _logger.info( - f"[OnlineEval epoch {epoch}] knn_acc({self.label_key}, k={k})={knn_acc:.3f} (cv={cv_folds})" - ) - else: - _logger.debug( - f"[OnlineEval epoch {epoch}] Skipping k-NN: " - f"smallest class has {min(np.bincount(labels))} samples (need >=2)." - ) + _logger.info(f"[OnlineEval epoch {epoch}] knn_acc({self.label_key}, k={k})={knn_acc:.3f} ({eval_desc})") # --- Temporal smoothness (requires track_id + timepoint) --- track_ids = self._extract_array(self.track_id_key, source="meta") diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py b/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py deleted file mode 100644 index cec542678..000000000 --- a/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Save resolved Lightning config to W&B files.""" - -from __future__ import annotations - -import logging -from pathlib import Path - -from lightning.pytorch import Callback, Trainer -from lightning.pytorch.loggers import WandbLogger - -logger = logging.getLogger(__name__) - - -class SaveConfigToWandb(Callback): - """Upload the resolved config.yaml to W&B so it appears in the Files tab. - - Lightning's SaveConfigCallback writes config.yaml to ``trainer.log_dir``, - but WandbLogger does not sync arbitrary files from that directory. - This callback copies it into the W&B run's files directory on fit start. - """ - - def setup(self, trainer: Trainer, pl_module, stage: str) -> None: - """Copy config.yaml to W&B run files on fit start.""" - if stage != "fit": - return - wandb_logger = None - for lg in trainer.loggers: - if isinstance(lg, WandbLogger): - wandb_logger = lg - break - if wandb_logger is None: - return - config_path = Path(trainer.log_dir) / "config.yaml" - if not config_path.exists(): - logger.debug("No config.yaml found at %s, skipping W&B upload.", config_path) - return - run = wandb_logger.experiment - run.save(str(config_path), base_path=str(config_path.parent), policy="now") - logger.info("Uploaded %s to W&B run %s.", config_path, run.id) diff --git a/packages/viscy-utils/src/viscy_utils/cli.py b/packages/viscy-utils/src/viscy_utils/cli.py index 1babc02aa..f753b9734 100644 --- a/packages/viscy-utils/src/viscy_utils/cli.py +++ b/packages/viscy-utils/src/viscy_utils/cli.py @@ -13,9 +13,8 @@ import yaml from jsonargparse import Namespace, lazy_instance from lightning.pytorch import LightningDataModule, LightningModule -from lightning.pytorch.callbacks import TQDMProgressBar from lightning.pytorch.cli import LightningCLI -from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers import WandbLogger from viscy_utils.compose import load_composed_config from viscy_utils.trainer import VisCyTrainer @@ -84,18 +83,12 @@ def subcommands() -> dict[str, set[str]]: return subcommands def add_arguments_to_parser(self, parser) -> None: - """Set default logger and progress bar.""" - defaults = { - "trainer.logger": lazy_instance( - TensorBoardLogger, - save_dir="", - version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), - log_graph=True, - ), - } - if not sys.stdout.isatty(): - defaults["trainer.callbacks"] = [lazy_instance(TQDMProgressBar, refresh_rate=10, leave=True)] - parser.set_defaults(defaults) + """Set default logger.""" + parser.set_defaults( + { + "trainer.logger": lazy_instance(WandbLogger), + } + ) def _parse_ckpt_path(self) -> None: try: diff --git a/packages/viscy-utils/src/viscy_utils/cli_utils.py b/packages/viscy-utils/src/viscy_utils/cli_utils.py index 78903f48b..b72837bc4 100644 --- a/packages/viscy-utils/src/viscy_utils/cli_utils.py +++ b/packages/viscy-utils/src/viscy_utils/cli_utils.py @@ -2,12 +2,8 @@ from pathlib import Path -import yaml - -def format_markdown_table( - data: dict | list[dict], title: str = None, headers: list[str] = None -) -> str: +def format_markdown_table(data: dict | list[dict], title: str = None, headers: list[str] = None) -> str: """Format data as a markdown table. Parameters @@ -90,9 +86,52 @@ def load_config(config_path: str | Path) -> dict: yaml.YAMLError If the YAML file is malformed. """ + from viscy_utils.compose import load_composed_config + config_path = Path(config_path) if not config_path.exists(): raise FileNotFoundError(f"Config file not found: {config_path}") + return load_composed_config(config_path) + + +def load_config_section(config_path: str | Path, section: str | None, default_section: str | None = None) -> dict: + """Load a YAML config file, optionally selecting a subsection. + + This enables reusing a single YAML file for multiple CLI steps by storing + per-command configuration under a top-level key (``section``), while keeping + shared keys (e.g., ``datasets``) at the root. + + Parameters + ---------- + config_path : str | Path + Path to YAML configuration file. + section : str | None + If provided, selects ``config[section]`` and merges in any shared root + keys that are not already present in the section. + default_section : str | None + If ``section`` is None and ``default_section`` exists in the YAML, that section is used. - with open(config_path, "r") as f: - return yaml.safe_load(f) + Returns + ------- + dict + Configuration dictionary (either full or merged subsection). + """ + cfg = load_config(config_path) + if section is None: + if default_section is None or default_section not in cfg: + return cfg + section = default_section + + if section not in cfg: + raise KeyError(f"Config section not found: {section}") + + section_cfg = cfg[section] or {} + if not isinstance(section_cfg, dict): + raise TypeError(f"Config section must be a mapping: {section}") + + merged = dict(section_cfg) + for k, v in cfg.items(): + if k == section: + continue + merged.setdefault(k, v) + return merged diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py b/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py index 91a5af9c3..c3a0e4566 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py @@ -129,9 +129,24 @@ def load_annotation_anndata(adata: ad.AnnData, path: str, name: str, categories: annotation = pd.read_csv(path) annotation["fov_name"] = annotation["fov_name"].str.strip("/") - annotation = annotation.set_index(["fov_name", "id"]) - - mi = pd.MultiIndex.from_arrays([adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]) + # Normalize obs fov_name: strip leading/trailing slashes so both sides match. + obs_fov = adata.obs["fov_name"].astype(object).str.strip("/") + + if "id" in adata.obs.columns and "id" in annotation.columns: + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays([obs_fov, adata.obs["id"]], names=["fov_name", "id"]) + elif all(c in adata.obs.columns for c in ("fov_name", "t", "track_id")) and all( + c in annotation.columns for c in ("fov_name", "t", "track_id") + ): + annotation = annotation.set_index(["fov_name", "t", "track_id"]) + mi = pd.MultiIndex.from_arrays( + [obs_fov, adata.obs["t"], adata.obs["track_id"]], + names=["fov_name", "t", "track_id"], + ) + else: + raise KeyError( + "Cannot join annotations: embeddings have neither (fov_name, id) nor (fov_name, t, track_id) columns." + ) # Use reindex to handle missing annotations gracefully # This will return NaN for observations that don't have annotations diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py b/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py index bbcf690a8..5a3450208 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py @@ -18,6 +18,10 @@ def compute_phate( knn_dist: str = "cosine", update_dataset: bool = False, random_state: int = 42, + n_pca: int = 50, + subsample: int | None = None, + lineage_ids: NDArray | None = None, + n_jobs: int = 1, **phate_kwargs, ) -> tuple[object, NDArray]: """Compute PHATE embeddings. @@ -66,17 +70,38 @@ def compute_phate( else: embeddings_scaled = embeddings + import numpy as np + phate_model = phate.PHATE( n_components=n_components, knn=knn, decay=decay, knn_dist=knn_dist, random_state=random_state, - n_jobs=-1, + n_jobs=n_jobs, + n_pca=n_pca, **phate_kwargs, ) - phate_embedding = phate_model.fit_transform(embeddings_scaled) + n_samples = embeddings_scaled.shape[0] + if subsample is not None and subsample < n_samples: + rng = np.random.default_rng(random_state) + if lineage_ids is not None: + unique_lineages = np.unique(lineage_ids) + n_lineages = min(subsample, len(unique_lineages)) + chosen_lineages = rng.choice(unique_lineages, size=n_lineages, replace=False) + idx = np.where(np.isin(lineage_ids, chosen_lineages))[0] + _logger.info( + f"PHATE: fitting on {len(idx):,} cells ({n_lineages:,} lineages) " + f"/ {n_samples:,} total, projecting the rest" + ) + else: + idx = rng.choice(n_samples, size=subsample, replace=False) + _logger.info(f"PHATE: fitting on {subsample:,} / {n_samples:,} cells, projecting the rest") + phate_model.fit(embeddings_scaled[idx]) + phate_embedding = phate_model.transform(embeddings_scaled) + else: + phate_embedding = phate_model.fit_transform(embeddings_scaled) if update_dataset and isinstance(embedding_dataset, Dataset): for i in range(min(2, phate_embedding.shape[1])): diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py b/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py new file mode 100644 index 000000000..7952738c1 --- /dev/null +++ b/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py @@ -0,0 +1,120 @@ +"""Embedding-level mean Average Precision (mAP) via copairs.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd + + +def compute_embedding_map( + meta: pd.DataFrame, + features: np.ndarray, + reference_condition: str, + target_condition: str, + condition_col: str = "condition", + group_col: str = "marker", + distance: str = "cosine", + null_size: int = 10000, + seed: int = 0, +) -> dict | None: + """Compute mean Average Precision for embedding-space phenotypic profiling. + + Uses ``copairs`` to compute per-cell Average Precision (AP) between a + reference and target condition, then aggregates to mAP per group. Positive + pairs share the same group and condition; negative pairs share only the group + but differ in condition. + + Parameters + ---------- + meta : pd.DataFrame + Cell metadata, one row per cell. Must contain ``condition_col`` and + ``group_col`` columns. + features : np.ndarray + Embedding matrix, shape (n_cells, n_features). Rows correspond to + ``meta`` rows. + reference_condition : str + Value of ``condition_col`` for the reference/control group (``cond_a``). + target_condition : str + Value of ``condition_col`` for the treatment group (``cond_b``). + condition_col : str + Column in ``meta`` that holds condition labels. Default: ``"condition"``. + group_col : str + Column in ``meta`` that holds group labels (e.g. marker/organelle). + Default: ``"marker"``. + distance : str + Distance metric for copairs (e.g. ``"cosine"``). Default: ``"cosine"``. + null_size : int + Number of null pairs for the mAP significance test. Default: 10000. + seed : int + Random seed. Default: 0. + + Returns + ------- + dict or None + ``{"mean_average_precision": float, "p_value": float, + "n_reference": int, "n_target": int}`` or ``None`` if either condition + has no cells. + """ + try: + import copairs.map + import copairs.matching + except ImportError as e: + raise ImportError("copairs is required for mAP computation. Install it with: pip install copairs") from e + + mask_ref = meta[condition_col] == reference_condition + mask_tgt = meta[condition_col] == target_condition + mask = mask_ref | mask_tgt + + if mask_ref.sum() == 0 or mask_tgt.sum() == 0: + return None + + sub_meta = meta[mask].reset_index(drop=True) + sub_feats = features[mask.values] + + reference_col = "reference_index" + sub_meta = sub_meta.copy() + sub_meta[reference_col] = copairs.matching.assign_reference_index( + sub_meta, reference_condition, condition_col, group_col + ) + + pos_sameby = [group_col, condition_col, reference_col] + neg_sameby = [group_col] + neg_diffby = [condition_col, reference_col] + + ap_df = copairs.map.average_precision( + sub_meta, + sub_feats, + pos_sameby=pos_sameby, + neg_sameby=neg_sameby, + neg_diffby=neg_diffby, + batch_size=20000, + distance=distance, + ) + + target_ap = ap_df[sub_meta[condition_col] == target_condition] + if len(target_ap) == 0: + return None + + map_result = copairs.map.mean_average_precision( + target_ap, + sameby=[group_col], + null_size=null_size, + threshold=0.05, + seed=seed, + ) + + if hasattr(map_result, "mean_average_precision"): + mmap = float(map_result.mean_average_precision.iloc[0]) + pval = float(map_result.p_value.iloc[0]) if "p_value" in map_result.columns else float("nan") + elif isinstance(map_result, dict): + mmap = float(map_result.get("mean_average_precision", float("nan"))) + pval = float(map_result.get("p_value", float("nan"))) + else: + return None + + return { + "mean_average_precision": mmap, + "p_value": pval, + "n_reference": int(mask_ref.sum()), + "n_target": int(mask_tgt.sum()), + } diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py b/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py index 9bdc0bd35..d0518a80b 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py @@ -203,7 +203,7 @@ def train_linear_classifier( classifier_params: Optional[dict[str, Any]] = None, split_train_data: float = 0.8, random_seed: int = 42, -) -> tuple[LinearClassifierPipeline, dict[str, float]]: +) -> tuple[LinearClassifierPipeline, dict[str, float], dict[str, Any]]: """Train a linear classifier on embeddings with preprocessing and evaluation. Parameters @@ -231,6 +231,9 @@ def train_linear_classifier( Trained classifier pipeline with preprocessing. dict Dictionary of evaluation metrics (train and validation if split). + dict + Raw validation outputs for plotting: ``y_val``, ``y_val_proba``, + ``classes``. Values are ``None`` when no validation split was made. """ print("\n" + "=" * 60) print("TRAINING CLASSIFIER") @@ -316,6 +319,7 @@ def train_linear_classifier( train_metrics[f"train_{class_name}_f1"] = train_report[class_name]["f1-score"] val_metrics = {} + y_val_proba: Optional[np.ndarray] = None if X_val is not None and y_val is not None: y_val_pred = classifier.predict(X_val) val_report = classification_report(y_val, y_val_pred, digits=3, output_dict=True) @@ -336,6 +340,15 @@ def train_linear_classifier( else: val_metrics["val_auroc"] = roc_auc_score(y_val, y_val_proba, multi_class="ovr", average="macro") print(f" Val AUROC: {val_metrics['val_auroc']:.3f}") + + if len(classifier.classes_) > 2: + for i, class_name in enumerate(classifier.classes_): + try: + val_metrics[f"val_{class_name}_auroc"] = roc_auc_score( + (y_val == class_name).astype(int), y_val_proba[:, i] + ) + except ValueError: + pass except ValueError as e: _logger.warning(f"Could not compute val AUROC (likely only one class present): {e}") @@ -365,7 +378,13 @@ def train_linear_classifier( task=task, ) - return pipeline, all_metrics + val_outputs: dict[str, Any] = { + "y_val": y_val, + "y_val_proba": y_val_proba, + "classes": classifier.classes_.tolist(), + } + + return pipeline, all_metrics, val_outputs def predict_with_classifier( diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py b/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py new file mode 100644 index 000000000..d911c0d5a --- /dev/null +++ b/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py @@ -0,0 +1,217 @@ +"""Maximum Mean Discrepancy (MMD) with Gaussian RBF kernel and permutation test.""" + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.distance import cdist + + +def median_heuristic(X: NDArray, Y: NDArray, subsample: int = 1000) -> float: + """Compute Gaussian RBF bandwidth via the median heuristic. + + Subsamples jointly from X and Y, computes all pairwise squared Euclidean + distances, and returns the median. This is the standard bandwidth selection + for MMD tests (Gretton et al., 2012). + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + subsample : int + Max samples to draw from the joint (X, Y) pool for median computation. + + Returns + ------- + float + Bandwidth sigma^2 for the Gaussian RBF kernel. + """ + rng = np.random.default_rng(0) + pool = np.concatenate([X, Y], axis=0).astype(np.float32) + if len(pool) > subsample: + idx = rng.choice(len(pool), subsample, replace=False) + pool = pool[idx] + sq_dists = cdist(pool, pool, metric="sqeuclidean") + upper = sq_dists[np.triu_indices_from(sq_dists, k=1)] + return float(np.median(upper)) + 1e-12 + + +def gaussian_rbf_kernel(X: NDArray, Y: NDArray, bandwidth: float) -> NDArray: + """Compute Gaussian RBF kernel matrix K(X, Y) in float32. + + K(x, y) = exp(-||x - y||^2 / (2 * bandwidth)) + + Parameters + ---------- + X : NDArray + Shape (n, d). + Y : NDArray + Shape (m, d). + bandwidth : float + Kernel bandwidth (sigma^2). Must be > 0. + + Returns + ------- + NDArray + Kernel matrix, shape (n, m), float32. + """ + sq_dists = cdist(X.astype(np.float32), Y.astype(np.float32), metric="sqeuclidean") + return np.exp(-sq_dists / (2.0 * bandwidth), dtype=np.float32) + + +def compute_mmd_unbiased(X: NDArray, Y: NDArray, bandwidth: float | None = None) -> float: + """Compute the unbiased quadratic-time MMD^2 estimator. + + MMD^2_u = (1/(n(n-1))) sum_{i!=j} k(x_i, x_j) + + (1/(m(m-1))) sum_{i!=j} k(y_i, y_j) + - (2/(nm)) sum_{i,j} k(x_i, y_j) + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + bandwidth : float or None + Gaussian RBF bandwidth. None = median heuristic. + + Returns + ------- + float + Unbiased MMD^2 estimate. + """ + if bandwidth is None: + bandwidth = median_heuristic(X, Y) + n = len(X) + m = len(Y) + K_XX = gaussian_rbf_kernel(X, X, bandwidth) + K_YY = gaussian_rbf_kernel(Y, Y, bandwidth) + K_XY = gaussian_rbf_kernel(X, Y, bandwidth) + np.fill_diagonal(K_XX, 0.0) + np.fill_diagonal(K_YY, 0.0) + mmd2 = K_XX.sum() / (n * (n - 1)) + K_YY.sum() / (m * (m - 1)) - 2.0 * K_XY.mean() + return float(mmd2) + + +def _mmd2_from_kernel(K_pool: NDArray, n: int, perm: NDArray) -> float: + """Compute unbiased MMD^2 from a pre-computed pooled kernel matrix. + + Parameters + ---------- + K_pool : NDArray + Full pooled kernel matrix, shape (n+m, n+m). + n : int + Number of samples in X (first group). + perm : NDArray + Permutation index array of length n+m. + + Returns + ------- + float + Unbiased MMD^2 for this permutation. + """ + m = len(perm) - n + ix = perm[:n] + iy = perm[n:] + K_XX = K_pool[np.ix_(ix, ix)] + K_YY = K_pool[np.ix_(iy, iy)] + K_XY = K_pool[np.ix_(ix, iy)] + # Unbiased: zero diagonal contribution + kxx = (K_XX.sum() - K_XX.trace()) / (n * (n - 1)) + kyy = (K_YY.sum() - K_YY.trace()) / (m * (m - 1)) + kxy = K_XY.mean() + return float(kxx + kyy - 2.0 * kxy) + + +def mmd_permutation_test( + X: NDArray, + Y: NDArray, + n_permutations: int = 1000, + bandwidth: float | None = None, + seed: int = 42, +) -> tuple[float, float, NDArray]: + """MMD^2 with vectorized permutation test for significance. + + Precomputes the pooled kernel matrix K_pool once, then all permutations + are evaluated via vectorized row/column sums — no repeated cdist calls + and no Python loop over individual permutations. + + Strategy: for each permutation p, MMD^2 = sum_X/n(n-1) + sum_Y/m(m-1) - 2*mean_XY + where sum_X = sum of K_pool[ix,ix] off-diagonal = (K_pool[ix,:] * one_hot_X).sum(). + We represent each permutation as a binary label vector z in {0,1}^(n+m), + then use K_pool @ z and K_pool @ (1-z) to get row sums in O(n_perm * N) ops. + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + n_permutations : int + Number of permutations for the null distribution. + bandwidth : float or None + Gaussian RBF bandwidth. None = median heuristic (computed once). + seed : int + Random seed for reproducibility. + + Returns + ------- + mmd2 : float + Observed MMD^2 (unbiased). + p_value : float + Permutation test p-value. + null_distribution : NDArray + Null MMD^2 values from permutations, shape (n_permutations,). + """ + if bandwidth is None: + bandwidth = median_heuristic(X, Y) + n = len(X) + m = len(Y) + N = n + m + pool = np.concatenate([X, Y], axis=0).astype(np.float32) + # Compute full pooled kernel matrix once: (N, N) float32 + K = gaussian_rbf_kernel(pool, pool, bandwidth) + np.fill_diagonal(K, 0.0) + + def _mmd2_from_labels(z: NDArray) -> NDArray: + """Vectorized MMD^2 for a batch of label vectors. + + Parameters + ---------- + z : NDArray + Shape (n_perm, N), float32, 1 = assigned to X group. + + Returns + ------- + NDArray + MMD^2 values, shape (n_perm,). + """ + nz = z.sum(axis=1) # actual n per permutation (n_perm,) + mz = N - nz # actual m per permutation + # Row sums of K restricted to X-group and Y-group + # K @ z.T -> (N, n_perm), then z @ (K @ z.T) -> (n_perm, n_perm) diagonal = sum_XX + KzT = K @ z.T # (N, n_perm) + sum_XX = (z * KzT.T).sum(axis=1) # (n_perm,) — within-X kernel sums (diagonal zeroed) + sum_YY = ((1 - z) * (K @ (1 - z).T).T).sum(axis=1) # (n_perm,) — within-Y + sum_XY = (z * (K @ (1 - z).T).T).sum(axis=1) # (n_perm,) — cross + kxx = sum_XX / (nz * (nz - 1)) + kyy = sum_YY / (mz * (mz - 1)) + kxy = sum_XY / (nz * mz) + return kxx + kyy - 2.0 * kxy + + # Observed: original split (first n are X) + z_obs = np.zeros((1, N), dtype=np.float32) + z_obs[0, :n] = 1.0 + observed = float(_mmd2_from_labels(z_obs)[0]) + + # Null: random permutations as binary label vectors + rng = np.random.default_rng(seed) + # Generate all permutation indices at once + perms = np.stack([rng.permutation(N) for _ in range(n_permutations)]) # (n_perm, N) + z_null = np.zeros((n_permutations, N), dtype=np.float32) + row_idx = np.arange(n_permutations)[:, None] + z_null[row_idx, perms[:, :n]] = 1.0 + + null = _mmd2_from_labels(z_null) + p_value = float((np.sum(null >= observed) + 1) / (n_permutations + 1)) + return observed, p_value, null diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py b/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py index a6e0aefe2..e4566b029 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py @@ -7,6 +7,7 @@ import pandas as pd import zarr from anndata.io import write_elem +from pandas.arrays import ArrowStringArray def append_to_anndata_zarr( @@ -31,12 +32,25 @@ def append_to_anndata_zarr( obs : pd.DataFrame, optional Observation metadata. Replaces the entire ``obs`` group. uns : dict, optional - Unstructured annotation. Replaces the entire ``uns`` group. + Mapping of uns keys to values. Each key is written to ``uns/{key}``, + replacing any existing entry while preserving other uns keys. """ store = zarr.open(str(zarr_path), mode="a", use_consolidated=False) ad.settings.allow_write_nullable_strings = True if obs is not None: + # TODO: remove once anndata 0.13 supports pandas 3 Arrow-backed strings natively. + # anndata 0.12.9+ requires pandas <3, so we stay on 0.12.6 + pandas 3 and + # must manually downcast ArrowStringArray columns to object dtype before writing. + obs = obs.copy() + for col in obs.columns: + arr = obs[col].array + if isinstance(arr, ArrowStringArray): + obs[col] = obs[col].astype(object) + elif isinstance(arr, pd.Categorical) and isinstance(arr.categories._values, ArrowStringArray): + obs[col] = obs[col].cat.rename_categories(arr.categories.astype(object)) + if isinstance(obs.index._values, ArrowStringArray): + obs.index = obs.index.astype(object) if "obs" in store: del store["obs"] write_elem(store, "obs", obs) @@ -49,9 +63,13 @@ def append_to_anndata_zarr( write_elem(store, obsm_path, value) if uns is not None: - if "uns" in store: - del store["uns"] - write_elem(store, "uns", uns) + if "uns" not in store: + store.create_group("uns") + for key, value in uns.items(): + uns_path = f"uns/{key}" + if uns_path in store: + del store[uns_path] + write_elem(store, uns_path, value) zarr.consolidate_metadata(str(zarr_path)) diff --git a/packages/viscy-utils/tests/test_linear_classifier.py b/packages/viscy-utils/tests/test_linear_classifier.py index aad22f43d..efcd356b8 100644 --- a/packages/viscy-utils/tests/test_linear_classifier.py +++ b/packages/viscy-utils/tests/test_linear_classifier.py @@ -42,11 +42,13 @@ def synthetic_adata_with_unknowns(): class TestLinearClassifierPipeline: @pytest.fixture def trained_pipeline(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True, use_pca=False) + pipeline, _, _ = train_linear_classifier( + annotated_adata, task="cell_death_state", use_scaling=True, use_pca=False + ) return pipeline def test_transform_with_scaler_and_pca(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_scaling=True, @@ -58,7 +60,7 @@ def test_transform_with_scaler_and_pca(self, annotated_adata): assert X_transformed.shape == (X.shape[0], 5) def test_transform_scaler_only(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_scaling=True, @@ -70,7 +72,7 @@ def test_transform_scaler_only(self, annotated_adata): assert pipeline.pca is None def test_transform_no_preprocessing(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_scaling=False, @@ -94,18 +96,18 @@ def test_predict_proba_shape(self, trained_pipeline, annotated_adata): class TestTrainLinearClassifier: def test_train_basic(self, annotated_adata): - pipeline, metrics = train_linear_classifier(annotated_adata, task="cell_death_state") + pipeline, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state") assert isinstance(pipeline, LinearClassifierPipeline) assert isinstance(metrics, dict) assert "train_accuracy" in metrics assert "train_weighted_f1" in metrics def test_train_with_scaling(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True) + pipeline, _, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True) assert pipeline.scaler is not None def test_train_with_pca(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_pca=True, @@ -115,26 +117,26 @@ def test_train_with_pca(self, annotated_adata): assert pipeline.pca.n_components == 5 def test_train_no_split(self, annotated_adata): - pipeline, metrics = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=1.0) + pipeline, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=1.0) assert "train_accuracy" in metrics assert "val_accuracy" not in metrics def test_train_metrics_keys(self, annotated_adata): - _, metrics = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=0.8) + _, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=0.8) assert "train_accuracy" in metrics assert "train_weighted_f1" in metrics for class_name in ["alive", "dead", "apoptotic"]: assert f"train_{class_name}_f1" in metrics def test_train_reproducibility(self, annotated_adata): - _, metrics_a = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) - _, metrics_b = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) + _, metrics_a, _ = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) + _, metrics_b, _ = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) assert metrics_a == metrics_b def test_train_sparse_matrix(self, annotated_adata): sparse_adata = annotated_adata.copy() sparse_adata.X = scipy.sparse.csr_matrix(sparse_adata.X) - pipeline, metrics = train_linear_classifier(sparse_adata, task="cell_death_state") + pipeline, metrics, _ = train_linear_classifier(sparse_adata, task="cell_death_state") assert isinstance(pipeline, LinearClassifierPipeline) assert "train_accuracy" in metrics @@ -142,7 +144,7 @@ def test_train_sparse_matrix(self, annotated_adata): class TestPredictWithClassifier: @pytest.fixture def pipeline_and_adata(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state") + pipeline, _, _ = train_linear_classifier(annotated_adata, task="cell_death_state") return pipeline, annotated_adata def test_predict_adds_obs_columns(self, pipeline_and_adata): diff --git a/uv.lock b/uv.lock index 10657c08b..6211a6573 100644 --- a/uv.lock +++ b/uv.lock @@ -2,14 +2,22 @@ version = 1 revision = 3 requires-python = ">=3.12, <3.14" resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version < '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and sys_platform == 'emscripten'", - "python_full_version >= '3.13' and sys_platform == 'linux'", - "python_full_version < '3.13' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", ] [manifest] @@ -684,6 +692,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/99/2adc7d8ffead633234817ef8e9a87115c8a11927a94478f6bb3d3f4d4f7d/contourpy-1.3.3-cp313-cp313t-win_arm64.whl", hash = "sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301", size = 199713, upload-time = "2025-07-26T12:02:14.4Z" }, ] +[[package]] +name = "copairs" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "duckdb" }, + { name = "pandas" }, + { name = "statsmodels" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/65/25/7e2b2327ce9b3a7312be41070f264a09761fccb146cf60206d27c50e24b6/copairs-0.5.4.tar.gz", hash = "sha256:4d821784fa42d388db66e6a90c4ca1849c79957059260655faa884ffe6559648", size = 41895, upload-time = "2026-01-27T12:21:07.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/2a/86a6255d7e892419833ba5951f7574d02c9c83648cd939bb5a921e386858/copairs-0.5.4-py3-none-any.whl", hash = "sha256:e24e41ffdcfabf8d76b4288423f8951ea9c69884d5c4e88f8d9d33ff1ee32bbf", size = 34092, upload-time = "2026-01-27T12:21:06.368Z" }, +] + [[package]] name = "coverage" version = "7.13.4" @@ -759,7 +782,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "cuda-pathfinder", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, @@ -990,6 +1013,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] +[[package]] +name = "dtaidistance" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/01/aa26cc97b64d397ff03b9576b0a04cc79d0e3bae512eb087cfab7d98f4ec/dtaidistance-2.4.0.tar.gz", hash = "sha256:bd4066800254fbd5b620e6462bb759c9d85b79ac2080b354cedc901f446b6c82", size = 1316462, upload-time = "2026-02-12T22:23:56.35Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/63/c1546dc5a4a98f77ca044206e8d8b7604349d36d0b76d5c03ab393a55e60/dtaidistance-2.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:64d54f910b53cd7a56b215e06d2b24b22090af836102d48558d3e9569ded2b66", size = 2124723, upload-time = "2026-02-12T22:23:39.482Z" }, + { url = "https://files.pythonhosted.org/packages/ad/9a/4c0cb726c3c93436c993f55fc59d5fd2142c1a0fe6fe9ec06cc7bf25ab15/dtaidistance-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3afb229f4524f8bbf835a5dc3e07abcee9b6b9c6af4f14436cad19639102243c", size = 1549051, upload-time = "2026-02-13T08:14:46.866Z" }, + { url = "https://files.pythonhosted.org/packages/f5/8e/ccdd057e4ff71cf0b6fe34220cbd214d469f831b45acbbb4366fdfef6330/dtaidistance-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:349d6765e10ddbb5e22e937cf1bc42394f5f8d36bc127f8af24a0cd0259f4804", size = 4361729, upload-time = "2026-02-12T22:23:41.184Z" }, + { url = "https://files.pythonhosted.org/packages/00/cf/ef215e8864c21eb14872f98987d9736ebbbe5049d429039e2a93adcacad4/dtaidistance-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:6ab9431a5b66aafd37ab4dfcfe563b66694ed192019c1632d2de7a431a883bcd", size = 1443363, upload-time = "2026-02-12T22:23:43.706Z" }, + { url = "https://files.pythonhosted.org/packages/87/89/c64eea692eae3b269719ee5173bf5008b5c165280248e3fad1948c765a2b/dtaidistance-2.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4cf41f3edcc4c1b94ebbc1de029ee9b58da28f33f7bf3af89212cc05e35ec8f1", size = 2117805, upload-time = "2026-02-12T22:23:45.632Z" }, + { url = "https://files.pythonhosted.org/packages/ed/7f/06ce3d5ce51a959be0534584ad2556e6c8be966ef1218a866c6c3d62e3c5/dtaidistance-2.4.0-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:94b841d6575e3ad715b4e213f0f04de25e23c2da3ac21ee9c6775b38f5bdfecf", size = 1738478, upload-time = "2026-02-13T08:14:51.715Z" }, + { url = "https://files.pythonhosted.org/packages/db/8e/6c8a5c7710f9f5e3805281974ce8fea4ad0334c00a1e0f977977c045a594/dtaidistance-2.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b0f2a65628aea82175e7f8c5e96faf5372c933ed40e2e39a84957d8fe305158d", size = 4341606, upload-time = "2026-02-12T22:23:46.977Z" }, + { url = "https://files.pythonhosted.org/packages/9f/b6/7f77c6773380742660d09f379d43814b448296fc24c3fb1de15a3d813311/dtaidistance-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8c9ef4c7270d1a192e8f1b481c2e10e63c33c6e7edfc507acac7f3fdc19949f", size = 1441578, upload-time = "2026-02-12T22:23:48.897Z" }, +] + +[[package]] +name = "duckdb" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/66/744b4931b799a42f8cb9bc7a6f169e7b8e51195b62b246db407fd90bf15f/duckdb-1.5.2.tar.gz", hash = "sha256:638da0d5102b6cb6f7d47f83d0600708ac1d3cb46c5e9aaabc845f9ba4d69246", size = 18017166, upload-time = "2026-04-13T11:30:09.065Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/de/ebe66bbe78125fc610f4fd415447a65349d94245950f3b3dfb31d028af02/duckdb-1.5.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e6495b00cad16888384119842797c49316a96ae1cb132bb03856d980d95afee1", size = 30064950, upload-time = "2026-04-13T11:29:11.468Z" }, + { url = "https://files.pythonhosted.org/packages/2d/8a/3e25b5d03bcf1fb99d189912f8ce92b1db4f9c8778e1b1f55745973a855a/duckdb-1.5.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d72b8856b1839d35648f38301b058f6232f4d36b463fe4dc8f4d3fdff2df1a2e", size = 15969113, upload-time = "2026-04-13T11:29:14.139Z" }, + { url = "https://files.pythonhosted.org/packages/19/bb/58001f0815002b1a93431bf907f77854085c7d049b83d521814a07b9db0b/duckdb-1.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1de4f4d454b8c97aec546c82003fc834d3422ce4bc6a19902f3462ef293bed", size = 14224774, upload-time = "2026-04-13T11:29:16.758Z" }, + { url = "https://files.pythonhosted.org/packages/d3/2f/a7f0de9509d1cef35608aeb382919041cdd70f58c173865c3da6a0d87979/duckdb-1.5.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce0b8141a10d37ecef729c45bc41d334854013f4389f1488bd6035c5579aaac1", size = 19313510, upload-time = "2026-04-13T11:29:19.574Z" }, + { url = "https://files.pythonhosted.org/packages/26/78/eb1e064ea8b9df3b87b167bfd7a407b2f615a4291e06cba756727adfa06c/duckdb-1.5.2-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c99ef73a277c8921bc0a1f16dee38d924484251d9cfd20951748c20fcd5ed855", size = 21429692, upload-time = "2026-04-13T11:29:22.575Z" }, + { url = "https://files.pythonhosted.org/packages/5b/12/05b0c47d14839925c5e35b79081d918ca82e3f236bb724a6f58409dd5291/duckdb-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:8d599758b4e48bf12e18c9b960cf491d219f0c4972d19a45489c05cc5ab36f83", size = 13107594, upload-time = "2026-04-13T11:29:25.43Z" }, + { url = "https://files.pythonhosted.org/packages/0b/2c/80558a82b236e044330e84a154b96aacddb343316b479f3d49be03ea11cb/duckdb-1.5.2-cp312-cp312-win_arm64.whl", hash = "sha256:fc85a5dbcbe6eccac1113c72370d1d3aacfdd49198d63950bdf7d8638a307f00", size = 13927537, upload-time = "2026-04-13T11:29:27.842Z" }, + { url = "https://files.pythonhosted.org/packages/98/f2/e3d742808f138d374be4bb516fade3d1f33749b813650810ab7885cdc363/duckdb-1.5.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4420b3f47027a7849d0e1815532007f377fa95ee5810b47ea717d35525c12f79", size = 30064879, upload-time = "2026-04-13T11:29:30.763Z" }, + { url = "https://files.pythonhosted.org/packages/72/0d/f3dc1cf97e1267ca15e4307d456f96ce583961f0703fd75e62b2ad8d64fa/duckdb-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bb42e6ed543902e14eae647850da24103a89f0bc2587dec5601b1c1f213bd2ed", size = 15969327, upload-time = "2026-04-13T11:29:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e0/d5418def53ae4e05a63075705ff44ed5af5a1a5932627eb2b600c5df1c93/duckdb-1.5.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:98c0535cd6d901f61a5ea3c2e26a1fd28482953d794deb183daf568e3aa5dda6", size = 14225107, upload-time = "2026-04-13T11:29:35.882Z" }, + { url = "https://files.pythonhosted.org/packages/16/a7/15aaa59dbecc35e9711980fcdbf525b32a52470b32d18ef678193a146213/duckdb-1.5.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:486c862bf7f163c0110b6d85b3e5c031d224a671cca468f12ebb1d3a348f6b39", size = 19313433, upload-time = "2026-04-13T11:29:38.367Z" }, + { url = "https://files.pythonhosted.org/packages/bd/21/d903cc63a5140c822b7b62b373a87dc557e60c29b321dfb435061c5e67cf/duckdb-1.5.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70631c847ca918ee710ec874241b00cf9d2e5be90762cbb2a0389f17823c08f7", size = 21429837, upload-time = "2026-04-13T11:29:41.135Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0a/b770d1f60c70597302130d6247f418549b7094251a02348fbaf1c7e147ae/duckdb-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:52a21823f3fbb52f0f0e5425e20b07391ad882464b955879499b5ff0b45a376b", size = 13107699, upload-time = "2026-04-13T11:29:43.905Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cf/e200fe431d700962d1a908d2ce89f53ccee1cc8db260174ae663ba09686b/duckdb-1.5.2-cp313-cp313-win_arm64.whl", hash = "sha256:411ad438bd4140f189a10e7f515781335962c5d18bd07837dc6d202e3985253d", size = 13927646, upload-time = "2026-04-13T11:29:46.598Z" }, +] + [[package]] name = "dynacell" source = { editable = "applications/dynacell" } @@ -1054,6 +1118,7 @@ dependencies = [ [package.optional-dependencies] eval = [ { name = "anndata" }, + { name = "dtaidistance" }, { name = "natsort" }, { name = "phate" }, { name = "scikit-learn" }, @@ -1062,6 +1127,13 @@ eval = [ { name = "umap-learn" }, { name = "wandb" }, ] +tracking = [ + { name = "gurobipy" }, + { name = "onnxruntime-gpu" }, + { name = "py-ctcmetrics" }, + { name = "tabulate" }, + { name = "tracksdata" }, +] [package.dev-dependencies] dev = [ @@ -1087,15 +1159,21 @@ test = [ requires-dist = [ { name = "anndata", marker = "extra == 'eval'" }, { name = "click" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "dtaidistance", marker = "extra == 'eval'" }, + { name = "gurobipy", marker = "extra == 'tracking'", specifier = ">=12.0.1,<13.0.0" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "natsort", marker = "extra == 'eval'" }, + { name = "onnxruntime-gpu", marker = "extra == 'tracking'" }, { name = "phate", marker = "extra == 'eval'" }, + { name = "py-ctcmetrics", marker = "extra == 'tracking'" }, { name = "pytorch-metric-learning" }, { name = "pyyaml" }, { name = "scikit-learn", marker = "extra == 'eval'" }, { name = "seaborn", marker = "extra == 'eval'" }, { name = "statsmodels", marker = "extra == 'eval'" }, + { name = "tabulate", marker = "extra == 'tracking'" }, { name = "torchvision" }, + { name = "tracksdata", marker = "extra == 'tracking'" }, { name = "umap-learn", marker = "extra == 'eval'" }, { name = "viscy-data", extras = ["triplet"], editable = "packages/viscy-data" }, { name = "viscy-models", editable = "packages/viscy-models" }, @@ -1103,7 +1181,7 @@ requires-dist = [ { name = "viscy-utils", extras = ["eval"], editable = "packages/viscy-utils" }, { name = "wandb", marker = "extra == 'eval'" }, ] -provides-extras = ["eval"] +provides-extras = ["eval", "tracking"] [package.metadata.requires-dev] dev = [ @@ -1256,7 +1334,7 @@ requires-dist = [ { name = "dask", extras = ["array"] }, { name = "eet-features", editable = "../../../../../home/eduardo.hirata/repos/eet_features" }, { name = "eet-inference", editable = "../../../../../home/eduardo.hirata/repos/eet_inference" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "napari", extras = ["pyqt5"] }, { name = "napari-geff", editable = "../../../../../home/eduardo.hirata/repos/napari-geff" }, { name = "pyyaml" }, @@ -1379,6 +1457,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/24/f4ed44e103ee7ec9880c43bb06a9d60eab5f06d80022f83005c67304655d/fill_voids-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:976f6a3c5a68f3f3483da779d8c71f11e8e3eec4c104d0d594ba5cd11a36a7fa", size = 181694, upload-time = "2025-09-03T05:28:19.728Z" }, ] +[[package]] +name = "flatbuffers" +version = "25.12.19" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661, upload-time = "2025-12-19T23:16:13.622Z" }, +] + [[package]] name = "flexcache" version = "0.3" @@ -1956,7 +2042,7 @@ wheels = [ [[package]] name = "iohub" -version = "0.3.2" +version = "0.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "blosc2" }, @@ -1974,9 +2060,9 @@ dependencies = [ { name = "zarr" }, { name = "zarrs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/d8/601a5a2d648370cd90825e3c51bc26155b223443ed936ec8e6d62135c871/iohub-0.3.2.tar.gz", hash = "sha256:54eb5a146efbc94375e2f40f51a98234a33a2e3820de7745fdcf06f33d86fef0", size = 289875, upload-time = "2026-04-10T04:16:50.515Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/81/4400daf22b508a237bbe05a58320886b1549bde3cb41eaaf46c4c777f355/iohub-0.3.3.tar.gz", hash = "sha256:8190d3155a5dee0e0b98416970b648008b9d4e86e42a84e32078b04284a6b66e", size = 290272, upload-time = "2026-04-24T00:13:38.941Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/cd/7c3fd9ecc51b598468d7702b69ab738d5cf9dbe6ffd22f96f26b2aca4ceb/iohub-0.3.2-py3-none-any.whl", hash = "sha256:6808e979ea229e569a627636a073c41e85c0f0224936ddaa952a30b157b79810", size = 84464, upload-time = "2026-04-10T04:16:49.111Z" }, + { url = "https://files.pythonhosted.org/packages/a0/4f/cf3443512b38501677649f79fb7524b35cb1f26b0238d116ed0407e162bb/iohub-0.3.3-py3-none-any.whl", hash = "sha256:b0eb7781ae076bbd3db7143cb8482612fd414016c589a1beec98bb6de9da1173", size = 84950, upload-time = "2026-04-24T00:13:37.499Z" }, ] [[package]] @@ -3267,7 +3353,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -3278,7 +3364,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -3305,9 +3391,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -3318,7 +3404,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -3383,6 +3469,82 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/21/59baa90924b815b70f88045f0b206b7eab0b68b461c0192692486b516ab7/ome_zarr-0.12.2-py3-none-any.whl", hash = "sha256:655fe1b11ca01148603f9931a5b0af31207dfc03a3a35f9b0ab8639790282bbd", size = 41410, upload-time = "2025-08-22T08:57:12.44Z" }, ] +[[package]] +name = "onnx" +version = "1.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "protobuf" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/93/942d2a0f6a70538eea042ce0445c8aefd46559ad153469986f29a743c01c/onnx-1.21.0.tar.gz", hash = "sha256:4d8b67d0aaec5864c87633188b91cc520877477ec0254eda122bef8be43cd764", size = 12074608, upload-time = "2026-03-27T21:33:36.118Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/ae/cb644ec84c25e63575d9d8790fdcc5d1a11d67d3f62f872edb35fa38d158/onnx-1.21.0-cp312-abi3-macosx_12_0_universal2.whl", hash = "sha256:fc2635400fe39ff37ebc4e75342cc54450eadadf39c540ff132c319bf4960095", size = 17965930, upload-time = "2026-03-27T21:32:48.089Z" }, + { url = "https://files.pythonhosted.org/packages/6f/b6/eeb5903586645ef8a49b4b7892580438741acc3df91d7a5bd0f3a59ea9cb/onnx-1.21.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9003d5206c01fa2ff4b46311566865d8e493e1a6998d4009ec6de39843f1b59b", size = 17531344, upload-time = "2026-03-27T21:32:50.837Z" }, + { url = "https://files.pythonhosted.org/packages/a7/00/4823f06357892d1e60d6f34e7299d2ba4ed2108c487cc394f7ce85a3ff14/onnx-1.21.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9261bd580fb8548c9c37b3c6750387eb8f21ea43c63880d37b2c622e1684285", size = 17613697, upload-time = "2026-03-27T21:32:54.222Z" }, + { url = "https://files.pythonhosted.org/packages/23/1d/391f3c567ae068c8ac4f1d1316bae97c9eb45e702f05975fe0e17ad441f0/onnx-1.21.0-cp312-abi3-win32.whl", hash = "sha256:9ea4e824964082811938a9250451d89c4ec474fe42dd36c038bfa5df31993d1e", size = 16287200, upload-time = "2026-03-27T21:32:57.277Z" }, + { url = "https://files.pythonhosted.org/packages/9c/a6/5eefbe5b40ea96de95a766bd2e0e751f35bdea2d4b951991ec9afaa69531/onnx-1.21.0-cp312-abi3-win_amd64.whl", hash = "sha256:458d91948ad9a7729a347550553b49ab6939f9af2cddf334e2116e45467dc61f", size = 16441045, upload-time = "2026-03-27T21:33:00.081Z" }, + { url = "https://files.pythonhosted.org/packages/63/c4/0ed8dc037a39113d2a4d66e0005e07751c299c46b993f1ad5c2c35664c20/onnx-1.21.0-cp312-abi3-win_arm64.whl", hash = "sha256:ca14bc4842fccc3187eb538f07eabeb25a779b39388b006db4356c07403a7bbb", size = 16403134, upload-time = "2026-03-27T21:33:03.987Z" }, + { url = "https://files.pythonhosted.org/packages/f8/89/0e1a9beb536401e2f45ac88735e123f2735e12fc7b56ff6c11727e097526/onnx-1.21.0-cp313-cp313t-macosx_12_0_universal2.whl", hash = "sha256:257d1d1deb6a652913698f1e3f33ef1ca0aa69174892fe38946d4572d89dd94f", size = 17975430, upload-time = "2026-03-27T21:33:07.005Z" }, + { url = "https://files.pythonhosted.org/packages/ec/46/e6dc71a7b3b317265591b20a5f71d0ff5c0d26c24e52283139dc90c66038/onnx-1.21.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7cd7cb8f6459311bdb557cbf6c0ccc6d8ace11c304d1bba0a30b4a4688e245f8", size = 17537435, upload-time = "2026-03-27T21:33:09.765Z" }, + { url = "https://files.pythonhosted.org/packages/49/2e/27affcac63eaf2ef183a44fd1a1354b11da64a6c72fe6f3fdcf5571bcee5/onnx-1.21.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b58a4cfec8d9311b73dc083e4c1fa362069267881144c05139b3eba5dc3a840", size = 17617687, upload-time = "2026-03-27T21:33:12.619Z" }, + { url = "https://files.pythonhosted.org/packages/1c/5c/ac8ed15e941593a3672ce424280b764979026317811f2e8508432bfc3429/onnx-1.21.0-cp313-cp313t-win_amd64.whl", hash = "sha256:1a9baf882562c4cebf79589bebb7cd71a20e30b51158cac3e3bbaf27da6163bd", size = 16449402, upload-time = "2026-03-27T21:33:15.555Z" }, + { url = "https://files.pythonhosted.org/packages/0e/aa/d2231e0dcaad838217afc64c306c8152a080134d2034e247cc973d577674/onnx-1.21.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bba12181566acf49b35875838eba49536a327b2944664b17125577d230c637ad", size = 16408273, upload-time = "2026-03-27T21:33:18.599Z" }, +] + +[[package]] +name = "onnx-ir" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "sympy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/e6/672fefb2f108d077f58181a7babf4c0f8d1182a30353ffc9c79c63afc5ee/onnx_ir-0.2.1.tar.gz", hash = "sha256:8b8b10a93f43e65962104de6070c43c5dacb0e3cdfefc7c8059dd83c9db64f35", size = 144279, upload-time = "2026-04-20T20:21:47.735Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/aa/f7a53321c60b9ad9ee184b6018292ed6b5389947592a2c8c09c736bb7f9e/onnx_ir-0.2.1-py3-none-any.whl", hash = "sha256:c7285da889312f91882de2092e298a9eeeefbfc1d1951c49d983992967eb09a7", size = 166792, upload-time = "2026-04-20T20:21:46.357Z" }, +] + +[[package]] +name = "onnxruntime-gpu" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flatbuffers" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/6d/2c13d3eff74caa9e59820a044a75becd34e9cbeeaf7617ad7679cdb1fdb7/onnxruntime_gpu-1.25.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f0c36c63c8b0eb4091f2567067f480f66f0aedc189eb009545c98ce7e919056", size = 270342429, upload-time = "2026-04-22T17:28:10.526Z" }, + { url = "https://files.pythonhosted.org/packages/8c/2e/9fc303ae59d4caeb85ec3cea6881b7de8ca1d2a07140fade39913cd7ff10/onnxruntime_gpu-1.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:61178cc4d84f59861714554531e01cccbd33ddf13cc0e87a3adea13b24d297ce", size = 220847708, upload-time = "2026-04-22T17:20:47.993Z" }, + { url = "https://files.pythonhosted.org/packages/f5/15/e63fe7b1abad6884bed07e9bb333e9f0ea48fbb8cbc1ea4a67ee6019d5d0/onnxruntime_gpu-1.25.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e462eb13ee9955117baec4f518916c1e7cb1a96001114105632bc6d454c6aee6", size = 270342324, upload-time = "2026-04-22T17:28:21.142Z" }, + { url = "https://files.pythonhosted.org/packages/21/10/b3533243d062b589d4b1f3ae26584af332c5cde618e7f6f5ff6fabbfd5f2/onnxruntime_gpu-1.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:9a3682158e5e911385252eb95d6332b6f525972746c582e10f8a78213b39e624", size = 220848188, upload-time = "2026-04-22T17:20:56.946Z" }, + { url = "https://files.pythonhosted.org/packages/35/6c/d7706dd1d0eaafdba44d5c89f8d952de41e425a1b0cbd3ecfa60f918c249/onnxruntime_gpu-1.25.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8514b92c5929c953850090d823d018770cba2a971efab5f8f69a3c4280cdc632", size = 270364210, upload-time = "2026-04-22T17:28:33.568Z" }, +] + +[[package]] +name = "onnxscript" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "onnx-ir" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/99/fd948eba63ba65b52265a4cd09a14f96bb9f5b730fcef58876c4358bf406/onnxscript-0.7.0.tar.gz", hash = "sha256:c95ed7b339b02cface56ee27689565c46612e1fc542c562298dddfdad5268dc5", size = 612032, upload-time = "2026-04-20T17:09:19.775Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/ce/2ed92575cc3be4ea1db5f38f16f20765f9b20b69b14d6c1d9972658a8ee9/onnxscript-0.7.0-py3-none-any.whl", hash = "sha256:5b356907d4501e9919f8599c91d8da967406a37b1fac2b40caa55a49acf242ea", size = 714842, upload-time = "2026-04-20T17:09:22.089Z" }, +] + [[package]] name = "opencv-python-headless" version = "4.13.0.92" @@ -3856,6 +4018,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, ] +[[package]] +name = "py-ctcmetrics" +version = "1.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "imagecodecs" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "tifffile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5d/25/bc4ff397b3ac93606ee105ab6832cca5f2a06b2dee9e1240f6215f541d4f/py_ctcmetrics-1.3.3.tar.gz", hash = "sha256:e055b7713bc704a42673b1313c7fd5ae55b80d49455132ff27b6b7db609209b0", size = 35153, upload-time = "2026-03-12T08:53:53.572Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/cc/c3c0d99df9540ca8ac4ee9c9177c5f88bf9693f5808ab5a5330d7d2fda65/py_ctcmetrics-1.3.3-py3-none-any.whl", hash = "sha256:7f35906030aadf8a4b5be9cf44260969b82b2d6bb3959b93f24928ff557b5f6c", size = 43419, upload-time = "2026-03-12T08:53:52.367Z" }, +] + [[package]] name = "pyairtable" version = "3.3.0" @@ -4143,8 +4322,10 @@ name = "pyqt5-qt5" version = "5.15.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/09/99a222b0360616250fb2e6003a54e43a2a06b0774f0f8d5daafb86a2c375/PyQt5_Qt5-5.15.2-py3-none-macosx_10_13_intel.whl", hash = "sha256:76980cd3d7ae87e3c7a33bfebfaee84448fd650bad6840471d6cae199b56e154", size = 40546019, upload-time = "2021-03-10T13:52:47.763Z" }, @@ -4155,12 +4336,18 @@ name = "pyqt5-qt5" version = "5.15.18" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version < '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and sys_platform == 'emscripten'", - "python_full_version >= '3.13' and sys_platform == 'linux'", - "python_full_version < '3.13' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", ] wheels = [ { url = "https://files.pythonhosted.org/packages/9a/46/ffe177f99f897a59dc237a20059020427bd2d3853d713992b8081933ddfe/pyqt5_qt5-5.15.18-py3-none-manylinux2014_x86_64.whl", hash = "sha256:bf2457e6371969736b4f660a0c153258fa03dbc6a181348218e6f05421682af7", size = 60864590, upload-time = "2025-11-09T12:57:26.724Z" }, @@ -5231,6 +5418,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tabulate" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754, upload-time = "2026-03-04T18:55:34.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" }, +] + [[package]] name = "tasklogger" version = "1.2.0" @@ -5980,7 +6176,7 @@ test = [ [package.metadata] requires-dist = [ { name = "imageio" }, - { name = "iohub", specifier = ">=0.3.2" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "lightning", specifier = ">=2.3" }, { name = "monai", specifier = ">=1.5.2" }, { name = "numpy", specifier = ">=2.4.1" }, @@ -6140,6 +6336,8 @@ dependencies = [ { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, + { name = "onnx" }, + { name = "onnxscript" }, { name = "pyyaml" }, { name = "scikit-image" }, { name = "scipy" }, @@ -6153,6 +6351,7 @@ dependencies = [ [package.optional-dependencies] all = [ { name = "anndata" }, + { name = "copairs" }, { name = "natsort" }, { name = "phate" }, { name = "scikit-learn" }, @@ -6164,6 +6363,7 @@ anndata = [ { name = "natsort" }, ] eval = [ + { name = "copairs" }, { name = "phate" }, { name = "scikit-learn" }, { name = "umap-learn" }, @@ -6192,13 +6392,17 @@ test = [ requires-dist = [ { name = "anndata", marker = "extra == 'all'" }, { name = "anndata", marker = "extra == 'anndata'" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "copairs", marker = "extra == 'all'" }, + { name = "copairs", marker = "extra == 'eval'" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "jsonargparse", extras = ["signatures"], specifier = ">=4.26" }, { name = "lightning", specifier = ">=2.3" }, { name = "matplotlib", specifier = ">=3.10" }, { name = "natsort", marker = "extra == 'all'" }, { name = "natsort", marker = "extra == 'anndata'" }, { name = "numpy", specifier = ">=2.4.1" }, + { name = "onnx" }, + { name = "onnxscript" }, { name = "phate", marker = "extra == 'all'" }, { name = "phate", marker = "extra == 'eval'" }, { name = "pyyaml" },