diff --git a/scallops/cli/extract_crops.py b/scallops/cli/extract_crops.py index 05e56b6..b8a07f8 100644 --- a/scallops/cli/extract_crops.py +++ b/scallops/cli/extract_crops.py @@ -8,9 +8,13 @@ import pyarrow.parquet as pq from array_api_compat import get_namespace from skimage.util import img_as_ubyte -from zarr import Group -from scallops.cli.features import _read_merged_or_objects, get_labels +from scallops.cli.features import ( + _find_labels, + _image_key_without_time_and_selected_time, + _read_merged_or_objects, +) +from scallops.cli.find_objects import get_path from scallops.cli.util import ( _get_cli_logger, cli_metadata, @@ -27,23 +31,25 @@ logger = _get_cli_logger() -def _norm_block(image: np.ndarray, percentiles) -> np.ndarray: - xp = get_namespace(image) - percentiles = xp.percentile(image, percentiles, axis=(1, 2), keepdims=True) - image = (image - percentiles[0]) / (percentiles[1] - percentiles[0]) - image = xp.clip(image, 0, 1) - return image +def _norm_block(intensity_image: np.ndarray, percentiles) -> np.ndarray: + xp = get_namespace(intensity_image) + percentiles = xp.percentile( + intensity_image, percentiles, axis=(1, 2), keepdims=True + ) + intensity_image = (intensity_image - percentiles[0]) / ( + percentiles[1] - percentiles[0] + ) + intensity_image = xp.clip(intensity_image, 0, 1) + return intensity_image def single_crop( group: str, # NOT USED file_list: list[str], metadata: dict, - labels_group: Group, + label_paths: list[str], output_dir: str, - output_sep: str, - merge_dir: str, - merge_dir_sep: str, + merge_dirs: list[str], crop_size: tuple[int, int], output_format: Literal["tiff", "npy"], label_name: str, @@ -57,111 +63,136 @@ def single_crop( force: bool, ): image_key = metadata["id"] - - output_dir = f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}" - - output_parquet_path = f"{output_dir}.parquet" - if not force and is_parquet_file(output_parquet_path): - logger.info(f"Skipping features for {image_key} {label_name}") - return - - output_fs, _ = fsspec.core.url_to_fs(output_dir) - output_fs.makedirs(output_dir, exist_ok=True) image = _images2fov(file_list, metadata, dask=True).squeeze().data - logger.info(f"{image_key} image shape {image.shape}") - zarr_labels = get_labels( - labels_group=labels_group, - name=image_key, - suffix=label_name, # e.g. nuclei + + image_key_without_t, selected_timepoint = _image_key_without_time_and_selected_time( + metadata ) - if zarr_labels is None: - raise ValueError(f"Unable to read {label_name} labels for {image_key}.") - merged_df = _read_merged_or_objects( - merge_dir=merge_dir, - merge_dir_sep=merge_dir_sep, - label_name=label_name, + g, timepoints = _find_labels( + label_paths=label_paths, image_key=image_key, - label_filter=label_filter, - ) - if merged_df is None: - raise ValueError(f"Unable to read merged data for {image_key}.") - n_labels_before_filtering = len(merged_df) - if label_filter is not None: - merged_df = merged_df.query(label_filter) - label_prefix = _label_name_to_prefix[label_name] - area_column = f"{label_prefix}_AreaShape_Area" - merged_df = merged_df.query(f"{area_column}>=2") - n_labels_filtered = n_labels_before_filtering - len(merged_df) - logger.info( - f"Removed {n_labels_filtered:,} out of {n_labels_before_filtering:,} labels for {image_key}." + label_name=label_name, + image_key_without_t=image_key_without_t, + selected_timepoint=selected_timepoint, ) - if len(merged_df) == 0: - raise ValueError(f"No labels found for {image_key}.") - # e.g. CHAMMI-75 + if g is None: + logger.info(f"No labels found for {image_key}") + return + labels_array = da.from_array(g[list(g.keys())[0]]) + output_fs, _ = fsspec.core.url_to_fs(output_dir) + output_sep = output_fs.sep + for timepoint in timepoints: + features_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".parquet" + ) + output_dir = get_path( + output_dir, output_sep, label_name, image_key, timepoint, "" + ) - if percentile_normalize is not None: - if local_percentile_normalize: - chunksize = list(image.chunksize) - for i in range(len(chunksize) - 2): - chunksize[i] = -1 - if chunks is not None: - chunksize[-2] = chunks - chunksize[-1] = chunks - else: - logger.info( - f"{image_key} chunk size: {chunksize[-2]:,} by {chunksize[-1]:,}" - ) - image = image.rechunk(tuple(chunksize)) - depth = None - if local_normalize_overlap is not None and local_normalize_overlap > 0: - depth = { - image.ndim - 2: local_normalize_overlap, - image.ndim - 1: local_normalize_overlap, - } - image = da.map_overlap( - _norm_block, - image, - percentiles=percentile_normalize, - depth=depth, - dtype=float, + if not force and is_parquet_file(features_path): + logger.info( + f"Skipping crops for {image_key} {label_name}{' at t=' + timepoint if timepoint is not None else ''}." ) + continue + if timepoint is not None and labels_array.ndim == 3: + timepoint_index = timepoints.index(timepoint) + label_image = labels_array[timepoint_index] else: - percentiles = da.percentile( - image, percentile_normalize, axis=(1, 2), keepdims=True - ) - image = (image - percentiles[0]) / (percentiles[1] - percentiles[0]) - image = da.clip(image, 0, 1) - image = da.map_blocks(img_as_ubyte, image) - label_col = "label" if "label" in merged_df.columns else None - merged_df = to_label_crops( - label_image=da.from_zarr(zarr_labels), - intensity_image=image, - df=merged_df, - label_col=label_col, - output_dir=output_dir, - crop_size=crop_size, - output_format=output_format, - centroid_cols=[ - f"{label_prefix}_AreaShape_Center_Y", - f"{label_prefix}_AreaShape_Center_X", - ], - gaussian_sigma=gaussian_sigma, - ) + label_image = labels_array + intensity_image = ( + image.sel(t=timepoint) + if timepoint is not None and image.sizes.get("t", 0) > 1 + else image + ) + merged_df = _read_merged_or_objects( + paths=merge_dirs, + label_name=label_name, + timepoint=timepoint, + image_key=image_key, + image_key_without_t=image_key_without_t, + label_filter=label_filter, + ) + if merged_df is None: + raise ValueError(f"Unable to read metadata for {image_key}.") + n_labels_before_filtering = len(merged_df) + if label_filter is not None: + merged_df = merged_df.query(label_filter) + label_prefix = _label_name_to_prefix[label_name] + area_column = f"{label_prefix}_AreaShape_Area" + merged_df = merged_df.query(f"{area_column}>=2") + n_labels_filtered = n_labels_before_filtering - len(merged_df) + logger.info( + f"Removed {n_labels_filtered:,} out of {n_labels_before_filtering:,} labels for {image_key}." + ) + if len(merged_df) == 0: + raise ValueError(f"No labels found for {image_key}.") + # e.g. CHAMMI-75 + + if percentile_normalize is not None: + if local_percentile_normalize: + chunksize = list(intensity_image.chunksize) + for i in range(len(chunksize) - 2): + chunksize[i] = -1 + if chunks is not None: + chunksize[-2] = chunks + chunksize[-1] = chunks + else: + logger.info( + f"{image_key} chunk size: {chunksize[-2]:,} by {chunksize[-1]:,}" + ) + intensity_image = intensity_image.rechunk(tuple(chunksize)) + depth = None + if local_normalize_overlap is not None and local_normalize_overlap > 0: + depth = { + intensity_image.ndim - 2: local_normalize_overlap, + intensity_image.ndim - 1: local_normalize_overlap, + } + intensity_image = da.map_overlap( + _norm_block, + intensity_image, + percentiles=percentile_normalize, + depth=depth, + dtype=float, + ) + else: + percentiles = da.percentile( + intensity_image, percentile_normalize, axis=(1, 2), keepdims=True + ) + intensity_image = (intensity_image - percentiles[0]) / ( + percentiles[1] - percentiles[0] + ) + intensity_image = da.clip(intensity_image, 0, 1) + intensity_image = da.map_blocks(img_as_ubyte, intensity_image) + label_col = "label" if "label" in merged_df.columns else None + merged_df = to_label_crops( + label_image=label_image, + intensity_image=intensity_image, + df=merged_df, + label_col=label_col, + output_dir=output_dir, + crop_size=crop_size, + output_format=output_format, + centroid_cols=[ + f"{label_prefix}_AreaShape_Center_Y", + f"{label_prefix}_AreaShape_Center_X", + ], + gaussian_sigma=gaussian_sigma, + ) - output_metadata = cli_metadata() if not no_version else dict() + output_metadata = cli_metadata() if not no_version else dict() - table = pa.Table.from_pandas(merged_df, preserve_index=True) - table = table.replace_schema_metadata( - { - "scallops".encode(): json.dumps(output_metadata).encode(), - **table.schema.metadata, - } - ) + table = pa.Table.from_pandas(merged_df, preserve_index=True) + table = table.replace_schema_metadata( + { + "scallops".encode(): json.dumps(output_metadata).encode(), + **table.schema.metadata, + } + ) - fs, output_parquet_path = fsspec.url_to_fs(output_parquet_path) - pq.write_table( - table, - output_parquet_path, - filesystem=fs, - ) + fs, features_path = fsspec.url_to_fs(features_path) + pq.write_table( + table, + features_path, + filesystem=fs, + ) diff --git a/scallops/cli/extract_crops_main.py b/scallops/cli/extract_crops_main.py index db8afa0..dc2c823 100644 --- a/scallops/cli/extract_crops_main.py +++ b/scallops/cli/extract_crops_main.py @@ -1,7 +1,5 @@ import argparse -import fsspec -import zarr from dask.bag import from_sequence from scallops.cli.arg_parser import _sort_groups @@ -39,12 +37,13 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): image_patterns = arguments.image_pattern output_dir = arguments.output - merge_dir = arguments.merge + merge_dirs = arguments.merge subset = arguments.subset force = arguments.force groupby = arguments.groupby crop_size = arguments.crop_size crop_size = (crop_size, crop_size) + label_paths = arguments.labels label_filter = arguments.label_filter percentile_min = arguments.percentile_min @@ -67,22 +66,9 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): label_name = arguments.label_name # cell, cytosol, nuclei chunks = arguments.chunks if dask_server_url is None and arguments.dask_cluster is None: - dask_cluster_parameters = _dask_workers_threads() + dask_cluster_parameters = _dask_workers_threads(threads_per_worker=4) - merge_dir_sep = None - if merge_dir is not None: - merge_dir_sep = fsspec.core.url_to_fs(merge_dir)[0].sep - merge_dir = merge_dir.rstrip(merge_dir_sep) - - output_fs, _ = fsspec.core.url_to_fs(output_dir) - output_dir = output_dir.rstrip(output_fs.sep) - - labels_path = arguments.labels no_version = arguments.no_version - assert labels_path is not None, "No labels provided" - label_root = zarr.open(labels_path, mode="r") - labels_group = label_root["labels"] - image_seq = from_sequence( _set_up_experiment( images_paths, @@ -101,10 +87,8 @@ def run_pipeline_extract_crops(arguments: argparse.Namespace): image_seq.starmap( single_crop, output_dir=output_dir, - output_sep=output_fs.sep, - merge_dir=merge_dir, - merge_dir_sep=merge_dir_sep, - labels_group=labels_group, + merge_dirs=merge_dirs, + label_paths=label_paths, label_filter=label_filter, label_name=label_name, percentile_normalize=percentile_normalize, diff --git a/scallops/cli/features.py b/scallops/cli/features.py index 77917cc..af12a60 100644 --- a/scallops/cli/features.py +++ b/scallops/cli/features.py @@ -11,12 +11,11 @@ import warnings from collections.abc import Sequence from itertools import zip_longest -from typing import get_type_hints +from typing import Any, get_type_hints import dask.array import dask.array as da import fsspec -import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq @@ -26,6 +25,7 @@ from natsort import natsorted from zarr import Group +from scallops.cli.find_objects import get_path from scallops.cli.util import ( _create_dask_client, _create_default_dask_config, @@ -52,100 +52,84 @@ pluralize, read_anndata_zarr, ) -from scallops.zarr_io import _read_ome_zarr_array logger = _get_cli_logger() def _read_merged_or_objects( - merge_dir: str, - merge_dir_sep: str, + paths: list[str], + timepoint: str | None, label_name: str, image_key: str, + image_key_without_t: str | None, label_filter: str | None, ): - merge_dir = merge_dir.rstrip(merge_dir_sep) - merge_paths = [ - f"{merge_dir}{merge_dir_sep}{label_name}{merge_dir_sep}{image_key}.parquet", - f"{merge_dir}{merge_dir_sep}{image_key}.zarr", - f"{merge_dir}{merge_dir_sep}{image_key}.parquet", - f"{merge_dir}{merge_dir_sep}{label_name}{merge_dir_sep}{image_key}-objects.parquet", - ] - merge_path = None - for path in merge_paths: - if fsspec.core.url_to_fs(path)[0].exists(path): - merge_path = path - break - if merge_path is None: - return None - - area_column = f"{_label_name_to_prefix[label_name]}_AreaShape_Area" - if merge_path.lower().endswith(".zarr"): - data = read_anndata_zarr(merge_path, dask=True) - merged_df = data.obs - columns = {area_column} - assert area_column in data.var.index - if label_filter is not None: - query_columns = _get_names_from_pd_query(label_filter) - columns.update( - c - for c in query_columns - if c not in merged_df.columns and c in data.var.index - ) - columns = list(columns) - values = data[:, columns].X.compute() - for i in range(len(columns)): - merged_df[columns[i]] = values[:, i] - - else: - merged_df = pd.read_parquet(merge_path) - - return merged_df - - -def get_labels(labels_group: Group, name: str, suffix: str) -> zarr.Array | None: - """Retrieve labels from a zarr group. - - :param labels_group: The zarr group containing labels. - :param name: The identifier associated with the labels. - :param suffix: The suffix used to identify the specific set of labels (e.g., 'nuclei'). - :return: The retrieved labels as a DataArray or None if the labels are not found. - """ - try: - g = labels_group[f"{name}-{suffix}"] - return g[list(g.keys())[0]] - except KeyError as e: - logger.warning(f'"{name}-{suffix}" not found in {labels_group}.') - raise e - + found_paths = [] # tuple of (path, time) + + for path in paths: + path_sep = fsspec.core.url_to_fs(path)[0].sep + path = path.rstrip(path_sep) + + test_paths = [ + (f"{path}{path_sep}{label_name}{path_sep}{image_key}.parquet", None), + (f"{path}{path_sep}{image_key}.zarr", None), + (f"{path}{path_sep}{image_key}.parquet", None), + ( + get_path( + path, + path_sep, + label_name, + image_key_without_t + if image_key_without_t is not None + else image_key, + timepoint, + "-objects.parquet", + ), + timepoint if image_key_without_t is not None else None, + ), + ] -def _read_image(file_list: list[str], metadata: dict) -> xr.DataArray: - """Read image files and preprocess them into a standardized format. + for test_path, t in test_paths: + if fsspec.core.url_to_fs(path)[0].exists(test_path): + logger.info(f"Reading {test_path}") + found_paths.append((test_path, t)) - This function reads image files specified in the file_list and processes them into an - xarray.DataArray with dimensions adjusted as needed. It handles missing dimensions and stacks - time and channel dimensions for further processing. + if len(found_paths) == 0: + return None - :param file_list: List of file paths to the image files. - :param metadata: Dictionary containing metadata associated with the images. - :return: DataArray containing the processed image data. - """ - image = _images2fov(file_list, metadata, dask=True) - dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) + area_column = f"{_label_name_to_prefix[label_name]}_AreaShape_Area" + merged_dfs = [] + # add suffix for time specific paths + for path, t in found_paths: + if path.lower().endswith(".zarr"): + data = read_anndata_zarr(path, dask=True) + merged_df = data.obs + columns = {area_column} + assert area_column in data.var.index + if label_filter is not None: + query_columns = _get_names_from_pd_query(label_filter) + columns.update( + c + for c in query_columns + if c not in merged_df.columns and c in data.var.index + ) + columns = list(columns) + values = data[:, columns].X.compute() + for i in range(len(columns)): + merged_df[columns[i]] = values[:, i] - if len(dims) > 0: - image = image.stack(t_c_z=dims, create_index=False).transpose( - *("y", "x", "t_c_z") - ) - with warnings.catch_warnings(): - # ignore UserWarning: rename 't_c_z' to 'c' does not create an index anymore. - # Try using swap_dims instead or use set_index after rename to create an indexed coordinate. - warnings.filterwarnings("ignore", "rename .*", UserWarning) - image = image.rename({"t_c_z": "c"}) - else: - # add trailing c dimension - image = image.expand_dims("c", -1) - return image + else: + merged_df = pd.read_parquet(path) + if "label" in merged_df.columns: + merged_df = merged_df.set_index("label") + if t is not None: + merged_df.columns = merged_df.columns + f"_{t}" + merged_dfs.append(merged_df) + return ( + merged_dfs[0] + if len(merged_dfs) == 1 + else pd.concat(merged_dfs, axis=1, join="inner") + ) def _get_feature_channel_indices(tokens): @@ -172,14 +156,81 @@ def _get_feature_channel_indices(tokens): ] +def _find_labels( + label_paths: list[str], + image_key: str, + label_name: str, + image_key_without_t: str | None, + selected_timepoint: Any, +): + timepoints = None + g = None + for label_path in label_paths: + label_root = zarr.open(label_path, mode="r") + labels_group = label_root.get("labels") + if labels_group is not None: + g = labels_group.get(f"{image_key}-{label_name}") + if g is not None: + timepoints = ( + g.attrs["multiscales"][0]["metadata"]["t"] + if "t" in g.attrs["multiscales"][0]["metadata"] + else [None] + ) + return g, timepoints + if g is None and image_key_without_t is not None: + g = labels_group.get(f"{image_key_without_t}-{label_name}") + zarr_metadata = g.attrs["multiscales"][0]["metadata"] + + if "t" in zarr_metadata: + timepoints = zarr_metadata["t"] + if selected_timepoint in timepoints: + index = timepoints.index(selected_timepoint) + timepoints = [timepoints[index]] + return g, timepoints + else: + timepoints = [None] + return g, timepoints + return g, timepoints + + +def _stack_and_rename(image: xr.DataArray) -> xr.DataArray: + image_dims = tuple([d for d in ["t", "c", "z"] if d in image.dims]) + with warnings.catch_warnings(): + # ignore UserWarning: rename 't_c_z' to 'c' does not create an index anymore. + # Try using swap_dims instead or use set_index after rename to create an indexed coordinate. + warnings.filterwarnings("ignore", "rename .*", UserWarning) + return ( + image.stack(t_c_z=image_dims, create_index=False) + .transpose(*("y", "x", "t_c_z")) + .rename({"t_c_z": "c"}) + if len(image_dims) > 0 + else image.expand_dims("c", -1) + ) + + +def _image_key_without_time_and_selected_time(metadata): + image_key_without_t = None + selected_timepoint = None + if "t" in metadata["group_metadata"]["group"]: + image_key_without_t = [] + for key in metadata["group_metadata"]["group"]: + if key != "t": + image_key_without_t.append( + str(metadata["group_metadata"]["group"][key]) + ) + else: + selected_timepoint = metadata["group_metadata"]["group"][key] + image_key_without_t = "-".join(image_key_without_t).replace("/", "-") + + return image_key_without_t, selected_timepoint + + def single_feature( stacked_image_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, image_tuple: tuple[tuple[str, ...], list[str | Group], dict], - labels_group: Group, + label_paths: list[str], output_dir: str, - output_sep: str, - objects_dir: str, - objects_dir_sep: str, + merge_paths: list[str] | None, label_name_to_features: dict[str, set[str]], label_name_to_min_max_area: dict[str, tuple[float | None, float | None]], features_plot: set[str], @@ -207,13 +258,10 @@ def single_feature( - A tuple of metadata strings. - A list of file paths or Zarr groups containing the primary image data. - A dictionary with additional metadata. - :param labels_group: Zarr group containing labels used to identify regions of interest in the image + :param label_paths: Zarr paths containing labels used to identify regions of interest in the image for feature computation. :param output_dir: Directory path where the computed feature files will be saved. - :param output_sep: Separator string used to construct the output file names. This helps in organizing - the output files systematically. - :param objects_dir: Directory path containing find objects output. - :param objects_dir_sep: File separator for `objects_dir` + :param merge_paths: Directory path containing output to merge :param label_name_to_features: Dictionary mapping label names (keys) to sets of feature names (values). Label names correspond to components in the labeled image (e.g. nuclei), and feature names specify the features to compute. @@ -239,222 +287,232 @@ def single_feature( ) output_fs, _ = fsspec.core.url_to_fs(output_dir) + output_sep = output_fs.sep + output_dir = output_dir.rstrip(output_fs.sep) + image = _images2fov(file_list, metadata, dask=True) - zarr_inputs = True - - for f in file_list: - if not isinstance(f, (zarr.Group, zarr.Array)): - zarr_inputs = False - break - - if zarr_inputs and stacked_image_tuple is not None: - for f in stacked_file_list: - if not isinstance(f, (zarr.Group, zarr.Array)): - zarr_inputs = False - break - if not zarr_inputs: - image = _read_image(file_list, metadata) - else: - image = [] - for f in file_list: - array, _, _, _ = _read_ome_zarr_array(f) - image.append(array) n_channels1 = None + stacked_image = None if stacked_image_tuple is not None: - if not zarr_inputs: - stacked_image = _read_image(stacked_file_list, stacked_metadata) - n_channels1 = image.sizes["c"] - # clear coords to avoid issues with xr.concat - for c in list(image.coords.keys()): - del image.coords[c] - for c in list(stacked_image.coords.keys()): - del stacked_image.coords[c] - image = xr.concat((image, stacked_image), dim="c") - else: - n_channels1 = 0 - for img in image: - n_channels1 += np.prod(img.shape[:-2]) - n_channels1 = int(n_channels1) - for f in stacked_file_list: - array, _, _, _ = _read_ome_zarr_array(f) - image.append(array) - + stacked_image = _images2fov(stacked_file_list, stacked_metadata, dask=True) + n_channels1 = image.sizes["c"] + image_key_without_t, selected_timepoint = _image_key_without_time_and_selected_time( + metadata + ) for label_name in label_name_to_features: + label_prefix = _label_name_to_prefix[label_name] features = label_name_to_features[label_name] - output_parquet_path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}.parquet" + g, timepoints = _find_labels( + label_paths=label_paths, + image_key=image_key, + label_name=label_name, + image_key_without_t=image_key_without_t, + selected_timepoint=selected_timepoint, ) - - if not force and is_parquet_file(output_parquet_path): - logger.info(f"Skipping features for {image_key} {label_name}") + if g is None: + logger.info(f"No labels found for {image_key}") continue - zarr_labels = get_labels( - labels_group=labels_group, - name=image_key, - suffix=label_name, # e.g. nuclei - ) - - if zarr_labels is None: - logger.info(f"Unable to read {label_name} labels for {image_key}.") - continue - label_prefix = _label_name_to_prefix[label_name] - merged_df = None - if objects_dir is not None: - merged_df = _read_merged_or_objects( - merge_dir=objects_dir, - merge_dir_sep=objects_dir_sep, - label_name=label_name, - image_key=image_key, - label_filter=label_filter, + labels_array = da.from_array(g[list(g.keys())[0]]) + for timepoint in timepoints: + image_ = ( + image.sel(t=timepoint) + if timepoint is not None and image.sizes.get("t", 0) > 1 + else image ) - - if merged_df is None: - logger.info(f"Find {label_name} objects for {image_key}.") - merged_df = find_objects(zarr_labels) - objects_path = f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet" - merged_df.index.name = "label" - merged_df.columns = f"{label_prefix}_" + merged_df.columns - _to_parquet( - merged_df, - objects_path, - write_index=True, - compute=True, - custom_metadata=dict(scallops=json.dumps(cli_metadata())) - if not no_version - else None, + image_ = _stack_and_rename(image_) + if stacked_image is not None: + stacked_image_ = _stack_and_rename(stacked_image) + + intensity_image = ( + xr.concat((image_, stacked_image_), dim="c", join="outer") + if stacked_image is not None + else image_ ) - merged_df = pd.read_parquet(objects_path) - - features = normalize_features(features) - # strip nuclei_, etc. from features_plot - features_plot_label = [] - for feature in features_plot: - tokens = feature.lower().split("_") - if tokens[0] == label_prefix: - features_plot_label.append("_".join(tokens[1:])) - - features_plot_label = normalize_features(features_plot_label) - if stacked_image_tuple is not None: - stacked_features = set() - stacked_features_plot = set() - for feature in features: # add offset for image1 - tokens = feature.lower().split("_") - if ( - tokens[0] in _features_multichannel.keys() - or tokens[0] in _features_single_channel.keys() - ): - for token_index in range( - 1, - 3 if tokens[0] in _features_multichannel.keys() else 2, - ): - c = tokens[token_index] - if c[0] == "s": - if c in channel_names: - channel_names[str(n_channels1 + int(c[1:]))] = ( - channel_names.pop(c) - ) - tokens[token_index] = str(n_channels1 + int(c[1:])) - - new_feature = "_".join(tokens) - if feature in features_plot_label: - stacked_features_plot.add(new_feature) - stacked_features.add(new_feature) - - features = stacked_features - features_plot_label = stacked_features_plot - - features = list(set(natsorted(features))) - logger.info( - f"{image_key} {label_name} {len(features):,} {pluralize('feature', len(features))}: " - f"{', '.join(features)}" - ) - if label_filter is not None: - merged_df = merged_df.query(label_filter) - min_max_area = label_name_to_min_max_area.get(label_name) - area_column = f"{label_prefix}_AreaShape_Area" - n_labels = len(merged_df) - prefix = "" - if min_max_area[0] is not None or min_max_area[1] is not None: - area_query = [] - if min_max_area[0] is not None: - area_query.append(f"{area_column}>={min_max_area[0]}") - if min_max_area[1] is not None: - area_query.append(f"{area_column}<={min_max_area[1]}") - merged_df = merged_df.query("&".join(area_query)) - n_labels_filtered = n_labels - len(merged_df) - prefix = f"Removed {n_labels_filtered:,} out of " - logger.info( - f"{prefix}{n_labels:,} {pluralize('label', n_labels)}. " - f"Area: {merged_df[area_column].min():,.0f} to {merged_df[area_column].max():,.0f}." - ) + features_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".parquet" + ) + if not force and is_parquet_file(features_path): + logger.info( + f"Skipping features for {image_key} {label_name}{' at t=' + timepoint if timepoint is not None else ''}." + ) + continue + + merged_df = None + if merge_paths is not None and len(merge_paths) > 0: + merged_df = _read_merged_or_objects( + paths=merge_paths, + timepoint=timepoint, + label_name=label_name, + image_key=image_key, + image_key_without_t=image_key_without_t, + label_filter=label_filter, + ) + if merged_df is None: + raise ValueError(f"Metadata not found for {image_key}") + if timepoint is not None and labels_array.ndim == 3: + timepoint_index = timepoints.index(timepoint) + label_image = labels_array[timepoint_index] + else: + label_image = labels_array - df = label_features( - objects_df=merged_df, - label_image=zarr_labels if zarr_inputs else da.from_zarr(zarr_labels), - intensity_image=image if zarr_inputs else image.data, - features=features, - normalize=normalize, - bounding_box_columns=[ - f"{label_prefix}_AreaShape_BoundingBoxMinimum_Y", - f"{label_prefix}_AreaShape_BoundingBoxMinimum_X", - f"{label_prefix}_AreaShape_BoundingBoxMaximum_Y", - f"{label_prefix}_AreaShape_BoundingBoxMaximum_X", - ], - channel_names=channel_names, - ) - # df will be None if only area and coordinates requested - - if df is not None: - fs = fsspec.url_to_fs(output_parquet_path)[0] - if fs.exists(output_parquet_path): - fs.rm(output_parquet_path, recursive=True) - df.index.name = "label" - df.columns = f"{label_prefix}_" + df.columns - - if isinstance(df, pd.DataFrame): - table = pa.Table.from_pandas(df, preserve_index=True) - if not no_version: - table = table.replace_schema_metadata( - { - "scallops".encode(): json.dumps(cli_metadata()).encode(), - **table.schema.metadata, - } - ) - fs, output_file = fsspec.url_to_fs(output_parquet_path) - pq.write_table( - table, - output_parquet_path, - filesystem=fs, + if merged_df is None: + logger.info( + f"Find {label_name} objects for {image_key}{' at t=' + timepoint if timepoint is not None else ''}." + ) + merged_df = find_objects(label_image) + objects_path = get_path( + output_dir, + output_sep, + label_name, + image_key, + timepoint, + "-objects.parquet", ) - else: + merged_df.index.name = "label" + merged_df.columns = f"{label_prefix}_" + merged_df.columns _to_parquet( - df, - output_parquet_path, + merged_df, + objects_path, write_index=True, compute=True, custom_metadata=dict(scallops=json.dumps(cli_metadata())) if not no_version else None, ) + merged_df = pd.read_parquet(objects_path) - if len(features_plot_label) > 0: - features_plot_label = [ - label_prefix + "_" + feature for feature in features_plot_label - ] - df_features = pd.read_parquet( - output_parquet_path, columns=features_plot_label + features = normalize_features(features) + # strip nuclei_, etc. from features_plot + features_plot_label = [] + for feature in features_plot: + tokens = feature.lower().split("_") + if tokens[0] == label_prefix: + features_plot_label.append("_".join(tokens[1:])) + + features_plot_label = normalize_features(features_plot_label) + if stacked_image_tuple is not None: + stacked_features = set() + stacked_features_plot = set() + for feature in features: # add offset for image1 + tokens = feature.lower().split("_") + + if ( + tokens[0] in _features_multichannel.keys() + or tokens[0] in _features_single_channel.keys() + ): + for token_index in range( + 1, + 3 if tokens[0] in _features_multichannel.keys() else 2, + ): + c = tokens[token_index] + if c[0] == "s": + if c in channel_names: + channel_names[str(n_channels1 + int(c[1:]))] = ( + channel_names.pop(c) + ) + tokens[token_index] = str(n_channels1 + int(c[1:])) + + new_feature = "_".join(tokens) + if feature in features_plot_label: + stacked_features_plot.add(new_feature) + stacked_features.add(new_feature) + + features = stacked_features + features_plot_label = stacked_features_plot + + features = list(set(natsorted(features))) + + if label_filter is not None: + merged_df = merged_df.query(label_filter) + min_max_area = label_name_to_min_max_area.get(label_name) + area_column = f"{label_prefix}_AreaShape_Area" + n_labels = len(merged_df) + prefix = "" + if min_max_area[0] is not None or min_max_area[1] is not None: + area_query = [] + if min_max_area[0] is not None: + area_query.append(f"{area_column}>={min_max_area[0]}") + if min_max_area[1] is not None: + area_query.append(f"{area_column}<={min_max_area[1]}") + merged_df = merged_df.query("&".join(area_query)) + n_labels_filtered = n_labels - len(merged_df) + prefix = f"Removed {n_labels_filtered:,} out of " + logger.info( + f"{prefix}{n_labels:,} {pluralize('label', n_labels)}. " + f"Area: {merged_df[area_column].min():,.0f} to {merged_df[area_column].max():,.0f}." ) - centroid_columns = [ - label_prefix + "_centroid-1", - label_name + "_centroid-0", - ] - df = merged_df[centroid_columns].join(df_features) - pdf_path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}.pdf" + + df = label_features( + objects_df=merged_df, + label_image=label_image, + intensity_image=intensity_image, + features=features, + normalize=normalize, + bounding_box_columns=[ + f"{label_prefix}_AreaShape_BoundingBoxMinimum_Y", + f"{label_prefix}_AreaShape_BoundingBoxMinimum_X", + f"{label_prefix}_AreaShape_BoundingBoxMaximum_Y", + f"{label_prefix}_AreaShape_BoundingBoxMaximum_X", + ], + channel_names=channel_names, ) - _plot_features(df, features_plot_label, pdf_path, centroid_columns) + # df will be None if only area and coordinates requested + + if df is not None: + fs = fsspec.url_to_fs(features_path)[0] + if fs.exists(features_path): + fs.rm(features_path, recursive=True) + df.index.name = "label" + df.columns = f"{label_prefix}_" + df.columns + + if isinstance(df, pd.DataFrame): + table = pa.Table.from_pandas(df, preserve_index=True) + if not no_version: + table = table.replace_schema_metadata( + { + "scallops".encode(): json.dumps( + cli_metadata() + ).encode(), + **table.schema.metadata, + } + ) + fs, output_file = fsspec.url_to_fs(features_path) + pq.write_table( + table, + features_path, + filesystem=fs, + ) + + else: + _to_parquet( + df, + features_path, + write_index=True, + compute=True, + custom_metadata=dict(scallops=json.dumps(cli_metadata())) + if not no_version + else None, + ) + + if len(features_plot_label) > 0: + features_plot_label = [ + label_prefix + "_" + feature for feature in features_plot_label + ] + df_features = pd.read_parquet( + features_path, columns=features_plot_label + ) + centroid_columns = [ + label_prefix + "_centroid-1", + label_name + "_centroid-0", + ] + df = merged_df[centroid_columns].join(df_features) + pdf_path = get_path( + output_dir, output_sep, label_name, image_key, timepoint, ".pdf" + ) + + _plot_features(df, features_plot_label, pdf_path, centroid_columns) return [] @@ -474,13 +532,14 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: image_patterns = arguments.image_pattern output_dir = arguments.output - objects_dir = arguments.objects + merge_paths = arguments.merge subset = arguments.subset force = arguments.force groupby = arguments.groupby channel_names = arguments.channel_rename stack_images = arguments.stack_images label_filter = arguments.label_filter + label_paths = arguments.labels normalize = not arguments.no_normalize stack_image_pattern = arguments.stack_image_pattern cell_features = arguments.features_cell @@ -507,10 +566,6 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: threads_per_worker=4 if "sizeshape" in unique_features else 1 ) - objects_dir_sep = None - if objects_dir is not None: - objects_dir_sep = fsspec.core.url_to_fs(objects_dir)[0].sep - objects_dir = objects_dir.rstrip(objects_dir_sep) label_name_to_min_max_area = dict( nuclei=[arguments.nuclei_min_area, arguments.nuclei_max_area], cytosol=[arguments.cytosol_min_area, arguments.cytosol_max_area], @@ -522,11 +577,9 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: output_dir = output_dir.rstrip(output_fs.sep) for label in label_name_to_features: output_fs.makedirs(output_dir + output_fs.sep + label, exist_ok=True) - labels_path = arguments.labels + no_version = arguments.no_version - assert labels_path is not None, "No labels provided" - label_root = zarr.open(labels_path, mode="r") - labels_group = label_root["labels"] + if channel_names is not None: # keys are strings in json try: @@ -563,10 +616,8 @@ def run_pipeline_compute_features(arguments: argparse.Namespace) -> None: img_tuple[0], img_tuple[1], output_dir=output_dir, - output_sep=output_fs.sep, - objects_dir=objects_dir, - objects_dir_sep=objects_dir_sep, - labels_group=labels_group, + merge_paths=merge_paths, + label_paths=label_paths, label_filter=label_filter, label_name_to_min_max_area=label_name_to_min_max_area, label_name_to_features=label_name_to_features, diff --git a/scallops/cli/features_main.py b/scallops/cli/features_main.py index e7c5d19..9786ed0 100644 --- a/scallops/cli/features_main.py +++ b/scallops/cli/features_main.py @@ -73,7 +73,8 @@ def new_format_help(x): "--labels", dest="labels", required=True, - help="Path to zarr directory containing labels", + nargs="+", + help="Path(s) to zarr directory containing labels", ) generic_features_help = ( @@ -107,11 +108,6 @@ def new_format_help(x): help=generic_features_help, ) - parser.add_argument( - "--objects", - required=False, - help="Path to directory containing output from `find-objects` or `merge`", - ) parser.add_argument( "--stack-images", help="Path to additional images to stack with `images`. Add `s` prefix to refer" @@ -120,8 +116,8 @@ def new_format_help(x): ) required.add_argument( "--merge", - required=False, - help="Path to directory containing output from `merge`", + nargs="*", + help="Path(s) to directory containing output from `find-objects`, `merge`, or `features`", ) parser.add_argument( "--label-filter", diff --git a/scallops/cli/find_objects.py b/scallops/cli/find_objects.py index 88bf71e..2501201 100644 --- a/scallops/cli/find_objects.py +++ b/scallops/cli/find_objects.py @@ -8,6 +8,7 @@ import argparse import json +import dask.array as da import fsspec import zarr from zarr import Group @@ -27,30 +28,62 @@ logger = _get_cli_logger() +def get_path( + output_dir: str, + output_sep: str, + label_name: str, + image_key: str, + timepoint: str | None = None, + suffix="", +): + return ( + (f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}{suffix}") + if timepoint is None + else f"{output_dir}{output_sep}{label_name}{output_sep}t={timepoint}{output_sep}{image_key}{suffix}" + ) + + def _execute( label_tuple: tuple[tuple[str, ...], list[str | Group], dict], + timepoint: str | None, output_dir: str, output_sep: str, - force: bool = False, - no_version: bool = False, + force: bool, + no_version: bool, ): group, file_list, metadata = label_tuple assert len(file_list) == 1 label_name = group[len(group) - 1] image_key = "-".join(group[:-1]) # exclude suffix from key - path = ( - f"{output_dir}{output_sep}{label_name}{output_sep}{image_key}-objects.parquet" + path = get_path( + output_dir, + output_sep, + label_name, + image_key, + timepoint, + suffix="-objects.parquet", ) fs = fsspec.url_to_fs(path)[0] if fs.exists(path): if force: fs.rm(path, recursive=True) else: - logger.info(f"Skipping finding objects for {metadata['id']}.") + logger.info( + f"Skipping find objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''}." + ) return - logger.info(f"Finding objects for {metadata['id']}.") - array = file_list[0][list(file_list[0].keys())[0]] + logger.info( + f"Find objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''}." + ) + g = file_list[0] + array = da.from_zarr(g[list(g.keys())[0]]) + + if timepoint is not None: + timepoint_index = g.attrs["multiscales"][0]["metadata"]["t"].index(timepoint) + array = array[timepoint_index] + df = find_objects(array) + df.index.name = "label" prefix = _label_name_to_prefix.get(label_name) if prefix is not None: @@ -64,7 +97,9 @@ def _execute( if not no_version else None, ) - logger.info(f"Saved objects for {metadata['id']} to {path}.") + logger.info( + f"Saved objects for {metadata['id']}{' at t=' + timepoint if timepoint is not None else ''} to {path}." + ) def run_pipeline_find_objects(arguments: argparse.Namespace) -> None: @@ -91,33 +126,50 @@ def run_pipeline_find_objects(arguments: argparse.Namespace) -> None: output_fs, _ = fsspec.core.url_to_fs(output_dir) output_dir = output_dir.rstrip(output_fs.sep) - paths = [] + _, _, keys = _create_file_regex(label_pattern) + keys = list(keys) + label_tuples = [] + timepoints = [] + for path in labels_paths: label_root = zarr.open(path, mode="r") labels_group = label_root.get("labels") - if labels_group is None: - raise ValueError(f"Labels group not found for {path}") - paths.append(labels_group) - _, _, keys = _create_file_regex(label_pattern) - gen = _set_up_experiment( - image_path=paths, - files_pattern=label_pattern, - group_by=list(keys), - subset=subset, - ) + if labels_group is not None: + gen = _set_up_experiment( + image_path=labels_group, + files_pattern=label_pattern, + group_by=keys, + subset=subset, + ) + for label_tuple in gen: + label_key, file_list, metadata = label_tuple + + if ( + label_suffix is None + or label_key[len(label_key) - 1] in label_suffix + ): + assert len(file_list) == 1 + g = file_list[0] + zarr_metadata = g.attrs["multiscales"][0]["metadata"] + + if "t" not in zarr_metadata: + label_tuples.append(label_tuple) + timepoints.append(None) + else: + for timepoint_ in zarr_metadata["t"]: + label_tuples.append(label_tuple) + timepoints.append(timepoint_) with ( _create_default_dask_config(), _create_dask_client(dask_server_url, **dask_cluster_parameters), ): - [ + for i in range(len(label_tuples)): _execute( - label_tuple=g, + label_tuple=label_tuples[i], + timepoint=timepoints[i], output_dir=output_dir, output_sep=output_fs.sep, force=force, no_version=no_version, ) - for g in gen - if label_suffix is None or g[0][len(g[0]) - 1] in label_suffix - ] diff --git a/scallops/cli/pooled_if_sbs.py b/scallops/cli/pooled_if_sbs.py index 69696e5..4d1be20 100644 --- a/scallops/cli/pooled_if_sbs.py +++ b/scallops/cli/pooled_if_sbs.py @@ -573,7 +573,7 @@ def merge_sbs_phenotype_pipeline( prefixes = [] for i in range(len(phenotype_paths)): - df = dd.read_parquet(phenotype_paths[i]) + df = dd.read_parquet(path) _metadata_cols = df.columns[ df.columns.str.contains(_metadata_columns_whitelist_str) ].tolist() @@ -585,7 +585,7 @@ def merge_sbs_phenotype_pipeline( if phenotype_suffix is not None: df.columns = df.columns + phenotype_suffix[i] - prefixes.append(phenotype_paths[i].split("/")[-3]) + prefixes.append(path.split("/")[-3]) if output_format == "zarr": # read index and metadata if len(_metadata_cols) > 0: @@ -596,7 +596,7 @@ def merge_sbs_phenotype_pipeline( ) for key in rename_features: feature_names_i[feature_names_i.index(key)] = rename_features[key] - df = dd.read_parquet(phenotype_paths[i], columns=_metadata_cols) + df = dd.read_parquet(path, columns=_metadata_cols) feature_names += feature_names_i unique_columns.update(feature_names_i) @@ -676,34 +676,38 @@ def merge_sbs_phenotype_pipeline( ) -def _find_phenotype_paths( - phenotype_paths, phenotype_filesystems, phenotype_suffix, image_key -): - _phenotype_paths = [] - _phenotype_suffix = [] +def _find_phenotype_paths(paths: list[str], image_key: str, image_metadata: dict): + found_paths = [] + for path in paths: + path_ = path + path = path.format(**image_metadata["file_metadata"][0]) + fs, path = fsspec.url_to_fs(path) + if path_ == path and "*" not in path: # directory + # match */A1-*.parquet and */A1.parquet + sep = fs.sep + matches = fs.glob(f"{path}{sep}*{sep}{image_key}-*.parquet") + fs.glob( + f"{path}{sep}*{sep}{image_key}.parquet" + ) - for i in range(len(phenotype_paths)): - # match */A1-*.parquet and */A1.parquet - sep = phenotype_filesystems[i].sep - matches = phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}*{sep}{image_key}-*.parquet" - ) + phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}*{sep}{image_key}.parquet" - ) + if len(matches) == 0: + # match A1-*.parquet and A1.parquet + matches = fs.glob(f"{path}{sep}{image_key}-*.parquet") + fs.glob( + f"{path}{sep}{image_key}.parquet" + ) - if len(matches) == 0: - # match A1-*.parquet and A1.parquet - matches = phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}{image_key}-*.parquet" - ) + phenotype_filesystems[i].glob( - f"{phenotype_paths[i]}{sep}{image_key}.parquet" - ) + for x in matches: + path.append(fs.unstrip_protocol(x)) + # if phenotype_suffix is not None: + # _phenotype_suffix.append(phenotype_suffix[i]) - for x in matches: - _phenotype_paths.append(phenotype_filesystems[i].unstrip_protocol(x)) - if phenotype_suffix is not None: - _phenotype_suffix.append(phenotype_suffix[i]) - return _phenotype_paths, _phenotype_suffix + else: + if "*" in path: + matches = fs.glob(path) + for match in matches: + found_paths.append(fs.unstrip_protocol(match)) + elif fs.exists(path): + found_paths.append(fs.unstrip_protocol(path)) + return found_paths def merge_main(arguments: argparse.Namespace): @@ -750,8 +754,9 @@ def merge_main(arguments: argparse.Namespace): if len(set(phenotype_paths)) != len(phenotype_paths): raise ValueError("Duplicate phenotype paths") for i in range(len(phenotype_paths)): - phenotype_fs, _ = fsspec.core.url_to_fs(phenotype_paths[i]) - phenotype_paths[i] = phenotype_paths[i].rstrip(phenotype_fs.sep) + path = phenotype_paths[i] + phenotype_fs, _ = fsspec.core.url_to_fs(path) + path = path.rstrip(phenotype_fs.sep) phenotype_filesystems.append(phenotype_fs) paths = [] @@ -762,6 +767,7 @@ def merge_main(arguments: argparse.Namespace): sbs = sbs.rstrip(sbs_fs.sep) sbs_matches = sbs_fs.glob(sbs + sbs_fs.sep + "*.parquet") sbs_matches = [sbs_fs.unstrip_protocol(m) for m in sbs_matches] + print("sbs_matches", sbs_matches) for sbs_path in sbs_matches: name = os.path.splitext(os.path.basename(sbs_path))[0] if not name.startswith("."): # ignore hidden files @@ -978,9 +984,11 @@ def reads_pipeline( points_path = f"{points_path}/{image_key}-peaks.parquet" peaks = dd.read_parquet(points_path) maxed = read_ome_zarr_array(spots_root["images"][image_key + "-max"], dask=True) - labels = read_ome_zarr_array( - labels_root[image_key + "-" + label_name], dask=True - ).data.compute() + labels = ( + read_ome_zarr_array(labels_root[image_key + "-" + label_name], dask=True) + .data.squeeze() + .compute() + ) if expand_labels_distance is not None and expand_labels_distance > 0: labels = expand_labels(labels, distance=expand_labels_distance) iss_cycles = maxed.coords["t"].values diff --git a/scallops/cli/register.py b/scallops/cli/register.py index 70be43d..8f695db 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -5,7 +5,7 @@ import os from collections.abc import Sequence from itertools import zip_longest -from typing import Literal +from typing import Any, Literal import fsspec import itk @@ -13,6 +13,8 @@ import xarray as xr import zarr from dask.bag import from_sequence +from natsort import index_natsorted, natsorted +from xarray import DataArray from zarr import Group from scallops.cli.util import ( @@ -46,6 +48,7 @@ _get_fs, _get_store_path, _write_zarr_image, + _write_zarr_labels, is_ome_zarr_array, open_ome_zarr, read_ome_zarr_array, @@ -54,6 +57,63 @@ logger = _get_cli_logger() +def _output_exists( + register_self, label_output_root, moving_label_keys, image_output_root, image_key +): + labels_exist = True + if label_output_root is not None: + if not register_self: + for key in moving_label_keys: + key = os.path.basename(key) + if not is_ome_zarr_array(label_output_root.get(f"labels/{key}")): + labels_exist = False + break + + image_exists = True + if image_output_root is not None: + if not is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): + image_exists = False + elif label_output_root is None: + image_exists = False + return labels_exist and image_exists + + +def _get_timepoint_index_and_value( + timepoint: str | int, image: Sequence[DataArray] | xr.DataArray +) -> tuple[int, Any]: + timepoint_value = None + timepoint_index = None + if isinstance(timepoint, str): + if isinstance(image, Sequence): + for i in range(len(image)): + if ( + "t" in image[i].coords + and image[i].coords["t"].values[0] == timepoint + ): + timepoint_index = i + timepoint_value = image[i].coords["t"].values[0] + break + elif "t" in image.coords: + times = list(image.coords["t"].values) + if timepoint in times: + timepoint_index = times.index(timepoint) + timepoint_value = times[timepoint_index] + + if timepoint_index is None: + raise ValueError(f"Reference timepoint not found: {timepoint}.") + elif isinstance(timepoint, int): + timepoint_index = timepoint + if isinstance(image, Sequence): + if "t" in image[timepoint_index].coords: + timepoint_value = image[timepoint_index].coords["t"].values[0] + elif "t" in image.coords: + times = list(image.coords["t"].values) + timepoint_value = times[timepoint_index] + else: + raise ValueError() + return timepoint_index, timepoint_value + + def single_registration( fixed_tuple: tuple[tuple[str, ...], list[str | Group], dict] | None, moving_tuple: tuple[tuple[str, ...], list[str | Group], dict], @@ -67,7 +127,8 @@ def single_registration( moving_labels: list[str], moving_image_spacing: tuple[float, float] | None, fixed_image_spacing: tuple[float, float] | None, - reference_timepoint: int | str, + moving_timepoint: int | str, + fixed_timepoint: int | str, unroll_channels: bool = False, force: bool = False, z_index: int | str = "max", @@ -105,8 +166,8 @@ def single_registration( :param moving_image_spacing: Spacing of the moving image. :param fixed_image_spacing: Spacing of the fixed image. :param force: Whether to overwrite existing output - :param reference_timepoint: Index or value of timepoint to register to when registering - across timepoints + :param moving_timepoint: Index or value of moving timepoint to use + :param fixed_timepoint: Index or value of fixed timepoint to use :param unroll_channels: Whether to unroll channels across timepoints. :param landmarks_initialize: Use landmarks to initialize registration. :param landmark_slice_size: The slice size in physical coordinates @@ -130,51 +191,41 @@ def single_registration( transform_dest = f"{transform_output_dir}{transform_fs.sep}{image_key}" moving_label_keys = [] + # when registering self, transforms written to output_dir/time={t_value}/ + moving_image = _images2fov( + moving_file_list, + moving_metadata, + dask=True, + concat_dims=("c",), + ) + moving_timepoint, moving_timepoint_value = _get_timepoint_index_and_value( + moving_timepoint, moving_image + ) + if moving_labels is not None: - matching_label_prefix = image_key - if register_self: - matching_label_prefix = f"{matching_label_prefix}-{reference_timepoint}" for moving_label in moving_labels: moving_label_keys.extend( get_matching_names( - image_key=matching_label_prefix, image_dir=moving_label, labels=True + image_key=image_key, image_dir=moving_label, labels=True ) ) - moving_label_keys = sorted(moving_label_keys) + moving_label_keys = natsorted(moving_label_keys) if len(moving_label_keys) == 0: - logger.warning(f"No labels found for {image_key}") - - if not force: - labels_exist = True - if label_output_root is not None: - if not register_self: - for key in moving_label_keys: - key = os.path.basename(key) - if not is_ome_zarr_array(label_output_root.get(f"labels/{key}")): - labels_exist = False - break - # TODO check for transformed labels when register_self - - image_exists = True - if image_output_root is not None: - if not is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): - image_exists = False - elif label_output_root is None: - image_exists = False - if labels_exist and image_exists: - logger.info(f"Skipping registration for {image_key}") + logger.warning(f"No labels found for {image_key}.") return image_key - if register_self: - logger.info(f"Running registration for {image_key} t={reference_timepoint}") - logger.info( - f"{len(moving_file_list):,} {pluralize('input', len(moving_file_list))}:" - f" {', '.join([s.name.replace('/images/', '') if isinstance(s, zarr.Group) else str(s) for s in moving_file_list])}" - ) - else: - logger.info(f"Running registration for {image_key}") + if not force and _output_exists( + register_self, + label_output_root, + moving_label_keys, + image_output_root, + image_key, + ): + logger.info(f"Skipping registration for {image_key}") + return image_key if not register_self: + logger.info(f"Running registration for {image_key}") _, fixed_file_list, fixed_metadata = fixed_tuple assert fixed_metadata["id"] == moving_metadata["id"], ( @@ -187,18 +238,16 @@ def single_registration( concat_dims=("c",), dask=True, ) - if isinstance(fixed_image, Sequence): - fixed_image = fixed_image[0] - fixed_image = _z_projection(fixed_image, z_index).isel( - t=0, c=fixed_channel, missing_dims="ignore" + fixed_timepoint, fixed_timepoint_value = _get_timepoint_index_and_value( + fixed_timepoint, fixed_image ) - moving_image = _images2fov( - moving_file_list, - moving_metadata, - dask=True, - concat_dims=("c",), - ) - + if isinstance(fixed_image, Sequence): + fixed_image = fixed_image[fixed_timepoint] + else: + fixed_image = fixed_image.isel( + t=fixed_timepoint, c=fixed_channel, missing_dims="ignore" + ) + fixed_image = _z_projection(fixed_image, z_index) parameter_object = _load_itk_parameters(itk_parameters) parameter_object_across_channels = ( _load_itk_parameters(itk_channel_parameters) @@ -211,8 +260,15 @@ def single_registration( transform_fs.makedirs(transform_dest, exist_ok=True) if not register_self: - if isinstance(moving_image, Sequence): - moving_image = moving_image[0] + moving_image = ( + moving_image[moving_timepoint] + if isinstance(moving_image, Sequence) + else moving_image.isel(t=moving_timepoint, missing_dims="ignore") + ) + moving_image_align = _z_projection( + moving_image.isel(c=moving_channel, missing_dims="ignore"), z_index + ) + if ( moving_image_spacing is None and get_image_spacing(moving_image.attrs) is None @@ -226,9 +282,6 @@ def single_registration( f"Physical size not found for fixed image for {image_key}." ) - moving_image_align = _z_projection(moving_image, z_index).isel( - t=0, c=moving_channel, missing_dims="ignore" - ) if "c" in moving_image_align.dims and moving_image_align.sizes["c"] > 1: moving_image_align = moving_image_align.median(dim="c", keep_attrs=True) @@ -352,18 +405,12 @@ def single_registration( output_root=label_output_root, ) - else: # align to t=reference_timepoint - if isinstance(reference_timepoint, str): - reference_timepoint_found = False - for i in range(len(moving_image)): - if moving_image[i].coords["t"].values[0] == reference_timepoint: - reference_timepoint = i - reference_timepoint_found = True - break - if not reference_timepoint_found: - raise ValueError( - f"Reference timepoint not found: {reference_timepoint}." - ) + else: # align to t=moving_timepoint + logger.info(f"Running registration for {image_key} t={moving_timepoint}") + logger.info( + f"{len(moving_file_list):,} {pluralize('input', len(moving_file_list))}:" + f" {', '.join([s.name.replace('/images/', '') if isinstance(s, zarr.Group) else str(s) for s in moving_file_list])}" + ) set_automatic_transform_initialization(parameter_object, False) if output_aligned_channels_only and not isinstance(moving_image, xr.DataArray): @@ -373,10 +420,10 @@ def single_registration( moving_channel = 0 moving_image = new_moving_image if not no_version: - moving_image[reference_timepoint].attrs.update(cli_metadata()) + moving_image[moving_timepoint].attrs.update(cli_metadata()) _itk_align_reference_time_zarr( unroll_channels=unroll_channels, - reference_timepoint=reference_timepoint, + reference_timepoint=moving_timepoint, moving_image=moving_image, moving_channel=moving_channel, parameter_object=parameter_object, @@ -396,68 +443,97 @@ def single_registration( parameter_object_across_channels=parameter_object_across_channels, ) moving_image_attrs = moving_image[0].attrs.copy() + chunksize = moving_image[0].data.chunksize[-2:] del moving_image - + if moving_image_spacing is None: + moving_image_spacing = get_image_spacing(moving_image_attrs) if len(moving_label_keys) > 0: _transform_labels_t( - image_key=image_key, transform_fs=transform_fs, transform_dest=transform_dest, - moving_label_keys=moving_label_keys, + label_output_root=label_output_root, moving_image_attrs=moving_image_attrs, + moving_label_keys=moving_label_keys, moving_image_spacing=moving_image_spacing, - label_output_root=label_output_root, + moving_timepoint_value=moving_timepoint_value, + chunksize=chunksize, ) return image_key def _transform_labels_t( - image_key: str, transform_fs, transform_dest: str, - moving_label_keys: Sequence[str], - moving_image_attrs, - moving_image_spacing, label_output_root, + moving_image_attrs: dict, + moving_label_keys: Sequence[str], + moving_image_spacing: tuple[int, int], + moving_timepoint_value: str, + chunksize: tuple[int, int], ): # transform_dest structure is image_key/t=1 - # assume labels are named image_key-t-suffix - print("moving_label_keys", moving_label_keys) - for transform_file in transform_fs.ls(transform_dest, detail=True, refresh=True): - if transform_file["type"] == "directory": - transform_name = transform_file["name"] - basename = os.path.basename(transform_name) - if basename.startswith("t="): - time = basename[2:] - moving_label_keys_t = [] - output_label_prefix = f"{image_key}-{time}" - output_names = [] - # e.g. transform plateA-A1-IF-cell to plateA-A1-FISH-cell - for moving_label_key in moving_label_keys: - moving_label_key_basename = os.path.basename(moving_label_key) - output_label_suffix = "-" + moving_label_key_basename.split("-")[-1] - output_name = f"{output_label_prefix}{output_label_suffix}" - moving_label_keys_t.append(moving_label_key) - output_names.append(output_name) - - if len(moving_label_keys_t) > 0: - transform_parameter_object = _load_itk_parameters_from_dir( - transform_fs.unstrip_protocol(transform_name) + if len(moving_label_keys) > 0: + times = [] + transform_file_paths = [] + for transform_file in transform_fs.ls( + transform_dest, detail=True, refresh=True + ): + if transform_file["type"] == "directory": + transform_file_path = transform_file["name"] + basename = os.path.basename(transform_file_path) + if basename.startswith("t="): + times.append(basename[2:]) + transform_file_paths.append( + transform_fs.unstrip_protocol(transform_file_path) ) - if transform_parameter_object.GetNumberOfParameterMaps() > 0: - _transform_labels( + index = index_natsorted(times) + times = [times[val] for val in index] + transform_file_paths = [transform_file_paths[val] for val in index] + storage_options = {"chunks": chunksize} + for moving_label_key in moving_label_keys: + moving_label_key_basename = os.path.basename(moving_label_key) + transformed_labels = [] + times_ = [] + for i in range(len(times)): + transform_parameter_object = _load_itk_parameters_from_dir( + transform_file_paths[i] + ) + if transform_parameter_object.GetNumberOfParameterMaps() > 0: + src = read_ome_zarr_array(moving_label_key).squeeze() + times_.append(times[i]) + if src.sizes.get("t", 0) > 0 and "t" in src.coords: + src = src.sel(t=moving_timepoint_value) + transformed_labels.append( + itk_transform_labels( + image=src, transform_parameter_object=transform_parameter_object, - attrs=moving_image_attrs, - matching_keys=moving_label_keys_t, - output_names=output_names, - moving_image_spacing=moving_image_spacing, - output_root=label_output_root, + image_spacing=moving_image_spacing, ) + ) + + transformed_labels = xr.DataArray( + np.stack(transformed_labels), + dims=["t", "y", "x"], + coords={"t": times_}, + attrs=moving_image_attrs, + ) + + _write_zarr_labels( + name=moving_label_key_basename, + root=label_output_root, + metadata=None, + group_metadata=None, + labels=transformed_labels, + storage_options=storage_options, + ) def get_matching_names( - image_key: str, image_dir: str | Group, labels: bool = True + image_key: str, + image_dir: str | Group, + labels: bool = True, + label_suffixes: Sequence[str] = {"cell", "nuclei", "cytosol"}, ) -> list[str]: """Get matching keys for the given image key and directory. @@ -472,24 +548,25 @@ def get_matching_names( # look for f'labels/image_key-{suffix} or f'images/image_key zarr_dir = "labels" if labels else "images" if isinstance(image_dir, Group): - protocol = _get_fs_protocol(_get_fs(image_dir)) + fs = _get_fs(image_dir) image_dir = f"{_get_store_path(image_dir)}{image_dir.name}" - if protocol != "file": - image_dir = f"{protocol}://{image_dir}" + image_dir = fs.unstrip_protocol(image_dir) image_fs, _ = fsspec.core.url_to_fs(image_dir) image_dir = image_dir.rstrip(image_fs.sep) - glob_pattern = f"{image_dir}{image_fs.sep}{zarr_dir}{image_fs.sep}{image_key}" if labels: - glob_pattern += "-*" - paths = image_fs.glob(glob_pattern) - protocol = _get_fs_protocol(image_fs) - if protocol != "file": - paths = [f"{protocol}://{x}" for x in paths] + glob_pattern = f"{glob_pattern}-*" # for suffix results = [] - for path in paths: + + for path in image_fs.glob(glob_pattern): + path = image_fs.unstrip_protocol(path) name = os.path.basename(path) + + if labels: + tokens = name.split("-") + if tokens[-1] not in label_suffixes and tokens[:-1] == image_key: + continue if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, mode="r")): results.append(path) return results @@ -757,10 +834,16 @@ def run_itk_registration(arguments: argparse.Namespace) -> None: moving_image_pattern = arguments.moving_image_pattern fixed_image_pattern = arguments.fixed_image_pattern group_by = arguments.groupby - reference_timepoint = arguments.time - if reference_timepoint is not None: + fixed_timepoint = arguments.fixed_time if arguments.fixed_time is not None else 0 + moving_timepoint = arguments.moving_time if arguments.moving_time is not None else 0 + if isinstance(fixed_timepoint, str) and fixed_timepoint.isdigit(): + try: + fixed_timepoint = int(fixed_timepoint) + except ValueError: + pass + if isinstance(moving_timepoint, str) and moving_timepoint.isdigit(): try: - reference_timepoint = int(reference_timepoint) + moving_timepoint = int(moving_timepoint) except ValueError: pass unroll_channels = arguments.unroll_channels @@ -870,7 +953,8 @@ def run_itk_registration(arguments: argparse.Namespace) -> None: image_output_root=image_output_root, moving_image_spacing=moving_image_spacing, fixed_image_spacing=fixed_image_spacing, - reference_timepoint=reference_timepoint, + fixed_timepoint=fixed_timepoint, + moving_timepoint=moving_timepoint, landmarks_initialize=landmarks_initialize, landmark_slice_size=landmark_slice_size, landmark_min_count=landmark_min_count, diff --git a/scallops/cli/register_main.py b/scallops/cli/register_main.py index 598d481..8bde07e 100644 --- a/scallops/cli/register_main.py +++ b/scallops/cli/register_main.py @@ -152,10 +152,12 @@ def _create_elastix_parser(subparsers: ArgumentParser, default_help: bool) -> No ) parser.add_argument( - "--time", - "-t", - default="0", - help="Time index (0-based) or value for alignment across timepoints", + "--moving-time", + help="Time index (0-based) or value for moving image", + ) + parser.add_argument( + "--fixed-time", + help="Time index (0-based) or value for fixed image", ) parser.add_argument( "--unroll-channels", diff --git a/scallops/cli/segment.py b/scallops/cli/segment.py index b4afe86..1447867 100644 --- a/scallops/cli/segment.py +++ b/scallops/cli/segment.py @@ -12,15 +12,19 @@ import argparse import importlib +from collections.abc import Sequence from typing import Callable, Literal, Optional import dask.array as da import fsspec import numpy as np +import xarray as xr import zarr +from array_api_compat import get_namespace from dask.bag import from_sequence from zarr import Group +from scallops.cli.register import _get_timepoint_index_and_value from scallops.cli.util import ( _create_dask_client, _create_default_dask_config, @@ -51,6 +55,7 @@ def segment_nuclei( dapi_channel: int, method: Callable, root: Group, + timepoint: int | str, z_index: int | str, min_area: float | None = None, max_area: float | None = None, @@ -71,6 +76,7 @@ def segment_nuclei( :param method: Segmentation method. :param root: Zarr hierarchy root. :param z_index: Either 'max' or z-index + :param timepoint: Time to use. :param min_area: Minimum area threshold for filtering labels. :param max_area: Maximum area threshold for filtering labels. :param chunks: Tuple specifying chunking size for Dask arrays. @@ -88,7 +94,19 @@ def segment_nuclei( logger.info(f"Skipping nuclei segmentation for {image_key}") return root logger.info(f"Running nuclei segmentation for {image_key}") - image = _images2fov(file_list, metadata, dask=True).squeeze() + image = _images2fov( + file_list, + metadata, + concat_dims=("c",), + dask=True, + ) + timepoint, timepoint_value = _get_timepoint_index_and_value(timepoint, image) + + image = ( + image[timepoint] + if isinstance(image, Sequence) + else image.isel(t=timepoint, missing_dims="ignore") + ) image = _z_projection(image, z_index) nuclei_seg_args = {} @@ -125,10 +143,17 @@ def segment_nuclei( group_metadata = { "image-label": {"source": {"image": f"../../images/{image_key}"}} } - additional_metadata = label_metadata.get(key) if label_metadata else None + additional_metadata = label_metadata.get(key) if label_metadata else dict() storage_options = None if isinstance(label_data, np.ndarray): storage_options = {"chunks": image.data.chunksize[-2:]} + if timepoint_value is not None: + label_data = xr.DataArray( + get_namespace(label_data).expand_dims(label_data, 0), + dims=["t", "y", "x"], + coords={"t": [timepoint_value]}, + ) + additional_metadata.update(label_data.attrs) _write_zarr_labels( name=f"{image_key}-{key}", root=root, @@ -150,6 +175,7 @@ def segment_cells( method: Callable, root: Group, z_index: int | str, + timepoint: int | str, min_area: float | None = None, max_area: float | None = None, chunks: None | tuple[int, int] = None, @@ -160,7 +186,6 @@ def segment_cells( cell_segmentation_rolling_ball: bool = False, cell_segmentation_sigma: Optional[float] = None, closing_radius: Optional[int] = None, - cell_segmentation_t: Optional[list[int]] = None, force: bool = False, shrink_primary: bool = False, no_version: bool = False, @@ -178,6 +203,7 @@ def segment_cells( :param method: Segmentation method. :param root: Zarr hierarchy root. :param z_index: Either 'max' or z-index + :param timepoint: Time to use. :param min_area: Minimum area threshold for filtering labels. :param max_area: Maximum area threshold for filtering labels. :param chunks: Tuple specifying chunking size for Dask arrays. @@ -188,7 +214,6 @@ def segment_cells( :param cell_segmentation_rolling_ball: Use rolling ball mask for cell segmentation. :param cell_segmentation_sigma: Standard deviation for smoothing in cell segmentation. :param closing_radius: Radius for closing operation in cell segmentation. - :param cell_segmentation_t: List of timepoints to consider for cell segmentation. :param force: Whether to overwrite existing output :param shrink_primary: Whether to shrink primary labels. :param no_version: Whether to skip version/CLI information in output. @@ -197,7 +222,20 @@ def segment_cells( if not force and is_ome_zarr_array(root.get(f"labels/{image_key}-cell")): logger.info(f"Skipping cell segmentation for {image_key}") return root - image = _images2fov(file_list, metadata, dask=True).squeeze() + image = _images2fov( + file_list, + metadata, + concat_dims=("c",), + dask=True, + ) + + timepoint, timepoint_value = _get_timepoint_index_and_value(timepoint, image) + + image = ( + image[timepoint] + if isinstance(image, Sequence) + else image.isel(t=timepoint, missing_dims="ignore") + ) image = _z_projection(image, z_index) if cyto_channel is None: cyto_channel = np.delete(np.arange(image.sizes["c"]), dapi_channel) @@ -210,7 +248,14 @@ def segment_cells( if method.__name__ in ["segment_cells_watershed", "segment_cells_propagation"]: nuclei = read_ome_zarr_array( nuclei_image_root["labels"][image_key + "-nuclei"] - ).values + ).squeeze() + if ( + timepoint_value is not None + and nuclei.sizes.get("t", 0) > 0 + and "t" in nuclei.coords + ): + nuclei = nuclei.sel(t=timepoint_value) + nuclei = nuclei.values assert nuclei.shape == ( image.sizes["y"], image.sizes["x"], @@ -223,7 +268,6 @@ def segment_cells( cell_seg_args["rolling_ball"] = cell_segmentation_rolling_ball cell_seg_args["sigma"] = cell_segmentation_sigma cell_seg_args["closing_radius"] = closing_radius - cell_seg_args["t"] = cell_segmentation_t if method.__name__ == "segment_cells_watershed": cell_seg_args["watershed_method"] = watershed_method if chunks is not None: @@ -267,10 +311,17 @@ def segment_cells( group_metadata = { "image-label": {"source": {"image": f"../../images/{image_key}"}} } - additional_metadata = label_metadata.get(key) if label_metadata else None + additional_metadata = label_metadata.get(key) if label_metadata else dict() storage_options = None if isinstance(label_data, np.ndarray): storage_options = {"chunks": image.data.chunksize[-2:]} + if timepoint_value is not None: + label_data = xr.DataArray( + get_namespace(label_data).expand_dims(label_data, 0), + dims=["t", "y", "x"], + coords={"t": [timepoint_value]}, + ) + additional_metadata.update(label_data.attrs) _write_zarr_labels( name=f"{image_key}-{key}", root=root, @@ -332,7 +383,12 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): output_root = open_ome_zarr(output, mode="a") kwargs = dict() - + timepoint = arguments.time if arguments.time is not None else 0 + if isinstance(timepoint, str) and timepoint.isdigit(): + try: + timepoint = int(timepoint) + except ValueError: + pass if not nuclei: kwargs["nuclei_min_area"] = arguments.nuclei_min_area kwargs["nuclei_max_area"] = arguments.nuclei_max_area @@ -353,7 +409,6 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): "Please provide sigma for `local` threshold" ) - kwargs["cell_segmentation_t"] = arguments.cell_segmentation_t if method in ["watershed", "watershed-intensity", "propagation"]: nuclei_label = arguments.nuclei_label if nuclei_label is None: @@ -381,10 +436,12 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): kwargs["clip"] = arguments.stardist_clip kwargs["pmin"] = arguments.stardist_pmin kwargs["pmax"] = arguments.stardist_pmax + method = getattr( importlib.import_module("scallops.segmentation." + method), f"{'segment_nuclei_' if nuclei else 'segment_cells_'}{method}", ) + image_seq = from_sequence( _set_up_experiment(data_path, image_pattern, group_by, subset=subset) ) @@ -399,6 +456,7 @@ def run_pipeline(arguments: argparse.Namespace, nuclei: bool): root=output_root, min_area=min_area, max_area=max_area, + timepoint=timepoint, chunks=chunks, chunk_overlap=chunk_overlap, z_index=z_index, diff --git a/scallops/cli/segment_main.py b/scallops/cli/segment_main.py index fe06e6e..d6b513f 100644 --- a/scallops/cli/segment_main.py +++ b/scallops/cli/segment_main.py @@ -71,6 +71,10 @@ def _add_common_args(parser: ArgumentParser) -> None: default=0, help="Channel index (0-based) where DAPI is found", ) + parser.add_argument( + "--time", + help="Time index (0-based) or value.", + ) parser.add_argument( "--min-area", @@ -252,13 +256,6 @@ def _add_cell_parser(subparsers: ArgumentParser, default_help: bool = True) -> N type=int, ) - parser.add_argument( - "--time", - dest="cell_segmentation_t", - help="Time indices (0-based) to include when computing cell segmentation mask. Defaults to all time points.", - type=int, - action="append", - ) parser.add_argument( "--shrink-nuclei", help="Shrink nuclei prior to subtraction of nuclei from cells to identify the " diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 0e05501..2266e89 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -24,6 +24,7 @@ import dask.array as da import fsspec import numpy as np +import pandas as pd import xarray as xr import zarr from distributed import Client @@ -450,3 +451,253 @@ def load_json(path_or_str: str) -> dict: with fs.open(path_or_str, "rt") as fp: return json.load(fp) return json.loads(path_or_str) + + +def _write_img_size(file_list: list[str]): + from scallops.io import _images2fov, _localize_path + + local_file_list = [] + cleanup_file_list = [] + for path in file_list: + local_path = _localize_path(path) + if local_path is not None: + cleanup_file_list.append(local_path) + local_file_list.append(local_path) + else: + local_file_list.append(path) + sizes = _images2fov(local_file_list, dask=True).sizes + for path in cleanup_file_list: + os.remove(path) + with open("img_size.txt", "wt") as f: + for dim in ["t", "c", "z", "y", "x"]: + s = sizes[dim] if dim in sizes else 0 + f.write(f"{s}") + f.write("\n") + + +def _n_files_in_group(metadata: dict) -> int: + n_tiles = len(metadata["file_metadata"]) + metadata_fields = [v for v in ("c", "z") if v in metadata["file_metadata"][0]] + if len(metadata_fields) > 0: + from scallops.cli.util import _group_src_attrs + + keys, channel_sources, filepaths = _group_src_attrs( + metadata=metadata, metadata_fields=tuple(metadata_fields) + ) + n_tiles = len(filepaths) + return n_tiles + + +def _list_images_wdl( + image_pattern1: str, + urls1: list[str], + reference_time1: str | None, + n_cycles1: str | None, + image_pattern2: str, + urls2: list[str], + reference_time2: str | None, + n_cycles2: str | None, + groupby: list[str], + subset: list[str] | None, + batch_size: str | None, + save_group_size: bool = False, +): + """Used by WDL workflow to output info about images""" + + urls1 = [url.strip() for url in urls1 if url.strip() != ""] + reference_time1 = None if reference_time1 == "" else reference_time1 + n_cycles1 = int(n_cycles1) if n_cycles1 is not None and n_cycles1 != "" else None + + urls2 = [url.strip() for url in urls2 if url.strip() != ""] + reference_time2 = None if reference_time2 == "" else reference_time2 + n_cycles2 = int(n_cycles2) if n_cycles2 is not None and n_cycles2 != "" else None + batch_size = int(batch_size) if batch_size is not None and batch_size != "" else 1 + + if subset is not None and ( + len(subset) == 0 or (len(subset) == 1 and subset[0] == "") + ): + subset = None + if image_pattern1 != "": + groupby1 = [g for g in groupby if "{" + g + "}" in image_pattern1] + if image_pattern2 != "": + groupby2 = [g for g in groupby if "{" + g + "}" in image_pattern2] + if len(urls1) > 0 and len(urls2) > 0: + groupby = groupby1 + assert groupby1 == groupby2 + elif len(urls1) > 0: + groupby = groupby1 + elif len(urls2) > 0: + groupby = groupby2 + else: + raise ValueError() + + result1 = _list_images( + urls=urls1, + image_pattern=image_pattern1, + groupby=groupby, + subset=subset, + save_group_size=save_group_size, + reference_time=reference_time1, + n_cycles=n_cycles1, + ) + result2 = _list_images( + urls=urls2, + image_pattern=image_pattern2, + groupby=groupby, + subset=subset, + save_group_size=save_group_size, + reference_time=reference_time2, + n_cycles=n_cycles2, + ) + results = [result1, result2] + + if len(urls1) > 0 and len(urls2) > 0: + df1 = result1["subset_df"] + df2 = result2["subset_df"] + subset_ids = df1.index.intersection(df2.index) + result1["subset_df"] = df1.loc[subset_ids] + result2["subset_df"] = df2.loc[subset_ids] + elif len(urls1) > 0: + subset_ids = result1["subset_df"].index.values + elif len(urls2) > 0: + subset_ids = result2["subset_df"].index.values + + with open("subsets.txt", "wt") as f: # ["plate1-A1", "plate1-A2", ...] + for i in range(0, len(subset_ids), batch_size): + selected = subset_ids[i : i + batch_size] + f.write(" ".join(selected)) + f.write("\n") + + with open("groupby_array.txt", "wt") as f: # ['plate', 'well'] + for g in groupby: + f.write(g) + f.write("\n") + with open("groupby_pattern.txt", "wt") as f: # "{plate}-{well}" + first = True + for g in groupby: + if not first: + f.write("-") + first = False + f.write("{") + f.write(g) + f.write("}") + for index in range(len(results)): + result = results[index] + url_val = index + 1 + subset_df = result["subset_df"] + + with open( + f"groupby_array_with_time_{url_val}.txt", "wt" + ) as f: # ['plate', 'well', 't'] + for g in result["groupby_with_time"]: + f.write(g) + f.write("\n") + + with open(f"group_size_{url_val}.txt", "wt") as f: + f.write(f"{result['group_size']}") + f.write("\n") + + with open(f"reference_time_{url_val}.txt", "wt") as f: # IF + f.write(result["reference_time"]) + f.write("\n") + with open(f"times_{url_val}.txt", "wt") as f: # ["FISH", "IF"] + if result["times"] is not None: + for val in result["times"]: + f.write(str(val)) + f.write("\n") + + with open( + f"subsets_with_reference_time_{url_val}.txt", "wt" + ) as f: # ["plate1-A1-IF", "plate1-A2-IF", ...] + subset_ids_with_reference_times = ( + subset_df["subset_ids_with_reference_times"].values + if subset_df is not None + else [] + ) + for i in range(0, len(subset_ids_with_reference_times), batch_size): + selected = subset_ids_with_reference_times[i : i + batch_size] + f.write(" ".join(selected)) + f.write("\n") + + with open( + f"image_pattern_with_reference_time_{url_val}.txt", "wt" + ) as f: # "{plate}-{well}-IF" + first = True + for g in groupby: + if not first: + f.write("-") + first = False + f.write("{") + f.write(g) + f.write("}") + f.write(f"-{result['reference_time']}") + + +def _list_images( + urls: Sequence[str], + image_pattern: str, + groupby: list[str], + subset: list[str] | None, + save_group_size: bool, + reference_time: str | None, + n_cycles: int | None, +): + reference_time_suffix = "" + group_size = 0 + if len(urls) == 0: + return dict( + group_size=group_size, + subset_df=None, + times=None, + reference_time_suffix=reference_time_suffix, + ) + from scallops.io import _set_up_experiment + + exp_gen = _set_up_experiment( + image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset + ) + + groupby_t = "t" in groupby + times = None + subset_ids = [] + subset_ids_with_reference_times = [] + first = True + + for g, file_list, metadata in exp_gen: + times = None + if first: + first = False + if save_group_size: + group_size = _n_files_in_group(metadata) + if not groupby_t and "t" in metadata["file_metadata"][0]: + times = [md["t"] for md in metadata["file_metadata"]] + if n_cycles is not None: + assert len(times) == n_cycles + t_suffix = "" + if times is not None and len(times) > 0: + t_suffix = ( + f"-{times[0]}" if reference_time is None else f"-{reference_time}" + ) + + subset_ids.append('"' + metadata["id"] + '"') + subset_ids_with_reference_times.append('"' + metadata["id"] + t_suffix + '"') + subset_df = pd.DataFrame( + index=subset_ids, + data=dict(subset_ids_with_reference_times=subset_ids_with_reference_times), + ) + + groupby_with_time = list(groupby) + if not groupby_t and times is not None: + groupby_with_time.append("t") + + if reference_time is None: + reference_time = times[0] if times is not None and len(times) > 0 else "0" + + return dict( + group_size=group_size, + subset_df=subset_df, + groupby=groupby, + groupby_with_time=groupby_with_time, + times=times, + reference_time=reference_time, + ) diff --git a/scallops/features/find_objects.py b/scallops/features/find_objects.py index 9c4f175..b315be1 100644 --- a/scallops/features/find_objects.py +++ b/scallops/features/find_objects.py @@ -254,7 +254,7 @@ def _agg_objects(grouped): return objects_df -def find_objects(label_image: da.Array) -> dd.DataFrame: +def find_objects(label_image: da.Array | zarr.Array) -> dd.DataFrame: """Find objects in a labeled array. :param label_image: Image labels noted by integers. diff --git a/scallops/features/generate.py b/scallops/features/generate.py index 59b9234..df20ba7 100644 --- a/scallops/features/generate.py +++ b/scallops/features/generate.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import skimage.util +import xarray as xr import zarr from dask import delayed @@ -123,8 +124,8 @@ def _create_dd_metadata( def label_features( objects_df: pd.DataFrame, - label_image: da.Array | zarr.Array, - intensity_image: da.Array | zarr.Array | Sequence[zarr.Array] | None, + label_image: da.Array | xr.DataArray | zarr.Array, + intensity_image: da.Array | zarr.Array | xr.DataArray | Sequence[zarr.Array] | None, features: Iterable[str], channel_names: dict[int | str, str] | None = None, normalize: bool = True, @@ -133,7 +134,7 @@ def label_features( ) -> dd.DataFrame | pd.DataFrame: """Extract features from labeled regions in the image. - :param objects_df: Data frame containing labeled regions from `find_objects`. + :param objects_df: Data frame containing labeled regions from `find_objects` with frame index set to label id. :param label_image: Labeled regions. :param intensity_image: Intensity image with dimensions (y, x, c) or zarr array(s) with dimensions with leading dimensions unrolled to channel dimension @@ -146,6 +147,10 @@ def label_features( :return: DataFrame with extracted features. """ is_numpy = False + if isinstance(label_image, xr.DataArray): + label_image = label_image.data + if isinstance(intensity_image, xr.DataArray): + intensity_image = intensity_image.data if isinstance(label_image, np.ndarray): is_numpy = True label_image = da.from_array(label_image) diff --git a/scallops/io.py b/scallops/io.py index f58c16c..91f4343 100644 --- a/scallops/io.py +++ b/scallops/io.py @@ -1277,8 +1277,10 @@ def _images2fov( if image is None: raise ValueError(f"{file_list[i]} could not be read.") if file_metadata is not None: - if "t" in file_metadata[i] and isinstance( - file_metadata[i]["t"], str + if ( + "t" in file_metadata[i] + and isinstance(file_metadata[i]["t"], str) + and file_metadata[i]["t"].isdigit() ): # convert to int try: # note we replace "_" with "-" so we don't convert "1_2" to 12 e.g. diff --git a/scallops/tests/miniwdl_local/local_runner.py b/scallops/tests/miniwdl_local/local_runner.py index c9b1392..6207931 100644 --- a/scallops/tests/miniwdl_local/local_runner.py +++ b/scallops/tests/miniwdl_local/local_runner.py @@ -15,8 +15,13 @@ def global_init(cls, cfg, logger): """ Perform any necessary process-wide initialization of the container backend """ + cpu_count = os.environ.get("SCALLOPS_MINIWDL_CPU") + if cpu_count is None: + cpu_count = psutil.cpu_count() + else: + cpu_count = int(cpu_count) cls._resource_limits = { - "cpu": psutil.cpu_count(), + "cpu": cpu_count, "mem_bytes": psutil.virtual_memory().total, } diff --git a/scallops/tests/test_register_cli.py b/scallops/tests/test_register_cli.py index d2cf5b8..14a6118 100644 --- a/scallops/tests/test_register_cli.py +++ b/scallops/tests/test_register_cli.py @@ -118,7 +118,7 @@ def test_register_itk_cli_known_shift(tmp_path): "scallops", "registration", "elastix", - "--time", + "--moving-time", "1", "--moving", str(data_path), @@ -147,10 +147,11 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): tmp_path, "registration-input.zarr" ) exp = Experiment() - reference_t = 2 + + reference_t_index = 2 test_t = 10 array_A1_102_nuclei = array_A1_102_nuclei.squeeze() - exp.labels[f"A1-102-{reference_t}-mask"] = array_A1_102_nuclei + exp.labels["A1-102-nuclei"] = array_A1_102_nuclei exp.save(registration_input_moving_labels_path) cmd = [ @@ -179,8 +180,8 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): registration_input_moving_labels_path, "--label-output", elastix_output_dir, - "--time", - str(reference_t), + "--moving-time", + str(reference_t_index), ] subprocess.check_call(cmd) result_exp = read_experiment(elastix_output_dir) @@ -195,16 +196,24 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): .images["A1-102"] .squeeze() ) - assert len(result_exp.labels.keys()) == 8 + times = list(original_image.t.values) + del times[reference_t_index] + times = [str(t) for t in times] + transformed_times = list(result_exp.labels["A1-102-nuclei"].coords["t"].values) + transformed_times = [str(t) for t in transformed_times] + assert times == transformed_times, ( + f"{', '.join(times)} != {', '.join(transformed_times)}" + ) + assert len(result_exp.labels.keys()) == 1 np.testing.assert_array_equal(transformed_image.t.values, original_image.t.values) np.testing.assert_array_equal(transformed_image.c.values, original_image.c.values) np.testing.assert_array_equal( - transformed_image.isel(t=reference_t), - original_image.isel(t=reference_t), + transformed_image.isel(t=reference_t_index), + original_image.isel(t=reference_t_index), err_msg="Reference t not equal via CLI", ) for t in range(original_image.sizes["t"]): - if t != reference_t: + if t != reference_t_index: with np.testing.assert_raises(AssertionError): np.testing.assert_array_equal( transformed_image.isel(t=t), original_image.isel(t=t) @@ -232,8 +241,9 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): image_spacing=(1, 1), ) assert warped_labels.min() == 0 + np.testing.assert_array_equal( - result_exp.labels[f"A1-102-{test_t}-mask"].values, + result_exp.labels["A1-102-nuclei"].sel(t=str(test_t)).values, warped_labels, err_msg=f"t {test_t} labels not equal using itk_transform_labels and CLI", ) @@ -245,7 +255,7 @@ def test_register_itk_cli_t_reference(tmp_path, array_A1_102_nuclei): moving_channel=[0], parameter_object=parameter_object, moving_image_spacing=(1, 1), - reference_timepoint=reference_t, + reference_timepoint=reference_t_index, ) xr.testing.assert_equal(result_np, transformed_image) @@ -308,19 +318,25 @@ def test_register_transform_labels_moving_only(tmp_path): output_zarr = tmp_path / "out.zarr" output_transforms = tmp_path / "transforms" - img = read_image( - "scallops/tests/data/experimentC/10X_c0-DAPI-p65ab/10X_c0-DAPI-p65ab_A1_Tile-102.phenotype.tif" - ) - img.attrs["physical_pixel_sizes"] = (1, 1) - rng = np.random.default_rng(0) + img = rng.integers(low=0, high=10, size=(2, 2, 100, 100)) + img = xr.DataArray( + img, + dims=["t", "c", "y", "x"], + coords={"t": ["IF", "FISH"]}, + attrs={"physical_pixel_sizes": (1, 1)}, + ) - segmentation = rng.integers(low=0, high=10, size=(img.sizes["y"], img.sizes["x"])) + segmentation = rng.integers(low=0, high=10, size=(100, 100)) Experiment( - images={"plateA-A1-IF": img, "plateA-A1-FISH": img}, + images={"plateA-A1": img}, labels={ - "plateA-A1-IF-cell": segmentation, + "plateA-A1-cell": xr.DataArray( + np.expand_dims(segmentation, 0), + dims=["t", "y", "x"], + coords={"t": ["IF"]}, + ), }, ).save(image_zarr) cmd = [ @@ -330,7 +346,7 @@ def test_register_transform_labels_moving_only(tmp_path): "--moving", str(image_zarr), "--moving-image-pattern", - "{plate}-{well}-{t}", + "{plate}-{well}", "--moving-label", str(image_zarr), "--subset", @@ -344,7 +360,7 @@ def test_register_transform_labels_moving_only(tmp_path): "--label-output", str(output_zarr), "--output-aligned-channels-only", - "--time", + "--moving-time", "IF", "--transform-output", str(output_transforms), @@ -352,7 +368,8 @@ def test_register_transform_labels_moving_only(tmp_path): create_itk_param_file(tmp_path), ] subprocess.check_call(cmd) - transformed_labels = read_image(output_zarr / "labels" / "plateA-A1-FISH-cell") + transformed_labels = read_image(output_zarr / "labels" / "plateA-A1-cell") + assert list(transformed_labels.coords["t"].values) == ["FISH"] assert transformed_labels.max() > 0 transformed_image = read_image(output_zarr / "images" / "plateA-A1") assert transformed_image.shape[0] == 2 @@ -384,6 +401,9 @@ def _warp(img): dims=["c", "y", "x"], ) moving_image = moving_image.isel(c=[0, 1, 2]) + moving_image.coords["c"] = ["a", "b", "c"] + moving_image.attrs["physical_pixel_sizes"] = (2, 2) + moving_image.attrs["physical_pixel_units"] = ("micrometer", "micrometer") moving_labels = array_A1_102_nuclei.squeeze().values moving_labels = warp(moving_labels, st, order=0, preserve_range=True) @@ -396,9 +416,6 @@ def _warp(img): ), dims=["y", "x"], ) - moving_image.coords["c"] = ["a", "b", "c"] - moving_image.attrs["physical_pixel_sizes"] = (2, 2) - moving_image.attrs["physical_pixel_units"] = ("micrometer", "micrometer") fixed_image = xr.DataArray( resize( diff --git a/scallops/tests/test_wdl.py b/scallops/tests/test_wdl.py index 0e1b2ee..96786dd 100644 --- a/scallops/tests/test_wdl.py +++ b/scallops/tests/test_wdl.py @@ -10,6 +10,7 @@ import xarray as xr from scallops import Experiment +from scallops.cli.util import _list_images_wdl from scallops.io import read_image, save_ome_tiff from scallops.tests.test_stitch import _write_image_with_position @@ -27,6 +28,42 @@ def add_physical_size(input_path, output_path): save_ome_tiff(img.values, uri=output_path, ome_xml=img.attrs["processed"].to_xml()) +@pytest.mark.parametrize("reference_time", ["IF", None]) +@pytest.mark.cli_e2e +def test_list_images_wdl(reference_time, tmp_path, monkeypatch): + (tmp_path / "plate1-A1-IF").touch() + (tmp_path / "plate1-A1-FISH").touch() + monkeypatch.chdir(tmp_path) + _list_images_wdl( + image_pattern="{plate}-{well}-{t}", + reference_time=reference_time, + urls=[str(tmp_path)], + groupby=["plate", "well"], + subset=None, + batch_size_str=None, + save_group_size=False, + expected_cycles_str=None, + ) + groups = pd.read_csv(tmp_path / "groups.txt", header=None)[0].values + np.testing.assert_array_equal(groups, ["plate1-A1"]) + groupby = pd.read_csv(tmp_path / "groupby.txt", header=None)[0].values + np.testing.assert_array_equal(groupby, ["plate", "well"]) + groupby_pattern = pd.read_csv(tmp_path / "groupby_pattern.txt", header=None)[ + 0 + ].values + np.testing.assert_array_equal(groupby_pattern, ["{plate}-{well}"]) + + times = pd.read_csv(tmp_path / "t.txt", header=None)[0].values + np.testing.assert_array_equal(times, ["FISH", "IF"]) + groups_with_times = pd.read_csv(tmp_path / "groups_with_t.txt", header=None)[ + 0 + ].values + if reference_time == "IF": + np.testing.assert_array_equal(groups_with_times, ["plate1-A1-IF"]) + else: + np.testing.assert_array_equal(groups_with_times, ["plate1-A1-FISH"]) + + @pytest.mark.cli_e2e def test_stitch_wdl_z_stack(tmp_path): input_path = tmp_path / "input" @@ -120,13 +157,15 @@ def test_stitch_wdl(tmp_path): @pytest.mark.cli_e2e def test_ops_wdl(tmp_path): sbs_dir = tmp_path / "sbs" - pheno_dir = tmp_path / "pheno" output = tmp_path / "out" + pheno_dir = tmp_path / "pheno.zarr" sbs_dir.mkdir() - pheno_dir.mkdir() output.mkdir() for p in glob.glob("scallops/tests/data/experimentC/input/*/*Tile-102*"): - add_physical_size(p, str(sbs_dir / os.path.basename(p))) + cycles = os.path.basename(p).split("_")[1] + cycles = cycles.split("-")[0] + dest = f"plateA-A1-{cycles[1:]}.tif" + add_physical_size(p, str(sbs_dir / dest)) pheno_img = read_image( "scallops/tests/data/experimentC/10X_c0-DAPI-p65ab/10X_c0-DAPI-p65ab_A1_Tile-102.phenotype.tif" @@ -135,45 +174,43 @@ def test_ops_wdl(tmp_path): phenotype_mask = np.ones( (pheno_img.sizes["y"], pheno_img.sizes["x"]), dtype=np.uint8 ) - phenotype_mask[10, 10] = 1 + phenotype_mask[10, 10] = 0 phenotype_tile = np.ones( (pheno_img.sizes["y"], pheno_img.sizes["x"]), dtype=np.uint16 ) phenotype_tile[10, 10] = 2 - exp = Experiment( - images={"A1-102-1": pheno_img, "A1-102-2": pheno_img}, + Experiment( + images={"plateA-A1-IF": pheno_img, "plateA-A1-FISH": pheno_img}, labels={ - "A1-102-1-mask": phenotype_mask, - "A1-102-1-tile": phenotype_tile, - "A1-102-2-mask": phenotype_mask, - "A1-102-2-tile": phenotype_tile, + "plateA-A1-IF-mask": phenotype_mask, + "plateA-A1-IF-tile": phenotype_tile, + "plateA-A1-FISH-mask": phenotype_mask, + "plateA-A1-FISH-tile": phenotype_tile, }, - ) - exp.save(str(pheno_dir)) + ).save(pheno_dir) input_json = { "model_dir": "", "iss_url": str(sbs_dir.absolute()), - "iss_image_pattern": "{mag}X_c{t}-{experiment}-{t}_{well}_Tile-{tile}.{datatype}.tif", + "iss_image_pattern": "{plate}-{well}-{t}.tif", "output_directory": str(output.absolute()), "iss_registration_extra_arguments": "--no-landmarks", "pheno_to_iss_registration_extra_arguments": "--no-landmarks", "pheno_registration_extra_arguments": "--no-landmarks", "phenotype_cyto_channel": [1], - "phenotype_dapi_channel": 0, + "reference_phenotype_time": "IF", "phenotype_url": str(pheno_dir.absolute()), - "phenotype_nuclei_features": ["intensity_0", "intensity_1"], + "phenotype_nuclei_features": { + "IF": ["intensity_0", "intensity_1"], + "FISH": ["intensity_0", "intensity_1"], + }, # 2 batches - "phenotype_cell_features": ["intensity_0"], + "phenotype_cell_features": {"IF": ["intensity_0"]}, # "phenotype_cytosol_features": ["mean_0 area"], # no cytosol features - "phenotype_image_pattern": "{well}-{tile}-{t}", - "groupby": ["well", "tile"], "reads_threshold_peaks": "0", "reads_threshold_peaks_crosstalk": "20", "barcodes": os.path.abspath("scallops/tests/data/experimentC/barcodes.csv"), - "mark_stitch_boundary_cells": False, "reads_labels": "cell", - "merge_extra_arguments": "--format parquet", "docker": "", } diff --git a/scallops/utils.py b/scallops/utils.py index f015589..6663f8d 100644 --- a/scallops/utils.py +++ b/scallops/utils.py @@ -637,130 +637,3 @@ def _dask_from_array_no_copy( meta = x return da.Array(dsk, name, chunks, meta=meta, dtype=getattr(x, "dtype", None)) - - -def _write_img_size(file_list: list[str]): - from scallops.io import _images2fov, _localize_path - - local_file_list = [] - cleanup_file_list = [] - for path in file_list: - local_path = _localize_path(path) - if local_path is not None: - cleanup_file_list.append(local_path) - local_file_list.append(local_path) - else: - local_file_list.append(path) - sizes = _images2fov(local_file_list, dask=True).sizes - for path in cleanup_file_list: - os.remove(path) - with open("img_size.txt", "wt") as f: - for dim in ["t", "c", "z", "y", "x"]: - s = sizes[dim] if dim in sizes else 0 - f.write(f"{s}") - f.write("\n") - - -def _write_group_size(metadata: dict): - n_tiles = len(metadata["file_metadata"]) - metadata_fields = [v for v in ("c", "z") if v in metadata["file_metadata"][0]] - if len(metadata_fields) > 0: - from scallops.cli.util import _group_src_attrs - - keys, channel_sources, filepaths = _group_src_attrs( - metadata=metadata, metadata_fields=tuple(metadata_fields) - ) - n_tiles = len(filepaths) - with open("group_size.txt", "wt") as f: - f.write(f"{n_tiles}") - f.write("\n") - - -def _list_images_wdl( - image_pattern: str, - urls: list[str], - groupby: list[str], - subset: list[str], - batch_size_str: str, - save_group_size: bool = False, - expected_cycles_str: int | None = None, -): - """Used by WDL workflow to output info about images""" - from scallops.io import _set_up_experiment - - batch_size = 0 - expected_cycles = None - if expected_cycles_str != "": - expected_cycles = int(expected_cycles_str) - if batch_size_str != "": - batch_size = int(batch_size_str) - - if len(subset) == 0 or (len(subset) == 1 and subset[0] == ""): - subset = None - if image_pattern != "": - groupby = [g for g in groupby if "{" + g + "}" in image_pattern] - exp_gen = _set_up_experiment( - image_path=urls, files_pattern=image_pattern, group_by=groupby, subset=subset - ) - # "groups.txt" is passed to --subset in cli - # "groupby.txt" filtered groupby - groupby_t = "t" in groupby - t = [] - - if not save_group_size: - with open("group_size.txt", "wt") as f: - f.write("0\n") - if batch_size > 0: - with open("groups.txt", "wt") as f: - ids = [] - first = True - for g, file_list, metadata in exp_gen: - if first: - first = False - if save_group_size: - _write_group_size(metadata) - if not groupby_t and "t" in metadata["file_metadata"][0]: - t = [md["t"] for md in metadata["file_metadata"]] - if expected_cycles is not None: - assert len(t) == expected_cycles - - ids.append('"' + metadata["id"] + '"') - if len(ids) == batch_size: - f.write(" ".join(ids)) - f.write("\n") - ids = [] - if len(ids) > 0: - f.write(" ".join(ids)) - f.write("\n") - else: - with open("groups.txt", "wt") as f: - first = True - for g, file_list, metadata in exp_gen: - f.write(metadata["id"]) - f.write("\n") - if first: - first = False - if save_group_size: - _write_group_size(metadata) - if not groupby_t and "t" in metadata["file_metadata"][0]: - t = [md["t"] for md in metadata["file_metadata"]] - - with open("groupby.txt", "wt") as f: - for g in groupby: - f.write(g) - f.write("\n") - - with open("t.txt", "wt") as f: - for val in t: - f.write(str(val)) - f.write("\n") - - with open("groupby_pattern.txt", "wt") as f: - first = True - for g in groupby: - if not first: - f.write("-") - first = False - f.write("{") - f.write(g) - f.write("}") diff --git a/scallops/zarr_io.py b/scallops/zarr_io.py index 9f04010..56acd35 100644 --- a/scallops/zarr_io.py +++ b/scallops/zarr_io.py @@ -285,6 +285,66 @@ def _attrs_axes_coordinates( return image_attrs, axes, coordinate_transformations +def _write_zarr_labels( + name: str, + root: zarr.Group | str | Path, + labels: np.ndarray | xr.DataArray | da.Array, + metadata: dict[str, Any] | None = None, + group_metadata: dict[str, Any] | None = None, + compute: bool = True, + storage_options: JSONDict | None = None, +) -> list[Delayed]: + """Write label in zarr format. + + :param name: Zarr group name to store label + :param root: Root zarr group. + :param labels: Labels to write. + :param metadata: Optional label metadata. + :param group_metadata: Optional group level metadata. + :param compute: If true compute immediately otherwise a list + of :class:`dask.delayed.Delayed` is returned. + :param storage_options: Optional storage options. + :return: Empty list if the compute flag is True, otherwise it returns a list + of :class:`dask.delayed.Delayed` representing the value to be computed by dask. + """ + + # stored in labels/key + if isinstance(root, (str, Path)): + root = open_ome_zarr(root, mode="a") + + labels_grp = root.require_group("labels", overwrite=False) + dest_grp = labels_grp.create_group(name.replace("/", "-"), overwrite=True) + + label_attrs = None + coords = None + dims = ["y", "x"] + if isinstance(labels, xr.DataArray): + data = labels.data + label_attrs = labels.attrs.copy() + coords = labels.coords + dims = labels.dims + else: + data = labels + + # need 'image-label' attr to be recognized as label + group_metadata = group_metadata.copy() if group_metadata is not None else dict() + if "image-label" not in group_metadata: + group_metadata["image-label"] = {} + metadata = metadata.copy() if metadata is not None else {} + + return write_zarr( + grp=dest_grp, + data=data, + image_attrs=label_attrs, + coords=coords, + dims=dims, + metadata=metadata, + zarr_format="ome_zarr", + compute=compute, + storage_options=storage_options, + ) + + def _write_zarr_image( name: str | None, root: zarr.Group | str | Path, @@ -346,6 +406,7 @@ def write_zarr( metadata: dict[str, Any] | None = None, zarr_format: Literal["ome_zarr", "zarr"] = "ome_zarr", compute: bool = True, + storage_options: JSONDict | None = None, ) -> list[Delayed]: """Write data to a Zarr group with optional metadata and scaling. @@ -366,6 +427,7 @@ def write_zarr( :param compute: If True, compute immediately. Otherwise, return a list of dask.delayed. Delayed objects representing the value to be computed by dask. Default is True. + :param storage_options: Optional storage options. :return: Empty list if the compute flag is True, otherwise a list of dask.delayed.Delayed objects. @@ -408,19 +470,26 @@ def write_zarr( dask_delayed = [] fmt = _current_format() if zarr_format == "zarr": # No axis validation + chunks_opt = None + if storage_options is not None: + chunks_opt = storage_options.pop("chunks", None) if isinstance(data, da.Array): d = da.to_zarr( arr=data, url=grp.store, component=str(Path(grp.path, "0")), compute=compute, + storage_options=storage_options, **_da_to_zarr_kwargs(fmt), ) if not compute: dask_delayed.append(d) elif not isinstance(data, zarr.Array): + kwds = _da_to_zarr_kwargs(fmt) + if storage_options is not None: + kwds.update(storage_options) grp.create_dataset( - "0", data=data, overwrite=True, **_da_to_zarr_kwargs(fmt) + "0", data=data, overwrite=True, chunks=chunks_opt, **kwds ) datasets = [{"path": "0"}] @@ -470,6 +539,7 @@ def _write_metadata_delayed(grp, d): if coordinate_transformations is not None else None ), + storage_options=storage_options, ) @@ -509,68 +579,6 @@ def rechunk(image: xr.DataArray | da.Array) -> xr.DataArray | da.Array: return image -def _write_zarr_labels( - name: str, - root: zarr.Group | str | Path, - labels: np.ndarray | xr.DataArray | da.Array, - metadata: dict[str, Any] = None, - group_metadata: dict[str, Any] = None, - compute: bool = True, - storage_options: JSONDict | None = None, -) -> list[Delayed]: - """Write label in zarr format. - - :param name: Zarr group name to store label - :param root: Root zarr group. - :param labels: Labels to write. - :param metadata: Optional label metadata. - :param group_metadata: Optional group level metadata. - :param compute: If true compute immediately otherwise a list - of :class:`dask.delayed.Delayed` is returned. - :param storage_options: Optional storage options. - :return: Empty list if the compute flag is True, otherwise it returns a list - of :class:`dask.delayed.Delayed` representing the value to be computed by dask. - """ - - # stored in labels/key - name = name.replace("/", "-") - if isinstance(root, (str, Path)): - root = open_ome_zarr(root, mode="a") - labels_grp = root.require_group("labels") - grp = labels_grp.create_group(name, overwrite=True) - if not isinstance(labels, xr.DataArray): - if labels.ndim == 2: - label_axes = ["y", "x"] - elif labels.ndim == 5: - label_axes = ["t", "c", "z", "y", "x"] - else: - raise ValueError("Axes can't be inferred for 3D or 4D labels") - else: - label_axes = labels.dims - labels = labels.data - - # need 'image-label' attr to be recognized as label - group_metadata = group_metadata.copy() if group_metadata is not None else dict() - if "image-label" not in group_metadata: - group_metadata["image-label"] = {} - grp.attrs.update(group_metadata) - metadata = metadata.copy() if metadata is not None else {} - if isinstance(labels, da.Array) or ( - isinstance(labels, xr.DataArray) and isinstance(labels.data, da.Array) - ): - labels = rechunk(labels) - return write_image( - labels, - grp, - scaler=None, - # scale_factors=[], - axes=label_axes, - metadata=metadata, - compute=compute, - storage_options=storage_options, - ) - - def _read_zarr_attrs(attrs) -> tuple[dict, dict, list[str]]: """Read attributes from Zarr. diff --git a/wdl/ops_tasks.wdl b/wdl/ops_tasks.wdl index f17dc52..2a35a24 100644 --- a/wdl/ops_tasks.wdl +++ b/wdl/ops_tasks.wdl @@ -1,4 +1,4 @@ -version 1.0 +version 1.1 task segment_nuclei { input { @@ -7,6 +7,7 @@ task segment_nuclei { String? image_pattern Array[String] groupby Int? dapi_channel + String? time String output_directory String subset Boolean? force @@ -33,6 +34,7 @@ task segment_nuclei { --groupby ~{sep=" " groupby} \ ~{if defined(image_pattern) then '--image-pattern "' + image_pattern + '"' else ''} \ ~{'--dapi-channel ' + dapi_channel} \ + ~{'--time ' + time} \ --output "~{output_directory}" \ --subset ~{subset} \ ~{if defined(extra_arguments) then extra_arguments else ''} \ @@ -64,6 +66,7 @@ task segment_cell { Array[String] groupby Int? dapi_channel Array[Int] cyto_channel + String? time Int? chunks String? nuclei_label String? threshold @@ -94,6 +97,7 @@ task segment_cell { --groupby ~{sep=" " groupby} \ ~{if defined(image_pattern) then '--image-pattern "' + image_pattern + '"' else ''} \ ~{'--dapi-channel ' + dapi_channel} \ + ~{'--time ' + time} \ --cyto-channel ~{sep=" " cyto_channel} \ ~{"--nuclei-label " + nuclei_label} \ ~{"--method " + method} \ @@ -134,7 +138,8 @@ task register_elastix { Boolean? output_aligned_channels_only Boolean? unroll_channels String? fixed - String? reference_time + String? moving_time + String? fixed_time Int? fixed_channel Boolean? register_across_channels String transform_output_directory @@ -174,7 +179,8 @@ task register_elastix { --subset ~{subset} \ ~{if defined(label_output_directory) then '--label-output "' + label_output_directory + '"' else ''} \ ~{true="--unroll-channels" false="" unroll_channels} \ - ~{if defined(reference_time) then '--time "' + reference_time + '"' else ''} \ + ~{if defined(moving_time) then '--moving-time "' + moving_time + '"' else ''} \ + ~{if defined(fixed_time) then '--fixed-time "' + fixed_time + '"' else ''} \ ~{true="--force" false="" force} \ ~{true="--align-across-channels" false="" register_across_channels} \ ~{true="--output-aligned-channels-only" false="" output_aligned_channels_only} \ @@ -266,16 +272,17 @@ task register_pheno_to_iss_qc { } -task register_qc { +task register_iss_iss_qc { input { String images String? image_pattern String label_type String labels - Int channel + Int dapi_channel + Int n_timepoints String subset String output_directory - String channel_prefix + Array[String] groupby Boolean? force @@ -288,6 +295,7 @@ task register_qc { String memory Int max_retries } + Int n_channels = n_timepoints*5 command <<< set -ex @@ -296,58 +304,17 @@ task register_qc { ulimit -n 100000 fi - python < 0: - cmd.append("--subset") - cmd += subset - cmd += ["--output", output_directory] - cmd += ["--images", images] - cmd += ["--channel-rename", f"{json.dumps(channel_rename)}"] - - if force == "true": - cmd.append("--force") - print(" ".join(cmd)) - check_call(cmd) - - CODE + scallops features \ + --features-~{label_type} "correlationpearsonbox_~{dapi_channel}_5:~{n_channels}:5" \ + --labels "~{labels}" \ + --groupby ~{sep=" " groupby} \ + --subset ~{subset} \ + --output "~{output_directory}" \ + --images "~{images}" \ + ~{'--image-pattern ' + image_pattern} \ + ~{true="--force" false="" force} + + >>> @@ -373,7 +340,7 @@ task intersects_boundary { String images String? image_pattern String label_type - String labels + Array[String] labels String subset String? objects String output_directory @@ -399,12 +366,12 @@ task intersects_boundary { scallops features \ --features-~{label_type} "intersects-boundary_0" \ - --labels "~{labels}" \ + --labels ~{sep=" " labels} \ --groupby ~{sep=" " groupby} \ - --subset ~{subset} \ + --subset "~{subset}" \ --output "~{output_directory}" \ --images "~{images}" \ - --objects "~{objects}" \ + --merge "~{objects}" \ --no-normalize \ ~{'--image-pattern ' + image_pattern} \ ~{true="--force" false="" force} @@ -430,7 +397,7 @@ task intersects_boundary { task find_objects { input { - String? labels + Array[String] labels String subset Boolean? force String? label_pattern @@ -454,7 +421,7 @@ task find_objects { fi scallops find-objects \ - --labels "~{labels}" \ + --labels ~{sep=" " labels} \ --subset ~{subset} \ ~{"--label-pattern " + label_pattern} \ --label-suffix ~{suffix} \ @@ -493,8 +460,8 @@ task features { Int? cytosol_max_area String? features_extra_arguments String? model_dir - String? labels - String? objects + Array[String] labels + String? merge String images String subset Boolean? force @@ -531,8 +498,8 @@ task features { ~{if defined(cell_max_area) && select_first([cell_max_area])>0 then '--cell-max-area ' + cell_max_area else ''} \ ~{if defined(cytosol_max_area) && select_first([cytosol_max_area])>0 then '--cytosol-max-area ' + cytosol_max_area else ''} \ ~{if defined(features_extra_arguments) then features_extra_arguments else ''} \ - --labels "~{labels}" \ - ~{"--objects " + objects} \ + --merge ~{merge} \ + --labels ~{sep=" " labels} \ ~{"--label-filter " + '"' + label_filter + '"'} \ --subset ~{subset} \ ~{"--image-pattern " + image_pattern} \ diff --git a/wdl/ops_workflow.wdl b/wdl/ops_workflow.wdl index 1955d4a..b893206 100644 --- a/wdl/ops_workflow.wdl +++ b/wdl/ops_workflow.wdl @@ -1,4 +1,4 @@ -version 1.0 +version 1.1 import "utils.wdl" as utils import "ops_tasks.wdl" as tasks @@ -15,14 +15,14 @@ workflow ops_workflow { String output_directory - # t to align phenotyping rounds to e.g. "IF" + # t to align phenotyping rounds to e.g. "IF". If not specified then first round in natural sorted order is used String? reference_phenotype_time # features String? features_label_filter = "~barcode_count_0.isna()" # valid barcodes - Array[String]? phenotype_cell_features - Array[String]? phenotype_nuclei_features - Array[String]? phenotype_cytosol_features + Map[String, Array[String]]? phenotype_cell_features + Map[String, Array[String]]? phenotype_nuclei_features + Map[String, Array[String]]? phenotype_cytosol_features String? features_extra_arguments # Single string with extra arguments to scallops features cli Int? features_cell_min_area @@ -32,16 +32,15 @@ workflow ops_workflow { Int? features_nuclei_max_area Int? features_cytosol_max_area - Array[Int] phenotype_cyto_channel # indices after registration for cell segmentation - Int phenotype_dapi_channel # index after registration for segmentation and pheno to iss registration - Int? phenotype_dapi_channel_before_registration # for pheno to pheno registration + Array[Int] phenotype_cyto_channel # indices within referent time for cell segmentation + Int? phenotype_dapi_channel # index within t for nuclei segmentation and pheno to iss registration Int? iss_dapi_channel # ISS to ISS and pheno to ISS registration String? iss_registration_extra_arguments # Extra arguments in scallops registration elastix cli for ISS String? pheno_to_iss_registration_extra_arguments String? pheno_registration_extra_arguments - Boolean? register_across_channels + # spot detect Int? iss_expected_cycles @@ -68,7 +67,7 @@ workflow ops_workflow { String model_dir = "" # nuclei segment - String? nuclei_segmentation + String? nuclei_segmentation_method String? nuclei_segmentation_extra_arguments # cell segment @@ -78,6 +77,7 @@ workflow ops_workflow { String? cell_segmentation_extra_arguments Boolean mark_stitch_boundary_cells = true + String intersects_stitch_boundary_label = "cell" # nuclei # merge String? merge_extra_arguments @@ -152,10 +152,9 @@ workflow ops_workflow { String merge_memory = "256 GiB" String merge_disks = "local-disk 20 HDD" - Int cell_intersects_boundary_cpu = 16 - String cell_intersects_boundary_memory = "32 GiB" - String cell_intersects_boundary_disks = "local-disk 200 HDD" - + Int intersects_boundary_cpu = 16 + String intersects_boundary_memory = "32 GiB" + String intersects_boundary_disks = "local-disk 200 HDD" String docker @@ -169,9 +168,7 @@ workflow ops_workflow { String register_iss_transforms_suffix = "iss-transforms-t0" String register_pheno_to_iss_suffix = "pheno-to-iss-registered.zarr" String register_pheno_to_iss_transforms_suffix = "pheno-to-iss-transforms" - String nuclei_objects_suffix = "objects-nuclei" - String cell_objects_suffix = "objects-cell" - String cytosol_objects_suffix = "objects-cytosol" + String objects_suffix = "objects" String nuclei_features_suffix = "features-nuclei" String cell_features_suffix = "features-cell" String cytosol_features_suffix = "features-cytosol" @@ -183,11 +180,9 @@ workflow ops_workflow { String reads_suffix = "reads" String merge_meta_suffix = "merge-sbs-metadata" String merge_features_suffix = "merge-features" - String cell_intersects_boundary_suffix = "intersects-boundary" - String cell_intersects_boundary_non_reference_t_suffix = "intersects-boundary-t" + String intersects_boundary_suffix = "intersects-boundary" } - String output_stripped = sub(output_directory, "/+$", "") + "/" String segment_directory = output_stripped + segment_suffix String register_iss_t0_directory = output_stripped + register_iss_suffix @@ -197,9 +192,7 @@ workflow ops_workflow { String nuclei_features_directory = output_stripped + nuclei_features_suffix String cell_features_directory = output_stripped + cell_features_suffix String cytosol_features_directory = output_stripped + cytosol_features_suffix - String nuclei_objects_directory = output_stripped + nuclei_objects_suffix - String cell_objects_directory = output_stripped + cell_objects_suffix - String cytosol_objects_directory = output_stripped + cytosol_objects_suffix + String objects_directory = output_stripped + objects_suffix String register_pheno_to_pheno_directory = output_stripped + register_pheno_to_pheno_suffix String register_pheno_to_pheno_transform_directory = output_stripped + register_pheno_to_pheno_transform_suffix String spot_detect_directory = output_stripped + spot_detect_suffix @@ -207,71 +200,63 @@ workflow ops_workflow { String merge_meta_directory = output_stripped + merge_meta_suffix String merge_features_directory = output_stripped + merge_features_suffix String register_pheno_to_iss_qc_directory = output_stripped + register_pheno_to_iss_qc_suffix - String cell_intersects_boundary_directory = output_stripped + cell_intersects_boundary_suffix - String cell_intersects_boundary_directory_non_reference_t = output_stripped + cell_intersects_boundary_non_reference_t_suffix + String intersects_boundary_directory = output_stripped + intersects_boundary_suffix Boolean iss_url_supplied = defined(iss_url) Boolean pheno_url_supplied = defined(phenotype_url) call utils.list_images { input: - urls = [select_first([phenotype_url, iss_url])], - image_pattern = if pheno_url_supplied then phenotype_image_pattern else iss_image_pattern, + urls1 = select_all([phenotype_url]), + image_pattern1 = phenotype_image_pattern, + reference_time1=reference_phenotype_time, + + urls2 = select_all([iss_url]), + image_pattern2 = iss_image_pattern, batch_size=batch_size, + groupby=groupby, - subset=subset, + subset = subset, docker=docker, zones = zones, preemptible = preemptible, aws_queue_arn = aws_queue_arn, max_retries = max_retries } - String image_pattern_after_registration = list_images.groupby_pattern - Array[String] groups = list_images.groups - Array[String] times = list_images.t - scatter (group in groups) { + Array[String] subsets = list_images.subsets + String groupby_pattern = list_images.groupby_pattern # e.g. {plate}-{well} + Array[String] groupby_array = list_images.groupby_array # e.g. ["plate", "well"] + + + # Array[String] subsets_with_reference_times_pheno = list_images.subsets_with_reference_times_1 + # Array[String] subsets_with_reference_times_iss = list_images.subsets_with_reference_times_2 + + Array[String] times_pheno = list_images.times_1 + Array[String] times_iss = list_images.times_2 + + String reference_time_pheno = list_images.reference_time_1 + String reference_time_iss = list_images.reference_time_2 + Array[String] phenotype_group_by_with_time = list_images.groupby_array_with_time_2 + #String image_pattern_with_reference_time_pheno = list_images.image_pattern_with_reference_time_1 # e.g. {plate}-{well}-IF + # String image_pattern_with_reference_time_iss = list_images.image_pattern_with_reference_time_2 # e.g. {plate}-{well}-1 + scatter (subset_index in range(length(subsets))) { + String subset_ = subsets[subset_index] # e.g. plate1-A1 + # String subset_with_reference_times_pheno = subsets_with_reference_times_pheno[subset_index] + # String subset_with_reference_times_iss = subsets_with_reference_times_iss[subset_index] if(pheno_url_supplied) { - if(length(times)>1) { - call tasks.register_elastix as register_pheno_to_pheno { - input: - moving=select_all([phenotype_url]), - moving_label=phenotype_url, # transform stitch masks - moving_channel=phenotype_dapi_channel_before_registration, # DAPI index in each round - moving_image_pattern=phenotype_image_pattern, - reference_time=reference_phenotype_time, - extra_arguments=pheno_registration_extra_arguments, - unroll_channels=true, - register_across_channels=register_across_channels, - groupby=groupby, - moving_output_directory=register_pheno_to_pheno_directory, - label_output_directory=register_pheno_to_pheno_directory, - transform_output_directory=register_pheno_to_pheno_transform_directory, - subset = group, - force = force_register_pheno_to_pheno, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = register_pheno_to_pheno_disks, - memory = register_pheno_to_pheno_memory, - cpu = register_pheno_to_pheno_cpu, - max_retries = max_retries - } - } - String register_pheno_to_pheno_output_url = select_first([register_pheno_to_pheno.moving_output_url, phenotype_url]) - String register_pheno_to_pheno_image_pattern = if(length(times)>1) then image_pattern_after_registration else phenotype_image_pattern - if(run_nuclei_segmentation) { call tasks.segment_nuclei { input: - images = register_pheno_to_pheno_output_url, - image_pattern = register_pheno_to_pheno_image_pattern, - method = nuclei_segmentation, + images = select_first([phenotype_url]), + image_pattern = phenotype_image_pattern, + time=reference_time_pheno, + subset = subset_, + method = nuclei_segmentation_method, groupby=groupby, dapi_channel = phenotype_dapi_channel, output_directory=segment_directory, model_dir=model_dir, - subset = group, + extra_arguments=nuclei_segmentation_extra_arguments, force = force_segment_nuclei, docker=docker, @@ -283,14 +268,17 @@ workflow ops_workflow { cpu = segment_nuclei_cpu, max_retries = max_retries } + } if(run_cell_segmentation) { call tasks.segment_cell { input: - images = register_pheno_to_pheno_output_url, - image_pattern = register_pheno_to_pheno_image_pattern, + images = select_first([phenotype_url]), + image_pattern = phenotype_image_pattern, + time=reference_time_pheno, method = cell_segmentation_method, - groupby=groupby, + groupby = groupby, + subset = subset_, dapi_channel = phenotype_dapi_channel, cyto_channel=phenotype_cyto_channel, nuclei_label=select_first([segment_nuclei.output_url]), @@ -298,7 +286,7 @@ workflow ops_workflow { threshold_correction_factor = segment_cell_threshold_correction_factor, output_directory=segment_directory, model_dir=model_dir, - subset = group, + extra_arguments=cell_segmentation_extra_arguments, force = force_segment_cell, docker=docker, @@ -310,101 +298,129 @@ workflow ops_workflow { cpu = segment_cell_cpu, max_retries = max_retries } - call tasks.find_objects as find_objects_cell { - input: - labels= segment_cell.output_url, - label_pattern=image_pattern_after_registration, - suffix="cell", - output_directory=cell_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries - } - call tasks.find_objects as find_objects_cytosol { - input: - labels=segment_cell.output_url, - label_pattern=image_pattern_after_registration, - suffix="cytosol", - output_directory=cytosol_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries + if(length(times_pheno)>1) { + + call tasks.register_elastix as register_pheno_to_pheno { + input: + moving=select_all([phenotype_url]), + moving_label=segment_cell.output_url, + moving_channel=phenotype_dapi_channel, + moving_image_pattern=phenotype_image_pattern, + moving_time=reference_time_pheno, + + extra_arguments=pheno_registration_extra_arguments, + output_aligned_channels_only=true, + groupby=groupby, + subset = subset_, + moving_output_directory=register_pheno_to_pheno_directory, + label_output_directory=register_pheno_to_pheno_directory, + transform_output_directory=register_pheno_to_pheno_transform_directory, + + force = force_register_pheno_to_pheno, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = register_pheno_to_pheno_disks, + memory = register_pheno_to_pheno_memory, + cpu = register_pheno_to_pheno_cpu, + max_retries = max_retries + } } + if(run_nuclei_segmentation) { + call tasks.find_objects as find_objects_nuclei { + input: + labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=groupby_pattern, + suffix="nuclei", + output_directory=objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } + } + if(run_cell_segmentation) { + call tasks.find_objects as find_objects_cell { + input: + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=groupby_pattern, + suffix="cell", + output_directory=objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } + } + if (run_nuclei_segmentation && run_cell_segmentation) { + call tasks.find_objects as find_objects_cytosol { + input: + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_pattern=groupby_pattern, + suffix="cytosol", + output_directory=objects_directory, + subset = subset_, + force = force_find_objects, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = find_objects_disks, + memory = find_objects_memory, + cpu = find_objects_cpu, + max_retries = max_retries + } + } # determine whether cells intersect stitch boundary - # using stitch mask as image + # use stitch mask as image and segment output for reference phenotype or transformed phenotype for others + if(mark_stitch_boundary_cells) { - String t0 = if (length(times)>0) then times[0] else "" - String reference_phenotype_time_ = select_first([reference_phenotype_time, t0]) - String output_prefix = if (reference_phenotype_time_!="") then "-" else "" - String phenotype_url_stripped = if (pheno_url_supplied) then sub(select_first([phenotype_url]), "/+$", "") else "" - call tasks.intersects_boundary as cell_intersects_boundary { - # reference time mask is not transformed - # use mask from stitch output + String phenotype_url_stripped = sub(select_first([phenotype_url]), "/+$", "") + call tasks.intersects_boundary as intersects_boundary { + input: - labels=segment_cell.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), images=phenotype_url_stripped + '/labels/', - image_pattern=image_pattern_after_registration + output_prefix + reference_phenotype_time_ + '-mask', - output_directory=cell_intersects_boundary_directory, - label_type='cell', - objects=find_objects_cell.output_url, - groupby=groupby, - subset = group, - force = force_segment_cell, + image_pattern=phenotype_image_pattern + '-mask', + output_directory=intersects_boundary_directory, + label_type=intersects_stitch_boundary_label, + objects=if(intersects_stitch_boundary_label=='cell') then find_objects_cell.output_url else find_objects_nuclei.output_url, + groupby=phenotype_group_by_with_time, + subset = subset_ + '-*', + force = if(intersects_stitch_boundary_label=='cell') then force_segment_cell else force_segment_nuclei, docker=docker, zones = zones, preemptible = preemptible, aws_queue_arn = aws_queue_arn, - disks = cell_intersects_boundary_disks, - memory = cell_intersects_boundary_memory, - cpu = cell_intersects_boundary_cpu, + disks = intersects_boundary_disks, + memory = intersects_boundary_memory, + cpu = intersects_boundary_cpu, max_retries = max_retries } - if (length(times)>1) { - call tasks.intersects_boundary as cell_intersects_boundary_t { - # non-reference time masks are transformed - # use masks from registration output - input: - labels= segment_cell.output_url, - images=register_pheno_to_pheno.moving_output_url + '/labels/', - image_pattern=phenotype_image_pattern + '-mask', - output_directory=cell_intersects_boundary_directory_non_reference_t, - label_type='cell', - objects=find_objects_cell.output_url, - subset = group, - groupby=groupby, - force = force_segment_cell, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = cell_intersects_boundary_disks, - memory = cell_intersects_boundary_memory, - cpu = cell_intersects_boundary_cpu, - max_retries = max_retries - } - } + } } + } if(iss_url_supplied) { + call tasks.register_elastix as register_iss_t0 { input: moving=[select_first([iss_url])], @@ -413,9 +429,8 @@ workflow ops_workflow { groupby=groupby, moving_output_directory=register_iss_t0_directory, transform_output_directory=register_iss_t0_transforms_directory, - register_across_channels=register_across_channels, extra_arguments=iss_registration_extra_arguments, - subset = group, + subset = subset_, force = force_register_iss, docker=docker, zones = zones, @@ -429,21 +444,23 @@ workflow ops_workflow { } if(iss_url_supplied && pheno_url_supplied) { + # transfer phenotype segmentation and DAPI channel to ISS call tasks.register_elastix as register_pheno_to_iss { input: fixed=select_first([iss_url]), fixed_channel=iss_dapi_channel, moving_label=segment_cell.output_url, - moving=select_all([register_pheno_to_pheno_output_url]), - moving_image_pattern=register_pheno_to_pheno_image_pattern, + moving=select_all([phenotype_url]), + moving_image_pattern=phenotype_image_pattern, fixed_image_pattern=iss_image_pattern, moving_channel=phenotype_dapi_channel, + moving_time=reference_time_pheno, + fixed_time=reference_time_iss, output_aligned_channels_only=true, - register_across_channels=register_across_channels, moving_output_directory=register_pheno_to_iss_directory, label_output_directory=register_pheno_to_iss_directory, transform_output_directory=register_pheno_to_iss_transforms_directory, - subset = group, + subset = subset_, groupby=groupby, extra_arguments=pheno_to_iss_registration_extra_arguments, force = force_register_pheno_to_iss, @@ -457,36 +474,19 @@ workflow ops_workflow { max_retries = max_retries } if(run_nuclei_segmentation) { - call tasks.find_objects as find_objects_nuclei { - input: - labels=segment_nuclei.output_url, - label_pattern=image_pattern_after_registration, - suffix="nuclei", - output_directory=nuclei_objects_directory, - subset = group, - force = force_find_objects, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = find_objects_disks, - memory = find_objects_memory, - cpu = find_objects_cpu, - max_retries = max_retries - } - + # ISS t0 to phenotype reference time call tasks.register_pheno_to_iss_qc as register_pheno_to_iss_qc { input: - images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, - stacked_images=register_pheno_to_iss.moving_output_url, - stacked_image_pattern=image_pattern_after_registration, + images=register_pheno_to_iss.moving_output_url, + image_pattern=groupby_pattern, + stacked_images=select_first([register_pheno_to_iss.moving_output_url]), + stacked_image_pattern=groupby_pattern, image_channel=iss_dapi_channel, stacked_image_channel=0, label_type='nuclei', output_directory=register_pheno_to_iss_qc_directory, labels=register_pheno_to_iss.label_output_url, - subset = group, + subset = subset_, groupby=groupby, force = force_register_pheno_to_iss_qc, docker=docker, @@ -498,16 +498,18 @@ workflow ops_workflow { cpu = register_pheno_to_iss_qc_cpu, max_retries = max_retries } - call tasks.register_qc as register_iss_to_iss_qc { + # ISS t0 to other times + call tasks.register_iss_iss_qc as register_iss_to_iss_qc { input: images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, - channel=select_first([iss_dapi_channel, 0]), + image_pattern=groupby_pattern, + dapi_channel=select_first([iss_dapi_channel, 0]), + n_timepoints=length(times_iss), label_type='nuclei', - channel_prefix="ISS", + output_directory=register_iss_to_iss_qc_directory, labels=register_pheno_to_iss.label_output_url, - subset = group, + subset = subset_, groupby=groupby, force = force_register_iss_to_iss_qc, docker=docker, @@ -526,14 +528,14 @@ workflow ops_workflow { call tasks.spot_detect { input: images=select_first([register_iss_t0.moving_output_url]), - image_pattern=image_pattern_after_registration, + image_pattern=groupby_pattern, iss_channels=iss_channels, sigma_log=spot_detection_sigma_log, max_filter_width=spot_detection_max_filter_width, peak_neighborhood_size=spot_detection_peak_neighborhood_size, expected_cycles=iss_expected_cycles, output_directory=spot_detect_directory, - subset = group, + subset = subset_, groupby=groupby, extra_arguments=spot_detection_extra_arguments, force = force_spot_detect, @@ -562,7 +564,7 @@ workflow ops_workflow { label_name=reads_labels, mismatches=reads_mismatches, threshold_peaks_crosstalk=reads_threshold_peaks_crosstalk, - subset = group, + subset = subset_, extra_arguments=reads_extra_arguments, force = force_reads, docker=docker, @@ -580,20 +582,14 @@ workflow ops_workflow { call tasks.merge as merge_sbs_metadata { input: iss_reads=select_first([reads.output_url]) + '/labels', -# phenotypes_nuclei=features_nuclei.output_url, -# phenotypes_cell=features_cell.output_url, -# phenotypes_cytosol=features_cytosol.output_url, - objects_nuclei=find_objects_nuclei.output_url, - objects_cell=find_objects_cell.output_url, - objects_cytosol=find_objects_cytosol.output_url, - cell_intersects_boundary=cell_intersects_boundary.output_url, - cell_intersects_boundary_t=cell_intersects_boundary_t.output_url, + objects_nuclei=if(run_cell_segmentation) then find_objects_nuclei.output_url else find_objects_cell.output_url, + cell_intersects_boundary=intersects_boundary.output_url, register_pheno_to_iss_qc=register_pheno_to_iss_qc.output_url, register_iss_to_iss_qc=register_iss_to_iss_qc.output_url, barcodes=select_first([barcodes]), barcode_column=barcode_column, output_directory=merge_meta_directory, - subset = group, + subset = subset_, extra_arguments=merge_extra_arguments, force = force_merge, docker=docker, @@ -609,132 +605,148 @@ workflow ops_workflow { } if (defined(phenotype_nuclei_features)) { - Array[String] phenotype_nuclei_features_ = select_first([phenotype_nuclei_features]) - # cromwell hack + Map[String,Array[String]] phenotype_nuclei_features_ = select_first([phenotype_nuclei_features]) Int features_nuclei_min_area_ = select_first([features_nuclei_min_area, -1]) Int features_nuclei_max_area_ = select_first([features_nuclei_max_area, -1]) - scatter (index in range(length(phenotype_nuclei_features_))) { + Array[String] phenotype_nuclei_times = keys(phenotype_nuclei_features_) - call tasks.features as features_nuclei { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - nuclei_features = phenotype_nuclei_features_[index], - nuclei_min_area = features_nuclei_min_area_, - nuclei_max_area = features_nuclei_max_area_, - features_extra_arguments=features_extra_arguments, - labels= segment_cell.output_url, - model_dir=model_dir, - groupby=groupby, - output_directory=nuclei_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + scatter (phenotype_time in phenotype_nuclei_times) { + Array[String] nuclei_features = phenotype_nuclei_features_[phenotype_time] + scatter (feature_index in range(length(nuclei_features))) { + call tasks.features as features_nuclei { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + merge=merge_sbs_metadata.output_url, + labels=select_all([segment_nuclei.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, + nuclei_features = nuclei_features[feature_index], + nuclei_min_area = features_nuclei_min_area_, + nuclei_max_area = features_nuclei_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + + output_directory=nuclei_features_directory + '-' + phenotype_time + '-batch' + feature_index, + subset = subset_ + '-' + phenotype_time, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } if (defined(phenotype_cell_features)) { - Array[String] phenotype_cell_features_ = select_first([phenotype_cell_features]) - # cromwell hack + Map[String,Array[String]] phenotype_cell_features_ = select_first([phenotype_cell_features]) Int features_cell_min_area_ = select_first([features_cell_min_area, -1]) Int features_cell_max_area_ = select_first([features_cell_max_area, -1]) - scatter (index in range(length(phenotype_cell_features_))) { - call tasks.features as features_cell { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - cell_features = phenotype_cell_features_[index], - cell_min_area = features_cell_min_area_, - cell_max_area = features_cell_max_area_, - features_extra_arguments=features_extra_arguments, - labels= segment_cell.output_url, - model_dir=model_dir, - groupby=groupby, - output_directory=cell_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + Array[String] phenotype_cell_times = keys(phenotype_cell_features_) + + scatter (phenotype_time in phenotype_cell_times) { + Array[String] cell_features = phenotype_cell_features_[phenotype_time] + scatter (feature_index in range(length(cell_features))) { + call tasks.features as features_cell { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + merge=merge_sbs_metadata.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, + cell_features = cell_features[feature_index], + cell_min_area = features_cell_min_area_, + cell_max_area = features_cell_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + + output_directory=cell_features_directory + '-' + phenotype_time + '-' + feature_index, + subset = subset_ + '-' + phenotype_time, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } if (defined(phenotype_cytosol_features)) { - Array[String] phenotype_cytosol_features_ = select_first([phenotype_cytosol_features]) - # cromwell hack + Map[String,Array[String]] phenotype_cytosol_features_ = select_first([phenotype_cytosol_features]) Int features_cytosol_min_area_ = select_first([features_cytosol_min_area, -1]) Int features_cytosol_max_area_ = select_first([features_cytosol_max_area, -1]) - scatter (index in range(length(phenotype_cytosol_features_))) { - call tasks.features as features_cytosol { - input: - images = select_first([register_pheno_to_pheno_output_url]), - image_pattern=register_pheno_to_pheno_image_pattern, - objects=merge_sbs_metadata.output_url, - label_filter=features_label_filter, - cytosol_features = phenotype_cytosol_features_[index], - cytosol_min_area = features_cytosol_min_area_, - cytosol_max_area = features_cytosol_max_area_, - labels = segment_cell.output_url, - features_extra_arguments=features_extra_arguments, - model_dir=model_dir, - groupby=groupby, - output_directory=cytosol_features_directory + '-' + index, - subset = group, - force = force_features, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = features_disks, - memory = features_memory, - cpu = features_cpu, - max_retries = max_retries + Array[String] phenotype_cytosol_times = keys(phenotype_cytosol_features_) + + scatter (phenotype_time in phenotype_cytosol_times) { + Array[String] cytosol_features = phenotype_cytosol_features_[phenotype_time] + scatter (feature_index in range(length(cytosol_features))) { + call tasks.features as features_cytosol { + input: + images = select_first([phenotype_url]), + image_pattern=phenotype_image_pattern, + merge=merge_sbs_metadata.output_url, + labels=select_all([segment_cell.output_url, register_pheno_to_pheno.label_output_url]), + label_filter=features_label_filter, + groupby=phenotype_group_by_with_time, + output_directory=cytosol_features_directory + '-' + phenotype_time + '-' + feature_index, + cytosol_features = cytosol_features[feature_index], + cytosol_min_area = features_cytosol_min_area_, + cytosol_max_area = features_cytosol_max_area_, + features_extra_arguments=features_extra_arguments, + + model_dir=model_dir, + + subset = subset_ + '-' + phenotype_time, + force = force_features, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = features_disks, + memory = features_memory, + cpu = features_cpu, + max_retries = max_retries + } } } } - if (defined(barcodes)) { + if (defined(barcodes)) { - call tasks.merge as merge_features { - input: - phenotypes_nuclei=features_nuclei.output_url, - phenotypes_cell=features_cell.output_url, - phenotypes_cytosol=features_cytosol.output_url, - iss_reads=merge_sbs_metadata.output_url, - output_directory=merge_features_directory, - subset = group, - extra_arguments=merge_extra_arguments, - force = force_merge, - docker=docker, - zones = zones, - preemptible = preemptible, - aws_queue_arn = aws_queue_arn, - disks = merge_disks, - memory = merge_memory, - cpu = merge_cpu, - max_retries = max_retries - } + call tasks.merge as merge_features { + input: + phenotypes_nuclei=features_nuclei.output_url, + phenotypes_cell=features_cell.output_url, + phenotypes_cytosol=features_cytosol.output_url, + iss_reads=merge_sbs_metadata.output_url, + output_directory=merge_features_directory, + subset = subset_, + extra_arguments=merge_extra_arguments, + force = force_merge, + docker=docker, + zones = zones, + preemptible = preemptible, + aws_queue_arn = aws_queue_arn, + disks = merge_disks, + memory = merge_memory, + cpu = merge_cpu, + max_retries = max_retries } - + } } output { @@ -754,7 +766,5 @@ workflow ops_workflow { Array[Array[String]?] features_cytosol_output_url = features_cytosol.output_url Array[String?] merge_sbs_metadata_output_url = merge_sbs_metadata.output_url Array[String?] merge_features_output_url = merge_features.output_url - Array[String] list_images_groups = list_images.groups - } } diff --git a/wdl/utils.wdl b/wdl/utils.wdl index 35e8f01..3a8a0ee 100644 --- a/wdl/utils.wdl +++ b/wdl/utils.wdl @@ -2,12 +2,21 @@ version 1.0 task list_images { input { - Boolean? save_group_size - Array[String] urls - String? image_pattern + + Array[String]? urls1 + String? image_pattern1 + String? reference_time1 + Int? n_cycles1 + + Array[String]? urls2 + String? image_pattern2 + String? reference_time2 + Int? n_cycles2 + Array[String] groupby Array[String]? subset - Int? expected_cycles + + Boolean? save_group_size Int? batch_size String docker String zones @@ -23,25 +32,51 @@ task list_images { command <<< set -e python <>> output { - Array[String] groups = read_lines('groups.txt') - Array[String] t = read_lines('t.txt') - Array[String] filtered_groupby = read_lines('groupby.txt') - String groupby_pattern = read_lines('groupby_pattern.txt')[0] - Int group_size = read_int('group_size.txt') + Array[String] subsets = read_lines('subsets.txt') + String groupby_pattern = read_lines('groupby_pattern.txt')[0] # e.g. {plate}-{well} + Array[String] groupby_array = read_lines('groupby_array.txt') # e.g. ["plate", "well"] + + Array[String] groupby_array_with_time_1 = read_lines('groupby_array_with_time_1.txt') # e.g. ["plate", "well", "t"] + Array[String] groupby_array_with_time_2 = read_lines('groupby_array_with_time_2.txt') # e.g. ["plate", "well", "t"] + + Int group_size_1 = read_int('group_size_1.txt') + Int group_size_2 = read_int('group_size_2.txt') + Array[String] subsets_with_reference_times_1 = read_lines('subsets_with_reference_time_1.txt') + Array[String] subsets_with_reference_times_2 = read_lines('subsets_with_reference_time_2.txt') + + Array[String] times_1 = read_lines('times_1.txt') + Array[String] times_2 = read_lines('times_2.txt') + + String reference_time_1 = read_lines('reference_time_1.txt')[0] + String reference_time_2 = read_lines('reference_time_2.txt')[0] + + String image_pattern_with_reference_time_1 = read_lines('image_pattern_with_reference_time_1.txt')[0] # e.g. {plate}-{well}-IF + String image_pattern_with_reference_time_2 = read_lines('image_pattern_with_reference_time_2.txt')[0] # e.g. {plate}-{well}-IF }