diff --git a/.gitignore b/.gitignore index cc4d53c..c74aaa0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ tmp*/ *.n5 *.h5 .idea/ +build/ test/test-folder/ diff --git a/mobie/image_data.py b/mobie/image_data.py index fe7dcaf..ad4d8c0 100644 --- a/mobie/image_data.py +++ b/mobie/image_data.py @@ -168,7 +168,9 @@ def add_image(input_path, input_key, move_only=False, int_to_uint=False, channel=None, - skip_add_to_dataset=False): + skip_add_to_dataset=False, + selected_input_channel=None, + roi_begin=None, roi_end=None): """ Add an image source to a MoBIE dataset. Will create the dataset if it does not exist. @@ -205,6 +207,9 @@ def add_image(input_path, input_key, This should be used when calling `add_image` in parallel in order to avoid writing to dataset.json in parallel, which can cause issues. In this case the source needs to be added later , which can be done by calling this function again. (default: False) + selected_input_channel [list[int]] - A single channel (idx) to be added. If channel is not axis 0: [idx, dim] + roi_begin [list[int]] - Start of ROI to be extracted + roi_end [list[int]] - End of ROI to be extracted """ # TODO add 'setup_id' to the json schema for bdv formats to also support it there if channel is not None and file_format != "ome.zarr": @@ -247,7 +252,10 @@ def add_image(input_path, input_key, source_name=image_name, file_format=file_format, int_to_uint=int_to_uint, - channel=channel) + channel=channel, + selected_input_channel=selected_input_channel, + roi_begin=roi_begin, roi_end=roi_end + ) if transformation is not None: utils.update_transformation_parameter(image_metadata_path, transformation, file_format) diff --git a/mobie/import_data/image.py b/mobie/import_data/image.py index 4357606..5240ede 100644 --- a/mobie/import_data/image.py +++ b/mobie/import_data/image.py @@ -1,4 +1,5 @@ import multiprocessing as mp +from elf.io import open_file from .utils import downscale, ensure_volume @@ -7,7 +8,9 @@ def import_image_data(in_path, in_key, out_path, tmp_folder=None, target="local", max_jobs=mp.cpu_count(), block_shape=None, unit="micrometer", source_name=None, file_format="ome.zarr", - int_to_uint=False, channel=None): + int_to_uint=False, channel=None, + selected_input_channel=None, + roi_begin=None, roi_end=None): """ Import image data to mobie format. Arguments: @@ -28,13 +31,51 @@ def import_image_data(in_path, in_key, out_path, int_to_uint [bool] - whether to convert signed to unsigned integer (default: False) channel [int] - the channel to load from the data. Currently only supported for the ome.zarr format (default: None) + selected_input_channel [list[int]] - A single channel (idx) to be added. If channel is not axis 0: [idx, dim] + roi_begin [list[int]] - Start of ROI to be extracted + roi_end [list[int]] - End of ROI to be extracted """ + # we allow 2d data for ome.zarr file format if file_format != "ome.zarr": in_path, in_key = ensure_volume(in_path, in_key, tmp_folder, chunks) + if not all((selected_input_channel is None, roi_begin is None, roi_end is None)): + raise NotImplementedError("Selection of sub-arrays only possible with OME-Zarr output.") + + fit_to_roi = False + + if selected_input_channel: + if type(selected_input_channel) is int: + selected_input_channel = [0, selected_input_channel] + elif len(selected_input_channel) < 2: + # if only one element, we assume relevant image stack dimension is 0 (like channel for multi-channel tifs). + selected_input_channel = [0, selected_input_channel[0]] + elif len(selected_input_channel) > 2: + raise ValueError("Only single channel selection possible.") + + with open_file(in_path, mode="r") as f: + shape = f[in_key].shape + + roi_begin = [0] * len(shape) + roi_end = list(shape) + + if selected_input_channel[0] > len(shape) - 1: + raise ValueError("Wrong channel dimension.") + + if selected_input_channel[1] > shape[selected_input_channel[0]] - 1: + raise ValueError("Channel index exceeds axis length.") + + roi_begin[selected_input_channel[0]] = selected_input_channel[1] + roi_end[selected_input_channel[0]] = selected_input_channel[1] + 1 + + if any((roi_begin is not None, roi_end is not None)): + fit_to_roi = True + downscale(in_path, in_key, out_path, resolution, scale_factors, chunks, tmp_folder, target, max_jobs, block_shape, library="skimage", unit=unit, source_name=source_name, - metadata_format=file_format, int_to_uint=int_to_uint, + metadata_format=file_format, + roi_begin=roi_begin, roi_end=roi_end, + int_to_uint=int_to_uint, fit_to_roi=fit_to_roi, channel=channel) diff --git a/mobie/import_data/utils.py b/mobie/import_data/utils.py index 5d34709..3faeab5 100644 --- a/mobie/import_data/utils.py +++ b/mobie/import_data/utils.py @@ -1,5 +1,6 @@ import json import os +import numpy as np import luigi import nifty.distributed as ndist @@ -47,12 +48,13 @@ def compute_node_labels(seg_path, seg_key, return data -def check_input_data(in_path, in_key, resolution, require3d, channel): +def check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=None, roi_end=None): # TODO to support data with channel, we need to support downscaling with channels if channel is not None: raise NotImplementedError with open_file(in_path, "r") as f: ndim = f[in_key].ndim + if require3d and ndim != 3: raise ValueError(f"Expect 3d data, got ndim={ndim}") if len(resolution) != ndim: @@ -65,7 +67,7 @@ def downscale(in_path, in_key, out_path, library="vigra", library_kwargs=None, metadata_format="ome.zarr", out_key="", unit="micrometer", source_name=None, - roi_begin=None, roi_end=None, + roi_begin=None, roi_end=None, fit_to_roi=False, int_to_uint=False, channel=None): task = DownscalingWorkflow @@ -73,9 +75,9 @@ def downscale(in_path, in_key, out_path, config_dir = os.path.join(tmp_folder, "configs") # ome.zarr can also be written in 2d, all other formats require 3d require3d = metadata_format != "ome.zarr" - check_input_data(in_path, in_key, resolution, require3d, channel) + check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=roi_begin, roi_end=roi_end) write_global_config(config_dir, block_shape=block_shape, require3d=require3d, - roi_begin=roi_begin, roi_end=roi_end) + roi_begin=roi_begin, roi_end=roi_end, fit_to_roi=fit_to_roi) configs = DownscalingWorkflow.get_config() conf = configs["copy_volume"] diff --git a/mobie/utils.py b/mobie/utils.py index 8fe4598..2084352 100644 --- a/mobie/utils.py +++ b/mobie/utils.py @@ -220,6 +220,7 @@ def write_global_config(config_folder, block_shape=None, roi_begin=None, roi_end=None, + fit_to_roi=False, qos=None, require3d=True): os.makedirs(config_folder, exist_ok=True) @@ -248,6 +249,9 @@ def write_global_config(config_folder, raise ValueError(f"Invalid roi_end given: {roi_end}") global_config["roi_end"] = roi_end + if fit_to_roi: + global_config["fit_to_roi"] = True + if qos is not None: global_config["qos"] = qos diff --git a/test/test_image_data.py b/test/test_image_data.py index 3d7fe8d..5f35582 100644 --- a/test/test_image_data.py +++ b/test/test_image_data.py @@ -234,6 +234,7 @@ def test_cli(self): self.check_data(dataset_folder, im_name) # 2D + @unittest.skipIf(platform == "win32", "CLI does not work on windows") def test_cli_2D(self): @@ -267,7 +268,6 @@ def test_cli_2D(self): exp_data = imageio.imread(in_path) - dataset_folder = os.path.join(self.root, dataset_name) self.check_data(dataset_folder, im_name, exp_data=exp_data) @@ -351,6 +351,91 @@ def test_skip_metadata(self): self.check_data(os.path.join(self.root, self.dataset_name), im_name) + def test_input_channel(self): + path1 = os.path.join(self.test_folder, '3ch.h5') + key = 'data' + self.make_hdf5_data(path1, key, shape=(3, 128, 128)) + + with open_file(path1, mode="r") as f: + im = f[key][:] + + # test wrong channel input + im_name = 'channel_error' + for in_channel in ([1, 2, 3], [4, 0], [0, 4]): + with self.assertRaises(ValueError): + mobie.add_image(path1, key, self.root, self.dataset_name, im_name, + resolution=(1, 1, 1), scale_factors=[[2, 2, 2]], + chunks=(1, 64, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + target="local", max_jobs=self.max_jobs, selected_input_channel=in_channel) + + # check integer channel + for chidx, in_channel in enumerate([1,[1]]): + im_name = '3ch_test_int_' + str(chidx) + mobie.add_image(path1, key, self.root, self.dataset_name, im_name, + resolution=(1, 1, 1), scale_factors=[[2, 2, 2]], + chunks=(1, 64, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + target="local", max_jobs=self.max_jobs, selected_input_channel=in_channel) + test_data = im[1, :, :] + + self.check_data(os.path.join(self.root, self.dataset_name), im_name, exp_data=test_data) + + # check channel as list + im_name = '3ch_test_list' + mobie.add_image(path1, key, self.root, self.dataset_name, im_name, + resolution=(1, 1, 1), scale_factors=[[2, 2, 2]], + chunks=(1, 1, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + target="local", max_jobs=self.max_jobs, selected_input_channel=[1, 14]) + test_data = im[:, 14, :] + + self.check_data(os.path.join(self.root, self.dataset_name), im_name, exp_data=test_data) + + def test_input_roi(self): + path1 = os.path.join(self.test_folder, '3ch.h5') + key = 'data' + inshape=(123, 124, 125) + + self.make_hdf5_data(path1, key, shape=inshape) + + roi_vals = np.floor(np.random.random((2,3))*(np.array(inshape)-1)).astype(int) + + roi_begin = np.min(roi_vals, axis=0) + roi_end = np.max(roi_vals, axis=0) + + for idx in range(3): + if roi_begin[idx] == roi_end[idx]: + roi_end[idx] += 1 + + with open_file(path1, mode="r") as f: + im = f[key][:] + + # check integer channel + im_name = 'roi_test' + mobie.add_image(path1, key, self.root, self.dataset_name, im_name, + resolution=(1, 1, 1), scale_factors=[[2, 2, 2]], + chunks=(1, 64, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + roi_begin=roi_begin, roi_end=roi_end, + target="local", max_jobs=self.max_jobs, + ) + test_data = im[roi_begin[0]:roi_end[0], roi_begin[1]:roi_end[1], roi_begin[2]:roi_end[2]] + + + self.check_data(os.path.join(self.root, self.dataset_name), im_name, exp_data=test_data) + + # check channel as list + im_name = '3ch_test_list' + mobie.add_image(path1, key, self.root, self.dataset_name, im_name, + resolution=(1, 1, 1), scale_factors=[[2, 2, 2]], + chunks=(1, 1, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + target="local", max_jobs=self.max_jobs, selected_input_channel=[1, 14]) + test_data = im[:, 14, :] + + self.check_data(os.path.join(self.root, self.dataset_name), im_name, exp_data=test_data) + # # data validation