From 784adc2e3e3de477025fc369a4175594f1400878 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 9 Jun 2026 13:23:34 -0400 Subject: [PATCH 01/21] Transform labels instead of images --- scallops/cli/util.py | 154 ++++++++++- scallops/io.py | 6 +- scallops/tests/test_wdl.py | 86 ++++-- scallops/utils.py | 127 --------- wdl/ops_tasks.wdl | 6 +- wdl/ops_workflow.wdl | 524 ++++++++++++++++++------------------- wdl/utils.wdl | 20 +- 7 files changed, 494 insertions(+), 429 deletions(-) diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 31d11d0..595f0fa 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -93,7 +93,7 @@ def _dask_workers_threads( } -def _create_default_dask_config(config: dict | None = None) -> set: +def _create_default_dask_config(config: dict | None = None) -> dask.config.set: if config is None: return dask.config.set(DEFAULT_DASK_CONFIG) @@ -450,3 +450,155 @@ def load_json(path_or_str: str) -> dict: with fs.open(path_or_str, "rt") as fp: return json.load(fp) return json.loads(path_or_str) + + +def _write_img_size(file_list: list[str]): + from scallops.io import _images2fov, _localize_path + + local_file_list = [] + cleanup_file_list = [] + for path in file_list: + local_path = _localize_path(path) + if local_path is not None: + cleanup_file_list.append(local_path) + local_file_list.append(local_path) + else: + local_file_list.append(path) + sizes = _images2fov(local_file_list, dask=True).sizes + for path in cleanup_file_list: + os.remove(path) + with open("img_size.txt", "wt") as f: + for dim in ["t", "c", "z", "y", "x"]: + s = sizes[dim] if dim in sizes else 0 + f.write(f"{s}") + f.write("\n") + + +def _write_group_size(metadata: dict): + n_tiles = len(metadata["file_metadata"]) + metadata_fields = [v for v in ("c", "z") if v in metadata["file_metadata"][0]] + if len(metadata_fields) > 0: + from scallops.cli.util import _group_src_attrs + + keys, channel_sources, filepaths = _group_src_attrs( + metadata=metadata, metadata_fields=tuple(metadata_fields) + ) + n_tiles = len(filepaths) + with open("group_size.txt", "wt") as f: + f.write(f"{n_tiles}") + f.write("\n") + + +def _list_images_wdl( + image_pattern: str, + urls: list[str], + groupby: list[str], + reference_time: str | None, + subset: list[str] | None, + batch_size_str: str | None, + save_group_size: bool = False, + expected_cycles_str: int | None = None, +): + """Used by WDL workflow to output info about images""" + from scallops.io import _set_up_experiment + + batch_size = 1 + expected_cycles = None + if expected_cycles_str is not None and expected_cycles_str != "": + expected_cycles = int(expected_cycles_str) + if batch_size_str is not None and batch_size_str != "": + batch_size = int(batch_size_str) + if reference_time == "": + reference_time = None + + if subset is not None and ( + len(subset) == 0 or (len(subset) == 1 and subset[0] == "") + ): + subset = None + if image_pattern != "": + groupby = [g for g in groupby if "{" + g + "}" in image_pattern] + exp_gen = _set_up_experiment( + image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset + ) + # "groups.txt": each line passed to --subset in cli + # "groupby.txt": filtered groupby with values not in image_pattern removed + groupby_t = "t" in groupby + times = [] + + if not save_group_size: + with open("group_size.txt", "wt") as f: + f.write("0\n") + + with ( + open("subsets.txt", "wt") as groups_out, + open("subsets_with_t.txt", "wt") as groups_with_t_out, + ): + subset_ids = [] + subset_ids_with_reference_times = [] + first = True + + for g, file_list, metadata in exp_gen: + times = None + if first: + first = False + if save_group_size: + _write_group_size(metadata) + if not groupby_t and "t" in metadata["file_metadata"][0]: + times = [md["t"] for md in metadata["file_metadata"]] + if expected_cycles is not None: + assert len(times) == expected_cycles + t_suffix = "" + if times is not None and len(times) > 0: + t_suffix = ( + f"-{times[0]}" if reference_time is None else f"-{reference_time}" + ) + + subset_ids.append('"' + metadata["id"] + '"') + subset_ids_with_reference_times.append( + '"' + metadata["id"] + t_suffix + '"' + ) + if len(subset_ids) == batch_size: + groups_out.write(" ".join(subset_ids)) + groups_out.write("\n") + + groups_with_t_out.write(" ".join(subset_ids_with_reference_times)) + groups_with_t_out.write("\n") + + subset_ids = [] + subset_ids_with_reference_times = [] + if len(subset_ids) > 0: + groups_out.write(" ".join(subset_ids)) + groups_out.write("\n") + + groups_with_t_out.write(" ".join(subset_ids_with_reference_times)) + groups_with_t_out.write("\n") + + with open("groupby.txt", "wt") as f: + for g in groupby: + f.write(g) + f.write("\n") + groupby_with_t = list(groupby) + + if not groupby_t and times is not None: + groupby_with_t.append("t") + + with open("groupby_with_t.txt", "wt") as f: + for g in groupby_with_t: + f.write(g) + f.write("\n") + + with open("t.txt", "wt") as f: + if times is not None: + for val in times: + f.write(str(val)) + f.write("\n") + + with open("groupby_pattern.txt", "wt") as f: + first = True + for g in groupby: + if not first: + f.write("-") + first = False + f.write("{") + f.write(g) + f.write("}") diff --git a/scallops/io.py b/scallops/io.py index f58c16c..91f4343 100644 --- a/scallops/io.py +++ b/scallops/io.py @@ -1277,8 +1277,10 @@ def _images2fov( if image is None: raise ValueError(f"{file_list[i]} could not be read.") if file_metadata is not None: - if "t" in file_metadata[i] and isinstance( - file_metadata[i]["t"], str + if ( + "t" in file_metadata[i] + and isinstance(file_metadata[i]["t"], str) + and file_metadata[i]["t"].isdigit() ): # convert to int try: # note we replace "_" with "-" so we don't convert "1_2" to 12 e.g. diff --git a/scallops/tests/test_wdl.py b/scallops/tests/test_wdl.py index 0e1b2ee..1fa68bf 100644 --- a/scallops/tests/test_wdl.py +++ b/scallops/tests/test_wdl.py @@ -1,4 +1,3 @@ -import glob import json import os.path from subprocess import check_call @@ -10,6 +9,7 @@ import xarray as xr from scallops import Experiment +from scallops.cli.util import _list_images_wdl from scallops.io import read_image, save_ome_tiff from scallops.tests.test_stitch import _write_image_with_position @@ -27,6 +27,42 @@ def add_physical_size(input_path, output_path): save_ome_tiff(img.values, uri=output_path, ome_xml=img.attrs["processed"].to_xml()) +@pytest.mark.parametrize("reference_time", ["IF", None]) +@pytest.mark.cli_e2e +def test_list_images_wdl(reference_time, tmp_path, monkeypatch): + (tmp_path / "plate1-A1-IF").touch() + (tmp_path / "plate1-A1-FISH").touch() + monkeypatch.chdir(tmp_path) + _list_images_wdl( + image_pattern="{plate}-{well}-{t}", + reference_time=reference_time, + urls=[str(tmp_path)], + groupby=["plate", "well"], + subset=None, + batch_size_str=None, + save_group_size=False, + expected_cycles_str=None, + ) + groups = pd.read_csv(tmp_path / "groups.txt", header=None)[0].values + np.testing.assert_array_equal(groups, ["plate1-A1"]) + groupby = pd.read_csv(tmp_path / "groupby.txt", header=None)[0].values + np.testing.assert_array_equal(groupby, ["plate", "well"]) + groupby_pattern = pd.read_csv(tmp_path / "groupby_pattern.txt", header=None)[ + 0 + ].values + np.testing.assert_array_equal(groupby_pattern, ["{plate}-{well}"]) + + times = pd.read_csv(tmp_path / "t.txt", header=None)[0].values + np.testing.assert_array_equal(times, ["FISH", "IF"]) + groups_with_times = pd.read_csv(tmp_path / "groups_with_t.txt", header=None)[ + 0 + ].values + if reference_time == "IF": + np.testing.assert_array_equal(groups_with_times, ["plate1-A1-IF"]) + else: + np.testing.assert_array_equal(groups_with_times, ["plate1-A1-FISH"]) + + @pytest.mark.cli_e2e def test_stitch_wdl_z_stack(tmp_path): input_path = tmp_path / "input" @@ -119,15 +155,15 @@ def test_stitch_wdl(tmp_path): @pytest.mark.cli_e2e def test_ops_wdl(tmp_path): - sbs_dir = tmp_path / "sbs" - pheno_dir = tmp_path / "pheno" + sbs_dir = tmp_path / "sbs.zarr" + pheno_dir = tmp_path / "pheno.zarr" output = tmp_path / "out" - sbs_dir.mkdir() - pheno_dir.mkdir() output.mkdir() - for p in glob.glob("scallops/tests/data/experimentC/input/*/*Tile-102*"): - add_physical_size(p, str(sbs_dir / os.path.basename(p))) + iss_img = read_image( + "scallops/tests/data/experimentC/input/10X_c1-SBS-1/10X_c1-SBS-1_A1_Tile-102.sbs.tif" + ) + iss_img.attrs["physical_pixel_sizes"] = (1, 1) pheno_img = read_image( "scallops/tests/data/experimentC/10X_c0-DAPI-p65ab/10X_c0-DAPI-p65ab_A1_Tile-102.phenotype.tif" ) @@ -135,45 +171,47 @@ def test_ops_wdl(tmp_path): phenotype_mask = np.ones( (pheno_img.sizes["y"], pheno_img.sizes["x"]), dtype=np.uint8 ) - phenotype_mask[10, 10] = 1 + phenotype_mask[10, 10] = 0 phenotype_tile = np.ones( (pheno_img.sizes["y"], pheno_img.sizes["x"]), dtype=np.uint16 ) phenotype_tile[10, 10] = 2 - exp = Experiment( - images={"A1-102-1": pheno_img, "A1-102-2": pheno_img}, + Experiment( + images={"plateA-A1-IF": pheno_img, "plateA-A1-FISH": pheno_img}, labels={ - "A1-102-1-mask": phenotype_mask, - "A1-102-1-tile": phenotype_tile, - "A1-102-2-mask": phenotype_mask, - "A1-102-2-tile": phenotype_tile, + "plateA-A1-IF-mask": phenotype_mask, + "plateA-A1-IF-tile": phenotype_tile, + "plateA-A1-FISH-mask": phenotype_mask, + "plateA-A1-FISH-tile": phenotype_tile, }, - ) - exp.save(str(pheno_dir)) + ).save(pheno_dir) + + Experiment( + images={"plateA-A1-1": iss_img, "plateA-A1-2": iss_img}, + ).save(sbs_dir) input_json = { "model_dir": "", "iss_url": str(sbs_dir.absolute()), - "iss_image_pattern": "{mag}X_c{t}-{experiment}-{t}_{well}_Tile-{tile}.{datatype}.tif", "output_directory": str(output.absolute()), + "nuclei_segmentation_method": "cellpose", "iss_registration_extra_arguments": "--no-landmarks", "pheno_to_iss_registration_extra_arguments": "--no-landmarks", "pheno_registration_extra_arguments": "--no-landmarks", "phenotype_cyto_channel": [1], - "phenotype_dapi_channel": 0, + "reference_phenotype_time": "IF", "phenotype_url": str(pheno_dir.absolute()), - "phenotype_nuclei_features": ["intensity_0", "intensity_1"], + "phenotype_nuclei_features": { + "IF": ["intensity_0", "intensity_1"], + "FISH": ["intensity_0", "intensity_1"], + }, # 2 batches - "phenotype_cell_features": ["intensity_0"], + "phenotype_cell_features": {"IF": ["intensity_0"]}, # "phenotype_cytosol_features": ["mean_0 area"], # no cytosol features - "phenotype_image_pattern": "{well}-{tile}-{t}", - "groupby": ["well", "tile"], "reads_threshold_peaks": "0", "reads_threshold_peaks_crosstalk": "20", "barcodes": os.path.abspath("scallops/tests/data/experimentC/barcodes.csv"), - "mark_stitch_boundary_cells": False, "reads_labels": "cell", - "merge_extra_arguments": "--format parquet", "docker": "", } diff --git a/scallops/utils.py b/scallops/utils.py index f015589..6663f8d 100644 --- a/scallops/utils.py +++ b/scallops/utils.py @@ -637,130 +637,3 @@ def _dask_from_array_no_copy( meta = x return da.Array(dsk, name, chunks, meta=meta, dtype=getattr(x, "dtype", None)) - - -def _write_img_size(file_list: list[str]): - from scallops.io import _images2fov, _localize_path - - local_file_list = [] - cleanup_file_list = [] - for path in file_list: - local_path = _localize_path(path) - if local_path is not None: - cleanup_file_list.append(local_path) - local_file_list.append(local_path) - else: - local_file_list.append(path) - sizes = _images2fov(local_file_list, dask=True).sizes - for path in cleanup_file_list: - os.remove(path) - with open("img_size.txt", "wt") as f: - for dim in ["t", "c", "z", "y", "x"]: - s = sizes[dim] if dim in sizes else 0 - f.write(f"{s}") - f.write("\n") - - -def _write_group_size(metadata: dict): - n_tiles = len(metadata["file_metadata"]) - metadata_fields = [v for v in ("c", "z") if v in metadata["file_metadata"][0]] - if len(metadata_fields) > 0: - from scallops.cli.util import _group_src_attrs - - keys, channel_sources, filepaths = _group_src_attrs( - metadata=metadata, metadata_fields=tuple(metadata_fields) - ) - n_tiles = len(filepaths) - with open("group_size.txt", "wt") as f: - f.write(f"{n_tiles}") - f.write("\n") - - -def _list_images_wdl( - image_pattern: str, - urls: list[str], - groupby: list[str], - subset: list[str], - batch_size_str: str, - save_group_size: bool = False, - expected_cycles_str: int | None = None, -): - """Used by WDL workflow to output info about images""" - from scallops.io import _set_up_experiment - - batch_size = 0 - expected_cycles = None - if expected_cycles_str != "": - expected_cycles = int(expected_cycles_str) - if batch_size_str != "": - batch_size = int(batch_size_str) - - if len(subset) == 0 or (len(subset) == 1 and subset[0] == ""): - subset = None - if image_pattern != "": - groupby = [g for g in groupby if "{" + g + "}" in image_pattern] - exp_gen = _set_up_experiment( - image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset - ) - # "groups.txt" is passed to --subset in cli - # "groupby.txt" filtered groupby - groupby_t = "t" in groupby - t = [] - - if not save_group_size: - with open("group_size.txt", "wt") as f: - f.write("0\n") - if batch_size > 0: - with open("groups.txt", "wt") as f: - ids = [] - first = True - for g, file_list, metadata in exp_gen: - if first: - first = False - if save_group_size: - _write_group_size(metadata) - if not groupby_t and "t" in metadata["file_metadata"][0]: - t = [md["t"] for md in metadata["file_metadata"]] - if expected_cycles is not None: - assert len(t) == expected_cycles - - ids.append('"' + metadata["id"] + '"') - if len(ids) == batch_size: - f.write(" ".join(ids)) - f.write("\n") - ids = [] - if len(ids) > 0: - f.write(" ".join(ids)) - f.write("\n") - else: - with open("groups.txt", "wt") as f: - first = True - for g, file_list, metadata in exp_gen: - f.write(metadata["id"]) - f.write("\n") - if first: - first = False - if save_group_size: - _write_group_size(metadata) - if not groupby_t and "t" in metadata["file_metadata"][0]: - t = [md["t"] for md in metadata["file_metadata"]] - - with open("groupby.txt", "wt") as f: - for g in groupby: - f.write(g) - f.write("\n") - - with open("t.txt", "wt") as f: - for val in t: - f.write(str(val)) - f.write("\n") - - with open("groupby_pattern.txt", "wt") as f: - first = True - for g in groupby: - if not first: - f.write("-") - first = False - f.write("{") - f.write(g) - f.write("}") diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index f17dc52..396bdb5 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -1,4 +1,4 @@ -version 1.0 +version 1.1 task segment_nuclei { input { @@ -430,7 +430,7 @@ task intersects_boundary { task find_objects { input { - String? labels + Array[String] labels String subset Boolean? force String? label_pattern @@ -454,7 +454,7 @@ task find_objects { fi scallops find-objects \ - --labels "~{labels}" \ + --labels ~{sep=" " labels} \ --subset ~{subset} \ ~{"--label-pattern " + label_pattern} \ --label-suffix ~{suffix} \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 1955d4a..9c14b86 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -1,4 +1,4 @@ -version 1.0 +version 1.1 import "utils.wdl" as utils import "ops_tasks.wdl" as tasks @@ -15,14 +15,14 @@ workflow ops_workflow { String output_directory - # t to align phenotyping rounds to e.g. "IF" + # t to align phenotyping rounds to e.g. "IF". If not specified then first round in natural sorted order is used String? reference_phenotype_time # features String? features_label_filter = "~barcode_count_0.isna()" # valid barcodes - Array[String]? phenotype_cell_features - Array[String]? phenotype_nuclei_features - Array[String]? phenotype_cytosol_features + Map[String, Array[String]]? phenotype_cell_features + Map[String, Array[String]]? phenotype_nuclei_features + Map[String, Array[String]]? phenotype_cytosol_features String? features_extra_arguments # Single string with extra arguments to scallops features cli Int? features_cell_min_area @@ -32,16 +32,15 @@ workflow ops_workflow { Int? features_nuclei_max_area Int? features_cytosol_max_area - Array[Int] phenotype_cyto_channel # indices after registration for cell segmentation - Int phenotype_dapi_channel # index after registration for segmentation and pheno to iss registration - Int? phenotype_dapi_channel_before_registration # for pheno to pheno registration + Array[Int] phenotype_cyto_channel # indices within referent time for cell segmentation + Int? phenotype_dapi_channel # index within t for nuclei segmentation and pheno to iss registration Int? iss_dapi_channel # ISS to ISS and pheno to ISS registration String? iss_registration_extra_arguments # Extra arguments in scallops registration elastix cli for ISS String? pheno_to_iss_registration_extra_arguments String? pheno_registration_extra_arguments - Boolean? register_across_channels + # spot detect Int? iss_expected_cycles @@ -68,7 +67,7 @@ workflow ops_workflow { String model_dir = "" # nuclei segment - String? nuclei_segmentation + String? nuclei_segmentation_method String? nuclei_segmentation_extra_arguments # cell segment @@ -156,7 +155,6 @@ workflow ops_workflow { String cell_intersects_boundary_memory = "32 GiB" String cell_intersects_boundary_disks = "local-disk 200 HDD" - String docker Int preemptible = 0 @@ -187,7 +185,6 @@ workflow ops_workflow { String cell_intersects_boundary_non_reference_t_suffix = "intersects-boundary-t" } - String output_stripped = sub(output_directory, "/+$", "") + "/" String segment_directory = output_stripped + segment_suffix String register_iss_t0_directory = output_stripped + register_iss_suffix @@ -208,7 +205,6 @@ workflow ops_workflow { String merge_features_directory = output_stripped + merge_features_suffix String register_pheno_to_iss_qc_directory = output_stripped + register_pheno_to_iss_qc_suffix String cell_intersects_boundary_directory = output_stripped + cell_intersects_boundary_suffix - String cell_intersects_boundary_directory_non_reference_t = output_stripped + cell_intersects_boundary_non_reference_t_suffix Boolean iss_url_supplied = defined(iss_url) Boolean pheno_url_supplied = defined(phenotype_url) @@ -218,60 +214,35 @@ workflow ops_workflow { urls = [select_first([phenotype_url, iss_url])], image_pattern = if pheno_url_supplied then phenotype_image_pattern else iss_image_pattern, batch_size=batch_size, + reference_time=reference_phenotype_time, groupby=groupby, - subset=subset, + subset = subset, docker=docker, zones = zones, preemptible = preemptible, aws_queue_arn = aws_queue_arn, max_retries = max_retries } - String image_pattern_after_registration = list_images.groupby_pattern - Array[String] groups = list_images.groups + String groupby_pattern = list_images.groupby_pattern # plate-well + Array[String] subsets = list_images.subsets + Array[String] subset_with_reference_times = list_images.subset_with_reference_times Array[String] times = list_images.t - scatter (group in groups) { + Array[String] groupby_with_time = list_images.filtered_groupby_with_t + scatter (subset_index in range(length(subsets))) { + String subset_ = subsets[subset_index] + String subset_with_reference_time = subset_with_reference_times[subset_index] if(pheno_url_supplied) { - if(length(times)>1) { - call tasks.register_elastix as register_pheno_to_pheno { - input: - moving=select_all([phenotype_url]), - moving_label=phenotype_url, # transform stitch masks - moving_channel=phenotype_dapi_channel_before_registration, # DAPI index in each round - moving_image_pattern=phenotype_image_pattern, - reference_time=reference_phenotype_time, - extra_arguments=pheno_registration_extra_arguments, - unroll_channels=true, - register_across_channels=register_across_channels, - groupby=groupby, - moving_output_directory=register_pheno_to_pheno_directory, - label_output_directory=register_pheno_to_pheno_directory, - transform_output_directory=register_pheno_to_pheno_transform_directory, - subset = group, - force = force_register_pheno_to_pheno, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = register_pheno_to_pheno_disks, - memory = register_pheno_to_pheno_memory, - cpu = register_pheno_to_pheno_cpu, - max_retries = max_retries - } - } - String register_pheno_to_pheno_output_url = select_first([register_pheno_to_pheno.moving_output_url, phenotype_url]) - String register_pheno_to_pheno_image_pattern = if(length(times)>1) then image_pattern_after_registration else phenotype_image_pattern - if(run_nuclei_segmentation) { call tasks.segment_nuclei { input: - images = register_pheno_to_pheno_output_url, - image_pattern = register_pheno_to_pheno_image_pattern, - method = nuclei_segmentation, - groupby=groupby, + images = select_first([phenotype_url]), + image_pattern = phenotype_image_pattern, + method = nuclei_segmentation_method, + groupby=groupby_with_time, dapi_channel = phenotype_dapi_channel, output_directory=segment_directory, model_dir=model_dir, - subset = group, + subset = subset_with_reference_time, extra_arguments=nuclei_segmentation_extra_arguments, force = force_segment_nuclei, docker=docker, @@ -283,14 +254,16 @@ workflow ops_workflow { cpu = segment_nuclei_cpu, max_retries = max_retries } + } if(run_cell_segmentation) { call tasks.segment_cell { input: - images = register_pheno_to_pheno_output_url, - image_pattern = register_pheno_to_pheno_image_pattern, + images = select_first([phenotype_url]), + image_pattern = phenotype_image_pattern, method = cell_segmentation_method, - groupby=groupby, + groupby = groupby_with_time, + subset = subset_with_reference_time, dapi_channel = phenotype_dapi_channel, cyto_channel=phenotype_cyto_channel, nuclei_label=select_first([segment_nuclei.output_url]), @@ -298,7 +271,7 @@ workflow ops_workflow { threshold_correction_factor = segment_cell_threshold_correction_factor, output_directory=segment_directory, model_dir=model_dir, - subset = group, + extra_arguments=cell_segmentation_extra_arguments, force = force_segment_cell, docker=docker, @@ -310,62 +283,113 @@ workflow ops_workflow { cpu = segment_cell_cpu, max_retries = max_retries } - call tasks.find_objects as find_objects_cell { - input: - labels= segment_cell.output_url, - label_pattern=image_pattern_after_registration, - suffix="cell", - output_directory=cell_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries + + if(length(times)>1) { + call tasks.register_elastix as register_pheno_to_pheno { + input: + moving=select_all([phenotype_url]), + moving_label=segment_cell.output_url, + moving_channel=phenotype_dapi_channel, + moving_image_pattern=phenotype_image_pattern, + reference_time=reference_phenotype_time, + extra_arguments=pheno_registration_extra_arguments, + output_aligned_channels_only=true, + groupby=groupby, + subset = subset_, + moving_output_directory=register_pheno_to_pheno_directory, + label_output_directory=register_pheno_to_pheno_directory, + transform_output_directory=register_pheno_to_pheno_transform_directory, + + force = force_register_pheno_to_pheno, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = register_pheno_to_pheno_disks, + memory = register_pheno_to_pheno_memory, + cpu = register_pheno_to_pheno_cpu, + max_retries = max_retries + } } - call tasks.find_objects as find_objects_cytosol { - input: - labels=segment_cell.output_url, - label_pattern=image_pattern_after_registration, - suffix="cytosol", - output_directory=cytosol_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries + if(run_nuclei_segmentation) { + call tasks.find_objects as find_objects_nuclei { + input: + labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=phenotype_image_pattern, + suffix="nuclei", + output_directory=nuclei_objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } } + if(run_cell_segmentation) { + call tasks.find_objects as find_objects_cell { + input: + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=phenotype_image_pattern, + suffix="cell", + output_directory=cell_objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } + } + if (run_nuclei_segmentation && run_cell_segmentation) { + call tasks.find_objects as find_objects_cytosol { + input: + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=phenotype_image_pattern, + suffix="cytosol", + output_directory=cytosol_objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } + } + # String register_pheno_to_pheno_output_url = select_first([register_pheno_to_pheno.moving_output_url, phenotype_url]) + # String phenotype_image_pattern = if(length(times)>1) then image_pattern_after_registration else phenotype_image_pattern + # determine whether cells intersect stitch boundary - # using stitch mask as image + # use stitch mask as image and segment output for reference phenotype or transformed phenotype for others + if(mark_stitch_boundary_cells) { - String t0 = if (length(times)>0) then times[0] else "" - String reference_phenotype_time_ = select_first([reference_phenotype_time, t0]) - String output_prefix = if (reference_phenotype_time_!="") then "-" else "" String phenotype_url_stripped = if (pheno_url_supplied) then sub(select_first([phenotype_url]), "/+$", "") else "" call tasks.intersects_boundary as cell_intersects_boundary { - # reference time mask is not transformed - # use mask from stitch output + input: - labels=segment_cell.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), images=phenotype_url_stripped + '/labels/', - image_pattern=image_pattern_after_registration + output_prefix + reference_phenotype_time_ + '-mask', + image_pattern=phenotype_image_pattern, output_directory=cell_intersects_boundary_directory, label_type='cell', objects=find_objects_cell.output_url, groupby=groupby, - subset = group, + subset = subset_, force = force_segment_cell, docker=docker, zones = zones, @@ -376,35 +400,14 @@ workflow ops_workflow { cpu = cell_intersects_boundary_cpu, max_retries = max_retries } - if (length(times)>1) { - call tasks.intersects_boundary as cell_intersects_boundary_t { - # non-reference time masks are transformed - # use masks from registration output - input: - labels= segment_cell.output_url, - images=register_pheno_to_pheno.moving_output_url + '/labels/', - image_pattern=phenotype_image_pattern + '-mask', - output_directory=cell_intersects_boundary_directory_non_reference_t, - label_type='cell', - objects=find_objects_cell.output_url, - subset = group, - groupby=groupby, - force = force_segment_cell, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = cell_intersects_boundary_disks, - memory = cell_intersects_boundary_memory, - cpu = cell_intersects_boundary_cpu, - max_retries = max_retries - } - } + } } + } if(iss_url_supplied) { + call tasks.register_elastix as register_iss_t0 { input: moving=[select_first([iss_url])], @@ -413,9 +416,8 @@ workflow ops_workflow { groupby=groupby, moving_output_directory=register_iss_t0_directory, transform_output_directory=register_iss_t0_transforms_directory, - register_across_channels=register_across_channels, extra_arguments=iss_registration_extra_arguments, - subset = group, + subset = subset_, force = force_register_iss, docker=docker, zones = zones, @@ -429,21 +431,21 @@ workflow ops_workflow { } if(iss_url_supplied && pheno_url_supplied) { + # transfer phenotype segmentation and DAPI channel to ISS call tasks.register_elastix as register_pheno_to_iss { input: fixed=select_first([iss_url]), fixed_channel=iss_dapi_channel, moving_label=segment_cell.output_url, - moving=select_all([register_pheno_to_pheno_output_url]), - moving_image_pattern=register_pheno_to_pheno_image_pattern, + moving=select_all([phenotype_url]), + moving_image_pattern=phenotype_image_pattern, fixed_image_pattern=iss_image_pattern, moving_channel=phenotype_dapi_channel, output_aligned_channels_only=true, - register_across_channels=register_across_channels, moving_output_directory=register_pheno_to_iss_directory, label_output_directory=register_pheno_to_iss_directory, transform_output_directory=register_pheno_to_iss_transforms_directory, - subset = group, + subset = subset_with_reference_time, groupby=groupby, extra_arguments=pheno_to_iss_registration_extra_arguments, force = force_register_pheno_to_iss, @@ -457,36 +459,19 @@ workflow ops_workflow { max_retries = max_retries } if(run_nuclei_segmentation) { - call tasks.find_objects as find_objects_nuclei { - input: - labels=segment_nuclei.output_url, - label_pattern=image_pattern_after_registration, - suffix="nuclei", - output_directory=nuclei_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries - } - + # ISS t0 to phenotype reference time call tasks.register_pheno_to_iss_qc as register_pheno_to_iss_qc { input: - images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, + images=select_first([iss_url]), + image_pattern=iss_image_pattern, stacked_images=register_pheno_to_iss.moving_output_url, - stacked_image_pattern=image_pattern_after_registration, + stacked_image_pattern=phenotype_image_pattern, image_channel=iss_dapi_channel, stacked_image_channel=0, label_type='nuclei', output_directory=register_pheno_to_iss_qc_directory, labels=register_pheno_to_iss.label_output_url, - subset = group, + subset = subset_, groupby=groupby, force = force_register_pheno_to_iss_qc, docker=docker, @@ -498,16 +483,17 @@ workflow ops_workflow { cpu = register_pheno_to_iss_qc_cpu, max_retries = max_retries } - call tasks.register_qc as register_iss_to_iss_qc { + # ISS t0 to other times + call tasks.register_qc as register_iss_to_iss_qc { input: images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, + image_pattern=groupby_pattern, channel=select_first([iss_dapi_channel, 0]), label_type='nuclei', channel_prefix="ISS", output_directory=register_iss_to_iss_qc_directory, labels=register_pheno_to_iss.label_output_url, - subset = group, + subset = subset_, groupby=groupby, force = force_register_iss_to_iss_qc, docker=docker, @@ -526,14 +512,14 @@ workflow ops_workflow { call tasks.spot_detect { input: images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, + image_pattern=groupby_pattern, iss_channels=iss_channels, sigma_log=spot_detection_sigma_log, max_filter_width=spot_detection_max_filter_width, peak_neighborhood_size=spot_detection_peak_neighborhood_size, expected_cycles=iss_expected_cycles, output_directory=spot_detect_directory, - subset = group, + subset = subset_, groupby=groupby, extra_arguments=spot_detection_extra_arguments, force = force_spot_detect, @@ -562,7 +548,7 @@ workflow ops_workflow { label_name=reads_labels, mismatches=reads_mismatches, threshold_peaks_crosstalk=reads_threshold_peaks_crosstalk, - subset = group, + subset = subset_, extra_arguments=reads_extra_arguments, force = force_reads, docker=docker, @@ -580,20 +566,16 @@ workflow ops_workflow { call tasks.merge as merge_sbs_metadata { input: iss_reads=select_first([reads.output_url]) + '/labels', -# phenotypes_nuclei=features_nuclei.output_url, -# phenotypes_cell=features_cell.output_url, -# phenotypes_cytosol=features_cytosol.output_url, - objects_nuclei=find_objects_nuclei.output_url, + objects_nuclei=find_objects_nuclei.output_url, # all rounds objects_cell=find_objects_cell.output_url, objects_cytosol=find_objects_cytosol.output_url, cell_intersects_boundary=cell_intersects_boundary.output_url, - cell_intersects_boundary_t=cell_intersects_boundary_t.output_url, register_pheno_to_iss_qc=register_pheno_to_iss_qc.output_url, register_iss_to_iss_qc=register_iss_to_iss_qc.output_url, barcodes=select_first([barcodes]), barcode_column=barcode_column, output_directory=merge_meta_directory, - subset = group, + subset = subset_, extra_arguments=merge_extra_arguments, force = force_merge, docker=docker, @@ -609,132 +591,146 @@ workflow ops_workflow { } if (defined(phenotype_nuclei_features)) { - Array[String] phenotype_nuclei_features_ = select_first([phenotype_nuclei_features]) - # cromwell hack + Map[String,Array[String]] phenotype_nuclei_features_ = select_first([phenotype_nuclei_features]) Int features_nuclei_min_area_ = select_first([features_nuclei_min_area, -1]) Int features_nuclei_max_area_ = select_first([features_nuclei_max_area, -1]) - scatter (index in range(length(phenotype_nuclei_features_))) { + Array[String] phenotype_nuclei_times = keys(phenotype_nuclei_features_) - call tasks.features as features_nuclei { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - nuclei_features = phenotype_nuclei_features_[index], - nuclei_min_area = features_nuclei_min_area_, - nuclei_max_area = features_nuclei_max_area_, - features_extra_arguments=features_extra_arguments, - labels= segment_cell.output_url, - model_dir=model_dir, - groupby=groupby, - output_directory=nuclei_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + scatter (phenotype_time in phenotype_nuclei_times) { + Array[String] nuclei_features = phenotype_nuclei_features_[phenotype_time] + scatter (feature_index in range(length(nuclei_features))) { + call tasks.features as features_nuclei { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + objects=merge_sbs_metadata.output_url, + labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + nuclei_features = nuclei_features[feature_index], + nuclei_min_area = features_nuclei_min_area_, + nuclei_max_area = features_nuclei_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + groupby=groupby, + output_directory=nuclei_features_directory + '-' + phenotype_time + '-' + feature_index, + subset = subset_, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } if (defined(phenotype_cell_features)) { - Array[String] phenotype_cell_features_ = select_first([phenotype_cell_features]) - # cromwell hack + Map[String,Array[String]] phenotype_cell_features_ = select_first([phenotype_cell_features]) Int features_cell_min_area_ = select_first([features_cell_min_area, -1]) Int features_cell_max_area_ = select_first([features_cell_max_area, -1]) - scatter (index in range(length(phenotype_cell_features_))) { - call tasks.features as features_cell { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - cell_features = phenotype_cell_features_[index], - cell_min_area = features_cell_min_area_, - cell_max_area = features_cell_max_area_, - features_extra_arguments=features_extra_arguments, - labels= segment_cell.output_url, - model_dir=model_dir, - groupby=groupby, - output_directory=cell_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + Array[String] phenotype_cell_times = keys(phenotype_cell_features_) + + scatter (phenotype_time in phenotype_cell_times) { + Array[String] cell_features = phenotype_cell_features_[phenotype_time] + scatter (feature_index in range(length(cell_features))) { + call tasks.features as features_cell { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + objects=merge_sbs_metadata.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + cell_features = cell_features[feature_index], + cell_min_area = features_cell_min_area_, + cell_max_area = features_cell_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + groupby=groupby, + output_directory=cell_features_directory + '-' + phenotype_time + '-' + feature_index, + subset = subset_, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } if (defined(phenotype_cytosol_features)) { - Array[String] phenotype_cytosol_features_ = select_first([phenotype_cytosol_features]) - # cromwell hack + Map[String,Array[String]] phenotype_cytosol_features_ = select_first([phenotype_cytosol_features]) Int features_cytosol_min_area_ = select_first([features_cytosol_min_area, -1]) Int features_cytosol_max_area_ = select_first([features_cytosol_max_area, -1]) - scatter (index in range(length(phenotype_cytosol_features_))) { - call tasks.features as features_cytosol { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - cytosol_features = phenotype_cytosol_features_[index], - cytosol_min_area = features_cytosol_min_area_, - cytosol_max_area = features_cytosol_max_area_, - labels = segment_cell.output_url, - features_extra_arguments=features_extra_arguments, - model_dir=model_dir, - groupby=groupby, - output_directory=cytosol_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + Array[String] phenotype_cytosol_times = keys(phenotype_cytosol_features_) + + scatter (phenotype_time in phenotype_cytosol_times) { + Array[String] cytosol_features = phenotype_cytosol_features_[phenotype_time] + scatter (feature_index in range(length(cytosol_features))) { + call tasks.features as features_cytosol { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + objects=merge_sbs_metadata.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + output_directory=cytosol_features_directory + '-' + phenotype_time + '-' + feature_index, + cytosol_features = cytosol_features[feature_index], + cytosol_min_area = features_cytosol_min_area_, + cytosol_max_area = features_cytosol_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + groupby=groupby, + + subset = subset_, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } - if (defined(barcodes)) { + if (defined(barcodes)) { - call tasks.merge as merge_features { - input: - phenotypes_nuclei=features_nuclei.output_url, - phenotypes_cell=features_cell.output_url, - phenotypes_cytosol=features_cytosol.output_url, - iss_reads=merge_sbs_metadata.output_url, - output_directory=merge_features_directory, - subset = group, - extra_arguments=merge_extra_arguments, - force = force_merge, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = merge_disks, - memory = merge_memory, - cpu = merge_cpu, - max_retries = max_retries - } + call tasks.merge as merge_features { + input: + phenotypes_nuclei=features_nuclei.output_url, + phenotypes_cell=features_cell.output_url, + phenotypes_cytosol=features_cytosol.output_url, + iss_reads=merge_sbs_metadata.output_url, + output_directory=merge_features_directory, + subset = subset_, + extra_arguments=merge_extra_arguments, + force = force_merge, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = merge_disks, + memory = merge_memory, + cpu = merge_cpu, + max_retries = max_retries } - + } } output { @@ -754,7 +750,5 @@ workflow ops_workflow { Array[Array[String]?] features_cytosol_output_url = features_cytosol.output_url Array[String?] merge_sbs_metadata_output_url = merge_sbs_metadata.output_url Array[String?] merge_features_output_url = merge_features.output_url - Array[String] list_images_groups = list_images.groups - } } diff --git a/wdl/utils.wdl b/wdl/utils.wdl index 35e8f01..77060f4 100644 --- a/wdl/utils.wdl +++ b/wdl/utils.wdl @@ -1,10 +1,11 @@ -version 1.0 +version 1.1 task list_images { input { Boolean? save_group_size Array[String] urls String? image_pattern + String? reference_time Array[String] groupby Array[String]? subset Int? expected_cycles @@ -23,26 +24,31 @@ task list_images { command <<< set -e python <>> output { - Array[String] groups = read_lines('groups.txt') + Array[String] subsets = read_lines('subsets.txt') + Array[String] subset_with_reference_times = read_lines('subsets_with_t.txt') Array[String] t = read_lines('t.txt') - Array[String] filtered_groupby = read_lines('groupby.txt') - String groupby_pattern = read_lines('groupby_pattern.txt')[0] - Int group_size = read_int('group_size.txt') + String groupby_pattern = read_lines('groupby_pattern.txt')[0] # e.g. {plate}-{well} + + Array[String] filtered_groupby_with_t = read_lines('groupby_with_t.txt') # e.g. [plate, well, t] + Array[String] filtered_groupby = read_lines('groupby.txt') # e.g. [plate, well] + + Int group_size = read_int('group_size.txt') } meta { From 0fb20ab7fd266e8cd09071a3ee2543b46d258b26 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 9 Jun 2026 13:24:26 -0400 Subject: [PATCH 02/21] Transform labels instead of images --- scallops/tests/test_wdl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scallops/tests/test_wdl.py b/scallops/tests/test_wdl.py index 1fa68bf..943e549 100644 --- a/scallops/tests/test_wdl.py +++ b/scallops/tests/test_wdl.py @@ -194,7 +194,6 @@ def test_ops_wdl(tmp_path): "model_dir": "", "iss_url": str(sbs_dir.absolute()), "output_directory": str(output.absolute()), - "nuclei_segmentation_method": "cellpose", "iss_registration_extra_arguments": "--no-landmarks", "pheno_to_iss_registration_extra_arguments": "--no-landmarks", "pheno_registration_extra_arguments": "--no-landmarks", From 92a4b3592eb810d3d146f428b92bbaffee87843c Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 9 Jun 2026 15:00:23 -0400 Subject: [PATCH 03/21] Transform labels instead of images --- scallops/cli/util.py | 17 ++++++++++++++++- wdl/ops_workflow.wdl | 19 ++++++++----------- wdl/utils.wdl | 3 +-- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 595f0fa..28553b9 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -523,7 +523,7 @@ def _list_images_wdl( # "groups.txt": each line passed to --subset in cli # "groupby.txt": filtered groupby with values not in image_pattern removed groupby_t = "t" in groupby - times = [] + times = None if not save_group_size: with open("group_size.txt", "wt") as f: @@ -602,3 +602,18 @@ def _list_images_wdl( f.write("{") f.write(g) f.write("}") + groupby_with_reference_time = list(groupby) + if reference_time is not None: + groupby_with_reference_time.append(reference_time) + elif times is not None and len(times) > 0: + groupby_with_reference_time.append(times[0]) + + with open("groupby_pattern_with_reference_t.txt", "wt") as f: + first = True + for g in groupby_with_reference_time: + if not first: + f.write("-") + first = False + f.write("{") + f.write(g) + f.write("}") diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 9c14b86..eed2343 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -223,11 +223,12 @@ workflow ops_workflow { aws_queue_arn = aws_queue_arn, max_retries = max_retries } - String groupby_pattern = list_images.groupby_pattern # plate-well - Array[String] subsets = list_images.subsets - Array[String] subset_with_reference_times = list_images.subset_with_reference_times - Array[String] times = list_images.t - Array[String] groupby_with_time = list_images.filtered_groupby_with_t + String groupby_pattern = list_images.groupby_pattern # "{plate}-{well}" + Array[String] subsets = list_images.subsets # e.g. ["plate1-A1", "plate1-A2", ...] + Array[String] subset_with_reference_times = list_images.subset_with_reference_times # e.g. ["plate1-A1-IF", "plate1-A2-IF", ...] + Array[String] times = list_images.t # e.g. ["FISH", "IF"] + Array[String] groupby_with_time = list_images.filtered_groupby_with_t # e.g. ['plate', 'well', 't'] + String groupby_pattern_with_reference_t = list_images.groupby_pattern_with_reference_t # e.g. "{plate}-{well}-IF" scatter (subset_index in range(length(subsets))) { String subset_ = subsets[subset_index] String subset_with_reference_time = subset_with_reference_times[subset_index] @@ -369,10 +370,6 @@ workflow ops_workflow { max_retries = max_retries } } - # String register_pheno_to_pheno_output_url = select_first([register_pheno_to_pheno.moving_output_url, phenotype_url]) - # String phenotype_image_pattern = if(length(times)>1) then image_pattern_after_registration else phenotype_image_pattern - - # determine whether cells intersect stitch boundary # use stitch mask as image and segment output for reference phenotype or transformed phenotype for others @@ -438,14 +435,14 @@ workflow ops_workflow { fixed_channel=iss_dapi_channel, moving_label=segment_cell.output_url, moving=select_all([phenotype_url]), - moving_image_pattern=phenotype_image_pattern, + moving_image_pattern=groupby_pattern_with_reference_t, fixed_image_pattern=iss_image_pattern, moving_channel=phenotype_dapi_channel, output_aligned_channels_only=true, moving_output_directory=register_pheno_to_iss_directory, label_output_directory=register_pheno_to_iss_directory, transform_output_directory=register_pheno_to_iss_transforms_directory, - subset = subset_with_reference_time, + subset = subset_, groupby=groupby, extra_arguments=pheno_to_iss_registration_extra_arguments, force = force_register_pheno_to_iss, diff --git a/wdl/utils.wdl b/wdl/utils.wdl index 77060f4..393b806 100644 --- a/wdl/utils.wdl +++ b/wdl/utils.wdl @@ -43,8 +43,7 @@ task list_images { Array[String] subset_with_reference_times = read_lines('subsets_with_t.txt') Array[String] t = read_lines('t.txt') String groupby_pattern = read_lines('groupby_pattern.txt')[0] # e.g. {plate}-{well} - - + String groupby_pattern_with_reference_t = read_lines('groupby_pattern_with_reference_t.txt')[0] # e.g. {plate}-{well}-IF Array[String] filtered_groupby_with_t = read_lines('groupby_with_t.txt') # e.g. [plate, well, t] Array[String] filtered_groupby = read_lines('groupby.txt') # e.g. [plate, well] From fd1fa6442ebd14a0babad19b842232a5584546a5 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 9 Jun 2026 15:34:36 -0400 Subject: [PATCH 04/21] Transform labels instead of images --- scallops/cli/util.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 28553b9..74f514e 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -602,18 +602,19 @@ def _list_images_wdl( f.write("{") f.write(g) f.write("}") - groupby_with_reference_time = list(groupby) + reference_time_suffix = "" if reference_time is not None: - groupby_with_reference_time.append(reference_time) + reference_time_suffix = f"-{reference_time}" elif times is not None and len(times) > 0: - groupby_with_reference_time.append(times[0]) + reference_time_suffix = f"-{times[0]}" with open("groupby_pattern_with_reference_t.txt", "wt") as f: first = True - for g in groupby_with_reference_time: + for g in groupby: if not first: f.write("-") first = False f.write("{") f.write(g) f.write("}") + f.write(reference_time_suffix) From e6f66ab5b0c38ba879d0e160f421989a60fbc45c Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 9 Jun 2026 15:43:34 -0400 Subject: [PATCH 05/21] Transform labels instead of images --- scallops/cli/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 74f514e..b7360e1 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -520,8 +520,7 @@ def _list_images_wdl( exp_gen = _set_up_experiment( image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset ) - # "groups.txt": each line passed to --subset in cli - # "groupby.txt": filtered groupby with values not in image_pattern removed + groupby_t = "t" in groupby times = None @@ -617,4 +616,4 @@ def _list_images_wdl( f.write("{") f.write(g) f.write("}") - f.write(reference_time_suffix) + f.write(reference_time_suffix) From ce6634f14cb7f9bf66c4ec11696389dfcee58a78 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 11 Jun 2026 12:37:53 -0400 Subject: [PATCH 06/21] registration transform across times --- scallops/cli/register.py | 151 +++++++++++++++++++++++---------------- 1 file changed, 88 insertions(+), 63 deletions(-) diff --git a/scallops/cli/register.py b/scallops/cli/register.py index 70be43d..21245a3 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -5,7 +5,7 @@ import os from collections.abc import Sequence from itertools import zip_longest -from typing import Literal +from typing import Any, Literal import fsspec import itk @@ -13,6 +13,8 @@ import xarray as xr import zarr from dask.bag import from_sequence +from natsort import natsorted +from xarray import DataArray from zarr import Group from scallops.cli.util import ( @@ -54,6 +56,48 @@ logger = _get_cli_logger() +def _output_exists( + register_self, label_output_root, moving_label_keys, image_output_root, image_key +): + labels_exist = True + if label_output_root is not None: + if not register_self: + for key in moving_label_keys: + key = os.path.basename(key) + if not is_ome_zarr_array(label_output_root.get(f"labels/{key}")): + labels_exist = False + break + # TODO check for transformed labels when register_self + + image_exists = True + if image_output_root is not None: + if not is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): + image_exists = False + elif label_output_root is None: + image_exists = False + return labels_exist and image_exists + + +def _get_reference_timepoint( + reference_timepoint: str, moving_image: Sequence[DataArray] +) -> tuple[int, Any]: + reference_timepoint_value = reference_timepoint + if isinstance(reference_timepoint, str): + reference_timepoint_found = False + for i in range(len(moving_image)): + if moving_image[i].coords["t"].values[0] == reference_timepoint: + reference_timepoint = i + reference_timepoint_found = True + break + if not reference_timepoint_found: + raise ValueError(f"Reference timepoint not found: {reference_timepoint}.") + else: + reference_timepoint_value = ( + moving_image[reference_timepoint].coords["t"].values[0] + ) + return reference_timepoint, reference_timepoint_value + + def single_registration( fixed_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, moving_tuple: tuple[tuple[str, ...], list[str | Group], dict], @@ -130,51 +174,44 @@ def single_registration( transform_dest = f"{transform_output_dir}{transform_fs.sep}{image_key}" moving_label_keys = [] + # when registering self, transforms written to output_dir/time={t}/ + moving_image = _images2fov( + moving_file_list, + moving_metadata, + dask=True, + concat_dims=("c",), + ) + if register_self: + reference_timepoint, reference_timepoint_value = _get_reference_timepoint( + reference_timepoint=reference_timepoint, moving_image=moving_image + ) if moving_labels is not None: - matching_label_prefix = image_key - if register_self: - matching_label_prefix = f"{matching_label_prefix}-{reference_timepoint}" + matching_label_prefix = ( + f"{image_key}-{reference_timepoint_value}" if register_self else image_key + ) + for moving_label in moving_labels: moving_label_keys.extend( get_matching_names( image_key=matching_label_prefix, image_dir=moving_label, labels=True ) ) - moving_label_keys = sorted(moving_label_keys) + moving_label_keys = natsorted(moving_label_keys) if len(moving_label_keys) == 0: logger.warning(f"No labels found for {image_key}") - if not force: - labels_exist = True - if label_output_root is not None: - if not register_self: - for key in moving_label_keys: - key = os.path.basename(key) - if not is_ome_zarr_array(label_output_root.get(f"labels/{key}")): - labels_exist = False - break - # TODO check for transformed labels when register_self - - image_exists = True - if image_output_root is not None: - if not is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): - image_exists = False - elif label_output_root is None: - image_exists = False - if labels_exist and image_exists: - logger.info(f"Skipping registration for {image_key}") - return image_key - - if register_self: - logger.info(f"Running registration for {image_key} t={reference_timepoint}") - logger.info( - f"{len(moving_file_list):,} {pluralize('input', len(moving_file_list))}:" - f" {', '.join([s.name.replace('/images/', '') if isinstance(s, zarr.Group) else str(s) for s in moving_file_list])}" - ) - else: - logger.info(f"Running registration for {image_key}") + if not force and _output_exists( + register_self, + label_output_root, + moving_label_keys, + image_output_root, + image_key, + ): + logger.info(f"Skipping registration for {image_key}") + return image_key if not register_self: + logger.info(f"Running registration for {image_key}") _, fixed_file_list, fixed_metadata = fixed_tuple assert fixed_metadata["id"] == moving_metadata["id"], ( @@ -192,12 +229,6 @@ def single_registration( fixed_image = _z_projection(fixed_image, z_index).isel( t=0, c=fixed_channel, missing_dims="ignore" ) - moving_image = _images2fov( - moving_file_list, - moving_metadata, - dask=True, - concat_dims=("c",), - ) parameter_object = _load_itk_parameters(itk_parameters) parameter_object_across_channels = ( @@ -353,17 +384,11 @@ def single_registration( ) else: # align to t=reference_timepoint - if isinstance(reference_timepoint, str): - reference_timepoint_found = False - for i in range(len(moving_image)): - if moving_image[i].coords["t"].values[0] == reference_timepoint: - reference_timepoint = i - reference_timepoint_found = True - break - if not reference_timepoint_found: - raise ValueError( - f"Reference timepoint not found: {reference_timepoint}." - ) + logger.info(f"Running registration for {image_key} t={reference_timepoint}") + logger.info( + f"{len(moving_file_list):,} {pluralize('input', len(moving_file_list))}:" + f" {', '.join([s.name.replace('/images/', '') if isinstance(s, zarr.Group) else str(s) for s in moving_file_list])}" + ) set_automatic_transform_initialization(parameter_object, False) if output_aligned_channels_only and not isinstance(moving_image, xr.DataArray): @@ -423,7 +448,7 @@ def _transform_labels_t( ): # transform_dest structure is image_key/t=1 # assume labels are named image_key-t-suffix - print("moving_label_keys", moving_label_keys) + for transform_file in transform_fs.ls(transform_dest, detail=True, refresh=True): if transform_file["type"] == "directory": transform_name = transform_file["name"] @@ -457,7 +482,10 @@ def _transform_labels_t( def get_matching_names( - image_key: str, image_dir: str | Group, labels: bool = True + image_key: str, + image_dir: str | Group, + labels: bool = True, + label_suffixes: Sequence[str] = {"cell", "nuclei", "cytosol"}, ) -> list[str]: """Get matching keys for the given image key and directory. @@ -472,24 +500,21 @@ def get_matching_names( # look for f'labels/image_key-{suffix} or f'images/image_key zarr_dir = "labels" if labels else "images" if isinstance(image_dir, Group): - protocol = _get_fs_protocol(_get_fs(image_dir)) + fs = _get_fs(image_dir) image_dir = f"{_get_store_path(image_dir)}{image_dir.name}" - if protocol != "file": - image_dir = f"{protocol}://{image_dir}" + image_dir = fs.unstrip_protocol(image_dir) image_fs, _ = fsspec.core.url_to_fs(image_dir) image_dir = image_dir.rstrip(image_fs.sep) glob_pattern = f"{image_dir}{image_fs.sep}{zarr_dir}{image_fs.sep}{image_key}" - if labels: - glob_pattern += "-*" - paths = image_fs.glob(glob_pattern) - protocol = _get_fs_protocol(image_fs) - if protocol != "file": - paths = [f"{protocol}://{x}" for x in paths] + results = [] - for path in paths: + for path in image_fs.glob(glob_pattern): + path = image_fs.unstrip_protocol(path) name = os.path.basename(path) + if labels and name.split("-")[-1] not in label_suffixes: + continue if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, mode="r")): results.append(path) return results From cc2367555324d48d0e1f41776417ead83020524b Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jun 2026 16:24:23 -0400 Subject: [PATCH 07/21] Registration transform across times --- scallops/cli/register.py | 240 ++++++++++------ scallops/cli/register_main.py | 10 +- scallops/cli/segment.py | 78 ++++- scallops/cli/segment_main.py | 11 +- scallops/cli/util.py | 284 ++++++++++++------- scallops/tests/miniwdl_local/local_runner.py | 7 +- scallops/tests/test_register_cli.py | 61 ++-- scallops/zarr_io.py | 121 ++++---- wdl/ops_tasks.wdl | 10 +- wdl/ops_workflow.wdl | 65 +++-- wdl/utils.wdl | 65 +++-- 11 files changed, 609 insertions(+), 343 deletions(-) diff --git a/scallops/cli/register.py b/scallops/cli/register.py index 21245a3..e40bedd 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -13,7 +13,7 @@ import xarray as xr import zarr from dask.bag import from_sequence -from natsort import natsorted +from natsort import index_natsorted, natsorted from xarray import DataArray from zarr import Group @@ -48,6 +48,7 @@ _get_fs, _get_store_path, _write_zarr_image, + _write_zarr_labels, is_ome_zarr_array, open_ome_zarr, read_ome_zarr_array, @@ -78,24 +79,40 @@ def _output_exists( return labels_exist and image_exists -def _get_reference_timepoint( - reference_timepoint: str, moving_image: Sequence[DataArray] +def _get_timepoint_index_and_value( + timepoint: str | int, image: Sequence[DataArray] | xr.DataArray ) -> tuple[int, Any]: - reference_timepoint_value = reference_timepoint - if isinstance(reference_timepoint, str): - reference_timepoint_found = False - for i in range(len(moving_image)): - if moving_image[i].coords["t"].values[0] == reference_timepoint: - reference_timepoint = i - reference_timepoint_found = True - break - if not reference_timepoint_found: - raise ValueError(f"Reference timepoint not found: {reference_timepoint}.") + timepoint_value = None + timepoint_index = None + if isinstance(timepoint, str): + if isinstance(image, Sequence): + for i in range(len(image)): + if ( + "t" in image[i].coords + and image[i].coords["t"].values[0] == timepoint + ): + timepoint_index = i + timepoint_value = image[i].coords["t"].values[0] + break + elif "t" in image.coords: + times = list(image.coords["t"].values) + if timepoint in times: + timepoint_index = times.index(timepoint) + timepoint_value = times[timepoint_index] + + if timepoint_index is None: + raise ValueError(f"Reference timepoint not found: {timepoint}.") + elif isinstance(timepoint, int): + timepoint_index = timepoint + if isinstance(image, Sequence): + if "t" in image[timepoint_index].coords: + timepoint_value = image[timepoint_index].coords["t"].values[0] + elif "t" in image.coords: + times = list(image.coords["t"].values) + timepoint_value = times[timepoint_index] else: - reference_timepoint_value = ( - moving_image[reference_timepoint].coords["t"].values[0] - ) - return reference_timepoint, reference_timepoint_value + raise ValueError() + return timepoint_index, timepoint_value def single_registration( @@ -111,7 +128,8 @@ def single_registration( moving_labels: list[str], moving_image_spacing: tuple[float, float] | None, fixed_image_spacing: tuple[float, float] | None, - reference_timepoint: int | str, + moving_timepoint: int | str, + fixed_timepoint: int | str, unroll_channels: bool = False, force: bool = False, z_index: int | str = "max", @@ -149,8 +167,8 @@ def single_registration( :param moving_image_spacing: Spacing of the moving image. :param fixed_image_spacing: Spacing of the fixed image. :param force: Whether to overwrite existing output - :param reference_timepoint: Index or value of timepoint to register to when registering - across timepoints + :param moving_timepoint: Index or value of moving timepoint to use + :param fixed_timepoint: Index or value of fixed timepoint to use :param unroll_channels: Whether to unroll channels across timepoints. :param landmarks_initialize: Use landmarks to initialize registration. :param landmark_slice_size: The slice size in physical coordinates @@ -174,31 +192,28 @@ def single_registration( transform_dest = f"{transform_output_dir}{transform_fs.sep}{image_key}" moving_label_keys = [] - # when registering self, transforms written to output_dir/time={t}/ + # when registering self, transforms written to output_dir/time={t_value}/ moving_image = _images2fov( moving_file_list, moving_metadata, dask=True, concat_dims=("c",), ) - if register_self: - reference_timepoint, reference_timepoint_value = _get_reference_timepoint( - reference_timepoint=reference_timepoint, moving_image=moving_image - ) - if moving_labels is not None: - matching_label_prefix = ( - f"{image_key}-{reference_timepoint_value}" if register_self else image_key - ) + moving_timepoint, moving_timepoint_value = _get_timepoint_index_and_value( + moving_timepoint, moving_image + ) + if moving_labels is not None: for moving_label in moving_labels: moving_label_keys.extend( get_matching_names( - image_key=matching_label_prefix, image_dir=moving_label, labels=True + image_key=image_key, image_dir=moving_label, labels=True ) ) moving_label_keys = natsorted(moving_label_keys) if len(moving_label_keys) == 0: - logger.warning(f"No labels found for {image_key}") + logger.warning(f"No labels found for {image_key}.") + return image_key if not force and _output_exists( register_self, @@ -224,12 +239,16 @@ def single_registration( concat_dims=("c",), dask=True, ) - if isinstance(fixed_image, Sequence): - fixed_image = fixed_image[0] - fixed_image = _z_projection(fixed_image, z_index).isel( - t=0, c=fixed_channel, missing_dims="ignore" + fixed_timepoint, fixed_timepoint_value = _get_timepoint_index_and_value( + fixed_timepoint, fixed_image ) - + if isinstance(fixed_image, Sequence): + fixed_image = fixed_image[fixed_timepoint] + else: + fixed_image = fixed_image.isel( + t=fixed_timepoint, c=fixed_channel, missing_dims="ignore" + ) + fixed_image = _z_projection(fixed_image, z_index) parameter_object = _load_itk_parameters(itk_parameters) parameter_object_across_channels = ( _load_itk_parameters(itk_channel_parameters) @@ -242,8 +261,15 @@ def single_registration( transform_fs.makedirs(transform_dest, exist_ok=True) if not register_self: - if isinstance(moving_image, Sequence): - moving_image = moving_image[0] + moving_image = ( + moving_image[moving_timepoint].isel(c=moving_channel, missing_dims="ignore") + if isinstance(moving_image, Sequence) + else moving_image.isel( + t=moving_timepoint, c=moving_channel, missing_dims="ignore" + ) + ) + moving_image_align = _z_projection(moving_image, z_index) + if ( moving_image_spacing is None and get_image_spacing(moving_image.attrs) is None @@ -257,9 +283,6 @@ def single_registration( f"Physical size not found for fixed image for {image_key}." ) - moving_image_align = _z_projection(moving_image, z_index).isel( - t=0, c=moving_channel, missing_dims="ignore" - ) if "c" in moving_image_align.dims and moving_image_align.sizes["c"] > 1: moving_image_align = moving_image_align.median(dim="c", keep_attrs=True) @@ -383,8 +406,8 @@ def single_registration( output_root=label_output_root, ) - else: # align to t=reference_timepoint - logger.info(f"Running registration for {image_key} t={reference_timepoint}") + else: # align to t=moving_timepoint + logger.info(f"Running registration for {image_key} t={moving_timepoint}") logger.info( f"{len(moving_file_list):,} {pluralize('input', len(moving_file_list))}:" f" {', '.join([s.name.replace('/images/', '') if isinstance(s, zarr.Group) else str(s) for s in moving_file_list])}" @@ -398,10 +421,10 @@ def single_registration( moving_channel = 0 moving_image = new_moving_image if not no_version: - moving_image[reference_timepoint].attrs.update(cli_metadata()) + moving_image[moving_timepoint].attrs.update(cli_metadata()) _itk_align_reference_time_zarr( unroll_channels=unroll_channels, - reference_timepoint=reference_timepoint, + reference_timepoint=moving_timepoint, moving_image=moving_image, moving_channel=moving_channel, parameter_object=parameter_object, @@ -421,64 +444,90 @@ def single_registration( parameter_object_across_channels=parameter_object_across_channels, ) moving_image_attrs = moving_image[0].attrs.copy() + chunksize = moving_image[0].data.chunksize[-2:] del moving_image - + if moving_image_spacing is None: + moving_image_spacing = get_image_spacing(moving_image_attrs) if len(moving_label_keys) > 0: _transform_labels_t( - image_key=image_key, transform_fs=transform_fs, transform_dest=transform_dest, - moving_label_keys=moving_label_keys, + label_output_root=label_output_root, moving_image_attrs=moving_image_attrs, + moving_label_keys=moving_label_keys, moving_image_spacing=moving_image_spacing, - label_output_root=label_output_root, + moving_timepoint_value=moving_timepoint_value, + chunksize=chunksize, ) return image_key def _transform_labels_t( - image_key: str, transform_fs, transform_dest: str, - moving_label_keys: Sequence[str], - moving_image_attrs, - moving_image_spacing, label_output_root, + moving_image_attrs: dict, + moving_label_keys: Sequence[str], + moving_image_spacing: tuple[int, int], + moving_timepoint_value: str, + chunksize: tuple[int, int], ): # transform_dest structure is image_key/t=1 - # assume labels are named image_key-t-suffix - - for transform_file in transform_fs.ls(transform_dest, detail=True, refresh=True): - if transform_file["type"] == "directory": - transform_name = transform_file["name"] - basename = os.path.basename(transform_name) - if basename.startswith("t="): - time = basename[2:] - moving_label_keys_t = [] - output_label_prefix = f"{image_key}-{time}" - output_names = [] - # e.g. transform plateA-A1-IF-cell to plateA-A1-FISH-cell - for moving_label_key in moving_label_keys: - moving_label_key_basename = os.path.basename(moving_label_key) - output_label_suffix = "-" + moving_label_key_basename.split("-")[-1] - output_name = f"{output_label_prefix}{output_label_suffix}" - moving_label_keys_t.append(moving_label_key) - output_names.append(output_name) - - if len(moving_label_keys_t) > 0: - transform_parameter_object = _load_itk_parameters_from_dir( - transform_fs.unstrip_protocol(transform_name) + if len(moving_label_keys) > 0: + times = [] + transform_file_paths = [] + for transform_file in transform_fs.ls( + transform_dest, detail=True, refresh=True + ): + if transform_file["type"] == "directory": + transform_file_path = transform_file["name"] + basename = os.path.basename(transform_file_path) + if basename.startswith("t="): + times.append(basename[2:]) + transform_file_paths.append( + transform_fs.unstrip_protocol(transform_file_path) ) - if transform_parameter_object.GetNumberOfParameterMaps() > 0: - _transform_labels( + index = index_natsorted(times) + times = [times[val] for val in index] + transform_file_paths = [transform_file_paths[val] for val in index] + storage_options = {"chunks": chunksize} + for moving_label_key in moving_label_keys: + moving_label_key_basename = os.path.basename(moving_label_key) + transformed_labels = [] + times_ = [] + for i in range(len(times)): + transform_parameter_object = _load_itk_parameters_from_dir( + transform_file_paths[i] + ) + if transform_parameter_object.GetNumberOfParameterMaps() > 0: + src = read_ome_zarr_array(moving_label_key).squeeze() + times_.append(times[i]) + if src.sizes.get("t", 0) > 0 and "t" in src.coords: + src = src.sel(t=moving_timepoint_value) + transformed_labels.append( + itk_transform_labels( + image=src, transform_parameter_object=transform_parameter_object, - attrs=moving_image_attrs, - matching_keys=moving_label_keys_t, - output_names=output_names, - moving_image_spacing=moving_image_spacing, - output_root=label_output_root, + image_spacing=moving_image_spacing, ) + ) + + transformed_labels = xr.DataArray( + np.stack(transformed_labels), + dims=["t", "y", "x"], + coords={"t": times_}, + attrs=moving_image_attrs, + ) + + _write_zarr_labels( + name=moving_label_key_basename, + root=label_output_root, + metadata=None, + group_metadata=None, + labels=transformed_labels, + storage_options=storage_options, + ) def get_matching_names( @@ -506,15 +555,19 @@ def get_matching_names( image_fs, _ = fsspec.core.url_to_fs(image_dir) image_dir = image_dir.rstrip(image_fs.sep) - glob_pattern = f"{image_dir}{image_fs.sep}{zarr_dir}{image_fs.sep}{image_key}" - + if labels: + glob_pattern = f"{glob_pattern}-*" # for suffix results = [] + for path in image_fs.glob(glob_pattern): path = image_fs.unstrip_protocol(path) name = os.path.basename(path) - if labels and name.split("-")[-1] not in label_suffixes: - continue + + if labels: + tokens = name.split("-") + if tokens[-1] not in label_suffixes and tokens[:-1] == image_key: + continue if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, mode="r")): results.append(path) return results @@ -782,10 +835,16 @@ def run_itk_registration(arguments: argparse.Namespace) -> None: moving_image_pattern = arguments.moving_image_pattern fixed_image_pattern = arguments.fixed_image_pattern group_by = arguments.groupby - reference_timepoint = arguments.time - if reference_timepoint is not None: + fixed_timepoint = arguments.fixed_time if arguments.fixed_time is not None else 0 + moving_timepoint = arguments.moving_time if arguments.moving_time is not None else 0 + if isinstance(fixed_timepoint, str) and fixed_timepoint.isdigit(): + try: + fixed_timepoint = int(fixed_timepoint) + except ValueError: + pass + if isinstance(moving_timepoint, str) and moving_timepoint.isdigit(): try: - reference_timepoint = int(reference_timepoint) + moving_timepoint = int(moving_timepoint) except ValueError: pass unroll_channels = arguments.unroll_channels @@ -895,7 +954,8 @@ def run_itk_registration(arguments: argparse.Namespace) -> None: image_output_root=image_output_root, moving_image_spacing=moving_image_spacing, fixed_image_spacing=fixed_image_spacing, - reference_timepoint=reference_timepoint, + fixed_timepoint=fixed_timepoint, + moving_timepoint=moving_timepoint, landmarks_initialize=landmarks_initialize, landmark_slice_size=landmark_slice_size, landmark_min_count=landmark_min_count, diff --git a/scallops/cli/register_main.py b/scallops/cli/register_main.py index 598d481..8bde07e 100644 --- a/scallops/cli/register_main.py +++ b/scallops/cli/register_main.py @@ -152,10 +152,12 @@ def _create_elastix_parser(subparsers: ArgumentParser, default_help: bool) -> No ) parser.add_argument( - "--time", - "-t", - default="0", - help="Time index (0-based) or value for alignment across timepoints", + "--moving-time", + help="Time index (0-based) or value for moving image", + ) + parser.add_argument( + "--fixed-time", + help="Time index (0-based) or value for fixed image", ) parser.add_argument( "--unroll-channels", diff --git a/scallops/cli/segment.py b/scallops/cli/segment.py index b4afe86..1447867 100644 --- a/scallops/cli/segment.py +++ b/scallops/cli/segment.py @@ -12,15 +12,19 @@ import argparse import importlib +from collections.abc import Sequence from typing import Callable, Literal, Optional import dask.array as da import fsspec import numpy as np +import xarray as xr import zarr +from array_api_compat import get_namespace from dask.bag import from_sequence from zarr import Group +from scallops.cli.register import _get_timepoint_index_and_value from scallops.cli.util import ( _create_dask_client, _create_default_dask_config, @@ -51,6 +55,7 @@ def segment_nuclei( dapi_channel: int, method: Callable, root: Group, + timepoint: int | str, z_index: int | str, min_area: float | None = None, max_area: float | None = None, @@ -71,6 +76,7 @@ def segment_nuclei( :param method: Segmentation method. :param root: Zarr hierarchy root. :param z_index: Either 'max' or z-index + :param timepoint: Time to use. :param min_area: Minimum area threshold for filtering labels. :param max_area: Maximum area threshold for filtering labels. :param chunks: Tuple specifying chunking size for Dask arrays. @@ -88,7 +94,19 @@ def segment_nuclei( logger.info(f"Skipping nuclei segmentation for {image_key}") return root logger.info(f"Running nuclei segmentation for {image_key}") - image = _images2fov(file_list, metadata, dask=True).squeeze() + image = _images2fov( + file_list, + metadata, + concat_dims=("c",), + dask=True, + ) + timepoint, timepoint_value = _get_timepoint_index_and_value(timepoint, image) + + image = ( + image[timepoint] + if isinstance(image, Sequence) + else image.isel(t=timepoint, missing_dims="ignore") + ) image = _z_projection(image, z_index) nuclei_seg_args = {} @@ -125,10 +143,17 @@ def segment_nuclei( group_metadata = { "image-label": {"source": {"image": f"../../images/{image_key}"}} } - additional_metadata = label_metadata.get(key) if label_metadata else None + additional_metadata = label_metadata.get(key) if label_metadata else dict() storage_options = None if isinstance(label_data, np.ndarray): storage_options = {"chunks": image.data.chunksize[-2:]} + if timepoint_value is not None: + label_data = xr.DataArray( + get_namespace(label_data).expand_dims(label_data, 0), + dims=["t", "y", "x"], + coords={"t": [timepoint_value]}, + ) + additional_metadata.update(label_data.attrs) _write_zarr_labels( name=f"{image_key}-{key}", root=root, @@ -150,6 +175,7 @@ def segment_cells( method: Callable, root: Group, z_index: int | str, + timepoint: int | str, min_area: float | None = None, max_area: float | None = None, chunks: None | tuple[int, int] = None, @@ -160,7 +186,6 @@ def segment_cells( cell_segmentation_rolling_ball: bool = False, cell_segmentation_sigma: Optional[float] = None, closing_radius: Optional[int] = None, - cell_segmentation_t: Optional[list[int]] = None, force: bool = False, shrink_primary: bool = False, no_version: bool = False, @@ -178,6 +203,7 @@ def segment_cells( :param method: Segmentation method. :param root: Zarr hierarchy root. :param z_index: Either 'max' or z-index + :param timepoint: Time to use. :param min_area: Minimum area threshold for filtering labels. :param max_area: Maximum area threshold for filtering labels. :param chunks: Tuple specifying chunking size for Dask arrays. @@ -188,7 +214,6 @@ def segment_cells( :param cell_segmentation_rolling_ball: Use rolling ball mask for cell segmentation. :param cell_segmentation_sigma: Standard deviation for smoothing in cell segmentation. :param closing_radius: Radius for closing operation in cell segmentation. - :param cell_segmentation_t: List of timepoints to consider for cell segmentation. :param force: Whether to overwrite existing output :param shrink_primary: Whether to shrink primary labels. :param no_version: Whether to skip version/CLI information in output. @@ -197,7 +222,20 @@ def segment_cells( if not force and is_ome_zarr_array(root.get(f"labels/{image_key}-cell")): logger.info(f"Skipping cell segmentation for {image_key}") return root - image = _images2fov(file_list, metadata, dask=True).squeeze() + image = _images2fov( + file_list, + metadata, + concat_dims=("c",), + dask=True, + ) + + timepoint, timepoint_value = _get_timepoint_index_and_value(timepoint, image) + + image = ( + image[timepoint] + if isinstance(image, Sequence) + else image.isel(t=timepoint, missing_dims="ignore") + ) image = _z_projection(image, z_index) if cyto_channel is None: cyto_channel = np.delete(np.arange(image.sizes["c"]), dapi_channel) @@ -210,7 +248,14 @@ def segment_cells( if method.__name__ in ["segment_cells_watershed", "segment_cells_propagation"]: nuclei = read_ome_zarr_array( nuclei_image_root["labels"][image_key + "-nuclei"] - ).values + ).squeeze() + if ( + timepoint_value is not None + and nuclei.sizes.get("t", 0) > 0 + and "t" in nuclei.coords + ): + nuclei = nuclei.sel(t=timepoint_value) + nuclei = nuclei.values assert nuclei.shape == ( image.sizes["y"], image.sizes["x"], @@ -223,7 +268,6 @@ def segment_cells( cell_seg_args["rolling_ball"] = cell_segmentation_rolling_ball cell_seg_args["sigma"] = cell_segmentation_sigma cell_seg_args["closing_radius"] = closing_radius - cell_seg_args["t"] = cell_segmentation_t if method.__name__ == "segment_cells_watershed": cell_seg_args["watershed_method"] = watershed_method if chunks is not None: @@ -267,10 +311,17 @@ def segment_cells( group_metadata = { "image-label": {"source": {"image": f"../../images/{image_key}"}} } - additional_metadata = label_metadata.get(key) if label_metadata else None + additional_metadata = label_metadata.get(key) if label_metadata else dict() storage_options = None if isinstance(label_data, np.ndarray): storage_options = {"chunks": image.data.chunksize[-2:]} + if timepoint_value is not None: + label_data = xr.DataArray( + get_namespace(label_data).expand_dims(label_data, 0), + dims=["t", "y", "x"], + coords={"t": [timepoint_value]}, + ) + additional_metadata.update(label_data.attrs) _write_zarr_labels( name=f"{image_key}-{key}", root=root, @@ -332,7 +383,12 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): output_root = open_ome_zarr(output, mode="a") kwargs = dict() - + timepoint = arguments.time if arguments.time is not None else 0 + if isinstance(timepoint, str) and timepoint.isdigit(): + try: + timepoint = int(timepoint) + except ValueError: + pass if not nuclei: kwargs["nuclei_min_area"] = arguments.nuclei_min_area kwargs["nuclei_max_area"] = arguments.nuclei_max_area @@ -353,7 +409,6 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): "Please provide sigma for `local` threshold" ) - kwargs["cell_segmentation_t"] = arguments.cell_segmentation_t if method in ["watershed", "watershed-intensity", "propagation"]: nuclei_label = arguments.nuclei_label if nuclei_label is None: @@ -381,10 +436,12 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): kwargs["clip"] = arguments.stardist_clip kwargs["pmin"] = arguments.stardist_pmin kwargs["pmax"] = arguments.stardist_pmax + method = getattr( importlib.import_module("scallops.segmentation." + method), f"{'segment_nuclei_' if nuclei else 'segment_cells_'}{method}", ) + image_seq = from_sequence( _set_up_experiment(data_path, image_pattern, group_by, subset=subset) ) @@ -399,6 +456,7 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): root=output_root, min_area=min_area, max_area=max_area, + timepoint=timepoint, chunks=chunks, chunk_overlap=chunk_overlap, z_index=z_index, diff --git a/scallops/cli/segment_main.py b/scallops/cli/segment_main.py index fe06e6e..d6b513f 100644 --- a/scallops/cli/segment_main.py +++ b/scallops/cli/segment_main.py @@ -71,6 +71,10 @@ def _add_common_args(parser: ArgumentParser) -> None: default=0, help="Channel index (0-based) where DAPI is found", ) + parser.add_argument( + "--time", + help="Time index (0-based) or value.", + ) parser.add_argument( "--min-area", @@ -252,13 +256,6 @@ def _add_cell_parser(subparsers: ArgumentParser, default_help: bool = True) -> N type=int, ) - parser.add_argument( - "--time", - dest="cell_segmentation_t", - help="Time indices (0-based) to include when computing cell segmentation mask. Defaults to all time points.", - type=int, - action="append", - ) parser.add_argument( "--shrink-nuclei", help="Shrink nuclei prior to subtraction of nuclei from cells to identify the " diff --git a/scallops/cli/util.py b/scallops/cli/util.py index b7360e1..928df57 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -24,6 +24,7 @@ import dask.array as da import fsspec import numpy as np +import pandas as pd import xarray as xr import zarr from distributed import Client @@ -474,7 +475,7 @@ def _write_img_size(file_list: list[str]): f.write("\n") -def _write_group_size(metadata: dict): +def _n_files_in_group(metadata: dict) -> int: n_tiles = len(metadata["file_metadata"]) metadata_fields = [v for v in ("c", "z") if v in metadata["file_metadata"][0]] if len(metadata_fields) > 0: @@ -484,115 +485,94 @@ def _write_group_size(metadata: dict): metadata=metadata, metadata_fields=tuple(metadata_fields) ) n_tiles = len(filepaths) - with open("group_size.txt", "wt") as f: - f.write(f"{n_tiles}") - f.write("\n") + return n_tiles def _list_images_wdl( - image_pattern: str, - urls: list[str], + image_pattern1: str, + urls1: list[str], + reference_time1: str | None, + n_cycles1: str | None, + image_pattern2: str, + urls2: list[str], + reference_time2: str | None, + n_cycles2: str | None, groupby: list[str], - reference_time: str | None, subset: list[str] | None, - batch_size_str: str | None, + batch_size: str | None, save_group_size: bool = False, - expected_cycles_str: int | None = None, ): """Used by WDL workflow to output info about images""" - from scallops.io import _set_up_experiment - batch_size = 1 - expected_cycles = None - if expected_cycles_str is not None and expected_cycles_str != "": - expected_cycles = int(expected_cycles_str) - if batch_size_str is not None and batch_size_str != "": - batch_size = int(batch_size_str) - if reference_time == "": - reference_time = None + urls1 = [url.strip() for url in urls1 if url.strip() != ""] + reference_time1 = None if reference_time1 == "" else reference_time1 + n_cycles1 = int(n_cycles1) if n_cycles1 is not None and n_cycles1 != "" else None + + urls2 = [url.strip() for url in urls2 if url.strip() != ""] + reference_time2 = None if reference_time2 == "" else reference_time2 + n_cycles2 = int(n_cycles2) if n_cycles2 is not None and n_cycles2 != "" else None + batch_size = int(batch_size) if batch_size is not None and batch_size != "" else 1 if subset is not None and ( len(subset) == 0 or (len(subset) == 1 and subset[0] == "") ): subset = None - if image_pattern != "": - groupby = [g for g in groupby if "{" + g + "}" in image_pattern] - exp_gen = _set_up_experiment( - image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset + if image_pattern1 != "": + groupby1 = [g for g in groupby if "{" + g + "}" in image_pattern1] + if image_pattern2 != "": + groupby2 = [g for g in groupby if "{" + g + "}" in image_pattern2] + if len(urls1) > 0 and len(urls2) > 0: + groupby = groupby1 + assert groupby1 == groupby2 + elif len(urls1) > 0: + groupby = groupby1 + elif len(urls2) > 0: + groupby = groupby2 + else: + raise ValueError() + + result1 = _list_images( + urls=urls1, + image_pattern=image_pattern1, + groupby=groupby, + subset=subset, + save_group_size=save_group_size, + reference_time=reference_time1, + n_cycles=n_cycles1, ) - - groupby_t = "t" in groupby - times = None - - if not save_group_size: - with open("group_size.txt", "wt") as f: - f.write("0\n") - - with ( - open("subsets.txt", "wt") as groups_out, - open("subsets_with_t.txt", "wt") as groups_with_t_out, - ): - subset_ids = [] - subset_ids_with_reference_times = [] - first = True - - for g, file_list, metadata in exp_gen: - times = None - if first: - first = False - if save_group_size: - _write_group_size(metadata) - if not groupby_t and "t" in metadata["file_metadata"][0]: - times = [md["t"] for md in metadata["file_metadata"]] - if expected_cycles is not None: - assert len(times) == expected_cycles - t_suffix = "" - if times is not None and len(times) > 0: - t_suffix = ( - f"-{times[0]}" if reference_time is None else f"-{reference_time}" - ) - - subset_ids.append('"' + metadata["id"] + '"') - subset_ids_with_reference_times.append( - '"' + metadata["id"] + t_suffix + '"' - ) - if len(subset_ids) == batch_size: - groups_out.write(" ".join(subset_ids)) - groups_out.write("\n") - - groups_with_t_out.write(" ".join(subset_ids_with_reference_times)) - groups_with_t_out.write("\n") - - subset_ids = [] - subset_ids_with_reference_times = [] - if len(subset_ids) > 0: - groups_out.write(" ".join(subset_ids)) - groups_out.write("\n") - - groups_with_t_out.write(" ".join(subset_ids_with_reference_times)) - groups_with_t_out.write("\n") - - with open("groupby.txt", "wt") as f: - for g in groupby: - f.write(g) + result2 = _list_images( + urls=urls2, + image_pattern=image_pattern2, + groupby=groupby, + subset=subset, + save_group_size=save_group_size, + reference_time=reference_time2, + n_cycles=n_cycles2, + ) + results = [result1, result2] + + if len(urls1) > 0 and len(urls2) > 0: + df1 = result1["subset_df"] + df2 = result2["subset_df"] + subset_ids = df1.index.intersection(df2.index) + result1["subset_df"] = df1.loc[subset_ids] + result2["subset_df"] = df2.loc[subset_ids] + elif len(urls1) > 0: + subset_ids = result1["subset_df"].index.values + elif len(urls2) > 0: + subset_ids = result2["subset_df"].index.values + + with open("subsets.txt", "wt") as f: # ["plate1-A1", "plate1-A2", ...] + for i in range(0, len(subset_ids), batch_size): + selected = subset_ids[i : i + batch_size] + f.write(" ".join(selected)) f.write("\n") - groupby_with_t = list(groupby) - - if not groupby_t and times is not None: - groupby_with_t.append("t") - with open("groupby_with_t.txt", "wt") as f: - for g in groupby_with_t: + with open("groupby_array.txt", "wt") as f: # ['plate', 'well'] + for g in groupby: f.write(g) f.write("\n") - - with open("t.txt", "wt") as f: - if times is not None: - for val in times: - f.write(str(val)) - f.write("\n") - - with open("groupby_pattern.txt", "wt") as f: + with open("groupby_pattern.txt", "wt") as f: # "{plate}-{well}" first = True for g in groupby: if not first: @@ -601,19 +581,115 @@ def _list_images_wdl( f.write("{") f.write(g) f.write("}") + for index in range(len(results)): + result = results[index] + url_val = index + 1 + subset_df = result["subset_df"] + + with open(f"group_size_{url_val}.txt", "wt") as f: + f.write(f"{result['group_size']}") + f.write("\n") + + with open(f"reference_time_{url_val}.txt", "wt") as f: # IF + f.write(result["reference_time"]) + f.write("\n") + with open(f"times_{url_val}.txt", "wt") as f: # ["FISH", "IF"] + if result["times"] is not None: + for val in result["times"]: + f.write(str(val)) + f.write("\n") + + with open( + f"subsets_with_reference_time_{url_val}.txt", "wt" + ) as f: # ["plate1-A1-IF", "plate1-A2-IF", ...] + subset_ids_with_reference_times = ( + subset_df["subset_ids_with_reference_times"].values + if subset_df is not None + else [] + ) + for i in range(0, len(subset_ids_with_reference_times), batch_size): + selected = subset_ids_with_reference_times[i : i + batch_size] + f.write(" ".join(selected)) + f.write("\n") + + with open( + f"image_pattern_with_reference_time_{url_val}.txt", "wt" + ) as f: # "{plate}-{well}-IF" + first = True + for g in groupby: + if not first: + f.write("-") + first = False + f.write("{") + f.write(g) + f.write("}") + f.write(f"-{result['reference_time']}") + + +def _list_images( + urls: Sequence[str], + image_pattern: str, + groupby: list[str], + subset: list[str] | None, + save_group_size: bool, + reference_time: str | None, + n_cycles: int | None, +): reference_time_suffix = "" - if reference_time is not None: - reference_time_suffix = f"-{reference_time}" - elif times is not None and len(times) > 0: - reference_time_suffix = f"-{times[0]}" + group_size = 0 + if len(urls) == 0: + return dict( + group_size=group_size, + subset_df=None, + times=None, + reference_time_suffix=reference_time_suffix, + ) + from scallops.io import _set_up_experiment - with open("groupby_pattern_with_reference_t.txt", "wt") as f: - first = True - for g in groupby: - if not first: - f.write("-") + exp_gen = _set_up_experiment( + image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset + ) + + groupby_t = "t" in groupby + times = None + subset_ids = [] + subset_ids_with_reference_times = [] + first = True + + for g, file_list, metadata in exp_gen: + times = None + if first: first = False - f.write("{") - f.write(g) - f.write("}") - f.write(reference_time_suffix) + if save_group_size: + group_size = _n_files_in_group(metadata) + if not groupby_t and "t" in metadata["file_metadata"][0]: + times = [md["t"] for md in metadata["file_metadata"]] + if n_cycles is not None: + assert len(times) == n_cycles + t_suffix = "" + if times is not None and len(times) > 0: + t_suffix = ( + f"-{times[0]}" if reference_time is None else f"-{reference_time}" + ) + + subset_ids.append('"' + metadata["id"] + '"') + subset_ids_with_reference_times.append('"' + metadata["id"] + t_suffix + '"') + subset_df = pd.DataFrame( + index=subset_ids, + data=dict(subset_ids_with_reference_times=subset_ids_with_reference_times), + ) + + groupby_with_t = list(groupby) + if not groupby_t and times is not None: + groupby_with_t.append("t") + + if reference_time is None: + reference_time = times[0] if times is not None and len(times) > 0 else "0" + + return dict( + group_size=group_size, + subset_df=subset_df, + groupby=groupby, + times=times, + reference_time=reference_time, + ) diff --git a/scallops/tests/miniwdl_local/local_runner.py b/scallops/tests/miniwdl_local/local_runner.py index c9b1392..6207931 100644 --- a/scallops/tests/miniwdl_local/local_runner.py +++ b/scallops/tests/miniwdl_local/local_runner.py @@ -15,8 +15,13 @@ def global_init(cls, cfg, logger): """ Perform any necessary process-wide initialization of the container backend """ + cpu_count = os.environ.get("SCALLOPS_MINIWDL_CPU") + if cpu_count is None: + cpu_count = psutil.cpu_count() + else: + cpu_count = int(cpu_count) cls._resource_limits = { - "cpu": psutil.cpu_count(), + "cpu": cpu_count, "mem_bytes": psutil.virtual_memory().total, } diff --git a/scallops/tests/test_register_cli.py b/scallops/tests/test_register_cli.py index d2cf5b8..9be1113 100644 --- a/scallops/tests/test_register_cli.py +++ b/scallops/tests/test_register_cli.py @@ -118,7 +118,7 @@ def test_register_itk_cli_known_shift(tmp_path): "scallops", "registration", "elastix", - "--time", + "--moving-time", "1", "--moving", str(data_path), @@ -147,10 +147,11 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): tmp_path, "registration-input.zarr" ) exp = Experiment() - reference_t = 2 + + reference_t_index = 2 test_t = 10 array_A1_102_nuclei = array_A1_102_nuclei.squeeze() - exp.labels[f"A1-102-{reference_t}-mask"] = array_A1_102_nuclei + exp.labels["A1-102-nuclei"] = array_A1_102_nuclei exp.save(registration_input_moving_labels_path) cmd = [ @@ -179,8 +180,8 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): registration_input_moving_labels_path, "--label-output", elastix_output_dir, - "--time", - str(reference_t), + "--moving-time", + str(reference_t_index), ] subprocess.check_call(cmd) result_exp = read_experiment(elastix_output_dir) @@ -195,16 +196,24 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): .images["A1-102"] .squeeze() ) - assert len(result_exp.labels.keys()) == 8 + times = list(original_image.t.values) + del times[reference_t_index] + times = [str(t) for t in times] + transformed_times = list(result_exp.labels["A1-102-nuclei"].coords["t"].values) + transformed_times = [str(t) for t in transformed_times] + assert times == transformed_times, ( + f"{', '.join(times)} != {', '.join(transformed_times)}" + ) + assert len(result_exp.labels.keys()) == 1 np.testing.assert_array_equal(transformed_image.t.values, original_image.t.values) np.testing.assert_array_equal(transformed_image.c.values, original_image.c.values) np.testing.assert_array_equal( - transformed_image.isel(t=reference_t), - original_image.isel(t=reference_t), + transformed_image.isel(t=reference_t_index), + original_image.isel(t=reference_t_index), err_msg="Reference t not equal via CLI", ) for t in range(original_image.sizes["t"]): - if t != reference_t: + if t != reference_t_index: with np.testing.assert_raises(AssertionError): np.testing.assert_array_equal( transformed_image.isel(t=t), original_image.isel(t=t) @@ -232,8 +241,9 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): image_spacing=(1, 1), ) assert warped_labels.min() == 0 + np.testing.assert_array_equal( - result_exp.labels[f"A1-102-{test_t}-mask"].values, + result_exp.labels["A1-102-nuclei"].sel(t=str(test_t)).values, warped_labels, err_msg=f"t {test_t} labels not equal using itk_transform_labels and CLI", ) @@ -245,7 +255,7 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): moving_channel=[0], parameter_object=parameter_object, moving_image_spacing=(1, 1), - reference_timepoint=reference_t, + reference_timepoint=reference_t_index, ) xr.testing.assert_equal(result_np, transformed_image) @@ -308,19 +318,25 @@ def test_register_transform_labels_moving_only(tmp_path): output_zarr = tmp_path / "out.zarr" output_transforms = tmp_path / "transforms" - img = read_image( - "scallops/tests/data/experimentC/10X_c0-DAPI-p65ab/10X_c0-DAPI-p65ab_A1_Tile-102.phenotype.tif" - ) - img.attrs["physical_pixel_sizes"] = (1, 1) - rng = np.random.default_rng(0) + img = rng.integers(low=0, high=10, size=(2, 2, 100, 100)) + img = xr.DataArray( + img, + dims=["t", "c", "y", "x"], + coords={"t": ["IF", "FISH"]}, + attrs={"physical_pixel_sizes": (1, 1)}, + ) - segmentation = rng.integers(low=0, high=10, size=(img.sizes["y"], img.sizes["x"])) + segmentation = rng.integers(low=0, high=10, size=(100, 100)) Experiment( - images={"plateA-A1-IF": img, "plateA-A1-FISH": img}, + images={"plateA-A1": img}, labels={ - "plateA-A1-IF-cell": segmentation, + "plateA-A1-cell": xr.DataArray( + np.expand_dims(segmentation, 0), + dims=["t", "y", "x"], + coords={"t": ["IF"]}, + ), }, ).save(image_zarr) cmd = [ @@ -330,7 +346,7 @@ def test_register_transform_labels_moving_only(tmp_path): "--moving", str(image_zarr), "--moving-image-pattern", - "{plate}-{well}-{t}", + "{plate}-{well}", "--moving-label", str(image_zarr), "--subset", @@ -344,7 +360,7 @@ def test_register_transform_labels_moving_only(tmp_path): "--label-output", str(output_zarr), "--output-aligned-channels-only", - "--time", + "--moving-time", "IF", "--transform-output", str(output_transforms), @@ -352,7 +368,8 @@ def test_register_transform_labels_moving_only(tmp_path): create_itk_param_file(tmp_path), ] subprocess.check_call(cmd) - transformed_labels = read_image(output_zarr / "labels" / "plateA-A1-FISH-cell") + transformed_labels = read_image(output_zarr / "labels" / "plateA-A1-cell") + assert list(transformed_labels.coords["t"].values) == ["FISH"] assert transformed_labels.max() > 0 transformed_image = read_image(output_zarr / "images" / "plateA-A1") assert transformed_image.shape[0] == 2 diff --git a/scallops/zarr_io.py b/scallops/zarr_io.py index 9f04010..77facc6 100644 --- a/scallops/zarr_io.py +++ b/scallops/zarr_io.py @@ -285,6 +285,65 @@ def _attrs_axes_coordinates( return image_attrs, axes, coordinate_transformations +def _write_zarr_labels( + name: str, + root: zarr.Group | str | Path, + labels: np.ndarray | xr.DataArray | da.Array, + metadata: dict[str, Any] | None = None, + group_metadata: dict[str, Any] | None = None, + compute: bool = True, + storage_options: JSONDict | None = None, +) -> list[Delayed]: + """Write label in zarr format. + + :param name: Zarr group name to store label + :param root: Root zarr group. + :param labels: Labels to write. + :param metadata: Optional label metadata. + :param group_metadata: Optional group level metadata. + :param compute: If true compute immediately otherwise a list + of :class:`dask.delayed.Delayed` is returned. + :param storage_options: Optional storage options. + :return: Empty list if the compute flag is True, otherwise it returns a list + of :class:`dask.delayed.Delayed` representing the value to be computed by dask. + """ + + # stored in labels/key + if isinstance(root, (str, Path)): + root = open_ome_zarr(root, mode="a") + + labels_grp = root.require_group("labels", overwrite=False) + dest_grp = labels_grp.create_group(name.replace("/", "-"), overwrite=True) + + label_attrs = None + coords = None + dims = None + if isinstance(labels, xr.DataArray): + data = labels.data + label_attrs = labels.attrs.copy() + coords = labels.coords + dims = labels.dims + else: + data = labels + + # need 'image-label' attr to be recognized as label + group_metadata = group_metadata.copy() if group_metadata is not None else dict() + if "image-label" not in group_metadata: + group_metadata["image-label"] = {} + metadata = metadata.copy() if metadata is not None else {} + + return write_zarr( + grp=dest_grp, + data=data, + image_attrs=label_attrs, + coords=coords, + dims=dims, + metadata=metadata, + zarr_format="ome_zarr", + compute=compute, + ) + + def _write_zarr_image( name: str | None, root: zarr.Group | str | Path, @@ -509,68 +568,6 @@ def rechunk(image: xr.DataArray | da.Array) -> xr.DataArray | da.Array: return image -def _write_zarr_labels( - name: str, - root: zarr.Group | str | Path, - labels: np.ndarray | xr.DataArray | da.Array, - metadata: dict[str, Any] = None, - group_metadata: dict[str, Any] = None, - compute: bool = True, - storage_options: JSONDict | None = None, -) -> list[Delayed]: - """Write label in zarr format. - - :param name: Zarr group name to store label - :param root: Root zarr group. - :param labels: Labels to write. - :param metadata: Optional label metadata. - :param group_metadata: Optional group level metadata. - :param compute: If true compute immediately otherwise a list - of :class:`dask.delayed.Delayed` is returned. - :param storage_options: Optional storage options. - :return: Empty list if the compute flag is True, otherwise it returns a list - of :class:`dask.delayed.Delayed` representing the value to be computed by dask. - """ - - # stored in labels/key - name = name.replace("/", "-") - if isinstance(root, (str, Path)): - root = open_ome_zarr(root, mode="a") - labels_grp = root.require_group("labels") - grp = labels_grp.create_group(name, overwrite=True) - if not isinstance(labels, xr.DataArray): - if labels.ndim == 2: - label_axes = ["y", "x"] - elif labels.ndim == 5: - label_axes = ["t", "c", "z", "y", "x"] - else: - raise ValueError("Axes can't be inferred for 3D or 4D labels") - else: - label_axes = labels.dims - labels = labels.data - - # need 'image-label' attr to be recognized as label - group_metadata = group_metadata.copy() if group_metadata is not None else dict() - if "image-label" not in group_metadata: - group_metadata["image-label"] = {} - grp.attrs.update(group_metadata) - metadata = metadata.copy() if metadata is not None else {} - if isinstance(labels, da.Array) or ( - isinstance(labels, xr.DataArray) and isinstance(labels.data, da.Array) - ): - labels = rechunk(labels) - return write_image( - labels, - grp, - scaler=None, - # scale_factors=[], - axes=label_axes, - metadata=metadata, - compute=compute, - storage_options=storage_options, - ) - - def _read_zarr_attrs(attrs) -> tuple[dict, dict, list[str]]: """Read attributes from Zarr. diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index 396bdb5..d1adc5a 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -7,6 +7,7 @@ task segment_nuclei { String? image_pattern Array[String] groupby Int? dapi_channel + String? time String output_directory String subset Boolean? force @@ -33,6 +34,7 @@ task segment_nuclei { --groupby ~{sep=" " groupby} \ ~{if defined(image_pattern) then '--image-pattern "' + image_pattern + '"' else ''} \ ~{'--dapi-channel ' + dapi_channel} \ + ~{'--time ' + time} \ --output "~{output_directory}" \ --subset ~{subset} \ ~{if defined(extra_arguments) then extra_arguments else ''} \ @@ -64,6 +66,7 @@ task segment_cell { Array[String] groupby Int? dapi_channel Array[Int] cyto_channel + String? time Int? chunks String? nuclei_label String? threshold @@ -94,6 +97,7 @@ task segment_cell { --groupby ~{sep=" " groupby} \ ~{if defined(image_pattern) then '--image-pattern "' + image_pattern + '"' else ''} \ ~{'--dapi-channel ' + dapi_channel} \ + ~{'--time ' + time} \ --cyto-channel ~{sep=" " cyto_channel} \ ~{"--nuclei-label " + nuclei_label} \ ~{"--method " + method} \ @@ -134,7 +138,8 @@ task register_elastix { Boolean? output_aligned_channels_only Boolean? unroll_channels String? fixed - String? reference_time + String? moving_time + String? fixed_time Int? fixed_channel Boolean? register_across_channels String transform_output_directory @@ -174,7 +179,8 @@ task register_elastix { --subset ~{subset} \ ~{if defined(label_output_directory) then '--label-output "' + label_output_directory + '"' else ''} \ ~{true="--unroll-channels" false="" unroll_channels} \ - ~{if defined(reference_time) then '--time "' + reference_time + '"' else ''} \ + ~{if defined(moving_time) then '--moving-time "' + moving_time + '"' else ''} \ + ~{if defined(fixed_time) then '--fixed-time "' + fixed_time + '"' else ''} \ ~{true="--force" false="" force} \ ~{true="--align-across-channels" false="" register_across_channels} \ ~{true="--output-aligned-channels-only" false="" output_aligned_channels_only} \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index eed2343..ca801af 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -211,10 +211,14 @@ workflow ops_workflow { call utils.list_images { input: - urls = [select_first([phenotype_url, iss_url])], - image_pattern = if pheno_url_supplied then phenotype_image_pattern else iss_image_pattern, + urls1 = select_all([phenotype_url]), + image_pattern1 = phenotype_image_pattern, + reference_time1=reference_phenotype_time, + + urls2 = select_all([iss_url]), + image_pattern2 = iss_image_pattern, batch_size=batch_size, - reference_time=reference_phenotype_time, + groupby=groupby, subset = subset, docker=docker, @@ -223,27 +227,40 @@ workflow ops_workflow { aws_queue_arn = aws_queue_arn, max_retries = max_retries } - String groupby_pattern = list_images.groupby_pattern # "{plate}-{well}" - Array[String] subsets = list_images.subsets # e.g. ["plate1-A1", "plate1-A2", ...] - Array[String] subset_with_reference_times = list_images.subset_with_reference_times # e.g. ["plate1-A1-IF", "plate1-A2-IF", ...] - Array[String] times = list_images.t # e.g. ["FISH", "IF"] - Array[String] groupby_with_time = list_images.filtered_groupby_with_t # e.g. ['plate', 'well', 't'] - String groupby_pattern_with_reference_t = list_images.groupby_pattern_with_reference_t # e.g. "{plate}-{well}-IF" + Array[String] subsets = list_images.subsets + String groupby_pattern = list_images.groupby_pattern # e.g. {plate}-{well} + Array[String] groupby_array = list_images.groupby_array # e.g. ["plate", "well"] + + + # Array[String] subsets_with_reference_times_pheno = list_images.subsets_with_reference_times_1 + # Array[String] subsets_with_reference_times_iss = list_images.subsets_with_reference_times_2 + + Array[String] times_pheno = list_images.times_1 + Array[String] times_iss = list_images.times_2 + + String reference_time_pheno = list_images.reference_time_1 + String reference_time_iss = list_images.reference_time_2 + + #String image_pattern_with_reference_time_pheno = list_images.image_pattern_with_reference_time_1 # e.g. {plate}-{well}-IF + # String image_pattern_with_reference_time_iss = list_images.image_pattern_with_reference_time_2 # e.g. {plate}-{well}-1 scatter (subset_index in range(length(subsets))) { String subset_ = subsets[subset_index] - String subset_with_reference_time = subset_with_reference_times[subset_index] + # String subset_with_reference_times_pheno = subsets_with_reference_times_pheno[subset_index] + # String subset_with_reference_times_iss = subsets_with_reference_times_iss[subset_index] if(pheno_url_supplied) { if(run_nuclei_segmentation) { call tasks.segment_nuclei { input: images = select_first([phenotype_url]), image_pattern = phenotype_image_pattern, + time=reference_time_pheno, + subset = subset_, method = nuclei_segmentation_method, - groupby=groupby_with_time, + groupby=groupby, dapi_channel = phenotype_dapi_channel, output_directory=segment_directory, model_dir=model_dir, - subset = subset_with_reference_time, + extra_arguments=nuclei_segmentation_extra_arguments, force = force_segment_nuclei, docker=docker, @@ -262,9 +279,10 @@ workflow ops_workflow { input: images = select_first([phenotype_url]), image_pattern = phenotype_image_pattern, + time=reference_time_pheno, method = cell_segmentation_method, - groupby = groupby_with_time, - subset = subset_with_reference_time, + groupby = groupby, + subset = subset_, dapi_channel = phenotype_dapi_channel, cyto_channel=phenotype_cyto_channel, nuclei_label=select_first([segment_nuclei.output_url]), @@ -285,14 +303,15 @@ workflow ops_workflow { max_retries = max_retries } - if(length(times)>1) { + if(length(times_pheno)>1) { call tasks.register_elastix as register_pheno_to_pheno { input: moving=select_all([phenotype_url]), moving_label=segment_cell.output_url, moving_channel=phenotype_dapi_channel, moving_image_pattern=phenotype_image_pattern, - reference_time=reference_phenotype_time, + moving_time=reference_time_pheno, + extra_arguments=pheno_registration_extra_arguments, output_aligned_channels_only=true, groupby=groupby, @@ -375,7 +394,7 @@ workflow ops_workflow { # use stitch mask as image and segment output for reference phenotype or transformed phenotype for others if(mark_stitch_boundary_cells) { - String phenotype_url_stripped = if (pheno_url_supplied) then sub(select_first([phenotype_url]), "/+$", "") else "" + String phenotype_url_stripped = sub(select_first([phenotype_url]), "/+$", "") call tasks.intersects_boundary as cell_intersects_boundary { input: @@ -435,9 +454,11 @@ workflow ops_workflow { fixed_channel=iss_dapi_channel, moving_label=segment_cell.output_url, moving=select_all([phenotype_url]), - moving_image_pattern=groupby_pattern_with_reference_t, + moving_image_pattern=phenotype_image_pattern, fixed_image_pattern=iss_image_pattern, moving_channel=phenotype_dapi_channel, + moving_time=reference_time_pheno, + fixed_time=reference_time_iss, output_aligned_channels_only=true, moving_output_directory=register_pheno_to_iss_directory, label_output_directory=register_pheno_to_iss_directory, @@ -459,10 +480,10 @@ workflow ops_workflow { # ISS t0 to phenotype reference time call tasks.register_pheno_to_iss_qc as register_pheno_to_iss_qc { input: - images=select_first([iss_url]), - image_pattern=iss_image_pattern, - stacked_images=register_pheno_to_iss.moving_output_url, - stacked_image_pattern=phenotype_image_pattern, + images=register_pheno_to_iss.moving_output_url, + image_pattern=groupby_pattern, + stacked_images=select_first([iss_url]), + stacked_image_pattern=groupby_pattern, image_channel=iss_dapi_channel, stacked_image_channel=0, label_type='nuclei', diff --git a/wdl/utils.wdl b/wdl/utils.wdl index 393b806..0daf6aa 100644 --- a/wdl/utils.wdl +++ b/wdl/utils.wdl @@ -1,14 +1,22 @@ -version 1.1 +version 1.0 task list_images { input { - Boolean? save_group_size - Array[String] urls - String? image_pattern - String? reference_time + + Array[String]? urls1 + String? image_pattern1 + String? reference_time1 + Int? n_cycles1 + + Array[String]? urls2 + String? image_pattern2 + String? reference_time2 + Int? n_cycles2 + Array[String] groupby Array[String]? subset - Int? expected_cycles + + Boolean? save_group_size Int? batch_size String docker String zones @@ -26,28 +34,47 @@ task list_images { python <>> output { Array[String] subsets = read_lines('subsets.txt') - Array[String] subset_with_reference_times = read_lines('subsets_with_t.txt') - Array[String] t = read_lines('t.txt') String groupby_pattern = read_lines('groupby_pattern.txt')[0] # e.g. {plate}-{well} - String groupby_pattern_with_reference_t = read_lines('groupby_pattern_with_reference_t.txt')[0] # e.g. {plate}-{well}-IF - Array[String] filtered_groupby_with_t = read_lines('groupby_with_t.txt') # e.g. [plate, well, t] - Array[String] filtered_groupby = read_lines('groupby.txt') # e.g. [plate, well] + Array[String] groupby_array = read_lines('groupby_array.txt') # e.g. ["plate", "well"] + + Int group_size_1 = read_int('group_size_1.txt') + Int group_size_2 = read_int('group_size_2.txt') + Array[String] subsets_with_reference_times_1 = read_lines('subsets_with_reference_time_1.txt') + Array[String] subsets_with_reference_times_2 = read_lines('subsets_with_reference_time_2.txt') + + Array[String] times_1 = read_lines('times_1.txt') + Array[String] times_2 = read_lines('times_2.txt') + + String reference_time_1 = read_lines('reference_time_1.txt')[0] + String reference_time_2 = read_lines('reference_time_2.txt')[0] + + String image_pattern_with_reference_time_1 = read_lines('image_pattern_with_reference_time_1.txt')[0] # e.g. {plate}-{well}-IF + String image_pattern_with_reference_time_2 = read_lines('image_pattern_with_reference_time_2.txt')[0] # e.g. {plate}-{well}-IF - Int group_size = read_int('group_size.txt') } meta { From e52424364b22177cc21620ea40e51866dcdb7191 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jun 2026 16:38:02 -0400 Subject: [PATCH 08/21] Registration transform across times --- scallops/cli/register.py | 11 +++++------ scallops/tests/test_register_cli.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/scallops/cli/register.py b/scallops/cli/register.py index e40bedd..8f695db 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -68,7 +68,6 @@ def _output_exists( if not is_ome_zarr_array(label_output_root.get(f"labels/{key}")): labels_exist = False break - # TODO check for transformed labels when register_self image_exists = True if image_output_root is not None: @@ -262,13 +261,13 @@ def single_registration( if not register_self: moving_image = ( - moving_image[moving_timepoint].isel(c=moving_channel, missing_dims="ignore") + moving_image[moving_timepoint] if isinstance(moving_image, Sequence) - else moving_image.isel( - t=moving_timepoint, c=moving_channel, missing_dims="ignore" - ) + else moving_image.isel(t=moving_timepoint, missing_dims="ignore") + ) + moving_image_align = _z_projection( + moving_image.isel(c=moving_channel, missing_dims="ignore"), z_index ) - moving_image_align = _z_projection(moving_image, z_index) if ( moving_image_spacing is None diff --git a/scallops/tests/test_register_cli.py b/scallops/tests/test_register_cli.py index 9be1113..14a6118 100644 --- a/scallops/tests/test_register_cli.py +++ b/scallops/tests/test_register_cli.py @@ -401,6 +401,9 @@ def _warp(img): dims=["c", "y", "x"], ) moving_image = moving_image.isel(c=[0, 1, 2]) + moving_image.coords["c"] = ["a", "b", "c"] + moving_image.attrs["physical_pixel_sizes"] = (2, 2) + moving_image.attrs["physical_pixel_units"] = ("micrometer", "micrometer") moving_labels = array_A1_102_nuclei.squeeze().values moving_labels = warp(moving_labels, st, order=0, preserve_range=True) @@ -413,9 +416,6 @@ def _warp(img): ), dims=["y", "x"], ) - moving_image.coords["c"] = ["a", "b", "c"] - moving_image.attrs["physical_pixel_sizes"] = (2, 2) - moving_image.attrs["physical_pixel_units"] = ("micrometer", "micrometer") fixed_image = xr.DataArray( resize( From 6fc0e5546c42308807297ac84d17158baaad88e4 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jun 2026 16:52:24 -0400 Subject: [PATCH 09/21] storage options --- scallops/zarr_io.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/scallops/zarr_io.py b/scallops/zarr_io.py index 77facc6..56acd35 100644 --- a/scallops/zarr_io.py +++ b/scallops/zarr_io.py @@ -317,7 +317,7 @@ def _write_zarr_labels( label_attrs = None coords = None - dims = None + dims = ["y", "x"] if isinstance(labels, xr.DataArray): data = labels.data label_attrs = labels.attrs.copy() @@ -341,6 +341,7 @@ def _write_zarr_labels( metadata=metadata, zarr_format="ome_zarr", compute=compute, + storage_options=storage_options, ) @@ -405,6 +406,7 @@ def write_zarr( metadata: dict[str, Any] | None = None, zarr_format: Literal["ome_zarr", "zarr"] = "ome_zarr", compute: bool = True, + storage_options: JSONDict | None = None, ) -> list[Delayed]: """Write data to a Zarr group with optional metadata and scaling. @@ -425,6 +427,7 @@ def write_zarr( :param compute: If True, compute immediately. Otherwise, return a list of dask.delayed. Delayed objects representing the value to be computed by dask. Default is True. + :param storage_options: Optional storage options. :return: Empty list if the compute flag is True, otherwise a list of dask.delayed.Delayed objects. @@ -467,19 +470,26 @@ def write_zarr( dask_delayed = [] fmt = _current_format() if zarr_format == "zarr": # No axis validation + chunks_opt = None + if storage_options is not None: + chunks_opt = storage_options.pop("chunks", None) if isinstance(data, da.Array): d = da.to_zarr( arr=data, url=grp.store, component=str(Path(grp.path, "0")), compute=compute, + storage_options=storage_options, **_da_to_zarr_kwargs(fmt), ) if not compute: dask_delayed.append(d) elif not isinstance(data, zarr.Array): + kwds = _da_to_zarr_kwargs(fmt) + if storage_options is not None: + kwds.update(storage_options) grp.create_dataset( - "0", data=data, overwrite=True, **_da_to_zarr_kwargs(fmt) + "0", data=data, overwrite=True, chunks=chunks_opt, **kwds ) datasets = [{"path": "0"}] @@ -529,6 +539,7 @@ def _write_metadata_delayed(grp, d): if coordinate_transformations is not None else None ), + storage_options=storage_options, ) From baa48d33dade54f512dde06e5d575138c5d6b46b Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jun 2026 16:56:08 -0400 Subject: [PATCH 10/21] pattern --- wdl/ops_workflow.wdl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index ca801af..fe2121a 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -336,7 +336,7 @@ workflow ops_workflow { call tasks.find_objects as find_objects_nuclei { input: labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), - label_pattern=phenotype_image_pattern, + label_pattern=groupby_pattern, suffix="nuclei", output_directory=nuclei_objects_directory, subset = subset_, @@ -355,7 +355,7 @@ workflow ops_workflow { call tasks.find_objects as find_objects_cell { input: labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), - label_pattern=phenotype_image_pattern, + label_pattern=groupby_pattern, suffix="cell", output_directory=cell_objects_directory, subset = subset_, @@ -374,7 +374,7 @@ workflow ops_workflow { call tasks.find_objects as find_objects_cytosol { input: labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), - label_pattern=phenotype_image_pattern, + label_pattern=groupby_pattern, suffix="cytosol", output_directory=cytosol_objects_directory, subset = subset_, @@ -400,7 +400,7 @@ workflow ops_workflow { input: labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), images=phenotype_url_stripped + '/labels/', - image_pattern=phenotype_image_pattern, + image_pattern=groupby_pattern, output_directory=cell_intersects_boundary_directory, label_type='cell', objects=find_objects_cell.output_url, From a6cdfb7fa5615d7ccd221371b3080dc092be8b40 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Mon, 15 Jun 2026 12:36:50 -0400 Subject: [PATCH 11/21] handle timepoint in segmentation --- scallops/cli/find_objects.py | 80 ++++++++++++++++++++++--------- scallops/features/find_objects.py | 2 +- 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/scallops/cli/find_objects.py b/scallops/cli/find_objects.py index 88bf71e..c7fffec 100644 --- a/scallops/cli/find_objects.py +++ b/scallops/cli/find_objects.py @@ -8,6 +8,7 @@ import argparse import json +import dask.array as da import fsspec import zarr from zarr import Group @@ -29,10 +30,11 @@ def _execute( label_tuple: tuple[tuple[str, ...], list[str | Group], dict], + timepoint: str | None, output_dir: str, output_sep: str, - force: bool = False, - no_version: bool = False, + force: bool, + no_version: bool, ): group, file_list, metadata = label_tuple assert len(file_list) == 1 @@ -41,16 +43,31 @@ def _execute( path = ( f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet" ) + if timepoint is not None: + path = ( + f"{path}/t={timepoint}/{label_name}{output_sep}{image_key}-objects.parquet" + ) fs = fsspec.url_to_fs(path)[0] if fs.exists(path): if force: fs.rm(path, recursive=True) else: - logger.info(f"Skipping finding objects for {metadata['id']}.") + logger.info( + f"Skipping find objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''}." + ) return - logger.info(f"Finding objects for {metadata['id']}.") - array = file_list[0][list(file_list[0].keys())[0]] + logger.info( + f"Find objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''}." + ) + g = file_list[0] + array = da.from_zarr(g[list(g.keys())[0]]) + + if timepoint is not None: + timepoint_index = g.attrs["multiscales"][0]["metadata"]["t"].index(timepoint) + array = array[timepoint_index] + df = find_objects(array) + df.index.name = "label" prefix = _label_name_to_prefix.get(label_name) if prefix is not None: @@ -64,7 +81,9 @@ def _execute( if not no_version else None, ) - logger.info(f"Saved objects for {metadata['id']} to {path}.") + logger.info( + f"Saved objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''} to {path}." + ) def run_pipeline_find_objects(arguments: argparse.Namespace) -> None: @@ -91,33 +110,50 @@ def run_pipeline_find_objects(arguments: argparse.Namespace) -> None: output_fs, _ = fsspec.core.url_to_fs(output_dir) output_dir = output_dir.rstrip(output_fs.sep) - paths = [] + _, _, keys = _create_file_regex(label_pattern) + keys = list(keys) + label_tuples = [] + timepoints = [] + for path in labels_paths: label_root = zarr.open(path, mode="r") labels_group = label_root.get("labels") - if labels_group is None: - raise ValueError(f"Labels group not found for {path}") - paths.append(labels_group) - _, _, keys = _create_file_regex(label_pattern) - gen = _set_up_experiment( - image_path=paths, - files_pattern=label_pattern, - group_by=list(keys), - subset=subset, - ) + if labels_group is not None: + gen = _set_up_experiment( + image_path=labels_group, + files_pattern=label_pattern, + group_by=keys, + subset=subset, + ) + for label_tuple in gen: + label_key, file_list, metadata = label_tuple + + if ( + label_suffix is None + or label_key[len(label_key) - 1] in label_suffix + ): + assert len(file_list) == 1 + g = file_list[0] + zarr_metadata = g.attrs["multiscales"][0]["metadata"] + + if "t" not in zarr_metadata: + label_tuples.append(label_tuple) + timepoints.append(None) + else: + for timepoint_ in zarr_metadata["t"]: + label_tuples.append(label_tuple) + timepoints.append(timepoint_) with ( _create_default_dask_config(), _create_dask_client(dask_server_url, **dask_cluster_parameters), ): - [ + for i in range(len(label_tuples)): _execute( - label_tuple=g, + label_tuple=label_tuples[i], + timepoint=timepoints[i], output_dir=output_dir, output_sep=output_fs.sep, force=force, no_version=no_version, ) - for g in gen - if label_suffix is None or g[0][len(g[0]) - 1] in label_suffix - ] diff --git a/scallops/features/find_objects.py b/scallops/features/find_objects.py index 9c4f175..b315be1 100644 --- a/scallops/features/find_objects.py +++ b/scallops/features/find_objects.py @@ -254,7 +254,7 @@ def _agg_objects(grouped): return objects_df -def find_objects(label_image: da.Array) -> dd.DataFrame: +def find_objects(label_image: da.Array | zarr.Array) -> dd.DataFrame: """Find objects in a labeled array. :param label_image: Image labels noted by integers. From c9db0a2ca52d97fd98ecc913faee374c7fd2f34f Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Mon, 15 Jun 2026 12:44:46 -0400 Subject: [PATCH 12/21] handle timepoint in segmentation --- scallops/cli/pooled_if_sbs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scallops/cli/pooled_if_sbs.py b/scallops/cli/pooled_if_sbs.py index 6f45daf..ceb30bc 100644 --- a/scallops/cli/pooled_if_sbs.py +++ b/scallops/cli/pooled_if_sbs.py @@ -973,9 +973,11 @@ def reads_pipeline( points_path = f"{points_path}/{image_key}-peaks.parquet" peaks = dd.read_parquet(points_path) maxed = read_ome_zarr_array(spots_root["images"][image_key + "-max"], dask=True) - labels = read_ome_zarr_array( - labels_root[image_key + "-" + label_name], dask=True - ).data.compute() + labels = ( + read_ome_zarr_array(labels_root[image_key + "-" + label_name], dask=True) + .data.squeeze() + .compute() + ) if expand_labels_distance is not None and expand_labels_distance > 0: labels = expand_labels(labels, distance=expand_labels_distance) iss_cycles = maxed.coords["t"].values From 3815857f498e13215a5e8c799d704324a3723279 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Mon, 15 Jun 2026 13:08:30 -0400 Subject: [PATCH 13/21] handle timepoint in segmentation --- scallops/cli/find_objects.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scallops/cli/find_objects.py b/scallops/cli/find_objects.py index c7fffec..d292da9 100644 --- a/scallops/cli/find_objects.py +++ b/scallops/cli/find_objects.py @@ -41,12 +41,10 @@ def _execute( label_name = group[len(group) - 1] image_key = "-".join(group[:-1]) # exclude suffix from key path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet" + (f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet") + if timepoint is None + else f"{output_dir}{output_sep}{label_name}{output_sep}t={timepoint}{output_sep}{image_key}-objects.parquet" ) - if timepoint is not None: - path = ( - f"{path}/t={timepoint}/{label_name}{output_sep}{image_key}-objects.parquet" - ) fs = fsspec.url_to_fs(path)[0] if fs.exists(path): if force: From 3cea7ecea96bceaa98080b76d809aab53bee1190 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 16 Jun 2026 07:54:18 -0400 Subject: [PATCH 14/21] handle timepoint in segmentation --- scallops/cli/extract_crops.py | 13 +- scallops/cli/features.py | 526 ++++++++++++++++++---------------- scallops/cli/features_main.py | 3 +- scallops/cli/find_objects.py | 21 +- scallops/cli/util.py | 12 +- scallops/features/generate.py | 9 +- wdl/ops_tasks.wdl | 4 +- wdl/ops_workflow.wdl | 11 +- wdl/utils.wdl | 3 + 9 files changed, 322 insertions(+), 280 deletions(-) diff --git a/scallops/cli/extract_crops.py b/scallops/cli/extract_crops.py index 5c364e7..c5f7d03 100644 --- a/scallops/cli/extract_crops.py +++ b/scallops/cli/extract_crops.py @@ -10,7 +10,7 @@ from skimage.util import img_as_ubyte from zarr import Group -from scallops.cli.features import _read_merged_or_objects, get_labels +from scallops.cli.features import _read_merged_or_objects from scallops.cli.util import ( _get_cli_logger, cli_metadata, @@ -69,14 +69,11 @@ def single_crop( output_fs.makedirs(output_dir, exist_ok=True) image = _images2fov(file_list, metadata, dask=True).squeeze().data logger.info(f"{image_key} image shape {image.shape}") - zarr_labels = get_labels( - labels_group=labels_group, - name=image_key, - suffix=label_name, # e.g. nuclei - ) - - if zarr_labels is None: + g = labels_group.get(f"{image_key}-{label_name}") + if g is None: raise ValueError(f"Unable to read {label_name} labels for {image_key}.") + zarr_labels = g[list(g.keys())[0]] + merged_df = _read_merged_or_objects( merge_dir=merge_dir, merge_dir_sep=merge_dir_sep, diff --git a/scallops/cli/features.py b/scallops/cli/features.py index 77917cc..1e4ecce 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -8,15 +8,13 @@ import argparse import json -import warnings from collections.abc import Sequence from itertools import zip_longest -from typing import get_type_hints +from typing import Any, get_type_hints import dask.array import dask.array as da import fsspec -import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq @@ -26,6 +24,7 @@ from natsort import natsorted from zarr import Group +from scallops.cli.find_objects import get_path from scallops.cli.util import ( _create_dask_client, _create_default_dask_config, @@ -52,27 +51,28 @@ pluralize, read_anndata_zarr, ) -from scallops.zarr_io import _read_ome_zarr_array logger = _get_cli_logger() def _read_merged_or_objects( - merge_dir: str, - merge_dir_sep: str, + path: str, + path_sep: str, + timepoint: str | None, label_name: str, image_key: str, label_filter: str | None, ): - merge_dir = merge_dir.rstrip(merge_dir_sep) - merge_paths = [ - f"{merge_dir}{merge_dir_sep}{label_name}{merge_dir_sep}{image_key}.parquet", - f"{merge_dir}{merge_dir_sep}{image_key}.zarr", - f"{merge_dir}{merge_dir_sep}{image_key}.parquet", - f"{merge_dir}{merge_dir_sep}{label_name}{merge_dir_sep}{image_key}-objects.parquet", + path = path.rstrip(path_sep) + paths = [ + f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", + f"{path}{path_sep}{image_key}.zarr", + f"{path}{path_sep}{image_key}.parquet", + get_path(path, path_sep, label_name, image_key, timepoint), ] + merge_path = None - for path in merge_paths: + for path in paths: if fsspec.core.url_to_fs(path)[0].exists(path): merge_path = path break @@ -103,51 +103,6 @@ def _read_merged_or_objects( return merged_df -def get_labels(labels_group: Group, name: str, suffix: str) -> zarr.Array | None: - """Retrieve labels from a zarr group. - - :param labels_group: The zarr group containing labels. - :param name: The identifier associated with the labels. - :param suffix: The suffix used to identify the specific set of labels (e.g., 'nuclei'). - :return: The retrieved labels as a DataArray or None if the labels are not found. - """ - try: - g = labels_group[f"{name}-{suffix}"] - return g[list(g.keys())[0]] - except KeyError as e: - logger.warning(f'"{name}-{suffix}" not found in {labels_group}.') - raise e - - -def _read_image(file_list: list[str], metadata: dict) -> xr.DataArray: - """Read image files and preprocess them into a standardized format. - - This function reads image files specified in the file_list and processes them into an - xarray.DataArray with dimensions adjusted as needed. It handles missing dimensions and stacks - time and channel dimensions for further processing. - - :param file_list: List of file paths to the image files. - :param metadata: Dictionary containing metadata associated with the images. - :return: DataArray containing the processed image data. - """ - image = _images2fov(file_list, metadata, dask=True) - dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) - - if len(dims) > 0: - image = image.stack(t_c_z=dims, create_index=False).transpose( - *("y", "x", "t_c_z") - ) - with warnings.catch_warnings(): - # ignore UserWarning: rename 't_c_z' to 'c' does not create an index anymore. - # Try using swap_dims instead or use set_index after rename to create an indexed coordinate. - warnings.filterwarnings("ignore", "rename .*", UserWarning) - image = image.rename({"t_c_z": "c"}) - else: - # add trailing c dimension - image = image.expand_dims("c", -1) - return image - - def _get_feature_channel_indices(tokens): """Get indices of channel parameters in a feature name. @@ -172,10 +127,47 @@ def _get_feature_channel_indices(tokens): ] +def _find_labels( + label_paths: list[str], + image_key: str, + label_name: str, + image_key_no_t: str | None, + selected_timepoint: Any, +): + timepoints = None + g = None + for label_path in label_paths: + label_root = zarr.open(label_path, mode="r") + labels_group = label_root.get("labels") + if labels_group is not None: + g = labels_group.get(f"{image_key}-{label_name}") + if g is not None: + timepoints = ( + g.attrs["multiscales"][0]["metadata"]["t"] + if "t" in g.attrs["multiscales"][0]["metadata"] + else [None] + ) + return g, timepoints + if g is None and image_key_no_t is not None: + g = labels_group.get(f"{image_key_no_t}-{label_name}") + zarr_metadata = g.attrs["multiscales"][0]["metadata"] + + if "t" in zarr_metadata: + timepoints = zarr_metadata["t"] + if selected_timepoint in timepoints: + index = timepoints.index(selected_timepoint) + timepoints = [timepoints[index]] + return g, timepoints + else: + timepoints = [None] + return g, timepoints + return g, timepoints + + def single_feature( stacked_image_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, image_tuple: tuple[tuple[str, ...], list[str | Group], dict], - labels_group: Group, + label_paths: list[str], output_dir: str, output_sep: str, objects_dir: str, @@ -240,221 +232,246 @@ def single_feature( output_fs, _ = fsspec.core.url_to_fs(output_dir) - zarr_inputs = True - - for f in file_list: - if not isinstance(f, (zarr.Group, zarr.Array)): - zarr_inputs = False - break - - if zarr_inputs and stacked_image_tuple is not None: - for f in stacked_file_list: - if not isinstance(f, (zarr.Group, zarr.Array)): - zarr_inputs = False - break - if not zarr_inputs: - image = _read_image(file_list, metadata) - else: - image = [] - for f in file_list: - array, _, _, _ = _read_ome_zarr_array(f) - image.append(array) + image = _images2fov(file_list, metadata, dask=True) + image_dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) n_channels1 = None + stacked_image = None if stacked_image_tuple is not None: - if not zarr_inputs: - stacked_image = _read_image(stacked_file_list, stacked_metadata) - n_channels1 = image.sizes["c"] - # clear coords to avoid issues with xr.concat - for c in list(image.coords.keys()): - del image.coords[c] - for c in list(stacked_image.coords.keys()): - del stacked_image.coords[c] - image = xr.concat((image, stacked_image), dim="c") - else: - n_channels1 = 0 - for img in image: - n_channels1 += np.prod(img.shape[:-2]) - n_channels1 = int(n_channels1) - for f in stacked_file_list: - array, _, _, _ = _read_ome_zarr_array(f) - image.append(array) + stacked_image = _images2fov(stacked_file_list, stacked_metadata, dask=True) + n_channels1 = image.sizes["c"] + stacked_image_dims = tuple( + [d for d in ["t", "c", "z"] if d in stacked_image.dims] + ) + image_key_no_t = None + selected_timepoint = None + if "t" in metadata["group_metadata"]["group"]: + image_key_no_t = [] + for key in metadata["group_metadata"]["group"]: + if key != "t": + image_key_no_t.append(str(metadata["group_metadata"]["group"][key])) + else: + selected_timepoint = metadata["group_metadata"]["group"][key] + image_key_no_t = "-".join(image_key_no_t).replace("/", "-") for label_name in label_name_to_features: + label_prefix = _label_name_to_prefix[label_name] features = label_name_to_features[label_name] - output_parquet_path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}.parquet" - ) - - if not force and is_parquet_file(output_parquet_path): - logger.info(f"Skipping features for {image_key} {label_name}") - continue - zarr_labels = get_labels( - labels_group=labels_group, - name=image_key, - suffix=label_name, # e.g. nuclei + g, timepoints = _find_labels( + label_paths=label_paths, + image_key=image_key, + label_name=label_name, + image_key_no_t=image_key_no_t, + selected_timepoint=selected_timepoint, ) - - if zarr_labels is None: - logger.info(f"Unable to read {label_name} labels for {image_key}.") + if g is None: + logger.info(f"No labels found for {image_key}") continue - label_prefix = _label_name_to_prefix[label_name] - merged_df = None - if objects_dir is not None: - merged_df = _read_merged_or_objects( - merge_dir=objects_dir, - merge_dir_sep=objects_dir_sep, - label_name=label_name, - image_key=image_key, - label_filter=label_filter, + labels_array = da.from_array(g[list(g.keys())[0]]) + for timepoint in timepoints: + image_ = ( + image.sel(t=timepoint) + if timepoint is not None and image.sizes.get("t", 0) > 1 + else image + ) + image_ = ( + image_.stack(t_c_z=image_dims, create_index=False) + .transpose(*("y", "x", "t_c_z")) + .rename({"t_c_z": "c"}) + if len(image_dims) > 0 + else image_.expand_dims("c", -1) ) + if stacked_image is not None: + stacked_image_ = ( + stacked_image.stack(t_c_z=stacked_image_dims, create_index=False) + .transpose(*("y", "x", "t_c_z")) + .rename({"t_c_z": "c"}) + if len(stacked_image_dims) > 0 + else stacked_image.expand_dims("c", -1) + ) - if merged_df is None: - logger.info(f"Find {label_name} objects for {image_key}.") - merged_df = find_objects(zarr_labels) - objects_path = f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet" - merged_df.index.name = "label" - merged_df.columns = f"{label_prefix}_" + merged_df.columns - _to_parquet( - merged_df, - objects_path, - write_index=True, - compute=True, - custom_metadata=dict(scallops=json.dumps(cli_metadata())) - if not no_version - else None, + intensity_image = ( + xr.concat((image_, stacked_image_), dim="c", join="outer") + if stacked_image is not None + else image_ ) - merged_df = pd.read_parquet(objects_path) - - features = normalize_features(features) - # strip nuclei_, etc. from features_plot - features_plot_label = [] - for feature in features_plot: - tokens = feature.lower().split("_") - if tokens[0] == label_prefix: - features_plot_label.append("_".join(tokens[1:])) - - features_plot_label = normalize_features(features_plot_label) - if stacked_image_tuple is not None: - stacked_features = set() - stacked_features_plot = set() - for feature in features: # add offset for image1 - tokens = feature.lower().split("_") - if ( - tokens[0] in _features_multichannel.keys() - or tokens[0] in _features_single_channel.keys() - ): - for token_index in range( - 1, - 3 if tokens[0] in _features_multichannel.keys() else 2, - ): - c = tokens[token_index] - if c[0] == "s": - if c in channel_names: - channel_names[str(n_channels1 + int(c[1:]))] = ( - channel_names.pop(c) - ) - tokens[token_index] = str(n_channels1 + int(c[1:])) - - new_feature = "_".join(tokens) - if feature in features_plot_label: - stacked_features_plot.add(new_feature) - stacked_features.add(new_feature) - - features = stacked_features - features_plot_label = stacked_features_plot - - features = list(set(natsorted(features))) - logger.info( - f"{image_key} {label_name} {len(features):,} {pluralize('feature', len(features))}: " - f"{', '.join(features)}" - ) - if label_filter is not None: - merged_df = merged_df.query(label_filter) - min_max_area = label_name_to_min_max_area.get(label_name) - area_column = f"{label_prefix}_AreaShape_Area" - n_labels = len(merged_df) - prefix = "" - if min_max_area[0] is not None or min_max_area[1] is not None: - area_query = [] - if min_max_area[0] is not None: - area_query.append(f"{area_column}>={min_max_area[0]}") - if min_max_area[1] is not None: - area_query.append(f"{area_column}<={min_max_area[1]}") - merged_df = merged_df.query("&".join(area_query)) - n_labels_filtered = n_labels - len(merged_df) - prefix = f"Removed {n_labels_filtered:,} out of " - logger.info( - f"{prefix}{n_labels:,} {pluralize('label', n_labels)}. " - f"Area: {merged_df[area_column].min():,.0f} to {merged_df[area_column].max():,.0f}." - ) + features_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".parquet" + ) + if not force and is_parquet_file(features_path): + logger.info( + f"Skipping features for {image_key} {label_name}{' at t=' + timepoint if timepoint is not None else ''}." + ) + continue + + merged_df = None + if objects_dir is not None: + merged_df = _read_merged_or_objects( + path=objects_dir, + timepoint=timepoint, + path_sep=objects_dir_sep, + label_name=label_name, + image_key=image_key, + label_filter=label_filter, + ) + label_image = labels_array[ + timepoints.index(timepoint) + if timepoint is not None and labels_array.ndim == 3 + else labels_array + ] - df = label_features( - objects_df=merged_df, - label_image=zarr_labels if zarr_inputs else da.from_zarr(zarr_labels), - intensity_image=image if zarr_inputs else image.data, - features=features, - normalize=normalize, - bounding_box_columns=[ - f"{label_prefix}_AreaShape_BoundingBoxMinimum_Y", - f"{label_prefix}_AreaShape_BoundingBoxMinimum_X", - f"{label_prefix}_AreaShape_BoundingBoxMaximum_Y", - f"{label_prefix}_AreaShape_BoundingBoxMaximum_X", - ], - channel_names=channel_names, - ) - # df will be None if only area and coordinates requested - - if df is not None: - fs = fsspec.url_to_fs(output_parquet_path)[0] - if fs.exists(output_parquet_path): - fs.rm(output_parquet_path, recursive=True) - df.index.name = "label" - df.columns = f"{label_prefix}_" + df.columns - - if isinstance(df, pd.DataFrame): - table = pa.Table.from_pandas(df, preserve_index=True) - if not no_version: - table = table.replace_schema_metadata( - { - "scallops".encode(): json.dumps(cli_metadata()).encode(), - **table.schema.metadata, - } - ) - fs, output_file = fsspec.url_to_fs(output_parquet_path) - pq.write_table( - table, - output_parquet_path, - filesystem=fs, + if merged_df is None: + logger.info( + f"Find {label_name} objects for {image_key}{' at t=' + timepoint if timepoint is not None else ''}." + ) + merged_df = find_objects(label_image) + objects_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint ) - else: + merged_df.index.name = "label" + merged_df.columns = f"{label_prefix}_" + merged_df.columns _to_parquet( - df, - output_parquet_path, + merged_df, + objects_path, write_index=True, compute=True, custom_metadata=dict(scallops=json.dumps(cli_metadata())) if not no_version else None, ) + merged_df = pd.read_parquet(objects_path) - if len(features_plot_label) > 0: - features_plot_label = [ - label_prefix + "_" + feature for feature in features_plot_label - ] - df_features = pd.read_parquet( - output_parquet_path, columns=features_plot_label + features = normalize_features(features) + # strip nuclei_, etc. from features_plot + features_plot_label = [] + for feature in features_plot: + tokens = feature.lower().split("_") + if tokens[0] == label_prefix: + features_plot_label.append("_".join(tokens[1:])) + + features_plot_label = normalize_features(features_plot_label) + if stacked_image_tuple is not None: + stacked_features = set() + stacked_features_plot = set() + for feature in features: # add offset for image1 + tokens = feature.lower().split("_") + + if ( + tokens[0] in _features_multichannel.keys() + or tokens[0] in _features_single_channel.keys() + ): + for token_index in range( + 1, + 3 if tokens[0] in _features_multichannel.keys() else 2, + ): + c = tokens[token_index] + if c[0] == "s": + if c in channel_names: + channel_names[str(n_channels1 + int(c[1:]))] = ( + channel_names.pop(c) + ) + tokens[token_index] = str(n_channels1 + int(c[1:])) + + new_feature = "_".join(tokens) + if feature in features_plot_label: + stacked_features_plot.add(new_feature) + stacked_features.add(new_feature) + + features = stacked_features + features_plot_label = stacked_features_plot + + features = list(set(natsorted(features))) + + if label_filter is not None: + merged_df = merged_df.query(label_filter) + min_max_area = label_name_to_min_max_area.get(label_name) + area_column = f"{label_prefix}_AreaShape_Area" + n_labels = len(merged_df) + prefix = "" + if min_max_area[0] is not None or min_max_area[1] is not None: + area_query = [] + if min_max_area[0] is not None: + area_query.append(f"{area_column}>={min_max_area[0]}") + if min_max_area[1] is not None: + area_query.append(f"{area_column}<={min_max_area[1]}") + merged_df = merged_df.query("&".join(area_query)) + n_labels_filtered = n_labels - len(merged_df) + prefix = f"Removed {n_labels_filtered:,} out of " + logger.info( + f"{prefix}{n_labels:,} {pluralize('label', n_labels)}. " + f"Area: {merged_df[area_column].min():,.0f} to {merged_df[area_column].max():,.0f}." ) - centroid_columns = [ - label_prefix + "_centroid-1", - label_name + "_centroid-0", - ] - df = merged_df[centroid_columns].join(df_features) - pdf_path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}.pdf" + + df = label_features( + objects_df=merged_df, + label_image=label_image, + intensity_image=intensity_image, + features=features, + normalize=normalize, + bounding_box_columns=[ + f"{label_prefix}_AreaShape_BoundingBoxMinimum_Y", + f"{label_prefix}_AreaShape_BoundingBoxMinimum_X", + f"{label_prefix}_AreaShape_BoundingBoxMaximum_Y", + f"{label_prefix}_AreaShape_BoundingBoxMaximum_X", + ], + channel_names=channel_names, ) - _plot_features(df, features_plot_label, pdf_path, centroid_columns) + # df will be None if only area and coordinates requested + + if df is not None: + fs = fsspec.url_to_fs(features_path)[0] + if fs.exists(features_path): + fs.rm(features_path, recursive=True) + df.index.name = "label" + df.columns = f"{label_prefix}_" + df.columns + + if isinstance(df, pd.DataFrame): + table = pa.Table.from_pandas(df, preserve_index=True) + if not no_version: + table = table.replace_schema_metadata( + { + "scallops".encode(): json.dumps( + cli_metadata() + ).encode(), + **table.schema.metadata, + } + ) + fs, output_file = fsspec.url_to_fs(features_path) + pq.write_table( + table, + features_path, + filesystem=fs, + ) + + else: + _to_parquet( + df, + features_path, + write_index=True, + compute=True, + custom_metadata=dict(scallops=json.dumps(cli_metadata())) + if not no_version + else None, + ) + + if len(features_plot_label) > 0: + features_plot_label = [ + label_prefix + "_" + feature for feature in features_plot_label + ] + df_features = pd.read_parquet( + features_path, columns=features_plot_label + ) + centroid_columns = [ + label_prefix + "_centroid-1", + label_name + "_centroid-0", + ] + df = merged_df[centroid_columns].join(df_features) + pdf_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".pdf" + ) + + _plot_features(df, features_plot_label, pdf_path, centroid_columns) return [] @@ -481,6 +498,7 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: channel_names = arguments.channel_rename stack_images = arguments.stack_images label_filter = arguments.label_filter + label_paths = arguments.labels normalize = not arguments.no_normalize stack_image_pattern = arguments.stack_image_pattern cell_features = arguments.features_cell @@ -522,11 +540,9 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: output_dir = output_dir.rstrip(output_fs.sep) for label in label_name_to_features: output_fs.makedirs(output_dir + output_fs.sep + label, exist_ok=True) - labels_path = arguments.labels + no_version = arguments.no_version - assert labels_path is not None, "No labels provided" - label_root = zarr.open(labels_path, mode="r") - labels_group = label_root["labels"] + if channel_names is not None: # keys are strings in json try: @@ -566,7 +582,7 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: output_sep=output_fs.sep, objects_dir=objects_dir, objects_dir_sep=objects_dir_sep, - labels_group=labels_group, + label_paths=label_paths, label_filter=label_filter, label_name_to_min_max_area=label_name_to_min_max_area, label_name_to_features=label_name_to_features, diff --git a/scallops/cli/features_main.py b/scallops/cli/features_main.py index e7c5d19..0117d13 100644 --- a/scallops/cli/features_main.py +++ b/scallops/cli/features_main.py @@ -73,7 +73,8 @@ def new_format_help(x): "--labels", dest="labels", required=True, - help="Path to zarr directory containing labels", + nargs="+", + help="Path(s) to zarr directory containing labels", ) generic_features_help = ( diff --git a/scallops/cli/find_objects.py b/scallops/cli/find_objects.py index d292da9..c746517 100644 --- a/scallops/cli/find_objects.py +++ b/scallops/cli/find_objects.py @@ -28,6 +28,21 @@ logger = _get_cli_logger() +def get_path( + output_dir: str, + output_sep: str, + label_name: str, + image_key: str, + timepoint: str | None, + suffix="-objects.parquet", +): + return ( + (f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}{suffix}") + if timepoint is None + else f"{output_dir}{output_sep}{label_name}{output_sep}t={timepoint}{output_sep}{image_key}{suffix}" + ) + + def _execute( label_tuple: tuple[tuple[str, ...], list[str | Group], dict], timepoint: str | None, @@ -40,11 +55,7 @@ def _execute( assert len(file_list) == 1 label_name = group[len(group) - 1] image_key = "-".join(group[:-1]) # exclude suffix from key - path = ( - (f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet") - if timepoint is None - else f"{output_dir}{output_sep}{label_name}{output_sep}t={timepoint}{output_sep}{image_key}-objects.parquet" - ) + path = get_path(output_dir, output_sep, label_name, image_key, timepoint) fs = fsspec.url_to_fs(path)[0] if fs.exists(path): if force: diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 928df57..2266e89 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -586,6 +586,13 @@ def _list_images_wdl( url_val = index + 1 subset_df = result["subset_df"] + with open( + f"groupby_array_with_time_{url_val}.txt", "wt" + ) as f: # ['plate', 'well', 't'] + for g in result["groupby_with_time"]: + f.write(g) + f.write("\n") + with open(f"group_size_{url_val}.txt", "wt") as f: f.write(f"{result['group_size']}") f.write("\n") @@ -679,9 +686,9 @@ def _list_images( data=dict(subset_ids_with_reference_times=subset_ids_with_reference_times), ) - groupby_with_t = list(groupby) + groupby_with_time = list(groupby) if not groupby_t and times is not None: - groupby_with_t.append("t") + groupby_with_time.append("t") if reference_time is None: reference_time = times[0] if times is not None and len(times) > 0 else "0" @@ -690,6 +697,7 @@ def _list_images( group_size=group_size, subset_df=subset_df, groupby=groupby, + groupby_with_time=groupby_with_time, times=times, reference_time=reference_time, ) diff --git a/scallops/features/generate.py b/scallops/features/generate.py index 59b9234..036bd62 100644 --- a/scallops/features/generate.py +++ b/scallops/features/generate.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import skimage.util +import xarray as xr import zarr from dask import delayed @@ -123,8 +124,8 @@ def _create_dd_metadata( def label_features( objects_df: pd.DataFrame, - label_image: da.Array | zarr.Array, - intensity_image: da.Array | zarr.Array | Sequence[zarr.Array] | None, + label_image: da.Array | xr.DataArray | zarr.Array, + intensity_image: da.Array | zarr.Array | xr.DataArray | Sequence[zarr.Array] | None, features: Iterable[str], channel_names: dict[int | str, str] | None = None, normalize: bool = True, @@ -146,6 +147,10 @@ def label_features( :return: DataFrame with extracted features. """ is_numpy = False + if isinstance(label_image, xr.DataArray): + label_image = label_image.data + if isinstance(intensity_image, xr.DataArray): + intensity_image = intensity_image.data if isinstance(label_image, np.ndarray): is_numpy = True label_image = da.from_array(label_image) diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index d1adc5a..c01f669 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -379,7 +379,7 @@ task intersects_boundary { String images String? image_pattern String label_type - String labels + Array[String] labels String subset String? objects String output_directory @@ -405,7 +405,7 @@ task intersects_boundary { scallops features \ --features-~{label_type} "intersects-boundary_0" \ - --labels "~{labels}" \ + --labels ~{sep=" " labels} \ --groupby ~{sep=" " groupby} \ --subset ~{subset} \ --output "~{output_directory}" \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index fe2121a..5320ce6 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -240,11 +240,11 @@ workflow ops_workflow { String reference_time_pheno = list_images.reference_time_1 String reference_time_iss = list_images.reference_time_2 - + Array[String] phenotype_group_by_with_time = list_images.groupby_array_with_time_2 #String image_pattern_with_reference_time_pheno = list_images.image_pattern_with_reference_time_1 # e.g. {plate}-{well}-IF # String image_pattern_with_reference_time_iss = list_images.image_pattern_with_reference_time_2 # e.g. {plate}-{well}-1 scatter (subset_index in range(length(subsets))) { - String subset_ = subsets[subset_index] + String subset_ = subsets[subset_index] # e.g. plate1-A1 # String subset_with_reference_times_pheno = subsets_with_reference_times_pheno[subset_index] # String subset_with_reference_times_iss = subsets_with_reference_times_iss[subset_index] if(pheno_url_supplied) { @@ -304,6 +304,7 @@ workflow ops_workflow { } if(length(times_pheno)>1) { + call tasks.register_elastix as register_pheno_to_pheno { input: moving=select_all([phenotype_url]), @@ -400,12 +401,12 @@ workflow ops_workflow { input: labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), images=phenotype_url_stripped + '/labels/', - image_pattern=groupby_pattern, + image_pattern=phenotype_image_pattern + '-mask', output_directory=cell_intersects_boundary_directory, label_type='cell', objects=find_objects_cell.output_url, - groupby=groupby, - subset = subset_, + groupby=phenotype_group_by_with_time, + subset = subset_ + '-*', force = force_segment_cell, docker=docker, zones = zones, diff --git a/wdl/utils.wdl b/wdl/utils.wdl index 0daf6aa..3a8a0ee 100644 --- a/wdl/utils.wdl +++ b/wdl/utils.wdl @@ -61,6 +61,9 @@ task list_images { String groupby_pattern = read_lines('groupby_pattern.txt')[0] # e.g. {plate}-{well} Array[String] groupby_array = read_lines('groupby_array.txt') # e.g. ["plate", "well"] + Array[String] groupby_array_with_time_1 = read_lines('groupby_array_with_time_1.txt') # e.g. ["plate", "well", "t"] + Array[String] groupby_array_with_time_2 = read_lines('groupby_array_with_time_2.txt') # e.g. ["plate", "well", "t"] + Int group_size_1 = read_int('group_size_1.txt') Int group_size_2 = read_int('group_size_2.txt') Array[String] subsets_with_reference_times_1 = read_lines('subsets_with_reference_time_1.txt') From 13a32ca7c9406342a4ada8c4912c7d61ecb23a5c Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 16 Jun 2026 09:04:44 -0400 Subject: [PATCH 15/21] Segmentation time --- wdl/ops_tasks.wdl | 2 +- wdl/ops_workflow.wdl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index c01f669..0446745 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -407,7 +407,7 @@ task intersects_boundary { --features-~{label_type} "intersects-boundary_0" \ --labels ~{sep=" " labels} \ --groupby ~{sep=" " groupby} \ - --subset ~{subset} \ + --subset "~{subset}" \ --output "~{output_directory}" \ --images "~{images}" \ --objects "~{objects}" \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 5320ce6..34ef69f 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -483,7 +483,7 @@ workflow ops_workflow { input: images=register_pheno_to_iss.moving_output_url, image_pattern=groupby_pattern, - stacked_images=select_first([iss_url]), + stacked_images=select_first([register_pheno_to_iss.moving_output_url]), stacked_image_pattern=groupby_pattern, image_channel=iss_dapi_channel, stacked_image_channel=0, From 33ce260e486a6105a6ee57413802f7ed29956ec9 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 16 Jun 2026 10:39:23 -0400 Subject: [PATCH 16/21] Segmentation time --- scallops/cli/features.py | 48 ++++++++++++++------------ scallops/tests/test_wdl.py | 20 +++++------ wdl/ops_tasks.wdl | 71 +++++++++----------------------------- wdl/ops_workflow.wdl | 7 ++-- 4 files changed, 55 insertions(+), 91 deletions(-) diff --git a/scallops/cli/features.py b/scallops/cli/features.py index 1e4ecce..0e2c990 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -8,6 +8,7 @@ import argparse import json +import warnings from collections.abc import Sequence from itertools import zip_longest from typing import Any, get_type_hints @@ -164,6 +165,21 @@ def _find_labels( return g, timepoints +def _stack_and_rename(image: xr.DataArray) -> xr.DataArray: + image_dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) + with warnings.catch_warnings(): + # ignore UserWarning: rename 't_c_z' to 'c' does not create an index anymore. + # Try using swap_dims instead or use set_index after rename to create an indexed coordinate. + warnings.filterwarnings("ignore", "rename .*", UserWarning) + return ( + image.stack(t_c_z=image_dims, create_index=False) + .transpose(*("y", "x", "t_c_z")) + .rename({"t_c_z": "c"}) + if len(image_dims) > 0 + else image.expand_dims("c", -1) + ) + + def single_feature( stacked_image_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, image_tuple: tuple[tuple[str, ...], list[str | Group], dict], @@ -233,15 +249,13 @@ def single_feature( output_fs, _ = fsspec.core.url_to_fs(output_dir) image = _images2fov(file_list, metadata, dask=True) - image_dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) + n_channels1 = None stacked_image = None if stacked_image_tuple is not None: stacked_image = _images2fov(stacked_file_list, stacked_metadata, dask=True) n_channels1 = image.sizes["c"] - stacked_image_dims = tuple( - [d for d in ["t", "c", "z"] if d in stacked_image.dims] - ) + image_key_no_t = None selected_timepoint = None if "t" in metadata["group_metadata"]["group"]: @@ -273,21 +287,9 @@ def single_feature( if timepoint is not None and image.sizes.get("t", 0) > 1 else image ) - image_ = ( - image_.stack(t_c_z=image_dims, create_index=False) - .transpose(*("y", "x", "t_c_z")) - .rename({"t_c_z": "c"}) - if len(image_dims) > 0 - else image_.expand_dims("c", -1) - ) + image_ = _stack_and_rename(image_) if stacked_image is not None: - stacked_image_ = ( - stacked_image.stack(t_c_z=stacked_image_dims, create_index=False) - .transpose(*("y", "x", "t_c_z")) - .rename({"t_c_z": "c"}) - if len(stacked_image_dims) > 0 - else stacked_image.expand_dims("c", -1) - ) + stacked_image_ = _stack_and_rename(stacked_image) intensity_image = ( xr.concat((image_, stacked_image_), dim="c", join="outer") @@ -314,11 +316,11 @@ def single_feature( image_key=image_key, label_filter=label_filter, ) - label_image = labels_array[ - timepoints.index(timepoint) - if timepoint is not None and labels_array.ndim == 3 - else labels_array - ] + if timepoint is not None and labels_array.ndim == 3: + timepoint_index = timepoints.index(timepoint) + label_image = labels_array[timepoint_index] + else: + label_image = labels_array if merged_df is None: logger.info( diff --git a/scallops/tests/test_wdl.py b/scallops/tests/test_wdl.py index 943e549..96786dd 100644 --- a/scallops/tests/test_wdl.py +++ b/scallops/tests/test_wdl.py @@ -1,3 +1,4 @@ +import glob import json import os.path from subprocess import check_call @@ -155,15 +156,17 @@ def test_stitch_wdl(tmp_path): @pytest.mark.cli_e2e def test_ops_wdl(tmp_path): - sbs_dir = tmp_path / "sbs.zarr" - pheno_dir = tmp_path / "pheno.zarr" + sbs_dir = tmp_path / "sbs" output = tmp_path / "out" + pheno_dir = tmp_path / "pheno.zarr" + sbs_dir.mkdir() output.mkdir() + for p in glob.glob("scallops/tests/data/experimentC/input/*/*Tile-102*"): + cycles = os.path.basename(p).split("_")[1] + cycles = cycles.split("-")[0] + dest = f"plateA-A1-{cycles[1:]}.tif" + add_physical_size(p, str(sbs_dir / dest)) - iss_img = read_image( - "scallops/tests/data/experimentC/input/10X_c1-SBS-1/10X_c1-SBS-1_A1_Tile-102.sbs.tif" - ) - iss_img.attrs["physical_pixel_sizes"] = (1, 1) pheno_img = read_image( "scallops/tests/data/experimentC/10X_c0-DAPI-p65ab/10X_c0-DAPI-p65ab_A1_Tile-102.phenotype.tif" ) @@ -186,13 +189,10 @@ def test_ops_wdl(tmp_path): }, ).save(pheno_dir) - Experiment( - images={"plateA-A1-1": iss_img, "plateA-A1-2": iss_img}, - ).save(sbs_dir) - input_json = { "model_dir": "", "iss_url": str(sbs_dir.absolute()), + "iss_image_pattern": "{plate}-{well}-{t}.tif", "output_directory": str(output.absolute()), "iss_registration_extra_arguments": "--no-landmarks", "pheno_to_iss_registration_extra_arguments": "--no-landmarks", diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index 0446745..3ec391b 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -272,16 +272,17 @@ task register_pheno_to_iss_qc { } -task register_qc { +task register_iss_iss_qc { input { String images String? image_pattern String label_type String labels - Int channel + Int dapi_channel + Int n_timepoints String subset String output_directory - String channel_prefix + Array[String] groupby Boolean? force @@ -294,6 +295,7 @@ task register_qc { String memory Int max_retries } + Int n_channels = n_timepoints*5 command <<< set -ex @@ -302,58 +304,17 @@ task register_qc { ulimit -n 100000 fi - python < 0: - cmd.append("--subset") - cmd += subset - cmd += ["--output", output_directory] - cmd += ["--images", images] - cmd += ["--channel-rename", f"{json.dumps(channel_rename)}"] - - if force == "true": - cmd.append("--force") - print(" ".join(cmd)) - check_call(cmd) - - CODE + scallops features \ + --features-~{label_type} "correlationpearsonbox_~{dapi_channel}_5:~{n_channels}:5" \ + --labels "~{labels}" \ + --groupby ~{sep=" " groupby} \ + --subset ~{subset} \ + --output "~{output_directory}" \ + --images "~{images}" \ + ~{'--image-pattern ' + image_pattern} \ + ~{true="--force" false="" force} + + >>> diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 34ef69f..100f474 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -503,13 +503,14 @@ workflow ops_workflow { max_retries = max_retries } # ISS t0 to other times - call tasks.register_qc as register_iss_to_iss_qc { + call tasks.register_iss_iss_qc as register_iss_to_iss_qc { input: images=select_first([register_iss_t0.moving_output_url]), image_pattern=groupby_pattern, - channel=select_first([iss_dapi_channel, 0]), + dapi_channel=select_first([iss_dapi_channel, 0]), + n_timepoints=length(times_iss), label_type='nuclei', - channel_prefix="ISS", + output_directory=register_iss_to_iss_qc_directory, labels=register_pheno_to_iss.label_output_url, subset = subset_, From 5f88d479781f3f205ae3b7c90fc8787ffea02214 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 16 Jun 2026 10:52:24 -0400 Subject: [PATCH 17/21] Segmentation time --- wdl/ops_tasks.wdl | 4 ++-- wdl/ops_workflow.wdl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index 3ec391b..4f50e53 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -460,7 +460,7 @@ task features { Int? cytosol_max_area String? features_extra_arguments String? model_dir - String? labels + Array[String] labels String? objects String images String subset @@ -498,7 +498,7 @@ task features { ~{if defined(cell_max_area) && select_first([cell_max_area])>0 then '--cell-max-area ' + cell_max_area else ''} \ ~{if defined(cytosol_max_area) && select_first([cytosol_max_area])>0 then '--cytosol-max-area ' + cytosol_max_area else ''} \ ~{if defined(features_extra_arguments) then features_extra_arguments else ''} \ - --labels "~{labels}" \ + --labels ~{sep=" " labels} \ ~{"--objects " + objects} \ ~{"--label-filter " + '"' + label_filter + '"'} \ --subset ~{subset} \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 100f474..2974fda 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -697,7 +697,7 @@ workflow ops_workflow { Array[String] phenotype_cytosol_times = keys(phenotype_cytosol_features_) scatter (phenotype_time in phenotype_cytosol_times) { - Array[String] cytosol_features = phenotype_cytosol_features_[phenotype_time] + Array[String] cytosol_features = phenotype_cytosol_features_[phenotype_time] scatter (feature_index in range(length(cytosol_features))) { call tasks.features as features_cytosol { input: From 970a5bd8c3983ed21f4fca229a811509374f42db Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Tue, 16 Jun 2026 11:29:01 -0400 Subject: [PATCH 18/21] Segmentation time --- wdl/ops_workflow.wdl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 2974fda..6133f9f 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -626,15 +626,16 @@ workflow ops_workflow { objects=merge_sbs_metadata.output_url, labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, nuclei_features = nuclei_features[feature_index], nuclei_min_area = features_nuclei_min_area_, nuclei_max_area = features_nuclei_max_area_, features_extra_arguments=features_extra_arguments, model_dir=model_dir, - groupby=groupby, + output_directory=nuclei_features_directory + '-' + phenotype_time + '-' + feature_index, - subset = subset_, + subset = subset_ + '-' + phenotype_time, force = force_features, docker=docker, zones = zones, @@ -666,15 +667,16 @@ workflow ops_workflow { objects=merge_sbs_metadata.output_url, labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, cell_features = cell_features[feature_index], cell_min_area = features_cell_min_area_, cell_max_area = features_cell_max_area_, features_extra_arguments=features_extra_arguments, model_dir=model_dir, - groupby=groupby, + output_directory=cell_features_directory + '-' + phenotype_time + '-' + feature_index, - subset = subset_, + subset = subset_ + '-' + phenotype_time, force = force_features, docker=docker, zones = zones, @@ -706,6 +708,7 @@ workflow ops_workflow { objects=merge_sbs_metadata.output_url, labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, output_directory=cytosol_features_directory + '-' + phenotype_time + '-' + feature_index, cytosol_features = cytosol_features[feature_index], cytosol_min_area = features_cytosol_min_area_, @@ -713,9 +716,8 @@ workflow ops_workflow { features_extra_arguments=features_extra_arguments, model_dir=model_dir, - groupby=groupby, - subset = subset_, + subset = subset_ + '-' + phenotype_time, force = force_features, docker=docker, zones = zones, From eda1e3874b3f579c460cd7535f56e1905ebcb904 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Wed, 17 Jun 2026 12:42:24 -0400 Subject: [PATCH 19/21] Segmentation time --- scallops/cli/extract_crops.py | 245 ++++++++++++++++------------- scallops/cli/extract_crops_main.py | 24 +-- scallops/cli/features.py | 141 +++++++++-------- scallops/cli/features_main.py | 4 +- scallops/features/generate.py | 2 +- 5 files changed, 219 insertions(+), 197 deletions(-) diff --git a/scallops/cli/extract_crops.py b/scallops/cli/extract_crops.py index f8f089a..79a6796 100644 --- a/scallops/cli/extract_crops.py +++ b/scallops/cli/extract_crops.py @@ -8,9 +8,13 @@ import pyarrow.parquet as pq from array_api_compat import get_namespace from skimage.util import img_as_ubyte -from zarr import Group -from scallops.cli.features import _read_merged_or_objects +from scallops.cli.features import ( + _find_labels, + _image_key_without_time_and_selected_time, + _read_merged_or_objects, +) +from scallops.cli.find_objects import get_path from scallops.cli.util import ( _get_cli_logger, cli_metadata, @@ -27,23 +31,25 @@ logger = _get_cli_logger() -def _norm_block(image: np.ndarray, percentiles) -> np.ndarray: - xp = get_namespace(image) - percentiles = xp.percentile(image, percentiles, axis=(1, 2), keepdims=True) - image = (image - percentiles[0]) / (percentiles[1] - percentiles[0]) - image = xp.clip(image, 0, 1) - return image +def _norm_block(intensity_image: np.ndarray, percentiles) -> np.ndarray: + xp = get_namespace(intensity_image) + percentiles = xp.percentile( + intensity_image, percentiles, axis=(1, 2), keepdims=True + ) + intensity_image = (intensity_image - percentiles[0]) / ( + percentiles[1] - percentiles[0] + ) + intensity_image = xp.clip(intensity_image, 0, 1) + return intensity_image def single_crop( group: str, # NOT USED file_list: list[str], metadata: dict, - labels_group: Group, + label_paths: list[str], output_dir: str, - output_sep: str, - merge_dir: str, - merge_dir_sep: str, + merge_dirs: list[str], crop_size: tuple[int, int], output_format: Literal["tiff", "npy"], label_name: str, @@ -57,108 +63,135 @@ def single_crop( force: bool, ): image_key = metadata["id"] + image = _images2fov(file_list, metadata, dask=True).squeeze().data - output_dir = f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}" + image_key_no_t, selected_timepoint = _image_key_without_time_and_selected_time( + metadata + ) - output_parquet_path = f"{output_dir}.parquet" - if not force and is_parquet_file(output_parquet_path): - logger.info(f"Skipping features for {image_key} {label_name}") + g, timepoints = _find_labels( + label_paths=label_paths, + image_key=image_key, + label_name=label_name, + image_key_no_t=image_key_no_t, + selected_timepoint=selected_timepoint, + ) + if g is None: + logger.info(f"No labels found for {image_key}") return - + labels_array = da.from_array(g[list(g.keys())[0]]) output_fs, _ = fsspec.core.url_to_fs(output_dir) - output_fs.makedirs(output_dir, exist_ok=True) - image = _images2fov(file_list, metadata, dask=True).squeeze().data - logger.info(f"{image_key} image shape {image.shape}") - g = labels_group.get(f"{image_key}-{label_name}") - if g is None: - raise ValueError(f"Unable to read {label_name} labels for {image_key}.") - zarr_labels = g[list(g.keys())[0]] + output_sep = output_fs.sep + for timepoint in timepoints: + features_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".parquet" + ) + output_dir = get_path( + output_dir, output_sep, label_name, image_key, timepoint, "" + ) - merged_df = _read_merged_or_objects( - merge_dir=merge_dir, - merge_dir_sep=merge_dir_sep, - label_name=label_name, - image_key=image_key, - label_filter=label_filter, - ) - if merged_df is None: - raise ValueError(f"Unable to read merged data for {image_key}.") - n_labels_before_filtering = len(merged_df) - if label_filter is not None: - merged_df = merged_df.query(label_filter) - label_prefix = _label_name_to_prefix[label_name] - area_column = f"{label_prefix}_AreaShape_Area" - merged_df = merged_df.query(f"{area_column}>=2") - n_labels_filtered = n_labels_before_filtering - len(merged_df) - logger.info( - f"Removed {n_labels_filtered:,} out of {n_labels_before_filtering:,} labels for {image_key}." - ) - if len(merged_df) == 0: - raise ValueError(f"No labels found for {image_key}.") - # e.g. CHAMMI-75 + if not force and is_parquet_file(features_path): + logger.info( + f"Skipping crops for {image_key} {label_name}{' at t=' + timepoint if timepoint is not None else ''}." + ) + continue + if timepoint is not None and labels_array.ndim == 3: + timepoint_index = timepoints.index(timepoint) + label_image = labels_array[timepoint_index] + else: + label_image = labels_array + intensity_image = ( + image.sel(t=timepoint) + if timepoint is not None and image.sizes.get("t", 0) > 1 + else image + ) + merged_df = _read_merged_or_objects( + paths=merge_dirs, + label_name=label_name, + timepoint=timepoint, + image_key=image_key, + label_filter=label_filter, + ) + if merged_df is None: + raise ValueError(f"Unable to read merged data for {image_key}.") + n_labels_before_filtering = len(merged_df) + if label_filter is not None: + merged_df = merged_df.query(label_filter) + label_prefix = _label_name_to_prefix[label_name] + area_column = f"{label_prefix}_AreaShape_Area" + merged_df = merged_df.query(f"{area_column}>=2") + n_labels_filtered = n_labels_before_filtering - len(merged_df) + logger.info( + f"Removed {n_labels_filtered:,} out of {n_labels_before_filtering:,} labels for {image_key}." + ) + if len(merged_df) == 0: + raise ValueError(f"No labels found for {image_key}.") + # e.g. CHAMMI-75 - if percentile_normalize is not None: - if local_percentile_normalize: - chunksize = list(image.chunksize) - for i in range(len(chunksize) - 2): - chunksize[i] = -1 - if chunks is not None: - chunksize[-2] = chunks - chunksize[-1] = chunks + if percentile_normalize is not None: + if local_percentile_normalize: + chunksize = list(intensity_image.chunksize) + for i in range(len(chunksize) - 2): + chunksize[i] = -1 + if chunks is not None: + chunksize[-2] = chunks + chunksize[-1] = chunks + else: + logger.info( + f"{image_key} chunk size: {chunksize[-2]:,} by {chunksize[-1]:,}" + ) + intensity_image = intensity_image.rechunk(tuple(chunksize)) + depth = None + if local_normalize_overlap is not None and local_normalize_overlap > 0: + depth = { + intensity_image.ndim - 2: local_normalize_overlap, + intensity_image.ndim - 1: local_normalize_overlap, + } + intensity_image = da.map_overlap( + _norm_block, + intensity_image, + percentiles=percentile_normalize, + depth=depth, + dtype=float, + ) else: - logger.info( - f"{image_key} chunk size: {chunksize[-2]:,} by {chunksize[-1]:,}" + percentiles = da.percentile( + intensity_image, percentile_normalize, axis=(1, 2), keepdims=True ) - image = image.rechunk(tuple(chunksize)) - depth = None - if local_normalize_overlap is not None and local_normalize_overlap > 0: - depth = { - image.ndim - 2: local_normalize_overlap, - image.ndim - 1: local_normalize_overlap, - } - image = da.map_overlap( - _norm_block, - image, - percentiles=percentile_normalize, - depth=depth, - dtype=float, - ) - else: - percentiles = da.percentile( - image, percentile_normalize, axis=(1, 2), keepdims=True - ) - image = (image - percentiles[0]) / (percentiles[1] - percentiles[0]) - image = da.clip(image, 0, 1) - image = da.map_blocks(img_as_ubyte, image) - label_col = "label" if "label" in merged_df.columns else None - merged_df = to_label_crops( - label_image=da.from_zarr(zarr_labels), - intensity_image=image, - df=merged_df, - label_col=label_col, - output_dir=output_dir, - crop_size=crop_size, - output_format=output_format, - centroid_cols=[ - f"{label_prefix}_AreaShape_Center_Y", - f"{label_prefix}_AreaShape_Center_X", - ], - gaussian_sigma=gaussian_sigma, - ) + intensity_image = (intensity_image - percentiles[0]) / ( + percentiles[1] - percentiles[0] + ) + intensity_image = da.clip(intensity_image, 0, 1) + intensity_image = da.map_blocks(img_as_ubyte, intensity_image) + label_col = "label" if "label" in merged_df.columns else None + merged_df = to_label_crops( + label_image=label_image, + intensity_image=intensity_image, + df=merged_df, + label_col=label_col, + output_dir=output_dir, + crop_size=crop_size, + output_format=output_format, + centroid_cols=[ + f"{label_prefix}_AreaShape_Center_Y", + f"{label_prefix}_AreaShape_Center_X", + ], + gaussian_sigma=gaussian_sigma, + ) - output_metadata = cli_metadata() if not no_version else dict() + output_metadata = cli_metadata() if not no_version else dict() - table = pa.Table.from_pandas(merged_df, preserve_index=True) - table = table.replace_schema_metadata( - { - "scallops".encode(): json.dumps(output_metadata).encode(), - **table.schema.metadata, - } - ) + table = pa.Table.from_pandas(merged_df, preserve_index=True) + table = table.replace_schema_metadata( + { + "scallops".encode(): json.dumps(output_metadata).encode(), + **table.schema.metadata, + } + ) - fs, output_parquet_path = fsspec.url_to_fs(output_parquet_path) - pq.write_table( - table, - output_parquet_path, - filesystem=fs, - ) + fs, features_path = fsspec.url_to_fs(features_path) + pq.write_table( + table, + features_path, + filesystem=fs, + ) diff --git a/scallops/cli/extract_crops_main.py b/scallops/cli/extract_crops_main.py index db8afa0..9481be8 100644 --- a/scallops/cli/extract_crops_main.py +++ b/scallops/cli/extract_crops_main.py @@ -1,7 +1,5 @@ import argparse -import fsspec -import zarr from dask.bag import from_sequence from scallops.cli.arg_parser import _sort_groups @@ -39,12 +37,13 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): image_patterns = arguments.image_pattern output_dir = arguments.output - merge_dir = arguments.merge + merge_dirs = arguments.merge subset = arguments.subset force = arguments.force groupby = arguments.groupby crop_size = arguments.crop_size crop_size = (crop_size, crop_size) + label_paths = arguments.labels label_filter = arguments.label_filter percentile_min = arguments.percentile_min @@ -69,20 +68,7 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): if dask_server_url is None and arguments.dask_cluster is None: dask_cluster_parameters = _dask_workers_threads() - merge_dir_sep = None - if merge_dir is not None: - merge_dir_sep = fsspec.core.url_to_fs(merge_dir)[0].sep - merge_dir = merge_dir.rstrip(merge_dir_sep) - - output_fs, _ = fsspec.core.url_to_fs(output_dir) - output_dir = output_dir.rstrip(output_fs.sep) - - labels_path = arguments.labels no_version = arguments.no_version - assert labels_path is not None, "No labels provided" - label_root = zarr.open(labels_path, mode="r") - labels_group = label_root["labels"] - image_seq = from_sequence( _set_up_experiment( images_paths, @@ -101,10 +87,8 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): image_seq.starmap( single_crop, output_dir=output_dir, - output_sep=output_fs.sep, - merge_dir=merge_dir, - merge_dir_sep=merge_dir_sep, - labels_group=labels_group, + merge_dirs=merge_dirs, + label_paths=label_paths, label_filter=label_filter, label_name=label_name, percentile_normalize=percentile_normalize, diff --git a/scallops/cli/features.py b/scallops/cli/features.py index 0e2c990..a093a32 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -57,51 +57,61 @@ def _read_merged_or_objects( - path: str, - path_sep: str, + paths: list[str], timepoint: str | None, label_name: str, image_key: str, label_filter: str | None, ): - path = path.rstrip(path_sep) - paths = [ - f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", - f"{path}{path_sep}{image_key}.zarr", - f"{path}{path_sep}{image_key}.parquet", - get_path(path, path_sep, label_name, image_key, timepoint), - ] - - merge_path = None + found_paths = [] for path in paths: - if fsspec.core.url_to_fs(path)[0].exists(path): - merge_path = path - break - if merge_path is None: - return None + path_sep = fsspec.core.url_to_fs(path)[0].sep + path = path.rstrip(path_sep) + + test_paths = [ + f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", + f"{path}{path_sep}{image_key}.zarr", + f"{path}{path_sep}{image_key}.parquet", + get_path(path, path_sep, label_name, image_key, timepoint), + ] - area_column = f"{_label_name_to_prefix[label_name]}_AreaShape_Area" - if merge_path.lower().endswith(".zarr"): - data = read_anndata_zarr(merge_path, dask=True) - merged_df = data.obs - columns = {area_column} - assert area_column in data.var.index - if label_filter is not None: - query_columns = _get_names_from_pd_query(label_filter) - columns.update( - c - for c in query_columns - if c not in merged_df.columns and c in data.var.index - ) - columns = list(columns) - values = data[:, columns].X.compute() - for i in range(len(columns)): - merged_df[columns[i]] = values[:, i] + for test_path in test_paths: + if fsspec.core.url_to_fs(path)[0].exists(test_path): + found_paths.append(test_path) - else: - merged_df = pd.read_parquet(merge_path) + if len(found_paths) == 0: + return None - return merged_df + area_column = f"{_label_name_to_prefix[label_name]}_AreaShape_Area" + merged_dfs = [] + for path in found_paths: + if path.lower().endswith(".zarr"): + data = read_anndata_zarr(path, dask=True) + merged_df = data.obs + columns = {area_column} + assert area_column in data.var.index + if label_filter is not None: + query_columns = _get_names_from_pd_query(label_filter) + columns.update( + c + for c in query_columns + if c not in merged_df.columns and c in data.var.index + ) + columns = list(columns) + values = data[:, columns].X.compute() + for i in range(len(columns)): + merged_df[columns[i]] = values[:, i] + + else: + merged_df = pd.read_parquet(path) + if "label" in merged_df.columns: + merged_df = merged_df.set_index("label") + merged_dfs.append(merged_df) + return ( + merged_dfs[0] + if len(merged_dfs) == 1 + else pd.concat(merged_dfs, axis=1, join="inner") + ) def _get_feature_channel_indices(tokens): @@ -180,14 +190,27 @@ def _stack_and_rename(image: xr.DataArray) -> xr.DataArray: ) +def _image_key_without_time_and_selected_time(metadata): + image_key_no_t = None + selected_timepoint = None + if "t" in metadata["group_metadata"]["group"]: + image_key_no_t = [] + for key in metadata["group_metadata"]["group"]: + if key != "t": + image_key_no_t.append(str(metadata["group_metadata"]["group"][key])) + else: + selected_timepoint = metadata["group_metadata"]["group"][key] + image_key_no_t = "-".join(image_key_no_t).replace("/", "-") + + return image_key_no_t, selected_timepoint + + def single_feature( stacked_image_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, image_tuple: tuple[tuple[str, ...], list[str | Group], dict], label_paths: list[str], output_dir: str, - output_sep: str, - objects_dir: str, - objects_dir_sep: str, + merge_paths: list[str], label_name_to_features: dict[str, set[str]], label_name_to_min_max_area: dict[str, tuple[float | None, float | None]], features_plot: set[str], @@ -215,13 +238,10 @@ def single_feature( - A tuple of metadata strings. - A list of file paths or Zarr groups containing the primary image data. - A dictionary with additional metadata. - :param labels_group: Zarr group containing labels used to identify regions of interest in the image + :param label_paths: Zarr paths containing labels used to identify regions of interest in the image for feature computation. :param output_dir: Directory path where the computed feature files will be saved. - :param output_sep: Separator string used to construct the output file names. This helps in organizing - the output files systematically. - :param objects_dir: Directory path containing find objects output. - :param objects_dir_sep: File separator for `objects_dir` + :param merge_paths: Directory path containing output to merge :param label_name_to_features: Dictionary mapping label names (keys) to sets of feature names (values). Label names correspond to components in the labeled image (e.g. nuclei), and feature names specify the features to compute. @@ -247,7 +267,8 @@ def single_feature( ) output_fs, _ = fsspec.core.url_to_fs(output_dir) - + output_sep = output_fs.sep + output_dir = output_dir.rstrip(output_fs.sep) image = _images2fov(file_list, metadata, dask=True) n_channels1 = None @@ -255,18 +276,9 @@ def single_feature( if stacked_image_tuple is not None: stacked_image = _images2fov(stacked_file_list, stacked_metadata, dask=True) n_channels1 = image.sizes["c"] - - image_key_no_t = None - selected_timepoint = None - if "t" in metadata["group_metadata"]["group"]: - image_key_no_t = [] - for key in metadata["group_metadata"]["group"]: - if key != "t": - image_key_no_t.append(str(metadata["group_metadata"]["group"][key])) - else: - selected_timepoint = metadata["group_metadata"]["group"][key] - image_key_no_t = "-".join(image_key_no_t).replace("/", "-") - + image_key_no_t, selected_timepoint = _image_key_without_time_and_selected_time( + metadata + ) for label_name in label_name_to_features: label_prefix = _label_name_to_prefix[label_name] features = label_name_to_features[label_name] @@ -307,11 +319,10 @@ def single_feature( continue merged_df = None - if objects_dir is not None: + if len(merge_paths) > 0: merged_df = _read_merged_or_objects( - path=objects_dir, + paths=merge_paths, timepoint=timepoint, - path_sep=objects_dir_sep, label_name=label_name, image_key=image_key, label_filter=label_filter, @@ -493,7 +504,7 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: image_patterns = arguments.image_pattern output_dir = arguments.output - objects_dir = arguments.objects + merge_paths = arguments.merge subset = arguments.subset force = arguments.force groupby = arguments.groupby @@ -527,10 +538,6 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: threads_per_worker=4 if "sizeshape" in unique_features else 1 ) - objects_dir_sep = None - if objects_dir is not None: - objects_dir_sep = fsspec.core.url_to_fs(objects_dir)[0].sep - objects_dir = objects_dir.rstrip(objects_dir_sep) label_name_to_min_max_area = dict( nuclei=[arguments.nuclei_min_area, arguments.nuclei_max_area], cytosol=[arguments.cytosol_min_area, arguments.cytosol_max_area], @@ -581,9 +588,7 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: img_tuple[0], img_tuple[1], output_dir=output_dir, - output_sep=output_fs.sep, - objects_dir=objects_dir, - objects_dir_sep=objects_dir_sep, + merge_paths=merge_paths, label_paths=label_paths, label_filter=label_filter, label_name_to_min_max_area=label_name_to_min_max_area, diff --git a/scallops/cli/features_main.py b/scallops/cli/features_main.py index 0117d13..0062e86 100644 --- a/scallops/cli/features_main.py +++ b/scallops/cli/features_main.py @@ -121,8 +121,8 @@ def new_format_help(x): ) required.add_argument( "--merge", - required=False, - help="Path to directory containing output from `merge`", + nargs="*", + help="Path(s) to directory containing output from `merge`", ) parser.add_argument( "--label-filter", diff --git a/scallops/features/generate.py b/scallops/features/generate.py index 036bd62..df20ba7 100644 --- a/scallops/features/generate.py +++ b/scallops/features/generate.py @@ -134,7 +134,7 @@ def label_features( ) -> dd.DataFrame | pd.DataFrame: """Extract features from labeled regions in the image. - :param objects_df: Data frame containing labeled regions from `find_objects`. + :param objects_df: Data frame containing labeled regions from `find_objects` with frame index set to label id. :param label_image: Labeled regions. :param intensity_image: Intensity image with dimensions (y, x, c) or zarr array(s) with dimensions with leading dimensions unrolled to channel dimension From 6a7cd74e3c0cfd3ef4ad563d44f59e2f14031958 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Wed, 17 Jun 2026 13:02:15 -0400 Subject: [PATCH 20/21] Segmentation time --- scallops/cli/features.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scallops/cli/features.py b/scallops/cli/features.py index a093a32..d4fd211 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -210,7 +210,7 @@ def single_feature( image_tuple: tuple[tuple[str, ...], list[str | Group], dict], label_paths: list[str], output_dir: str, - merge_paths: list[str], + merge_paths: list[str] | None, label_name_to_features: dict[str, set[str]], label_name_to_min_max_area: dict[str, tuple[float | None, float | None]], features_plot: set[str], @@ -319,7 +319,7 @@ def single_feature( continue merged_df = None - if len(merge_paths) > 0: + if merge_paths is not None and len(merge_paths) > 0: merged_df = _read_merged_or_objects( paths=merge_paths, timepoint=timepoint, From c376bccf0f8dc1d0b168e23a0bf434f73cc7cc1d Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 18 Jun 2026 14:11:14 -0400 Subject: [PATCH 21/21] Segmentation time --- scallops/cli/extract_crops.py | 7 ++-- scallops/cli/extract_crops_main.py | 2 +- scallops/cli/features.py | 66 +++++++++++++++++++++--------- scallops/cli/features_main.py | 7 +--- scallops/cli/find_objects.py | 13 ++++-- scallops/cli/pooled_if_sbs.py | 66 ++++++++++++++++-------------- wdl/ops_tasks.wdl | 6 +-- wdl/ops_workflow.wdl | 56 +++++++++++-------------- 8 files changed, 127 insertions(+), 96 deletions(-) diff --git a/scallops/cli/extract_crops.py b/scallops/cli/extract_crops.py index 79a6796..b8a07f8 100644 --- a/scallops/cli/extract_crops.py +++ b/scallops/cli/extract_crops.py @@ -65,7 +65,7 @@ def single_crop( image_key = metadata["id"] image = _images2fov(file_list, metadata, dask=True).squeeze().data - image_key_no_t, selected_timepoint = _image_key_without_time_and_selected_time( + image_key_without_t, selected_timepoint = _image_key_without_time_and_selected_time( metadata ) @@ -73,7 +73,7 @@ def single_crop( label_paths=label_paths, image_key=image_key, label_name=label_name, - image_key_no_t=image_key_no_t, + image_key_without_t=image_key_without_t, selected_timepoint=selected_timepoint, ) if g is None: @@ -110,10 +110,11 @@ def single_crop( label_name=label_name, timepoint=timepoint, image_key=image_key, + image_key_without_t=image_key_without_t, label_filter=label_filter, ) if merged_df is None: - raise ValueError(f"Unable to read merged data for {image_key}.") + raise ValueError(f"Unable to read metadata for {image_key}.") n_labels_before_filtering = len(merged_df) if label_filter is not None: merged_df = merged_df.query(label_filter) diff --git a/scallops/cli/extract_crops_main.py b/scallops/cli/extract_crops_main.py index 9481be8..dc2c823 100644 --- a/scallops/cli/extract_crops_main.py +++ b/scallops/cli/extract_crops_main.py @@ -66,7 +66,7 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): label_name = arguments.label_name # cell, cytosol, nuclei chunks = arguments.chunks if dask_server_url is None and arguments.dask_cluster is None: - dask_cluster_parameters = _dask_workers_threads() + dask_cluster_parameters = _dask_workers_threads(threads_per_worker=4) no_version = arguments.no_version image_seq = from_sequence( diff --git a/scallops/cli/features.py b/scallops/cli/features.py index d4fd211..af12a60 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -61,30 +61,46 @@ def _read_merged_or_objects( timepoint: str | None, label_name: str, image_key: str, + image_key_without_t: str | None, label_filter: str | None, ): - found_paths = [] + found_paths = [] # tuple of (path, time) + for path in paths: path_sep = fsspec.core.url_to_fs(path)[0].sep path = path.rstrip(path_sep) test_paths = [ - f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", - f"{path}{path_sep}{image_key}.zarr", - f"{path}{path_sep}{image_key}.parquet", - get_path(path, path_sep, label_name, image_key, timepoint), + (f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", None), + (f"{path}{path_sep}{image_key}.zarr", None), + (f"{path}{path_sep}{image_key}.parquet", None), + ( + get_path( + path, + path_sep, + label_name, + image_key_without_t + if image_key_without_t is not None + else image_key, + timepoint, + "-objects.parquet", + ), + timepoint if image_key_without_t is not None else None, + ), ] - for test_path in test_paths: + for test_path, t in test_paths: if fsspec.core.url_to_fs(path)[0].exists(test_path): - found_paths.append(test_path) + logger.info(f"Reading {test_path}") + found_paths.append((test_path, t)) if len(found_paths) == 0: return None area_column = f"{_label_name_to_prefix[label_name]}_AreaShape_Area" merged_dfs = [] - for path in found_paths: + # add suffix for time specific paths + for path, t in found_paths: if path.lower().endswith(".zarr"): data = read_anndata_zarr(path, dask=True) merged_df = data.obs @@ -106,6 +122,8 @@ def _read_merged_or_objects( merged_df = pd.read_parquet(path) if "label" in merged_df.columns: merged_df = merged_df.set_index("label") + if t is not None: + merged_df.columns = merged_df.columns + f"_{t}" merged_dfs.append(merged_df) return ( merged_dfs[0] @@ -142,7 +160,7 @@ def _find_labels( label_paths: list[str], image_key: str, label_name: str, - image_key_no_t: str | None, + image_key_without_t: str | None, selected_timepoint: Any, ): timepoints = None @@ -159,8 +177,8 @@ def _find_labels( else [None] ) return g, timepoints - if g is None and image_key_no_t is not None: - g = labels_group.get(f"{image_key_no_t}-{label_name}") + if g is None and image_key_without_t is not None: + g = labels_group.get(f"{image_key_without_t}-{label_name}") zarr_metadata = g.attrs["multiscales"][0]["metadata"] if "t" in zarr_metadata: @@ -191,18 +209,20 @@ def _stack_and_rename(image: xr.DataArray) -> xr.DataArray: def _image_key_without_time_and_selected_time(metadata): - image_key_no_t = None + image_key_without_t = None selected_timepoint = None if "t" in metadata["group_metadata"]["group"]: - image_key_no_t = [] + image_key_without_t = [] for key in metadata["group_metadata"]["group"]: if key != "t": - image_key_no_t.append(str(metadata["group_metadata"]["group"][key])) + image_key_without_t.append( + str(metadata["group_metadata"]["group"][key]) + ) else: selected_timepoint = metadata["group_metadata"]["group"][key] - image_key_no_t = "-".join(image_key_no_t).replace("/", "-") + image_key_without_t = "-".join(image_key_without_t).replace("/", "-") - return image_key_no_t, selected_timepoint + return image_key_without_t, selected_timepoint def single_feature( @@ -276,7 +296,7 @@ def single_feature( if stacked_image_tuple is not None: stacked_image = _images2fov(stacked_file_list, stacked_metadata, dask=True) n_channels1 = image.sizes["c"] - image_key_no_t, selected_timepoint = _image_key_without_time_and_selected_time( + image_key_without_t, selected_timepoint = _image_key_without_time_and_selected_time( metadata ) for label_name in label_name_to_features: @@ -286,7 +306,7 @@ def single_feature( label_paths=label_paths, image_key=image_key, label_name=label_name, - image_key_no_t=image_key_no_t, + image_key_without_t=image_key_without_t, selected_timepoint=selected_timepoint, ) if g is None: @@ -325,8 +345,11 @@ def single_feature( timepoint=timepoint, label_name=label_name, image_key=image_key, + image_key_without_t=image_key_without_t, label_filter=label_filter, ) + if merged_df is None: + raise ValueError(f"Metadata not found for {image_key}") if timepoint is not None and labels_array.ndim == 3: timepoint_index = timepoints.index(timepoint) label_image = labels_array[timepoint_index] @@ -339,7 +362,12 @@ def single_feature( ) merged_df = find_objects(label_image) objects_path = get_path( - output_dir, output_sep, label_name, image_key, timepoint + output_dir, + output_sep, + label_name, + image_key, + timepoint, + "-objects.parquet", ) merged_df.index.name = "label" diff --git a/scallops/cli/features_main.py b/scallops/cli/features_main.py index 0062e86..9786ed0 100644 --- a/scallops/cli/features_main.py +++ b/scallops/cli/features_main.py @@ -108,11 +108,6 @@ def new_format_help(x): help=generic_features_help, ) - parser.add_argument( - "--objects", - required=False, - help="Path to directory containing output from `find-objects` or `merge`", - ) parser.add_argument( "--stack-images", help="Path to additional images to stack with `images`. Add `s` prefix to refer" @@ -122,7 +117,7 @@ def new_format_help(x): required.add_argument( "--merge", nargs="*", - help="Path(s) to directory containing output from `merge`", + help="Path(s) to directory containing output from `find-objects`, `merge`, or `features`", ) parser.add_argument( "--label-filter", diff --git a/scallops/cli/find_objects.py b/scallops/cli/find_objects.py index c746517..2501201 100644 --- a/scallops/cli/find_objects.py +++ b/scallops/cli/find_objects.py @@ -33,8 +33,8 @@ def get_path( output_sep: str, label_name: str, image_key: str, - timepoint: str | None, - suffix="-objects.parquet", + timepoint: str | None = None, + suffix="", ): return ( (f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}{suffix}") @@ -55,7 +55,14 @@ def _execute( assert len(file_list) == 1 label_name = group[len(group) - 1] image_key = "-".join(group[:-1]) # exclude suffix from key - path = get_path(output_dir, output_sep, label_name, image_key, timepoint) + path = get_path( + output_dir, + output_sep, + label_name, + image_key, + timepoint, + suffix="-objects.parquet", + ) fs = fsspec.url_to_fs(path)[0] if fs.exists(path): if force: diff --git a/scallops/cli/pooled_if_sbs.py b/scallops/cli/pooled_if_sbs.py index ceb30bc..b207b29 100644 --- a/scallops/cli/pooled_if_sbs.py +++ b/scallops/cli/pooled_if_sbs.py @@ -572,7 +572,7 @@ def merge_sbs_phenotype_pipeline( prefixes = [] for i in range(len(phenotype_paths)): - df = dd.read_parquet(phenotype_paths[i]) + df = dd.read_parquet(path) _metadata_cols = df.columns[ df.columns.str.contains(_metadata_columns_whitelist_str) ].tolist() @@ -584,7 +584,7 @@ def merge_sbs_phenotype_pipeline( if phenotype_suffix is not None: df.columns = df.columns + phenotype_suffix[i] - prefixes.append(phenotype_paths[i].split("/")[-3]) + prefixes.append(path.split("/")[-3]) if output_format == "zarr": # read index and metadata if len(_metadata_cols) > 0: @@ -595,7 +595,7 @@ def merge_sbs_phenotype_pipeline( ) for key in rename_features: feature_names_i[feature_names_i.index(key)] = rename_features[key] - df = dd.read_parquet(phenotype_paths[i], columns=_metadata_cols) + df = dd.read_parquet(path, columns=_metadata_cols) feature_names += feature_names_i unique_columns.update(feature_names_i) @@ -671,34 +671,38 @@ def merge_sbs_phenotype_pipeline( ) -def _find_phenotype_paths( - phenotype_paths, phenotype_filesystems, phenotype_suffix, image_key -): - _phenotype_paths = [] - _phenotype_suffix = [] +def _find_phenotype_paths(paths: list[str], image_key: str, image_metadata: dict): + found_paths = [] + for path in paths: + path_ = path + path = path.format(**image_metadata["file_metadata"][0]) + fs, path = fsspec.url_to_fs(path) + if path_ == path and "*" not in path: # directory + # match */A1-*.parquet and */A1.parquet + sep = fs.sep + matches = fs.glob(f"{path}{sep}*{sep}{image_key}-*.parquet") + fs.glob( + f"{path}{sep}*{sep}{image_key}.parquet" + ) - for i in range(len(phenotype_paths)): - # match */A1-*.parquet and */A1.parquet - sep = phenotype_filesystems[i].sep - matches = phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}*{sep}{image_key}-*.parquet" - ) + phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}*{sep}{image_key}.parquet" - ) + if len(matches) == 0: + # match A1-*.parquet and A1.parquet + matches = fs.glob(f"{path}{sep}{image_key}-*.parquet") + fs.glob( + f"{path}{sep}{image_key}.parquet" + ) - if len(matches) == 0: - # match A1-*.parquet and A1.parquet - matches = phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}{image_key}-*.parquet" - ) + phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}{image_key}.parquet" - ) + for x in matches: + path.append(fs.unstrip_protocol(x)) + # if phenotype_suffix is not None: + # _phenotype_suffix.append(phenotype_suffix[i]) - for x in matches: - _phenotype_paths.append(phenotype_filesystems[i].unstrip_protocol(x)) - if phenotype_suffix is not None: - _phenotype_suffix.append(phenotype_suffix[i]) - return _phenotype_paths, _phenotype_suffix + else: + if "*" in path: + matches = fs.glob(path) + for match in matches: + found_paths.append(fs.unstrip_protocol(match)) + elif fs.exists(path): + found_paths.append(fs.unstrip_protocol(path)) + return found_paths def merge_main(arguments: argparse.Namespace): @@ -745,8 +749,9 @@ def merge_main(arguments: argparse.Namespace): if len(set(phenotype_paths)) != len(phenotype_paths): raise ValueError("Duplicate phenotype paths") for i in range(len(phenotype_paths)): - phenotype_fs, _ = fsspec.core.url_to_fs(phenotype_paths[i]) - phenotype_paths[i] = phenotype_paths[i].rstrip(phenotype_fs.sep) + path = phenotype_paths[i] + phenotype_fs, _ = fsspec.core.url_to_fs(path) + path = path.rstrip(phenotype_fs.sep) phenotype_filesystems.append(phenotype_fs) paths = [] @@ -757,6 +762,7 @@ def merge_main(arguments: argparse.Namespace): sbs = sbs.rstrip(sbs_fs.sep) sbs_matches = sbs_fs.glob(sbs + sbs_fs.sep + "*.parquet") sbs_matches = [sbs_fs.unstrip_protocol(m) for m in sbs_matches] + print("sbs_matches", sbs_matches) for sbs_path in sbs_matches: name = os.path.splitext(os.path.basename(sbs_path))[0] if not name.startswith("."): # ignore hidden files diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index 4f50e53..2a35a24 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -371,7 +371,7 @@ task intersects_boundary { --subset "~{subset}" \ --output "~{output_directory}" \ --images "~{images}" \ - --objects "~{objects}" \ + --merge "~{objects}" \ --no-normalize \ ~{'--image-pattern ' + image_pattern} \ ~{true="--force" false="" force} @@ -461,7 +461,7 @@ task features { String? features_extra_arguments String? model_dir Array[String] labels - String? objects + String? merge String images String subset Boolean? force @@ -498,8 +498,8 @@ task features { ~{if defined(cell_max_area) && select_first([cell_max_area])>0 then '--cell-max-area ' + cell_max_area else ''} \ ~{if defined(cytosol_max_area) && select_first([cytosol_max_area])>0 then '--cytosol-max-area ' + cytosol_max_area else ''} \ ~{if defined(features_extra_arguments) then features_extra_arguments else ''} \ + --merge ~{merge} \ --labels ~{sep=" " labels} \ - ~{"--objects " + objects} \ ~{"--label-filter " + '"' + label_filter + '"'} \ --subset ~{subset} \ ~{"--image-pattern " + image_pattern} \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 6133f9f..b893206 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -77,6 +77,7 @@ workflow ops_workflow { String? cell_segmentation_extra_arguments Boolean mark_stitch_boundary_cells = true + String intersects_stitch_boundary_label = "cell" # nuclei # merge String? merge_extra_arguments @@ -151,9 +152,9 @@ workflow ops_workflow { String merge_memory = "256 GiB" String merge_disks = "local-disk 20 HDD" - Int cell_intersects_boundary_cpu = 16 - String cell_intersects_boundary_memory = "32 GiB" - String cell_intersects_boundary_disks = "local-disk 200 HDD" + Int intersects_boundary_cpu = 16 + String intersects_boundary_memory = "32 GiB" + String intersects_boundary_disks = "local-disk 200 HDD" String docker @@ -167,9 +168,7 @@ workflow ops_workflow { String register_iss_transforms_suffix = "iss-transforms-t0" String register_pheno_to_iss_suffix = "pheno-to-iss-registered.zarr" String register_pheno_to_iss_transforms_suffix = "pheno-to-iss-transforms" - String nuclei_objects_suffix = "objects-nuclei" - String cell_objects_suffix = "objects-cell" - String cytosol_objects_suffix = "objects-cytosol" + String objects_suffix = "objects" String nuclei_features_suffix = "features-nuclei" String cell_features_suffix = "features-cell" String cytosol_features_suffix = "features-cytosol" @@ -181,8 +180,7 @@ workflow ops_workflow { String reads_suffix = "reads" String merge_meta_suffix = "merge-sbs-metadata" String merge_features_suffix = "merge-features" - String cell_intersects_boundary_suffix = "intersects-boundary" - String cell_intersects_boundary_non_reference_t_suffix = "intersects-boundary-t" + String intersects_boundary_suffix = "intersects-boundary" } String output_stripped = sub(output_directory, "/+$", "") + "/" @@ -194,9 +192,7 @@ workflow ops_workflow { String nuclei_features_directory = output_stripped + nuclei_features_suffix String cell_features_directory = output_stripped + cell_features_suffix String cytosol_features_directory = output_stripped + cytosol_features_suffix - String nuclei_objects_directory = output_stripped + nuclei_objects_suffix - String cell_objects_directory = output_stripped + cell_objects_suffix - String cytosol_objects_directory = output_stripped + cytosol_objects_suffix + String objects_directory = output_stripped + objects_suffix String register_pheno_to_pheno_directory = output_stripped + register_pheno_to_pheno_suffix String register_pheno_to_pheno_transform_directory = output_stripped + register_pheno_to_pheno_transform_suffix String spot_detect_directory = output_stripped + spot_detect_suffix @@ -204,7 +200,7 @@ workflow ops_workflow { String merge_meta_directory = output_stripped + merge_meta_suffix String merge_features_directory = output_stripped + merge_features_suffix String register_pheno_to_iss_qc_directory = output_stripped + register_pheno_to_iss_qc_suffix - String cell_intersects_boundary_directory = output_stripped + cell_intersects_boundary_suffix + String intersects_boundary_directory = output_stripped + intersects_boundary_suffix Boolean iss_url_supplied = defined(iss_url) Boolean pheno_url_supplied = defined(phenotype_url) @@ -339,7 +335,7 @@ workflow ops_workflow { labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), label_pattern=groupby_pattern, suffix="nuclei", - output_directory=nuclei_objects_directory, + output_directory=objects_directory, subset = subset_, force = force_find_objects, docker=docker, @@ -358,7 +354,7 @@ workflow ops_workflow { labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_pattern=groupby_pattern, suffix="cell", - output_directory=cell_objects_directory, + output_directory=objects_directory, subset = subset_, force = force_find_objects, docker=docker, @@ -377,7 +373,7 @@ workflow ops_workflow { labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_pattern=groupby_pattern, suffix="cytosol", - output_directory=cytosol_objects_directory, + output_directory=objects_directory, subset = subset_, force = force_find_objects, docker=docker, @@ -396,25 +392,25 @@ workflow ops_workflow { if(mark_stitch_boundary_cells) { String phenotype_url_stripped = sub(select_first([phenotype_url]), "/+$", "") - call tasks.intersects_boundary as cell_intersects_boundary { + call tasks.intersects_boundary as intersects_boundary { input: labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), images=phenotype_url_stripped + '/labels/', image_pattern=phenotype_image_pattern + '-mask', - output_directory=cell_intersects_boundary_directory, - label_type='cell', - objects=find_objects_cell.output_url, + output_directory=intersects_boundary_directory, + label_type=intersects_stitch_boundary_label, + objects=if(intersects_stitch_boundary_label=='cell') then find_objects_cell.output_url else find_objects_nuclei.output_url, groupby=phenotype_group_by_with_time, subset = subset_ + '-*', - force = force_segment_cell, + force = if(intersects_stitch_boundary_label=='cell') then force_segment_cell else force_segment_nuclei, docker=docker, zones = zones, preemptible = preemptible, aws_queue_arn = aws_queue_arn, - disks = cell_intersects_boundary_disks, - memory = cell_intersects_boundary_memory, - cpu = cell_intersects_boundary_cpu, + disks = intersects_boundary_disks, + memory = intersects_boundary_memory, + cpu = intersects_boundary_cpu, max_retries = max_retries } @@ -586,10 +582,8 @@ workflow ops_workflow { call tasks.merge as merge_sbs_metadata { input: iss_reads=select_first([reads.output_url]) + '/labels', - objects_nuclei=find_objects_nuclei.output_url, # all rounds - objects_cell=find_objects_cell.output_url, - objects_cytosol=find_objects_cytosol.output_url, - cell_intersects_boundary=cell_intersects_boundary.output_url, + objects_nuclei=if(run_cell_segmentation) then find_objects_nuclei.output_url else find_objects_cell.output_url, + cell_intersects_boundary=intersects_boundary.output_url, register_pheno_to_iss_qc=register_pheno_to_iss_qc.output_url, register_iss_to_iss_qc=register_iss_to_iss_qc.output_url, barcodes=select_first([barcodes]), @@ -623,7 +617,7 @@ workflow ops_workflow { input: images = select_first([phenotype_url]), image_pattern=phenotype_image_pattern, - objects=merge_sbs_metadata.output_url, + merge=merge_sbs_metadata.output_url, labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, groupby=phenotype_group_by_with_time, @@ -634,7 +628,7 @@ workflow ops_workflow { model_dir=model_dir, - output_directory=nuclei_features_directory + '-' + phenotype_time + '-' + feature_index, + output_directory=nuclei_features_directory + '-' + phenotype_time + '-batch' + feature_index, subset = subset_ + '-' + phenotype_time, force = force_features, docker=docker, @@ -664,7 +658,7 @@ workflow ops_workflow { input: images = select_first([phenotype_url]), image_pattern=phenotype_image_pattern, - objects=merge_sbs_metadata.output_url, + merge=merge_sbs_metadata.output_url, labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, groupby=phenotype_group_by_with_time, @@ -705,7 +699,7 @@ workflow ops_workflow { input: images = select_first([phenotype_url]), image_pattern=phenotype_image_pattern, - objects=merge_sbs_metadata.output_url, + merge=merge_sbs_metadata.output_url, labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), label_filter=features_label_filter, groupby=phenotype_group_by_with_time,