diff --git a/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py b/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py index e86bea050..c62927984 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py +++ b/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py @@ -61,6 +61,19 @@ def format_predictions_markdown(adata, task: str) -> str: lines.append(f"**Classes:** {', '.join(adata.uns[classes_key])}") lines.append("") + artifact_key = f"classifier_{task}_artifact" + if artifact_key in adata.uns.keys(): + lines.append("### Classifier Provenance") + lines.append("") + lines.append(f"- **Artifact:** {adata.uns[artifact_key]}") + id_key = f"classifier_{task}_id" + if id_key in adata.uns.keys(): + lines.append(f"- **Artifact ID:** {adata.uns[id_key]}") + version_key = f"classifier_{task}_version" + if version_key in adata.uns.keys(): + lines.append(f"- **Artifact Version:** {adata.uns[version_key]}") + lines.append("") + return "\n".join(lines) @@ -88,14 +101,20 @@ def main(config: Path): click.echo(f"\n❌ Failed to load configuration: {e}", err=True) raise click.Abort() + write_path = ( + Path(inference_config.output_path) + if inference_config.output_path is not None + else Path(inference_config.embeddings_path) + ) + click.echo(f"\n✓ Configuration loaded: {config}") click.echo(f" Model: {inference_config.model_name}") click.echo(f" Version: {inference_config.version}") click.echo(f" Embeddings: {inference_config.embeddings_path}") - click.echo(f" Output: {inference_config.output_path}") + click.echo(f" Output: {write_path}") try: - pipeline, loaded_config = load_pipeline_from_wandb( + pipeline, loaded_config, artifact_metadata = load_pipeline_from_wandb( wandb_project=inference_config.wandb_project, model_name=inference_config.model_name, version=inference_config.version, @@ -103,21 +122,31 @@ def main(config: Path): ) task = loaded_config["task"] + marker = loaded_config.get("marker") + task_key = f"{task}_{marker}" if marker else task click.echo(f"\nLoading embeddings from: {inference_config.embeddings_path}") adata = read_zarr(inference_config.embeddings_path) click.echo(f"✓ Loaded embeddings: {adata.shape}") - adata = predict_with_classifier(adata, pipeline, task) + if inference_config.include_wells: + click.echo(f" Well filter: {inference_config.include_wells}") + + adata = predict_with_classifier( + adata, + pipeline, + task_key, + artifact_metadata=artifact_metadata, + include_wells=inference_config.include_wells, + ) - output_path = Path(inference_config.output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) + write_path.parent.mkdir(parents=True, exist_ok=True) - click.echo(f"\nSaving predictions to: {output_path}") - adata.write_zarr(output_path) + click.echo(f"\nSaving predictions to: {write_path}") + adata.write_zarr(write_path) click.echo("✓ Saved predictions") - click.echo("\n" + format_predictions_markdown(adata, task)) + click.echo("\n" + format_predictions_markdown(adata, task_key)) click.echo("\n✓ Inference complete!") diff --git a/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml b/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml index a6e882365..7f7fe16a6 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml +++ b/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml @@ -21,8 +21,16 @@ wandb_entity: null # Path to embeddings zarr file for inference embeddings_path: /path/to/embeddings.zarr -# Path to save output zarr file with predictions -output_path: /path/to/output_with_predictions.zarr +# Path to save output zarr file with predictions. +# When omitted (or null), predictions are written back to embeddings_path. +# output_path: /path/to/output_with_predictions.zarr -# Whether to overwrite output if it already exists +# Well prefixes to restrict predictions to (optional). +# When omitted, all cells are predicted. Cells in other wells get NaN. +# Useful for organelle-specific classifiers where different wells have different markers. +# include_wells: +# - A/1 +# - A/2 + +# Whether to overwrite output if it already exists (only used when output_path is set) overwrite: false diff --git a/applications/DynaCLR/evaluation/linear_classifiers/dataset_discovery.py b/applications/DynaCLR/evaluation/linear_classifiers/dataset_discovery.py deleted file mode 100644 index 5add0ef52..000000000 --- a/applications/DynaCLR/evaluation/linear_classifiers/dataset_discovery.py +++ /dev/null @@ -1,226 +0,0 @@ -"""Shared discovery functions for finding predictions, annotations, and gaps.""" - -# %% -from glob import glob -from pathlib import Path - -import pandas as pd -from natsort import natsorted - -from viscy.representation.evaluation.linear_classifier_config import ( - VALID_CHANNELS, - VALID_TASKS, -) - -CHANNELS = list(VALID_CHANNELS.__args__) -TASKS = list(VALID_TASKS.__args__) - - -def discover_predictions( - embeddings_dir: Path, - model_name: str, - version: str, -) -> dict[str, Path]: - """Find datasets that have a predictions folder for the given model/version. - - Searches for paths matching: - {embeddings_dir}/{dataset}/*phenotyping*/*prediction*/{model_glob}/{version}/ - - Parameters - ---------- - embeddings_dir : Path - Base directory containing dataset folders. - model_name : str - Model directory name (supports glob patterns). - version : str - Version subdirectory (e.g. "v3"). - - Returns - ------- - dict[str, Path] - Mapping of dataset_name -> resolved predictions version directory. - """ - pattern = str( - embeddings_dir / "*" / "*phenotyping*" / "*prediction*" / model_name / version - ) - matches = natsorted(glob(pattern)) - - results = {} - for match in matches: - match_path = Path(match) - dataset_name = match_path.relative_to(embeddings_dir).parts[0] - results[dataset_name] = match_path - - return results - - -def find_channel_zarrs( - predictions_dir: Path, - channels: list[str] | None = None, -) -> dict[str, Path]: - """Find embedding zarr files for each channel in a predictions directory. - - Parameters - ---------- - predictions_dir : Path - Path to the version directory containing zarr files. - channels : list[str] or None - Channel names to search for. Defaults to CHANNELS. - - Returns - ------- - dict[str, Path] - Mapping of channel_name -> zarr path (only channels with a match). - """ - if channels is None: - channels = CHANNELS - channel_zarrs = {} - for channel in channels: - matches = natsorted(glob(str(predictions_dir / f"*{channel}*.zarr"))) - if matches: - channel_zarrs[channel] = Path(matches[0]) - return channel_zarrs - - -def find_annotation_csv(annotations_dir: Path, dataset_name: str) -> Path | None: - """Find the annotation CSV for a dataset. - - Parameters - ---------- - annotations_dir : Path - Base annotations directory. - dataset_name : str - Dataset folder name. - - Returns - ------- - Path or None - Path to CSV if found, None otherwise. - """ - dataset_dir = annotations_dir / dataset_name - if not dataset_dir.is_dir(): - return None - csvs = natsorted(glob(str(dataset_dir / "*.csv"))) - return Path(csvs[0]) if csvs else None - - -def get_available_tasks(csv_path: Path) -> list[str]: - """Read CSV header and return which valid task columns are present. - - Parameters - ---------- - csv_path : Path - Path to annotation CSV. - - Returns - ------- - list[str] - Task names found in the CSV columns. - """ - columns = pd.read_csv(csv_path, nrows=0).columns.tolist() - return [t for t in TASKS if t in columns] - - -def build_registry( - embeddings_dir: Path, - annotations_dir: Path, - model_name: str, - version: str, -) -> tuple[list[dict], list[dict], list[str], list[str]]: - """Build a registry of datasets with predictions and annotations. - - Parameters - ---------- - embeddings_dir : Path - Base directory containing dataset folders with embeddings. - annotations_dir : Path - Base directory containing dataset annotation folders. - model_name : str - Model directory name (supports glob patterns). - version : str - Version subdirectory (e.g. "v3"). - - Returns - ------- - registry : list[dict] - Datasets with both predictions and annotations. - skipped : list[dict] - Datasets with predictions but missing annotations or tasks. - annotations_only : list[str] - Annotation datasets with no matching predictions. - predictions_only : list[str] - Prediction datasets with no matching annotations. - """ - predictions = discover_predictions(embeddings_dir, model_name, version) - - registry: list[dict] = [] - skipped: list[dict] = [] - - for dataset_name, pred_dir in predictions.items(): - channel_zarrs = find_channel_zarrs(pred_dir) - csv_path = find_annotation_csv(annotations_dir, dataset_name) - - if not csv_path: - skipped.append({"dataset": dataset_name, "reason": "No annotation CSV"}) - continue - if not channel_zarrs: - skipped.append({"dataset": dataset_name, "reason": "No channel zarrs"}) - continue - - available_tasks = get_available_tasks(csv_path) - if not available_tasks: - skipped.append( - {"dataset": dataset_name, "reason": "No valid task columns in CSV"} - ) - continue - - registry.append( - { - "dataset": dataset_name, - "predictions_dir": pred_dir, - "channel_zarrs": channel_zarrs, - "annotations_csv": csv_path, - "available_tasks": available_tasks, - } - ) - - annotation_datasets = set(d.name for d in annotations_dir.iterdir() if d.is_dir()) - prediction_datasets = set(predictions.keys()) - - annotations_only = natsorted(annotation_datasets - prediction_datasets) - predictions_only = natsorted(prediction_datasets - annotation_datasets) - - return registry, skipped, annotations_only, predictions_only - - -def print_registry_summary( - registry: list[dict], - skipped: list[dict], - annotations_only: list[str], - predictions_only: list[str], -): - """Print a markdown summary of the dataset registry and gaps.""" - print("## Dataset Registry\n") - print("| Dataset | Annotations | Channels | Tasks |") - print("|---------|-------------|----------|-------|") - for entry in registry: - channels_str = ", ".join(sorted(entry["channel_zarrs"].keys())) - tasks_str = ", ".join(entry["available_tasks"]) - print( - f"| {entry['dataset']} | {entry['annotations_csv'].name} " - f"| {channels_str} | {tasks_str} |" - ) - - if annotations_only or predictions_only or skipped: - print("\n## Gaps\n") - print("| Dataset | Status |") - print("|---------|--------|") - for d in annotations_only: - print(f"| {d} | Annotations only (missing predictions) |") - for d in predictions_only: - print(f"| {d} | Predictions only (missing annotations) |") - for s in skipped: - print(f"| {s['dataset']} | {s['reason']} |") - - -# %% diff --git a/applications/DynaCLR/evaluation/linear_classifiers/generate_batch_predictions.py b/applications/DynaCLR/evaluation/linear_classifiers/generate_batch_predictions.py new file mode 100644 index 000000000..49d7bb0c5 --- /dev/null +++ b/applications/DynaCLR/evaluation/linear_classifiers/generate_batch_predictions.py @@ -0,0 +1,308 @@ +# %% +"""Batch DynaCLR prediction config & SLURM script generator. + +Generates prediction YAML configs and SLURM submission scripts for +multiple datasets, channels, and checkpoints. Automatically resolves +z_range from focus_slice metadata (computing it on the fly if missing) +and detects source channel names from the zarr. + +Usage: run cells interactively or execute as a script. +""" + +import subprocess +from pathlib import Path + +from iohub import open_ome_zarr + +from utils import ( + FOCUS_PARAMS, + MODEL_2D_BAG_TIMEAWARE, # noqa: F401 — alternate model choice + MODEL_3D_BAG_TIMEAWARE, + build_registry, + extract_epoch, + find_phenotyping_predictions_dir, + generate_slurm_script, + generate_yaml, + get_z_range, + print_registry_summary, + resolve_channel_name, + resolve_dataset_paths, +) + +# %% +# =========================================================================== +# USER CONFIGURATION — edit this cell +# =========================================================================== + +BASE_DIR = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") + +# Choose model template +MODEL = MODEL_3D_BAG_TIMEAWARE +# MODEL = MODEL_2D_BAG_TIMEAWARE + +VERSION = "v1" + +CHANNELS = ["phase", "organelle", "sensor"] + +CHECKPOINTS = [ + "/hpc/projects/organelle_phenotyping/models/bag_of_channels/h2b_caax_tomm_sec61_g3bp1_sensor_phase/tb_logs/dynaclr3d_bag_channels_v1/version_2/checkpoints/epoch=40-step=44746.ckpt", +] + +# Datasets to process. Set to [] to auto-discover from annotations_only. +DATASETS = [ + "2025_01_24_A549_G3BP1_DENV", + "2024_11_07_A549_SEC61_DENV", + "2025_01_28_A549_G3BP1_ZIKV_DENV", + "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV", +] + +# Per-dataset channel keyword overrides. +# E.g., {"2025_04_10_...": {"organelle": "Cy5"}} +CHANNEL_OVERRIDES: dict[str, dict[str, str]] = {} + +# Annotations directory (used for auto-discovery when DATASETS is empty). +ANNOTATIONS_DIR = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") + +# Set to True for a dry run (preview only, no files written). +DRY_RUN = False + +# Set to True to overwrite existing config files. False to skip them. +OVERWRITE_FILES = True + +# Set to True to submit all generated predict_all.sh scripts via sbatch. +SUBMIT_JOBS = True + +# %% +# =========================================================================== +# Discovery & validation +# =========================================================================== + +# Auto-discover datasets from annotations when DATASETS is empty +if not DATASETS: + registry, skipped, annotations_only, predictions_only = build_registry( + BASE_DIR, ANNOTATIONS_DIR, MODEL["name"], VERSION + ) + print_registry_summary(registry, skipped, annotations_only, predictions_only) + DATASETS = annotations_only + print(f"\nAuto-discovered {len(DATASETS)} datasets missing predictions.\n") + +print("## Batch Prediction Config Generator\n") +print(f"- **Model**: `{MODEL['name']}`") +print(f"- **Version**: `{VERSION}`") +print(f"- **Channels**: {CHANNELS}") +print(f"- **Checkpoints**: {len(CHECKPOINTS)}") +print(f"- **Datasets**: {len(DATASETS)}") +print(f"- **Dry run**: {DRY_RUN}\n") + +validated: list[dict] = [] +errors: list[dict] = [] + +for ds in DATASETS: + try: + paths = resolve_dataset_paths(ds, BASE_DIR, MODEL) + print(f"Resolving {ds}...") + + # Read channel names from data zarr + plate = open_ome_zarr(str(paths["data_path"]), mode="r") + zarr_channels = list(plate.channel_names) + plate.close() + + # Resolve channel names + ds_overrides = CHANNEL_OVERRIDES.get(ds) + available = {} + for ch_type in CHANNELS: + ch_name = resolve_channel_name(zarr_channels, ch_type, ds_overrides) + if ch_name: + available[ch_type] = ch_name + else: + print(f" WARNING: channel '{ch_type}' not found in {ds}") + + # Resolve z_range (may compute focus on the fly) + phase_ch = available.get("phase") + z_range = get_z_range( + paths["data_path"], MODEL, FOCUS_PARAMS, phase_channel=phase_ch + ) + print(f" z_range: {z_range}") + + validated.append( + { + "dataset": ds, + "paths": paths, + "z_range": z_range, + "channels": available, + } + ) + + except Exception as e: + errors.append({"dataset": ds, "error": str(e)}) + print(f" ERROR: {e}") + +# %% +# =========================================================================== +# Summary before generation +# =========================================================================== + +print("\n### Validated Datasets\n") +print("| Dataset | z_range | Channels | data_path |") +print("|---------|---------|----------|-----------|") +for v in validated: + ch_str = ", ".join(sorted(v["channels"].keys())) + print( + f"| {v['dataset']} | {v['z_range']} | {ch_str} | `{v['paths']['data_path'].name}` |" + ) + +if errors: + print("\n### Errors\n") + print("| Dataset | Error |") + print("|---------|-------|") + for e in errors: + print(f"| {e['dataset']} | {e['error']} |") + +print( + f"\n**Will generate**: {len(validated)} datasets " + f"x {len(CHECKPOINTS)} checkpoints " + f"= {len(validated) * len(CHECKPOINTS)} config sets" +) + +# %% +# =========================================================================== +# Generate configs and scripts +# =========================================================================== + +generated: list[dict] = [] + +for entry in validated: + ds = entry["dataset"] + paths = entry["paths"] + z_range = entry["z_range"] + channels = entry["channels"] + + output_dir = find_phenotyping_predictions_dir(BASE_DIR / ds, MODEL["name"], VERSION) + + # TODO: support multiple checkpoints (namespace files by epoch or subdirs) + for ckpt in CHECKPOINTS: + epoch = extract_epoch(ckpt) + suffix = "" + files_written = [] + + for ch_type, ch_name in channels.items(): + yml_content = generate_yaml( + ds, + paths["data_path"], + paths["tracks_path"], + MODEL, + ch_type, + ch_name, + z_range, + ckpt, + output_dir, + VERSION, + ) + sh_content = generate_slurm_script(ch_type, output_dir, suffix=suffix) + + yml_path = output_dir / f"predict_{ch_type}{suffix}.yml" + sh_path = output_dir / f"predict_{ch_type}{suffix}.sh" + + if not OVERWRITE_FILES and yml_path.exists(): + print(f" Skipping {yml_path.name} (exists)") + continue + + if not DRY_RUN: + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "slurm_out").mkdir(exist_ok=True) + yml_path.write_text(yml_content) + sh_path.write_text(sh_content) + sh_path.chmod(0o755) + + files_written.append( + { + "channel": ch_type, + "yml": yml_path, + "sh": sh_path, + "yml_content": yml_content, + "sh_content": sh_content, + } + ) + + # predict_all.sh + if files_written: + run_all_lines = ["#!/bin/bash", ""] + for f in files_written: + run_all_lines.append(f"sbatch {f['sh']}") + run_all_content = "\n".join(run_all_lines) + "\n" + + run_all_name = f"predict_all{suffix}.sh" + run_all_path = output_dir / run_all_name + if not DRY_RUN: + run_all_path.write_text(run_all_content) + run_all_path.chmod(0o755) + + generated.append( + { + "dataset": ds, + "checkpoint": ckpt, + "epoch": epoch, + "output_dir": output_dir, + "files": files_written, + } + ) + +# %% +# =========================================================================== +# Generation summary +# =========================================================================== + +action = "Generated" if not DRY_RUN else "Would generate (DRY RUN)" +print(f"\n## {action}\n") +print("| Dataset | Epoch | Channels | Output Dir |") +print("|---------|-------|----------|------------|") +for g in generated: + ch_str = ", ".join(f["channel"] for f in g["files"]) + print(f"| {g['dataset']} | {g['epoch']} | {ch_str} | `{g['output_dir']}` |") + +print("\n### Files\n") +for g in generated: + print(f"**{g['dataset']}** (epoch {g['epoch']}):") + for f in g["files"]: + print(f" - `{f['yml']}`") + print(f" - `{f['sh']}`") + print(f" - `{g['output_dir'] / 'predict_all.sh'}`") + +if DRY_RUN and generated: + print("\n### Preview (first config)\n") + print("```yaml") + print(generated[0]["files"][0]["yml_content"]) + print("```") + print("\nSet `DRY_RUN = False` to write files.") + +# %% +# =========================================================================== +# Submit SLURM jobs +# =========================================================================== + +if SUBMIT_JOBS and not DRY_RUN and generated: + print("\n## Submitting SLURM jobs\n") + print("| Dataset | Script | Job ID |") + print("|---------|--------|--------|") + for g in generated: + predict_all = g["output_dir"] / "predict_all.sh" + if not predict_all.exists(): + print(f"| {g['dataset']} | `{predict_all}` | MISSING |") + continue + result = subprocess.run( + ["bash", str(predict_all)], + capture_output=True, + text=True, + ) + output = result.stdout.strip() + if result.returncode != 0: + print( + f"| {g['dataset']} | `{predict_all.name}` | ERROR: {result.stderr.strip()} |" + ) + else: + for line in output.splitlines(): + print(f"| {g['dataset']} | `{predict_all.name}` | {line} |") +elif SUBMIT_JOBS and DRY_RUN: + print("\n**SUBMIT_JOBS is True but DRY_RUN is also True — skipping submission.**") + +# %% diff --git a/applications/DynaCLR/evaluation/linear_classifiers/generate_prediction_scripts.py b/applications/DynaCLR/evaluation/linear_classifiers/generate_prediction_scripts.py deleted file mode 100644 index 2b35e3c16..000000000 --- a/applications/DynaCLR/evaluation/linear_classifiers/generate_prediction_scripts.py +++ /dev/null @@ -1,191 +0,0 @@ -# %% -"""Generate prediction .sh/.yml scripts for datasets missing embeddings. - -Uses an existing dataset's prediction configs as a template, swaps in the -target dataset name, and enforces a single checkpoint across all datasets. -""" - -import re -from glob import glob -from pathlib import Path - -from dataset_discovery import ( - CHANNELS, - build_registry, - print_registry_summary, -) -from natsort import natsorted - -# %% -# --- Configuration --- -embeddings_dir = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") -annotations_dir = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -model = "DynaCLR-2D-Bag*Channels-timeaware" -version = "v3" -ckpt_path = ( - "/hpc/projects/organelle_phenotyping/models/" - "SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/" - "organelle_sensor_phase_maxproj_ver3_150epochs/saved_checkpoints/" - "epoch=104-step=53760.ckpt" -) - -# %% -# --- Discover datasets and gaps --- -registry, skipped, annotations_only, predictions_only = build_registry( - embeddings_dir, annotations_dir, model, version -) -print_registry_summary(registry, skipped, annotations_only, predictions_only) - -# %% -# --- Pick reference dataset --- -if not registry: - raise RuntimeError( - "No reference dataset found with both predictions and annotations." - ) - -reference_dataset = registry[0]["dataset"] -reference_pred_dir = registry[0]["predictions_dir"] -reference_model_dir = reference_pred_dir.parent.name - -print("\n## Prediction Script Generation\n") -print(f"- Reference dataset: `{reference_dataset}`") -print(f"- Reference dir: `{reference_pred_dir}`") -print(f"- Checkpoint: `{ckpt_path}`\n") - -# %% -# --- Generate scripts for each dataset missing predictions --- -prediction_scripts_generated: list[dict] = [] -generation_skipped: list[dict] = [] - -for target_dataset in annotations_only: - target_base = embeddings_dir / target_dataset - if not target_base.is_dir(): - generation_skipped.append( - {"dataset": target_dataset, "reason": "No directory in embeddings_dir"} - ) - continue - - phenotyping_matches = natsorted(glob(str(target_base / "*phenotyping*"))) - if not phenotyping_matches: - generation_skipped.append( - {"dataset": target_dataset, "reason": "No *phenotyping* directory"} - ) - continue - phenotyping_dir = Path(phenotyping_matches[0]) - - # Find existing predictions parent or default to "predictions" - pred_parent_matches = natsorted(glob(str(phenotyping_dir / "*prediction*"))) - pred_parent = ( - Path(pred_parent_matches[0]) - if pred_parent_matches - else phenotyping_dir / "predictions" - ) - target_pred_dir = pred_parent / reference_model_dir / version - - # Verify data_path and tracks_path exist - data_path_matches = natsorted( - glob(str(phenotyping_dir / "train-test" / f"{target_dataset}*.zarr")) - ) - tracks_path_matches = natsorted( - glob( - str( - target_base - / "1-preprocess" - / "label-free" - / "3-track" - / f"{target_dataset}*cropped.zarr" - ) - ) - ) - - if not data_path_matches: - generation_skipped.append( - {"dataset": target_dataset, "reason": "No train-test zarr found"} - ) - continue - if not tracks_path_matches: - generation_skipped.append( - {"dataset": target_dataset, "reason": "No tracking zarr found"} - ) - continue - - generated_files = [] - for channel in CHANNELS: - ref_yml = reference_pred_dir / f"predict_{channel}.yml" - ref_sh = reference_pred_dir / f"predict_{channel}.sh" - - if not ref_yml.exists() or not ref_sh.exists(): - continue - - # Swap dataset name in all paths - new_yml = ref_yml.read_text().replace(reference_dataset, target_dataset) - new_sh = ref_sh.read_text().replace(reference_dataset, target_dataset) - - # Enforce the configured checkpoint - new_yml = re.sub(r"(?m)^ckpt_path:.*$", f"ckpt_path: {ckpt_path}", new_yml) - - generated_files.append( - { - "channel": channel, - "yml_path": target_pred_dir / f"predict_{channel}.yml", - "yml_content": new_yml, - "sh_path": target_pred_dir / f"predict_{channel}.sh", - "sh_content": new_sh, - } - ) - - if generated_files: - prediction_scripts_generated.append( - { - "dataset": target_dataset, - "pred_dir": target_pred_dir, - "files": generated_files, - } - ) - -# %% -# --- Print summary --- -if prediction_scripts_generated: - print("### Will Generate\n") - print("| Dataset | Prediction Dir | Channels |") - print("|---------|---------------|----------|") - for entry in prediction_scripts_generated: - channels_str = ", ".join(f["channel"] for f in entry["files"]) - print(f"| {entry['dataset']} | `{entry['pred_dir']}` | {channels_str} |") -else: - print("No datasets need prediction scripts generated.") - -if generation_skipped: - print("\n### Cannot Generate\n") - print("| Dataset | Reason |") - print("|---------|--------|") - for s in generation_skipped: - print(f"| {s['dataset']} | {s['reason']} |") - -# %% -# --- Write prediction scripts and run_all.sh --- -for entry in prediction_scripts_generated: - pred_dir = entry["pred_dir"] - pred_dir.mkdir(parents=True, exist_ok=True) - (pred_dir / "slurm_out").mkdir(exist_ok=True) - - sh_names = [] - for f in entry["files"]: - f["yml_path"].write_text(f["yml_content"]) - f["sh_path"].write_text(f["sh_content"]) - f["sh_path"].chmod(0o755) - sh_names.append(f["sh_path"].name) - - # Generate run_all.sh - run_all_path = pred_dir / "run_all.sh" - run_all_lines = ["#!/bin/bash", ""] - for sh_name in sh_names: - run_all_lines.append(f"sbatch {sh_name}") - run_all_content = "\n".join(run_all_lines) + "\n" - run_all_path.write_text(run_all_content) - run_all_path.chmod(0o755) - - print(f"Wrote {entry['dataset']} -> {pred_dir}") - for sh_name in sh_names: - print(f" {sh_name}") - print(" run_all.sh") diff --git a/applications/DynaCLR/evaluation/linear_classifiers/generate_train_config.py b/applications/DynaCLR/evaluation/linear_classifiers/generate_train_config.py index ef6e03ec9..01a070ec3 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/generate_train_config.py +++ b/applications/DynaCLR/evaluation/linear_classifiers/generate_train_config.py @@ -9,7 +9,8 @@ from pathlib import Path import yaml -from dataset_discovery import ( + +from utils import ( CHANNELS, TASKS, build_registry, diff --git a/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py b/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py index c554f5be9..d89f12a53 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py +++ b/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py @@ -84,6 +84,8 @@ def main(config: Path): click.echo(f"\n✓ Configuration loaded: {config}") click.echo(f" Task: {train_config.task}") click.echo(f" Input channel: {train_config.input_channel}") + if train_config.marker: + click.echo(f" Marker: {train_config.marker}") click.echo(f" Model: {train_config.embedding_model}") click.echo(f" Datasets: {len(train_config.train_datasets)}") diff --git a/applications/DynaCLR/evaluation/linear_classifiers/utils.py b/applications/DynaCLR/evaluation/linear_classifiers/utils.py new file mode 100644 index 000000000..37b603008 --- /dev/null +++ b/applications/DynaCLR/evaluation/linear_classifiers/utils.py @@ -0,0 +1,707 @@ +"""Shared utilities for the linear_classifiers workflow. + +Constants, path resolution, config generation, dataset discovery, +and focus/z-range helpers used by both ``generate_batch_predictions.py`` +and ``generate_train_config.py``. +""" + +# %% +import re +from glob import glob +from pathlib import Path + +import pandas as pd +from natsort import natsorted + +from viscy.representation.evaluation.linear_classifier_config import ( + VALID_CHANNELS, + VALID_TASKS, +) + +CHANNELS = list(VALID_CHANNELS.__args__) +TASKS = list(VALID_TASKS.__args__) + +# --------------------------------------------------------------------------- +# Model templates +# --------------------------------------------------------------------------- + +MODEL_3D_BAG_TIMEAWARE = { + "name": "DynaCLR-3D-BagOfChannels-timeaware", + "in_stack_depth": 30, + "stem_kernel_size": [5, 4, 4], + "stem_stride": [5, 4, 4], + "patch_size": 192, + "data_path_type": "2-assemble", + "z_range": "auto", + # Fraction of z slices below the focus plane (0.33 = 1/3 below, 2/3 above). + "focus_below_fraction": 1 / 3, + "logger_base": "/hpc/projects/organelle_phenotyping/models/tb_logs", +} + +MODEL_2D_BAG_TIMEAWARE = { + "name": "DynaCLR-2D-BagOfChannels-timeaware", + "in_stack_depth": 1, + "stem_kernel_size": [1, 4, 4], + "stem_stride": [1, 4, 4], + "patch_size": 160, + "data_path_type": "train-test", + "z_range": [0, 1], + "logger_base": "/hpc/projects/organelle_phenotyping/models/embedding_logs", +} + +# --------------------------------------------------------------------------- +# Channel defaults +# --------------------------------------------------------------------------- + +CHANNEL_DEFAULTS: dict[str, dict] = { + "organelle": { + "keyword": "GFP", + "yaml_alias": "fluor", + "normalization_class": "viscy.transforms.ScaleIntensityRangePercentilesd", + "normalization_args": { + "lower": 50, + "upper": 99, + "b_min": 0.0, + "b_max": 1.0, + }, + "batch_size": {"2d": 32, "3d": 64}, + "num_workers": {"2d": 8, "3d": 16}, + }, + "phase": { + "keyword": "Phase", + "yaml_alias": "Ph", + "normalization_class": "viscy.transforms.NormalizeSampled", + "normalization_args": { + "level": "fov_statistics", + "subtrahend": "mean", + "divisor": "std", + }, + "batch_size": {"2d": 64, "3d": 64}, + "num_workers": {"2d": 16, "3d": 16}, + }, + "sensor": { + "keyword": "mCherry", + "yaml_alias": "fluor", + "normalization_class": "viscy.transforms.ScaleIntensityRangePercentilesd", + "normalization_args": { + "lower": 50, + "upper": 99, + "b_min": 0.0, + "b_max": 1.0, + }, + "batch_size": {"2d": 32, "3d": 64}, + "num_workers": {"2d": 8, "3d": 16}, + }, +} + +# --------------------------------------------------------------------------- +# Focus parameters (microscope-specific defaults) +# --------------------------------------------------------------------------- + +FOCUS_PARAMS = { + "NA_det": 1.35, + "lambda_ill": 0.450, + "pixel_size": 0.1494, + "device": "cuda", +} + + +# --------------------------------------------------------------------------- +# Checkpoint utilities +# --------------------------------------------------------------------------- + + +def extract_epoch(ckpt_path: str) -> str: + """Extract epoch number from a checkpoint filename. + + ``epoch=32-step=33066.ckpt`` -> ``"32"`` + """ + m = re.search(r"epoch=(\d+)", Path(ckpt_path).stem) + if m: + return m.group(1) + return Path(ckpt_path).stem + + +# --------------------------------------------------------------------------- +# Channel utilities +# --------------------------------------------------------------------------- + + +def resolve_channel_name( + channel_names: list[str], + channel_type: str, + channel_overrides: dict[str, str] | None = None, +) -> str | None: + """Find the full channel name by keyword substring match. + + When multiple channels match the keyword, the ``raw`` variant is + preferred (e.g. ``"raw GFP EX488 EM525-45"`` over ``"GFP EX488 EM525-45"``). + + Parameters + ---------- + channel_names : list[str] + Channel names from the zarr dataset. + channel_type : str + One of "organelle", "phase", "sensor". + channel_overrides : dict[str, str] or None + Optional mapping of channel_type -> keyword override. + + Returns + ------- + str or None + Matched channel name, or None if not found. + """ + keyword = channel_overrides.get(channel_type) if channel_overrides else None + if keyword is None: + keyword = CHANNEL_DEFAULTS[channel_type]["keyword"] + matches = [name for name in channel_names if keyword in name] + if not matches: + return None + # Prefer the "raw" variant when both raw and processed exist + raw = [m for m in matches if m.lower().startswith("raw")] + return raw[0] if raw else matches[0] + + +# --------------------------------------------------------------------------- +# Path resolution +# --------------------------------------------------------------------------- + + +def resolve_dataset_paths( + dataset_name: str, + base_dir: Path, + model_config: dict, +) -> dict: + """Resolve data_path and tracks_path for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset folder name. + base_dir : Path + Base directory containing all datasets. + model_config : dict + Model template (used to determine data_path_type). + + Returns + ------- + dict + Keys: data_path, tracks_path (both as Path objects). + + Raises + ------ + FileNotFoundError + If required paths cannot be found. + """ + dataset_dir = base_dir / dataset_name + + # Data path + if model_config["data_path_type"] == "train-test": + matches = natsorted( + glob( + str( + dataset_dir + / "*phenotyping*" + / "*train-test*" + / f"{dataset_name}*.zarr" + ) + ) + ) + if not matches: + raise FileNotFoundError(f"No train-test zarr found for {dataset_name}") + data_path = Path(matches[0]) + else: + matches = natsorted( + glob(str(dataset_dir / "2-assemble" / f"{dataset_name}*.zarr")) + ) + if not matches: + raise FileNotFoundError(f"No 2-assemble zarr found for {dataset_name}") + data_path = Path(matches[0]) + + # Tracks path + tracks_matches = natsorted( + glob( + str( + dataset_dir + / "1-preprocess" + / "label-free" + / "3-track" + / f"{dataset_name}*cropped.zarr" + ) + ) + ) + if not tracks_matches: + raise FileNotFoundError(f"No tracking zarr found for {dataset_name}") + tracks_path = Path(tracks_matches[0]) + + return {"data_path": data_path, "tracks_path": tracks_path} + + +def find_phenotyping_predictions_dir( + dataset_dir: Path, + model_name: str, + version: str, +) -> Path: + """Locate or create the predictions output directory for a dataset.""" + pheno_matches = natsorted(glob(str(dataset_dir / "*phenotyping*"))) + if not pheno_matches: + pheno_dir = dataset_dir / "4-phenotyping" + else: + pheno_dir = Path(pheno_matches[0]) + + pred_matches = natsorted(glob(str(pheno_dir / "*prediction*"))) + pred_parent = Path(pred_matches[0]) if pred_matches else pheno_dir / "predictions" + + return pred_parent / model_name / version + + +# --------------------------------------------------------------------------- +# Focus / z-range +# --------------------------------------------------------------------------- + + +def get_z_range( + data_path: str | Path, + model_config: dict, + focus_params: dict | None = None, + phase_channel: str | None = None, +) -> list[int]: + """Determine z_range for prediction. + + For models with ``z_range="auto"``, reads focus_slice metadata from the + zarr. If metadata is missing, computes it on the fly. + + Parameters + ---------- + data_path : str or Path + Path to the OME-Zarr dataset. + model_config : dict + Model template dictionary. + focus_params : dict or None + Parameters for on-the-fly focus computation. + phase_channel : str or None + Name of the phase channel in the zarr. Used to look up focus_slice + metadata. If None, auto-detected by keyword match. + + Returns + ------- + list[int] + [z_start, z_end] range for prediction. + """ + from iohub import open_ome_zarr + + if model_config["z_range"] != "auto": + return list(model_config["z_range"]) + + plate = open_ome_zarr(str(data_path), mode="r") + + # Resolve phase channel name if not provided + if phase_channel is None: + phase_channel = resolve_channel_name(list(plate.channel_names), "phase") + if phase_channel is None: + plate.close() + raise ValueError( + f"Cannot determine z_range: no phase channel found in {data_path}" + ) + + focus_data = plate.zattrs.get("focus_slice", {}) + phase_stats = focus_data.get(phase_channel, {}).get("dataset_statistics", {}) + z_focus_mean = phase_stats.get("z_focus_mean") + + # Get total z depth from first position + for _, pos in plate.positions(): + z_total = pos["0"].shape[2] + break + plate.close() + + if z_focus_mean is None: + print(f" Focus metadata missing for {Path(data_path).name}, computing...") + z_focus_mean = _compute_focus( + str(data_path), focus_params or FOCUS_PARAMS, phase_channel + ) + + depth = model_config["in_stack_depth"] + below_frac = model_config.get("focus_below_fraction", 0.5) + slices_below = int(round(depth * below_frac)) + z_center = int(round(z_focus_mean)) + z_start = max(0, z_center - slices_below) + z_end = min(z_total, z_start + depth) + # Re-adjust start if we hit the ceiling + z_start = max(0, z_end - depth) + + return [z_start, z_end] + + +def _compute_focus(zarr_path: str, focus_params: dict, phase_channel: str) -> float: + """Compute focus_slice metadata and write it to the zarr. + + Returns the dataset-level z_focus_mean. + """ + from iohub import open_ome_zarr + + from viscy.preprocessing.focus import FocusSliceMetric + from viscy.preprocessing.qc_metrics import generate_qc_metadata + + metric = FocusSliceMetric( + NA_det=focus_params["NA_det"], + lambda_ill=focus_params["lambda_ill"], + pixel_size=focus_params["pixel_size"], + channel_names=[phase_channel], + device=focus_params.get("device", "cpu"), + ) + generate_qc_metadata(zarr_path, [metric]) + + plate = open_ome_zarr(zarr_path, mode="r") + z_focus_mean = plate.zattrs["focus_slice"][phase_channel]["dataset_statistics"][ + "z_focus_mean" + ] + plate.close() + return z_focus_mean + + +# --------------------------------------------------------------------------- +# Config generation +# --------------------------------------------------------------------------- + + +def model_dim_key(model_config: dict) -> str: + """Return '2d' or '3d' based on model template.""" + return "2d" if model_config["in_stack_depth"] == 1 else "3d" + + +def generate_yaml( + dataset_name: str, + data_path: Path, + tracks_path: Path, + model_config: dict, + channel_type: str, + channel_name: str, + z_range: list[int], + ckpt_path: str, + output_dir: Path, + version: str, +) -> str: + """Generate a prediction YAML config string. + + Uses YAML anchors to match the existing config style. + """ + dim = model_dim_key(model_config) + ch_cfg = CHANNEL_DEFAULTS[channel_type] + patch = model_config["patch_size"] + depth = model_config["in_stack_depth"] + epoch = extract_epoch(ckpt_path) + yaml_alias = ch_cfg["yaml_alias"] + + output_zarr = output_dir / f"timeaware_{channel_type}_{patch}patch_{epoch}ckpt.zarr" + + # Build normalization block + norm_class = ch_cfg["normalization_class"] + norm_args = dict(ch_cfg["normalization_args"]) + + # Format normalization init_args as YAML lines + norm_lines = [f" keys: [*{yaml_alias}]"] + for k, v in norm_args.items(): + norm_lines.append(f" {k}: {v}") + norm_block = "\n".join(norm_lines) + + logger_base = model_config["logger_base"] + model_name = model_config["name"] + logger_save_dir = f"{logger_base}/{dataset_name}" + logger_name = f"{model_name}/{version}/{channel_type}" + + yaml_str = f"""\ +seed_everything: 42 +trainer: + accelerator: gpu + strategy: auto + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: + - class_path: viscy.representation.embedding_writer.EmbeddingWriter + init_args: + output_path: "{output_zarr}" + logger: + save_dir: "{logger_save_dir}" + name: "{logger_name}" + inference_mode: true +model: + class_path: viscy.representation.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy.representation.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: {depth} + stem_kernel_size: {model_config["stem_kernel_size"]} + stem_stride: {model_config["stem_stride"]} + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.0 + example_input_array_shape: [1, 1, {depth}, {patch}, {patch}] +data: + class_path: viscy.data.triplet.TripletDataModule + init_args: + data_path: {data_path} + tracks_path: {tracks_path} + source_channel: + - &{yaml_alias} {channel_name} + z_range: {z_range} + batch_size: {ch_cfg["batch_size"][dim]} + num_workers: {ch_cfg["num_workers"][dim]} + initial_yx_patch_size: [{patch}, {patch}] + final_yx_patch_size: [{patch}, {patch}] + normalizations: + - class_path: {norm_class} + init_args: +{norm_block} +return_predictions: false +ckpt_path: {ckpt_path} +""" + return yaml_str + + +def generate_slurm_script( + channel_type: str, + output_dir: Path, + suffix: str = "", +) -> str: + """Generate a SLURM submission shell script.""" + config_file = output_dir / f"predict_{channel_type}{suffix}.yml" + slurm_out = output_dir / "slurm_out" / "pred_%j.out" + + return f"""\ +#!/bin/bash + +#SBATCH --job-name=dynaclr_pred +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=32 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=0-02:00:00 +#SBATCH --output={slurm_out} + +module load anaconda/latest +conda activate viscy + +cat {config_file} +srun viscy predict -c {config_file} +""" + + +# --------------------------------------------------------------------------- +# Dataset discovery +# --------------------------------------------------------------------------- + + +def discover_predictions( + embeddings_dir: Path, + model_name: str, + version: str, +) -> dict[str, Path]: + """Find datasets that have a predictions folder for the given model/version. + + Searches for paths matching: + {embeddings_dir}/{dataset}/*phenotyping*/*prediction*/{model_glob}/{version}/ + + Parameters + ---------- + embeddings_dir : Path + Base directory containing dataset folders. + model_name : str + Model directory name (supports glob patterns). + version : str + Version subdirectory (e.g. "v3"). + + Returns + ------- + dict[str, Path] + Mapping of dataset_name -> resolved predictions version directory. + """ + pattern = str( + embeddings_dir / "*" / "*phenotyping*" / "*prediction*" / model_name / version + ) + matches = natsorted(glob(pattern)) + + results = {} + for match in matches: + match_path = Path(match) + dataset_name = match_path.relative_to(embeddings_dir).parts[0] + results[dataset_name] = match_path + + return results + + +def find_channel_zarrs( + predictions_dir: Path, + channels: list[str] | None = None, +) -> dict[str, Path]: + """Find embedding zarr files for each channel in a predictions directory. + + Parameters + ---------- + predictions_dir : Path + Path to the version directory containing zarr files. + channels : list[str] or None + Channel names to search for. Defaults to CHANNELS. + + Returns + ------- + dict[str, Path] + Mapping of channel_name -> zarr path (only channels with a match). + """ + if channels is None: + channels = CHANNELS + channel_zarrs = {} + for channel in channels: + matches = natsorted(glob(str(predictions_dir / f"*{channel}*.zarr"))) + if matches: + channel_zarrs[channel] = Path(matches[0]) + return channel_zarrs + + +def find_annotation_csv(annotations_dir: Path, dataset_name: str) -> Path | None: + """Find the annotation CSV for a dataset. + + Parameters + ---------- + annotations_dir : Path + Base annotations directory. + dataset_name : str + Dataset folder name. + + Returns + ------- + Path or None + Path to CSV if found, None otherwise. + """ + dataset_dir = annotations_dir / dataset_name + if not dataset_dir.is_dir(): + return None + csvs = natsorted(glob(str(dataset_dir / "*.csv"))) + return Path(csvs[0]) if csvs else None + + +def get_available_tasks(csv_path: Path) -> list[str]: + """Read CSV header and return which valid task columns are present. + + Parameters + ---------- + csv_path : Path + Path to annotation CSV. + + Returns + ------- + list[str] + Task names found in the CSV columns. + """ + columns = pd.read_csv(csv_path, nrows=0).columns.tolist() + return [t for t in TASKS if t in columns] + + +def build_registry( + embeddings_dir: Path, + annotations_dir: Path, + model_name: str, + version: str, +) -> tuple[list[dict], list[dict], list[str], list[str]]: + """Build a registry of datasets with predictions and annotations. + + Parameters + ---------- + embeddings_dir : Path + Base directory containing dataset folders with embeddings. + annotations_dir : Path + Base directory containing dataset annotation folders. + model_name : str + Model directory name (supports glob patterns). + version : str + Version subdirectory (e.g. "v3"). + + Returns + ------- + registry : list[dict] + Datasets with both predictions and annotations. + skipped : list[dict] + Datasets with predictions but missing annotations or tasks. + annotations_only : list[str] + Annotation datasets with no matching predictions. + predictions_only : list[str] + Prediction datasets with no matching annotations. + """ + predictions = discover_predictions(embeddings_dir, model_name, version) + + registry: list[dict] = [] + skipped: list[dict] = [] + + for dataset_name, pred_dir in predictions.items(): + channel_zarrs = find_channel_zarrs(pred_dir) + csv_path = find_annotation_csv(annotations_dir, dataset_name) + + if not csv_path: + skipped.append({"dataset": dataset_name, "reason": "No annotation CSV"}) + continue + if not channel_zarrs: + skipped.append({"dataset": dataset_name, "reason": "No channel zarrs"}) + continue + + available_tasks = get_available_tasks(csv_path) + if not available_tasks: + skipped.append( + {"dataset": dataset_name, "reason": "No valid task columns in CSV"} + ) + continue + + registry.append( + { + "dataset": dataset_name, + "predictions_dir": pred_dir, + "channel_zarrs": channel_zarrs, + "annotations_csv": csv_path, + "available_tasks": available_tasks, + } + ) + + annotation_datasets = set(d.name for d in annotations_dir.iterdir() if d.is_dir()) + prediction_datasets = set(predictions.keys()) + + annotations_only = natsorted(annotation_datasets - prediction_datasets) + predictions_only = natsorted(prediction_datasets - annotation_datasets) + + return registry, skipped, annotations_only, predictions_only + + +def print_registry_summary( + registry: list[dict], + skipped: list[dict], + annotations_only: list[str], + predictions_only: list[str], +): + """Print a markdown summary of the dataset registry and gaps.""" + print("## Dataset Registry\n") + print("| Dataset | Annotations | Channels | Tasks |") + print("|---------|-------------|----------|-------|") + for entry in registry: + channels_str = ", ".join(sorted(entry["channel_zarrs"].keys())) + tasks_str = ", ".join(entry["available_tasks"]) + print( + f"| {entry['dataset']} | {entry['annotations_csv'].name} " + f"| {channels_str} | {tasks_str} |" + ) + + if annotations_only or predictions_only or skipped: + print("\n## Gaps\n") + print("| Dataset | Status |") + print("|---------|--------|") + for d in annotations_only: + print(f"| {d} | Annotations only (missing predictions) |") + for d in predictions_only: + print(f"| {d} | Predictions only (missing annotations) |") + for s in skipped: + print(f"| {s['dataset']} | {s['reason']} |") + + +# %% diff --git a/applications/qc/README.md b/applications/qc/README.md new file mode 100644 index 000000000..c12371de1 --- /dev/null +++ b/applications/qc/README.md @@ -0,0 +1,137 @@ +# QC Metrics Pipeline + +Composable quality control metrics for HCS OME-Zarr datasets. Results are written to `.zattrs` at both plate and position levels. + +## Usage + +```bash +viscy qc -c applications/qc/qc_config.yml +``` + +## Available Metrics + +### FocusSliceMetric + +Detects the in-focus z-slice per timepoint using midband spatial frequency power (Z-vectorized FFT via waveorder). + +**Parameters:** + +| Parameter | Description | +|---|---| +| `NA_det` | Detection numerical aperture | +| `lambda_ill` | Illumination wavelength (same units as `pixel_size`) | +| `pixel_size` | Object-space pixel size (camera pixel size / magnification) | +| `channel_names` | List of channel names, or `-1` for all channels in the dataset | +| `midband_fractions` | Inner/outer fractions of cutoff frequency (default `[0.125, 0.25]`) | +| `device` | Torch device (`cpu` or `cuda`) | + +## Configuration + +```yaml +data_path: /path/to/dataset.zarr +num_workers: 4 +metrics: + - class_path: viscy.preprocessing.focus.FocusSliceMetric + init_args: + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + channel_names: + - Phase3D + - GFP + device: cuda +``` + +Use `channel_names: -1` to run on all channels: + +```yaml +metrics: + - class_path: viscy.preprocessing.focus.FocusSliceMetric + init_args: + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + channel_names: -1 + device: cuda +``` + +Multiple metrics can be composed in the `metrics` list. + +## Output Structure + +### Plate-level `.zattrs` + +```json +{ + "focus_slice": { + "Phase3D": { + "dataset_statistics": { + "z_focus_mean": 5.3, + "z_focus_std": 1.2, + "z_focus_min": 3, + "z_focus_max": 8 + } + } + } +} +``` + +### Position-level `.zattrs` + +```json +{ + "focus_slice": { + "Phase3D": { + "dataset_statistics": {"z_focus_mean": 5.3, "z_focus_std": 1.2, "z_focus_min": 3, "z_focus_max": 8}, + "fov_statistics": {"z_focus_mean": 5.5, "z_focus_std": 0.7}, + "per_timepoint": {"0": 5, "1": 6, "2": 5} + } + } +} +``` + +## Inspecting Results + +```python +from iohub import open_ome_zarr + +ds = open_ome_zarr("/path/to/dataset.zarr", mode="r") +print(ds.zattrs["focus_slice"]) + +for name, pos in ds.positions(): + print(name, pos.zattrs["focus_slice"]) + break +``` + +## Adding Custom Metrics + +Subclass `QCMetric` and implement `channels()` and `__call__()`: + +```python +from viscy.preprocessing.qc_metrics import QCMetric + +class MyMetric(QCMetric): + field_name = "my_metric" + + def __init__(self, channel_names, ...): + self.channel_names = channel_names + + def channels(self): + return self.channel_names # list[str] or -1 for all + + def __call__(self, position, channel_name, channel_index, num_workers=4): + # compute metric per FOV + return { + "fov_statistics": {"key": value}, + "per_timepoint": {"0": value, "1": value}, + } +``` + +Then add it to the config: + +```yaml +metrics: + - class_path: my_module.MyMetric + init_args: + channel_names: -1 +``` diff --git a/applications/qc/qc_config.yml b/applications/qc/qc_config.yml new file mode 100644 index 000000000..af355c052 --- /dev/null +++ b/applications/qc/qc_config.yml @@ -0,0 +1,12 @@ +data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2-assemble/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr +num_workers: 8 # num workers to fetch data +metrics: + - class_path: viscy.preprocessing.focus.FocusSliceMetric + init_args: + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + channel_names: + - Phase3D + # - "raw GFP EX488 EM525-45" + device: cuda diff --git a/pyproject.toml b/pyproject.toml index d9985413a..419e9e7aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "timm>=0.9.5", "torch>=2.4.1", "wandb", + "waveorder @ git+https://github.com/mehta-lab/waveorder.git@main", # temporary pin for xarray#10851 "xarray<=2025.9", ] diff --git a/tests/preprocessing/test_focus.py b/tests/preprocessing/test_focus.py new file mode 100644 index 000000000..0dece0cb7 --- /dev/null +++ b/tests/preprocessing/test_focus.py @@ -0,0 +1,112 @@ +import pytest +from iohub import open_ome_zarr + +from viscy.preprocessing.focus import FocusSliceMetric +from viscy.preprocessing.qc_metrics import generate_qc_metadata + + +@pytest.fixture +def focus_metric(): + return FocusSliceMetric( + NA_det=0.55, + lambda_ill=0.532, + pixel_size=0.325, + channel_names=["Phase"], + ) + + +@pytest.fixture +def focus_metric_all_channels(): + return FocusSliceMetric( + NA_det=0.55, + lambda_ill=0.532, + pixel_size=0.325, + channel_names=-1, + ) + + +def test_focus_slice_metric_call(temporal_hcs_dataset, focus_metric): + with open_ome_zarr(temporal_hcs_dataset, mode="r") as plate: + channel_index = plate.channel_names.index("Phase") + _, pos = next(iter(plate.positions())) + result = focus_metric(pos, "Phase", channel_index, num_workers=1) + + assert "fov_statistics" in result + assert "per_timepoint" in result + assert "z_focus_mean" in result["fov_statistics"] + assert "z_focus_std" in result["fov_statistics"] + for t in range(5): + assert str(t) in result["per_timepoint"] + idx = result["per_timepoint"][str(t)] + assert isinstance(idx, int) + assert 0 <= idx < 10 + + +def test_generate_qc_metadata_focus(temporal_hcs_dataset, focus_metric): + generate_qc_metadata( + zarr_dir=temporal_hcs_dataset, + metrics=[focus_metric], + num_workers=1, + ) + + with open_ome_zarr(temporal_hcs_dataset, mode="r") as plate: + assert "focus_slice" in plate.zattrs + assert "Phase" in plate.zattrs["focus_slice"] + ds_stats = plate.zattrs["focus_slice"]["Phase"]["dataset_statistics"] + assert "z_focus_mean" in ds_stats + assert "z_focus_std" in ds_stats + assert "z_focus_min" in ds_stats + assert "z_focus_max" in ds_stats + + for _, pos in plate.positions(): + assert "focus_slice" in pos.zattrs + pos_meta = pos.zattrs["focus_slice"]["Phase"] + assert "dataset_statistics" in pos_meta + assert "fov_statistics" in pos_meta + assert "per_timepoint" in pos_meta + + +def test_generate_qc_metadata_skips_unconfigured_channel( + temporal_hcs_dataset, focus_metric +): + generate_qc_metadata( + zarr_dir=temporal_hcs_dataset, + metrics=[focus_metric], + num_workers=1, + ) + + with open_ome_zarr(temporal_hcs_dataset, mode="r") as plate: + assert "Retardance" not in plate.zattrs.get("focus_slice", {}) + for _, pos in plate.positions(): + assert "Retardance" not in pos.zattrs.get("focus_slice", {}) + + +def test_generate_qc_metadata_per_timepoint_count(temporal_hcs_dataset, focus_metric): + generate_qc_metadata( + zarr_dir=temporal_hcs_dataset, + metrics=[focus_metric], + num_workers=1, + ) + + with open_ome_zarr(temporal_hcs_dataset, mode="r") as plate: + for _, pos in plate.positions(): + per_tp = pos.zattrs["focus_slice"]["Phase"]["per_timepoint"] + assert len(per_tp) == 5 + for t in range(5): + assert str(t) in per_tp + + +def test_generate_qc_metadata_all_channels( + temporal_hcs_dataset, focus_metric_all_channels +): + generate_qc_metadata( + zarr_dir=temporal_hcs_dataset, + metrics=[focus_metric_all_channels], + num_workers=1, + ) + + with open_ome_zarr(temporal_hcs_dataset, mode="r") as plate: + for ch in plate.channel_names: + assert ch in plate.zattrs["focus_slice"] + for _, pos in plate.positions(): + assert ch in pos.zattrs["focus_slice"] diff --git a/tests/representation/evaluation/test_linear_classifier.py b/tests/representation/evaluation/test_linear_classifier.py index e9a6975ab..6cf816b34 100644 --- a/tests/representation/evaluation/test_linear_classifier.py +++ b/tests/representation/evaluation/test_linear_classifier.py @@ -195,6 +195,57 @@ def test_predict_adds_uns_classes(self, pipeline_and_adata): pipeline.classifier.classes_ ) + def test_predict_stores_provenance(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + metadata = { + "artifact_name": "linear-classifier-cell_death_state-phase:v2", + "artifact_id": "abc123", + "artifact_version": "v2", + } + result = predict_with_classifier( + adata.copy(), pipeline, "cell_death_state", artifact_metadata=metadata + ) + assert ( + result.uns["classifier_cell_death_state_artifact"] + == "linear-classifier-cell_death_state-phase:v2" + ) + assert result.uns["classifier_cell_death_state_id"] == "abc123" + assert result.uns["classifier_cell_death_state_version"] == "v2" + + def test_predict_no_provenance_by_default(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + result = predict_with_classifier(adata.copy(), pipeline, "cell_death_state") + assert "classifier_cell_death_state_artifact" not in result.uns + assert "classifier_cell_death_state_id" not in result.uns + assert "classifier_cell_death_state_version" not in result.uns + + def test_predict_with_include_wells(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + data = adata.copy() + result = predict_with_classifier( + data, pipeline, "cell_death_state", include_wells=["A/1"] + ) + well_mask = result.obs["fov_name"].str.startswith("A/1/") + predicted = result.obs["predicted_cell_death_state"] + assert predicted[well_mask].notna().all() + assert predicted[~well_mask].isna().all() + + proba = result.obsm["predicted_cell_death_state_proba"] + assert np.isfinite(proba[well_mask]).all() + assert np.isnan(proba[~well_mask]).all() + + def test_predict_marker_namespaced_task(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + result = predict_with_classifier( + adata.copy(), + pipeline, + "organelle_state_g3bp1", + include_wells=["A/1"], + ) + assert "predicted_organelle_state_g3bp1" in result.obs.columns + assert "predicted_organelle_state_g3bp1_proba" in result.obsm + assert "predicted_organelle_state_g3bp1_classes" in result.uns + class TestLoadAndCombineDatasets: """Tests for the load_and_combine_datasets function.""" @@ -434,3 +485,34 @@ def test_output_exists_with_overwrite(self, tmp_path): overwrite=True, ) assert config.overwrite is True + + def test_output_path_none_defaults_to_inplace(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + ) + assert config.output_path is None + + def test_include_wells(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + include_wells=["A/1", "B/2"], + ) + assert config.include_wells == ["A/1", "B/2"] + + def test_include_wells_none_by_default(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + ) + assert config.include_wells is None diff --git a/viscy/cli.py b/viscy/cli.py index 096106d38..510256fa3 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -23,6 +23,7 @@ def subcommands() -> dict[str, set[str]]: subcommands["export"] = subcommand_base_args subcommands["precompute"] = subcommand_base_args subcommands["convert_to_anndata"] = subcommand_base_args + subcommands["qc"] = subcommand_base_args return subcommands def add_arguments_to_parser(self, parser) -> None: @@ -63,12 +64,14 @@ def main() -> None: "preprocess", "precompute", "convert_to_anndata", + "qc", }.isdisjoint(sys.argv) require_data = { "preprocess", "precompute", "export", "convert_to_anndata", + "qc", }.isdisjoint(sys.argv) _ = VisCyCLI( model_class=LightningModule, diff --git a/viscy/preprocessing/focus.py b/viscy/preprocessing/focus.py new file mode 100644 index 000000000..0066ba8ba --- /dev/null +++ b/viscy/preprocessing/focus.py @@ -0,0 +1,84 @@ +from typing import Literal + +import numpy as np +import tensorstore +import torch +from waveorder.focus import focus_from_transverse_band + +from viscy.preprocessing.qc_metrics import QCMetric + + +class FocusSliceMetric(QCMetric): + """In-focus z-slice detection using midband spatial frequency power. + + Parameters + ---------- + NA_det : float + Detection numerical aperture. + lambda_ill : float + Illumination wavelength (same units as pixel_size). + pixel_size : float + Object-space pixel size (camera pixel size / magnification). + channel_names : list[str] or -1 + Channel names to compute focus for. Use -1 for all channels. + midband_fractions : tuple[float, float] + Inner and outer fractions of cutoff frequency. + device : str + Torch device for FFT computation. + """ + + field_name = "focus_slice" + + def __init__( + self, + NA_det: float, + lambda_ill: float, + pixel_size: float, + channel_names: list[str] | Literal[-1] = -1, + midband_fractions: tuple[float, float] = (0.125, 0.25), + device: str = "cpu", + ): + self.NA_det = NA_det + self.lambda_ill = lambda_ill + self.pixel_size = pixel_size + self.channel_names = channel_names + self.midband_fractions = midband_fractions + self.device = device + + def channels(self) -> list[str] | Literal[-1]: + return self.channel_names + + def __call__(self, position, channel_name, channel_index, num_workers=4): + tzyx = ( + position["0"] + .tensorstore( + context=tensorstore.Context( + {"data_copy_concurrency": {"limit": num_workers}} + ) + )[:, channel_index] + .read() + .result() + ) + + T = tzyx.shape[0] + focus_indices = np.empty(T, dtype=int) + + for t in range(T): + zyx = torch.from_numpy(np.asarray(tzyx[t])).to(self.device) + focus_indices[t] = focus_from_transverse_band( + zyx, + NA_det=self.NA_det, + lambda_ill=self.lambda_ill, + pixel_size=self.pixel_size, + midband_fractions=self.midband_fractions, + ) + + per_timepoint = {str(t): int(idx) for t, idx in enumerate(focus_indices)} + fov_stats = { + "z_focus_mean": float(np.mean(focus_indices)), + "z_focus_std": float(np.std(focus_indices)), + } + return { + "fov_statistics": fov_stats, + "per_timepoint": per_timepoint, + } diff --git a/viscy/preprocessing/qc_metrics.py b/viscy/preprocessing/qc_metrics.py new file mode 100644 index 000000000..fddf81232 --- /dev/null +++ b/viscy/preprocessing/qc_metrics.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Literal + +import iohub.ngff as ngff +import numpy as np +from tqdm import tqdm + +from viscy.utils.meta_utils import write_meta_field + + +class QCMetric(ABC): + """Base class for composable QC metrics. + + Each metric: + - Owns its channel list and per-channel config + - Reads data and computes results per FOV + - Returns structured dicts for zattrs storage + """ + + field_name: str + + @abstractmethod + def channels(self) -> list[str] | Literal[-1]: + """Channel names this metric operates on. + + Return -1 to operate on all channels in the dataset. + """ + ... + + @abstractmethod + def __call__( + self, + position: ngff.Position, + channel_name: str, + channel_index: int, + num_workers: int = 4, + ) -> dict: + """Compute metric for one FOV and one channel. + + Returns + ------- + dict + { + "fov_statistics": {"key": value, ...}, + "per_timepoint": {"0": value, "1": value, ...}, + } + """ + ... + + +def generate_qc_metadata( + zarr_dir: str, + metrics: list[QCMetric], + num_workers: int = 4, +) -> None: + """Run composable QC metrics across an HCS dataset. + + Each metric specifies its own channels (or -1 for all). + The orchestrator iterates positions, dispatches to each metric + for its channels, aggregates dataset-level statistics, and + writes to .zattrs. + + Parameters + ---------- + zarr_dir : str + Path to the HCS OME-Zarr dataset. + metrics : list[QCMetric] + List of QC metric instances to compute. + num_workers : int + Number of workers for data loading. + """ + plate = ngff.open_ome_zarr(zarr_dir, mode="r+") + position_map = list(plate.positions()) + + for metric in metrics: + channel_list = metric.channels() + if channel_list == -1: + channel_list = list(plate.channel_names) + + for channel_name in channel_list: + channel_index = plate.channel_names.index(channel_name) + print(f"Computing {metric.field_name} for channel '{channel_name}'") + + all_focus_values = [] + position_results = [] + + for _, pos in tqdm(position_map, desc="Positions"): + result = metric(pos, channel_name, channel_index, num_workers) + position_results.append((pos, result)) + tp_values = list(result["per_timepoint"].values()) + all_focus_values.extend(tp_values) + + arr = np.array(all_focus_values, dtype=float) + dataset_stats = { + "z_focus_mean": float(np.mean(arr)), + "z_focus_std": float(np.std(arr)), + "z_focus_min": int(np.min(arr)), + "z_focus_max": int(np.max(arr)), + } + + write_meta_field( + position=plate, + metadata={"dataset_statistics": dataset_stats}, + field_name=metric.field_name, + subfield_name=channel_name, + ) + + for pos, result in position_results: + write_meta_field( + position=pos, + metadata={ + "dataset_statistics": dataset_stats, + **result, + }, + field_name=metric.field_name, + subfield_name=channel_name, + ) + + plate.close() diff --git a/viscy/representation/evaluation/linear_classifier.py b/viscy/representation/evaluation/linear_classifier.py index 84a2730d3..7b9027800 100644 --- a/viscy/representation/evaluation/linear_classifier.py +++ b/viscy/representation/evaluation/linear_classifier.py @@ -361,6 +361,8 @@ def predict_with_classifier( adata: ad.AnnData, pipeline: LinearClassifierPipeline, task: str, + artifact_metadata: Optional[dict] = None, + include_wells: Optional[list[str]] = None, ) -> ad.AnnData: """Apply trained classifier to make predictions on new data. @@ -371,7 +373,16 @@ def predict_with_classifier( pipeline : LinearClassifierPipeline Trained classifier pipeline with preprocessing. task : str - Name of the classification task. + Name of the classification task (used as column suffix). + artifact_metadata : Optional[dict] + W&B artifact metadata from ``load_pipeline_from_wandb``. When provided, + provenance keys are stored in ``adata.uns`` under + ``classifier_{task}_artifact``, ``classifier_{task}_id``, and + ``classifier_{task}_version``. + include_wells : Optional[list[str]] + Well prefixes to restrict prediction to (e.g. ``["A/1", "B/2"]``). + Cells in other wells will have ``NaN`` for prediction columns. + When ``None``, all cells are predicted. Returns ------- @@ -381,19 +392,44 @@ def predict_with_classifier( and class labels in .uns[f"predicted_{task}_classes"]. """ print("\nApplying preprocessing and making predictions...") - X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray() - predictions = pipeline.predict(X) - prediction_proba = pipeline.predict_proba(X) + if include_wells is not None: + well_mask = adata.obs["fov_name"].str.startswith( + tuple(w + "/" for w in include_wells) + ) + n_matched = well_mask.sum() + print(f" Well filter: {include_wells} -> {n_matched}/{len(adata)} cells") + else: + well_mask = np.ones(len(adata), dtype=bool) + + X_full = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray() + X_subset = X_full[well_mask] + + predictions_subset = pipeline.predict(X_subset) + proba_subset = pipeline.predict_proba(X_subset) + n_classes = proba_subset.shape[1] - adata.obs[f"predicted_{task}"] = predictions - adata.obsm[f"predicted_{task}_proba"] = prediction_proba + all_predictions = np.full(len(adata), np.nan, dtype=object) + all_predictions[well_mask] = predictions_subset + + all_proba = np.full((len(adata), n_classes), np.nan) + all_proba[well_mask] = proba_subset + + adata.obs[f"predicted_{task}"] = all_predictions + adata.obsm[f"predicted_{task}_proba"] = all_proba adata.uns[f"predicted_{task}_classes"] = pipeline.classifier.classes_.tolist() + if artifact_metadata is not None: + adata.uns[f"classifier_{task}_artifact"] = artifact_metadata["artifact_name"] + adata.uns[f"classifier_{task}_id"] = artifact_metadata["artifact_id"] + adata.uns[f"classifier_{task}_version"] = artifact_metadata["artifact_version"] + + predicted_values = adata.obs[f"predicted_{task}"].dropna() print("✓ Predictions complete") + print(f" Predicted {len(predicted_values)}/{len(adata)} cells") print(" Predicted class distribution:") - print(adata.obs[f"predicted_{task}"].value_counts()) - print(f" Probability matrix shape: {prediction_proba.shape}") + print(predicted_values.value_counts()) + print(f" Probability matrix shape: {all_proba.shape}") print(f" Classes: {pipeline.classifier.classes_.tolist()}") return adata @@ -435,17 +471,21 @@ def save_pipeline_to_wandb( task = config["task"] input_channel = config["input_channel"] + marker = config.get("marker") use_pca = config.get("preprocessing", {}).get("use_pca", False) n_pca = config.get("preprocessing", {}).get("n_pca_components") model_name = f"linear-classifier-{task}-{input_channel}" + if marker: + model_name += f"-{marker}" if use_pca: model_name += f"-pca{n_pca}" run = wandb.init( project=wandb_project, entity=wandb_entity, - job_type=f"linear-classifier-{task}-{input_channel}", + job_type=f"linear-classifier-{task}-{input_channel}" + + (f"-{marker}" if marker else ""), name=model_name, group=model_name, config=config, @@ -503,7 +543,7 @@ def load_pipeline_from_wandb( model_name: str, version: str = "latest", wandb_entity: Optional[str] = None, -) -> tuple[LinearClassifierPipeline, dict]: +) -> tuple[LinearClassifierPipeline, dict, dict]: """Load trained pipeline and config from Weights & Biases. Parameters @@ -523,6 +563,9 @@ def load_pipeline_from_wandb( Loaded classifier pipeline. dict Configuration used for training. + dict + Artifact metadata with keys ``artifact_name``, ``artifact_id``, + and ``artifact_version``. """ print("\n" + "=" * 60) print("LOADING MODEL FROM WANDB") @@ -535,6 +578,11 @@ def load_pipeline_from_wandb( ) artifact = run.use_artifact(f"{model_name}:{version}") + artifact_metadata = { + "artifact_name": f"{model_name}:{artifact.version}", + "artifact_id": artifact.id, + "artifact_version": artifact.version, + } artifact_dir = Path(artifact.download()) config_path = artifact_dir / f"{model_name}_config.json" @@ -573,4 +621,4 @@ def load_pipeline_from_wandb( run.finish() - return pipeline, config + return pipeline, config, artifact_metadata diff --git a/viscy/representation/evaluation/linear_classifier_config.py b/viscy/representation/evaluation/linear_classifier_config.py index c1c4812d3..170c8b8e6 100644 --- a/viscy/representation/evaluation/linear_classifier_config.py +++ b/viscy/representation/evaluation/linear_classifier_config.py @@ -57,6 +57,10 @@ class LinearClassifierTrainConfig(BaseModel): # Task metadata task: VALID_TASKS = Field(...) input_channel: VALID_CHANNELS = Field(...) + marker: Optional[str] = Field( + default=None, + description="Marker name for marker-specific tasks (e.g. g3bp1, sec61b, tomm20).", + ) embedding_model: str = Field(..., min_length=1) # Training datasets @@ -138,8 +142,13 @@ class LinearClassifierInferenceConfig(BaseModel): W&B entity (username or team). embeddings_path : str Path to embeddings zarr file for inference. - output_path : str - Path to save output zarr file with predictions. + output_path : Optional[str] + Path to save output zarr file with predictions. When ``None`` + (the default), predictions are written back to ``embeddings_path``. + include_wells : Optional[list[str]] + Well prefixes to restrict prediction to (e.g. ``["A/1", "B/2"]``). + Cells in other wells will have ``NaN`` for prediction columns. + When ``None`` (the default), all cells are predicted. overwrite : bool Whether to overwrite output if it exists. """ @@ -149,12 +158,11 @@ class LinearClassifierInferenceConfig(BaseModel): version: str = Field(default="latest", min_length=1) wandb_entity: Optional[str] = Field(default=None) embeddings_path: str = Field(..., min_length=1) - output_path: str = Field(..., min_length=1) + output_path: Optional[str] = Field(default=None) + include_wells: Optional[list[str]] = Field(default=None) overwrite: bool = Field(default=False) - @field_validator( - "wandb_project", "model_name", "version", "embeddings_path", "output_path" - ) + @field_validator("wandb_project", "model_name", "version", "embeddings_path") @classmethod def validate_non_empty(cls, v: str) -> str: """Ensure string fields are non-empty.""" @@ -166,14 +174,15 @@ def validate_non_empty(cls, v: str) -> str: def validate_paths(self): """Validate input exists and output doesn't exist unless overwrite=True.""" embeddings_path = Path(self.embeddings_path) - output_path = Path(self.output_path) if not embeddings_path.exists(): raise ValueError(f"Embeddings file not found: {self.embeddings_path}") - if output_path.exists() and not self.overwrite: - raise ValueError( - f"Output file already exists: {self.output_path}. " - f"Set overwrite=true to overwrite." - ) + if self.output_path is not None: + output_path = Path(self.output_path) + if output_path.exists() and not self.overwrite: + raise ValueError( + f"Output file already exists: {self.output_path}. " + f"Set overwrite=true to overwrite." + ) return self diff --git a/viscy/trainer.py b/viscy/trainer.py index 76cc2ba3b..9b0d90c31 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -9,6 +9,7 @@ from torch.onnx import OperatorExportTypes from viscy.preprocessing.precompute import precompute_array +from viscy.preprocessing.qc_metrics import QCMetric, generate_qc_metadata from viscy.representation.evaluation.annotation import convert from viscy.utils.meta_utils import generate_normalization_metadata @@ -130,6 +131,34 @@ def precompute( exclude_fovs=exclude_fovs, ) + def qc( + self, + data_path: Path, + metrics: list[QCMetric], + num_workers: int = 1, + model: LightningModule | None = None, + ): + """Run composable QC metrics on an HCS OME-Zarr dataset. + + Parameters + ---------- + data_path : Path + Path to the HCS OME-Zarr dataset. + metrics : list[QCMetric] + QC metric instances (uses jsonargparse class_path/init_args). + num_workers : int + Number of workers for data loading. + model : LightningModule or None + Ignored placeholder. + """ + if model is not None: + _logger.warning("Ignoring model configuration during QC.") + generate_qc_metadata( + zarr_dir=data_path, + metrics=metrics, + num_workers=num_workers, + ) + def convert_to_anndata( self, embeddings_path: Path,