From 545ab0d1e894b03b272c1ad812229f80e056564c Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 28 Jan 2026 11:47:19 -0500 Subject: [PATCH 01/10] feat: add channel axis handling in CellMapDatasetWriter and ImageWriter --- src/cellmap_data/dataset_writer.py | 4 ++++ src/cellmap_data/image_writer.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 24f8c75..0d5efca 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -431,6 +431,10 @@ def get_image_writer( shape = array_info["shape"] if not isinstance(shape, (Mapping, Sequence)): raise TypeError(f"Shape must be a Mapping or Sequence, not {type(shape)}") + if "n_channels" in array_info: + shape = [array_info["n_channels"]] + list(shape) + if "c" not in self.axis_order: + self.axis_order = "c" + self.axis_order scale_level = array_info.get("scale_level", 0) if not isinstance(scale_level, int): raise TypeError(f"Scale level must be an int, not {type(scale_level)}") diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index a9e8913..71ee0b7 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -38,12 +38,15 @@ def __init__( self.base_path = str(path) self.path = (UPath(path) / f"s{scale_level}").path self.label_class = self.target_class = target_class + if len(write_voxel_shape) == len(axis_order) + 1 and "c" not in axis_order: + # Add channel axis if missing + axis_order = "c" + axis_order if isinstance(scale, Sequence): if len(axis_order) > len(scale): scale = [scale[0]] * (len(axis_order) - len(scale)) + list(scale) scale = {c: s for c, s in zip(axis_order, scale)} if isinstance(write_voxel_shape, Sequence): - if len(axis_order) > len(write_voxel_shape): + if len(axis_order) > len(write_voxel_shape): # TODO: This might be a bug write_voxel_shape = [1] * ( len(axis_order) - len(write_voxel_shape) ) + list(write_voxel_shape) From 42f2cec58bcff830798a5c3a9fc6cd511b86557a Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 16 Feb 2026 22:33:54 -0500 Subject: [PATCH 02/10] fix: clarify dataset length description and initialize additional attributes in CellMapDataset --- src/cellmap_data/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 4dfdd7b..d434c03 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -509,7 +509,7 @@ def device(self) -> torch.device: return self._device def __len__(self) -> int: - """Returns the number of patches in the dataset.""" + """Returns the number of unique patches in the dataset.""" if not self.has_data and not self.force_has_data: return 0 # Return at least 1 if the dataset has data, so that samplers can be initialized @@ -976,4 +976,6 @@ def empty() -> "CellMapDataset": # Directly instantiate to bypass __new__ logic instance = super(CellMapDataset, CellMapDataset).__new__(CellMapDataset) instance.__init__("", "", [], {}, {}, force_has_data=False) + instance.has_data = False + instance._sampling_box_shape = {c: 0 for c in instance.axis_order} return instance From a7a9fd7250623257446c7fe760d22394eea9c1b9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:30:22 +0000 Subject: [PATCH 03/10] Initial plan From 109c52a3ba661d61c2827dae11cb7195d44dd38b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:43:01 +0000 Subject: [PATCH 04/10] Add test coverage for multi-channel writing and fix bounding box handling Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 8 +++- tests/test_cellmap_dataset.py | 35 ++++++++++++++++ tests/test_dataset_writer.py | 69 ++++++++++++++++++++++++++++++++ tests/test_empty_image_writer.py | 55 +++++++++++++++++++++++++ 4 files changed, 166 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 71ee0b7..ced8a45 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -38,12 +38,14 @@ def __init__( self.base_path = str(path) self.path = (UPath(path) / f"s{scale_level}").path self.label_class = self.target_class = target_class + channel_axis_added = False if len(write_voxel_shape) == len(axis_order) + 1 and "c" not in axis_order: # Add channel axis if missing axis_order = "c" + axis_order + channel_axis_added = True if isinstance(scale, Sequence): if len(axis_order) > len(scale): - scale = [scale[0]] * (len(axis_order) - len(scale)) + list(scale) + scale = [1.0] * (len(axis_order) - len(scale)) + list(scale) scale = {c: s for c, s in zip(axis_order, scale)} if isinstance(write_voxel_shape, Sequence): if len(axis_order) > len(write_voxel_shape): # TODO: This might be a bug @@ -52,6 +54,10 @@ def __init__( ) + list(write_voxel_shape) write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)} self.scale = scale + # Add bounding_box for channel axis if it was added or if 'c' is in axis_order but not in bounding_box + if "c" in axis_order and "c" not in bounding_box: + n_channels = write_voxel_shape["c"] + bounding_box = {"c": [0, n_channels], **bounding_box} self.bounding_box = bounding_box self.write_voxel_shape = write_voxel_shape self.write_world_shape = { diff --git a/tests/test_cellmap_dataset.py b/tests/test_cellmap_dataset.py index 371e645..b01cf98 100644 --- a/tests/test_cellmap_dataset.py +++ b/tests/test_cellmap_dataset.py @@ -445,3 +445,38 @@ def test_max_workers_parameter(self, minimal_dataset_config): # Dataset should be created successfully assert dataset is not None + + def test_empty_dataset_creation(self): + """Test CellMapDataset.empty() static method.""" + from cellmap_data import CellMapDataset + + # Create an empty dataset + empty_dataset = CellMapDataset.empty() + + # Verify basic properties + assert empty_dataset is not None + assert isinstance(empty_dataset, CellMapDataset) + assert empty_dataset.has_data is False + assert len(empty_dataset) == 0 + + # Verify the newly initialized attributes + assert hasattr(empty_dataset, "_sampling_box_shape") + assert isinstance(empty_dataset._sampling_box_shape, dict) + assert all(v == 0 for v in empty_dataset._sampling_box_shape.values()) + + # Verify axis_order is set (should have default value) + assert hasattr(empty_dataset, "axis_order") + assert len(empty_dataset._sampling_box_shape) == len(empty_dataset.axis_order) + + def test_empty_dataset_sampling_box_shape(self): + """Test that empty dataset has correct _sampling_box_shape initialization.""" + from cellmap_data import CellMapDataset + + empty_dataset = CellMapDataset.empty() + + # Verify sampling_box_shape keys match axis_order + assert set(empty_dataset._sampling_box_shape.keys()) == set(empty_dataset.axis_order) + + # Verify all dimensions are 0 + for axis in empty_dataset.axis_order: + assert empty_dataset._sampling_box_shape[axis] == 0 diff --git a/tests/test_dataset_writer.py b/tests/test_dataset_writer.py index a2186d4..2bd8dcb 100644 --- a/tests/test_dataset_writer.py +++ b/tests/test_dataset_writer.py @@ -277,6 +277,75 @@ def test_context_parameter(self, writer_config): assert writer.context is context + def test_n_channels_in_target_arrays(self, writer_config): + """Test n_channels parameter in target arrays configuration.""" + config = writer_config["input_config"] + + # Test with n_channels to create multi-channel output + target_arrays = { + "predictions": { + "shape": (32, 32, 32), + "scale": (8.0, 8.0, 8.0), + "n_channels": 3, + } + } + + target_bounds = { + "predictions": { + "x": [0, 256], + "y": [0, 256], + "z": [0, 256], + } + } + + writer = CellMapDatasetWriter( + raw_path=config["raw_path"], + target_path=writer_config["output_path"], + classes=["class_0"], + input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, + target_arrays=target_arrays, + target_bounds=target_bounds, + ) + + # Verify that the channel axis was added + assert "c" in writer.axis_order + assert writer.axis_order.startswith("c") + + def test_n_channels_with_existing_channel_axis(self, writer_config): + """Test n_channels parameter when channel axis already exists.""" + config = writer_config["input_config"] + + target_arrays = { + "predictions": { + "shape": (32, 32, 32), + "scale": (8.0, 8.0, 8.0), + "n_channels": 4, + } + } + + target_bounds = { + "predictions": { + "x": [0, 256], + "y": [0, 256], + "z": [0, 256], + } + } + + # Test with explicit channel axis in axis_order + writer = CellMapDatasetWriter( + raw_path=config["raw_path"], + target_path=writer_config["output_path"], + classes=["class_0"], + input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, + target_arrays=target_arrays, + axis_order="czyx", + target_bounds=target_bounds, + ) + + # Verify channel axis is present and not duplicated + assert writer.axis_order == "czyx" + assert writer.axis_order.count("c") == 1 + class TestWriterOperations: """Test writer operations and functionality.""" diff --git a/tests/test_empty_image_writer.py b/tests/test_empty_image_writer.py index dcc9aef..93b867b 100644 --- a/tests/test_empty_image_writer.py +++ b/tests/test_empty_image_writer.py @@ -336,3 +336,58 @@ def test_multiple_writers_different_classes(self, tmp_upath): assert len(writers) == 3 assert all(w.target_class in classes for w in writers) + + def test_image_writer_channel_axis_detection(self, tmp_upath): + """Test automatic channel axis detection when write_voxel_shape has extra dimension.""" + path = tmp_upath / "output_channels.zarr" + + # Test with 4D shape but 3D axis_order (should add channel axis) + writer = ImageWriter( + path=str(path), + target_class="multichannel", + scale=(4.0, 4.0, 4.0), + write_voxel_shape=(3, 32, 32, 32), # 4D shape with channels + axis_order="zyx", # 3D axis order + bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, + ) + + # Verify channel axis was added + assert "c" in writer.axes + assert writer.axes.startswith("c") + assert len(writer.axes) == 4 + + def test_image_writer_with_explicit_channel_axis(self, tmp_upath): + """Test ImageWriter with explicit channel axis in axis_order.""" + path = tmp_upath / "output_explicit_channels.zarr" + + # Test with explicit channel axis + writer = ImageWriter( + path=str(path), + target_class="multichannel", + scale=(4.0, 4.0, 4.0), + write_voxel_shape=(5, 32, 32, 32), # 4D shape with 5 channels + axis_order="czyx", # Explicit channel axis + bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, + ) + + # Verify channel axis is present + assert writer.axes == "czyx" + assert writer.write_voxel_shape["c"] == 5 + + def test_image_writer_no_channel_detection_when_not_needed(self, tmp_upath): + """Test that channel axis is not added when dimensions match.""" + path = tmp_upath / "output_no_channels.zarr" + + # Test with matching dimensions (no channel detection needed) + writer = ImageWriter( + path=str(path), + target_class="test", + scale=(4.0, 4.0, 4.0), + write_voxel_shape=(32, 32, 32), # 3D shape matching 3D axis order + axis_order="zyx", + bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, + ) + + # Verify no channel axis was added + assert "c" not in writer.axes + assert writer.axes == "zyx" From 36dc0447f29436f28caa0d85dd6c7423ae0a1203 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 18 Feb 2026 10:54:14 -0500 Subject: [PATCH 05/10] test: clean up whitespace in test cases for CellMapDataset and ImageWriter --- tests/test_cellmap_dataset.py | 18 ++++++++++-------- tests/test_empty_image_writer.py | 12 ++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/test_cellmap_dataset.py b/tests/test_cellmap_dataset.py index b01cf98..2750513 100644 --- a/tests/test_cellmap_dataset.py +++ b/tests/test_cellmap_dataset.py @@ -449,21 +449,21 @@ def test_max_workers_parameter(self, minimal_dataset_config): def test_empty_dataset_creation(self): """Test CellMapDataset.empty() static method.""" from cellmap_data import CellMapDataset - + # Create an empty dataset empty_dataset = CellMapDataset.empty() - + # Verify basic properties assert empty_dataset is not None assert isinstance(empty_dataset, CellMapDataset) assert empty_dataset.has_data is False assert len(empty_dataset) == 0 - + # Verify the newly initialized attributes assert hasattr(empty_dataset, "_sampling_box_shape") assert isinstance(empty_dataset._sampling_box_shape, dict) assert all(v == 0 for v in empty_dataset._sampling_box_shape.values()) - + # Verify axis_order is set (should have default value) assert hasattr(empty_dataset, "axis_order") assert len(empty_dataset._sampling_box_shape) == len(empty_dataset.axis_order) @@ -471,12 +471,14 @@ def test_empty_dataset_creation(self): def test_empty_dataset_sampling_box_shape(self): """Test that empty dataset has correct _sampling_box_shape initialization.""" from cellmap_data import CellMapDataset - + empty_dataset = CellMapDataset.empty() - + # Verify sampling_box_shape keys match axis_order - assert set(empty_dataset._sampling_box_shape.keys()) == set(empty_dataset.axis_order) - + assert set(empty_dataset._sampling_box_shape.keys()) == set( + empty_dataset.axis_order + ) + # Verify all dimensions are 0 for axis in empty_dataset.axis_order: assert empty_dataset._sampling_box_shape[axis] == 0 diff --git a/tests/test_empty_image_writer.py b/tests/test_empty_image_writer.py index 93b867b..2d733bf 100644 --- a/tests/test_empty_image_writer.py +++ b/tests/test_empty_image_writer.py @@ -340,7 +340,7 @@ def test_multiple_writers_different_classes(self, tmp_upath): def test_image_writer_channel_axis_detection(self, tmp_upath): """Test automatic channel axis detection when write_voxel_shape has extra dimension.""" path = tmp_upath / "output_channels.zarr" - + # Test with 4D shape but 3D axis_order (should add channel axis) writer = ImageWriter( path=str(path), @@ -350,7 +350,7 @@ def test_image_writer_channel_axis_detection(self, tmp_upath): axis_order="zyx", # 3D axis order bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, ) - + # Verify channel axis was added assert "c" in writer.axes assert writer.axes.startswith("c") @@ -359,7 +359,7 @@ def test_image_writer_channel_axis_detection(self, tmp_upath): def test_image_writer_with_explicit_channel_axis(self, tmp_upath): """Test ImageWriter with explicit channel axis in axis_order.""" path = tmp_upath / "output_explicit_channels.zarr" - + # Test with explicit channel axis writer = ImageWriter( path=str(path), @@ -369,7 +369,7 @@ def test_image_writer_with_explicit_channel_axis(self, tmp_upath): axis_order="czyx", # Explicit channel axis bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, ) - + # Verify channel axis is present assert writer.axes == "czyx" assert writer.write_voxel_shape["c"] == 5 @@ -377,7 +377,7 @@ def test_image_writer_with_explicit_channel_axis(self, tmp_upath): def test_image_writer_no_channel_detection_when_not_needed(self, tmp_upath): """Test that channel axis is not added when dimensions match.""" path = tmp_upath / "output_no_channels.zarr" - + # Test with matching dimensions (no channel detection needed) writer = ImageWriter( path=str(path), @@ -387,7 +387,7 @@ def test_image_writer_no_channel_detection_when_not_needed(self, tmp_upath): axis_order="zyx", bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, ) - + # Verify no channel axis was added assert "c" not in writer.axes assert writer.axes == "zyx" From d490a24fc1d393cd4e46f1fae90fde70f899f145 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 18 Feb 2026 15:09:29 -0500 Subject: [PATCH 06/10] feat: refactor neuroglancer integration and improve scale pyramid handling --- src/cellmap_data/utils/view.py | 237 +++++++++++++++++---------------- 1 file changed, 122 insertions(+), 115 deletions(-) diff --git a/src/cellmap_data/utils/view.py b/src/cellmap_data/utils/view.py index 3e667db..264ef80 100644 --- a/src/cellmap_data/utils/view.py +++ b/src/cellmap_data/utils/view.py @@ -8,15 +8,8 @@ import webbrowser from multiprocessing.pool import ThreadPool -import neuroglancer import numpy as np -import s3fs import zarr -from IPython.core.getipython import get_ipython -from IPython.display import IFrame, display -from tensorstore import d as ts_d -from tensorstore import open as ts_open -from upath import UPath logger = logging.getLogger(__name__) @@ -30,6 +23,8 @@ def get_multiscale_voxel_sizes(path: str): if "s3://" in path: + import s3fs + # Use s3fs to read the zarr metadata fs = s3fs.S3FileSystem(anon=True) store = s3fs.S3Map( @@ -114,6 +109,10 @@ def open_neuroglancer(metadata): Returns the Neuroglancer.Viewer object. """ + import neuroglancer + from IPython.core.getipython import get_ipython + from IPython.display import IFrame, display + # 1) bind to localhost on a random port neuroglancer.set_server_bind_address("localhost", 0) viewer = neuroglancer.Viewer() @@ -178,7 +177,7 @@ def get_layer( data_path: str, layer_type: str = "image", multiscale: bool = True, -) -> neuroglancer.Layer: +): """ Get a Neuroglancer layer from a zarr data path for a LocalVolume. @@ -196,6 +195,9 @@ def get_layer( neuroglancer.Layer The Neuroglancer layer. """ + import neuroglancer + from upath import UPath + # Construct an xarray with Tensorstore backend # Get metadata if multiscale: @@ -217,6 +219,115 @@ def get_layer( voxel_offset=metadata[scale]["voxel_offset"], ) ) + class ScalePyramid(neuroglancer.LocalVolume): + """A neuroglancer layer that provides volume data on different scales. + Mimics a LocalVolume. + From https://github.com/funkelab/funlib.show.neuroglancer/blob/master/funlib/show/neuroglancer/scale_pyramid.py + + Args: + ---- + volume_layers (``list`` of ``LocalVolume``): + + One ``LocalVolume`` per provided resolution. + """ + + def __init__(self, volume_layers): + volume_layers = volume_layers + + super(neuroglancer.LocalVolume, self).__init__() + + logger.debug("Creating scale pyramid...") + + self.min_voxel_size = min( + [tuple(layer.dimensions.scales) for layer in volume_layers] + ) + self.max_voxel_size = max( + [tuple(layer.dimensions.scales) for layer in volume_layers] + ) + + self.dims = len(volume_layers[0].dimensions.scales) + self.volume_layers = { + tuple( + int(x) + for x in map( + operator.truediv, + layer.dimensions.scales, + self.min_voxel_size, + ) + ): layer + for layer in volume_layers + } + + logger.debug("min_voxel_size: %s", self.min_voxel_size) + logger.debug("scale keys: %s", self.volume_layers.keys()) + logger.debug(self.info()) + + @property + def volume_type(self): + return self.volume_layers[(1,) * self.dims].volume_type + + @property + def token(self): + return self.volume_layers[(1,) * self.dims].token + + def info(self): + reference_layer = self.volume_layers[(1,) * self.dims] + reference_info = reference_layer.info() + + info = { + "dataType": reference_info["dataType"], + "encoding": reference_info["encoding"], + "generation": reference_info["generation"], + "coordinateSpace": reference_info["coordinateSpace"], + "shape": reference_info["shape"], + "volumeType": reference_info["volumeType"], + "voxelOffset": reference_info["voxelOffset"], + "chunkLayout": reference_info["chunkLayout"], + "downsamplingLayout": reference_info["downsamplingLayout"], + "maxDownsampling": int( + np.prod( + np.array(self.max_voxel_size) + // np.array(self.min_voxel_size) + ) + ), + "maxDownsampledSize": reference_info["maxDownsampledSize"], + "maxDownsamplingScales": reference_info["maxDownsamplingScales"], + } + + return info + + def get_encoded_subvolume(self, data_format, start, end, scale_key=None): + if scale_key is None: + scale_key = ",".join(("1",) * self.dims) + + scale = tuple(int(s) for s in scale_key.split(",")) + closest_scale = None + min_diff = np.inf + for volume_scales in self.volume_layers.keys(): + scale_diff = np.array(scale) // np.array(volume_scales) + if any(scale_diff < 1): + continue + scale_diff = scale_diff.max() + if scale_diff < min_diff: + min_diff = scale_diff + closest_scale = volume_scales + + assert closest_scale is not None + relative_scale = np.array(scale) // np.array(closest_scale) + + return self.volume_layers[closest_scale].get_encoded_subvolume( + data_format, + start, + end, + scale_key=",".join(map(str, relative_scale)), + ) + + def get_object_mesh(self, object_id): + return self.volume_layers[(1,) * self.dims].get_object_mesh(object_id) + + def invalidate(self): + return self.volume_layers[(1,) * self.dims].invalidate() + volume = ScalePyramid(layers) else: @@ -319,114 +430,10 @@ def parse_multiscale_metadata(data_path: str): return scales, parsed -class ScalePyramid(neuroglancer.LocalVolume): - """A neuroglancer layer that provides volume data on different scales. - Mimics a LocalVolume. - From https://github.com/funkelab/funlib.show.neuroglancer/blob/master/funlib/show/neuroglancer/scale_pyramid.py - - Args: - ---- - volume_layers (``list`` of ``LocalVolume``): - - One ``LocalVolume`` per provided resolution. - """ - - def __init__(self, volume_layers): - volume_layers = volume_layers - - super(neuroglancer.LocalVolume, self).__init__() - - logger.debug("Creating scale pyramid...") - - self.min_voxel_size = min( - [tuple(layer.dimensions.scales) for layer in volume_layers] - ) - self.max_voxel_size = max( - [tuple(layer.dimensions.scales) for layer in volume_layers] - ) - - self.dims = len(volume_layers[0].dimensions.scales) - self.volume_layers = { - tuple( - int(x) - for x in map( - operator.truediv, layer.dimensions.scales, self.min_voxel_size - ) - ): layer - for layer in volume_layers - } - - logger.debug("min_voxel_size: %s", self.min_voxel_size) - logger.debug("scale keys: %s", self.volume_layers.keys()) - logger.debug(self.info()) - - @property - def volume_type(self): - return self.volume_layers[(1,) * self.dims].volume_type - - @property - def token(self): - return self.volume_layers[(1,) * self.dims].token - - def info(self): - reference_layer = self.volume_layers[(1,) * self.dims] - # return reference_layer.info() - - reference_info = reference_layer.info() - - info = { - "dataType": reference_info["dataType"], - "encoding": reference_info["encoding"], - "generation": reference_info["generation"], - "coordinateSpace": reference_info["coordinateSpace"], - "shape": reference_info["shape"], - "volumeType": reference_info["volumeType"], - "voxelOffset": reference_info["voxelOffset"], - "chunkLayout": reference_info["chunkLayout"], - "downsamplingLayout": reference_info["downsamplingLayout"], - "maxDownsampling": int( - np.prod(np.array(self.max_voxel_size) // np.array(self.min_voxel_size)) - ), - "maxDownsampledSize": reference_info["maxDownsampledSize"], - "maxDownsamplingScales": reference_info["maxDownsamplingScales"], - } - - return info - - def get_encoded_subvolume(self, data_format, start, end, scale_key=None): - if scale_key is None: - scale_key = ",".join(("1",) * self.dims) - - scale = tuple(int(s) for s in scale_key.split(",")) - closest_scale = None - min_diff = np.inf - for volume_scales in self.volume_layers.keys(): - scale_diff = np.array(scale) // np.array(volume_scales) - if any(scale_diff < 1): - continue - scale_diff = scale_diff.max() - if scale_diff < min_diff: - min_diff = scale_diff - closest_scale = volume_scales - - assert closest_scale is not None - relative_scale = np.array(scale) // np.array(closest_scale) - - return self.volume_layers[closest_scale].get_encoded_subvolume( - data_format, - start, - end, - scale_key=",".join(map(str, relative_scale)), - ) - - def get_object_mesh(self, object_id): - return self.volume_layers[(1,) * self.dims].get_object_mesh(object_id) - - def invalidate(self): - return self.volume_layers[(1,) * self.dims].invalidate() - - def open_ds_tensorstore(dataset_path: str, mode="r", concurrency_limit=None): + from tensorstore import d as ts_d + from tensorstore import open as ts_open + # open with zarr or n5 depending on extension filetype = ( "zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5" From 529feaacded6d9fd76cae33d28cf16259375067c Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 18 Feb 2026 15:10:21 -0500 Subject: [PATCH 07/10] black format --- src/cellmap_data/utils/view.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cellmap_data/utils/view.py b/src/cellmap_data/utils/view.py index 264ef80..e377650 100644 --- a/src/cellmap_data/utils/view.py +++ b/src/cellmap_data/utils/view.py @@ -219,6 +219,7 @@ def get_layer( voxel_offset=metadata[scale]["voxel_offset"], ) ) + class ScalePyramid(neuroglancer.LocalVolume): """A neuroglancer layer that provides volume data on different scales. Mimics a LocalVolume. From 5ef3a14c3046c78d10abe98bb5f4bf3cdf9a1c44 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:11:28 -0500 Subject: [PATCH 08/10] Update src/cellmap_data/image_writer.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index ced8a45..1d99658 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -54,7 +54,7 @@ def __init__( ) + list(write_voxel_shape) write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)} self.scale = scale - # Add bounding_box for channel axis if it was added or if 'c' is in axis_order but not in bounding_box + # Add bounding_box for channel axis if 'c' is in axis_order but not in bounding_box if "c" in axis_order and "c" not in bounding_box: n_channels = write_voxel_shape["c"] bounding_box = {"c": [0, n_channels], **bounding_box} From 571e606c76eadb717dd15adf7c44f29dbe171187 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 18 Feb 2026 15:57:42 -0500 Subject: [PATCH 09/10] fix: improve axis handling and bounding box initialization in ImageWriter --- src/cellmap_data/image_writer.py | 38 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 1d99658..095d0ad 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -38,41 +38,35 @@ def __init__( self.base_path = str(path) self.path = (UPath(path) / f"s{scale_level}").path self.label_class = self.target_class = target_class - channel_axis_added = False - if len(write_voxel_shape) == len(axis_order) + 1 and "c" not in axis_order: - # Add channel axis if missing - axis_order = "c" + axis_order - channel_axis_added = True if isinstance(scale, Sequence): - if len(axis_order) > len(scale): - scale = [1.0] * (len(axis_order) - len(scale)) + list(scale) - scale = {c: s for c, s in zip(axis_order, scale)} + scale = {c: s for c, s in zip(axis_order[::-1], scale[::-1])} + self.scale = scale if isinstance(write_voxel_shape, Sequence): if len(axis_order) > len(write_voxel_shape): # TODO: This might be a bug write_voxel_shape = [1] * ( len(axis_order) - len(write_voxel_shape) ) + list(write_voxel_shape) + elif ( + len(axis_order) + 1 == len(write_voxel_shape) and "c" not in axis_order + ): + axis_order = "c" + axis_order write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)} - self.scale = scale - # Add bounding_box for channel axis if 'c' is in axis_order but not in bounding_box - if "c" in axis_order and "c" not in bounding_box: - n_channels = write_voxel_shape["c"] - bounding_box = {"c": [0, n_channels], **bounding_box} + self.axes = axis_order + # Assume axes correspond to last dimensions of voxel shape + self.spatial_axes = axis_order[-len(scale) :] self.bounding_box = bounding_box self.write_voxel_shape = write_voxel_shape self.write_world_shape = { - c: write_voxel_shape[c] * scale[c] for c in axis_order + c: write_voxel_shape[c] * scale[c] for c in self.spatial_axes } - self.axes = axis_order[: len(write_voxel_shape)] self.scale_level = scale_level self.context = context self.overwrite = overwrite self.dtype = dtype self.fill_value = fill_value - dims = [c for c in axis_order] self.metadata = { "offset": list(self.offset.values()), - "axes": dims, + "axes": [c for c in axis_order], "voxel_size": list(self.scale.values()), "shape": list(self.shape.values()), "units": "nanometer", @@ -176,7 +170,8 @@ def world_shape(self) -> Mapping[str, float]: return self._world_shape except AttributeError: self._world_shape = { - c: self.bounding_box[c][1] - self.bounding_box[c][0] for c in self.axes + c: self.bounding_box[c][1] - self.bounding_box[c][0] + for c in self.spatial_axes } return self._world_shape @@ -186,7 +181,8 @@ def shape(self) -> Mapping[str, int]: return self._shape except AttributeError: self._shape = { - c: int(np.ceil(self.world_shape[c] / self.scale[c])) for c in self.axes + c: int(np.ceil(self.world_shape[c] / self.scale[c])) + for c in self.spatial_axes } return self._shape @@ -205,7 +201,7 @@ def offset(self) -> Mapping[str, float]: try: return self._offset except AttributeError: - self._offset = {c: self.bounding_box[c][0] for c in self.axes} + self._offset = {c: self.bounding_box[c][0] for c in self.spatial_axes} return self._offset @property @@ -234,7 +230,7 @@ def align_coords( self, coords: Mapping[str, tuple[Sequence, np.ndarray]] ) -> Mapping[str, tuple[Sequence, np.ndarray]]: aligned_coords = {} - for c in self.axes: + for c in self.spatial_axes: aligned_coords[c] = np.array( self.array.coords[c][ np.abs(np.array(self.array.coords[c])[:, None] - coords[c]).argmin( From 25aa8a459aae87362db5e38efc47f0fd040edd36 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 18 Feb 2026 16:28:09 -0500 Subject: [PATCH 10/10] feat: add tests for metadata utilities and enhance existing utility tests --- .gitignore | 1 + tests/test_metadata.py | 291 +++++++++++++++++++++ tests/test_multidataset_datasplit.py | 368 +++++++++++++++++++++++++++ tests/test_utils.py | 188 +++++++++++++- 4 files changed, 847 insertions(+), 1 deletion(-) create mode 100644 tests/test_metadata.py diff --git a/.gitignore b/.gitignore index b4274f1..eaa04de 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,4 @@ clean/ .pytest_cache/ __pycache__/ mypy_cache/ +.claude/ \ No newline at end of file diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 0000000..ddb1f74 --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,291 @@ +""" +Tests for utils/metadata.py. + +Tests OME-NGFF metadata generation, writing, and scale-level lookup. +""" + +import json +import os + +import numpy as np +import pytest +import zarr + +from cellmap_data.utils.metadata import ( + add_multiscale_metadata_levels, + create_multiscale_metadata, + find_level, + generate_base_multiscales_metadata, + write_metadata, +) + + +class TestGenerateBaseMultiscalesMetadata: + """Tests for generate_base_multiscales_metadata.""" + + def test_basic_structure(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="my_dataset", + scale_level=0, + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + assert "multiscales" in z_attrs + assert len(z_attrs["multiscales"]) == 1 + ms = z_attrs["multiscales"][0] + assert ms["version"] == "0.4" + + def test_axes_populated(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="test", + scale_level=0, + voxel_size=[8.0, 8.0, 8.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + axes = z_attrs["multiscales"][0]["axes"] + assert len(axes) == 3 + axis_names = [a["name"] for a in axes] + assert axis_names == ["z", "y", "x"] + for a in axes: + assert a["type"] == "space" + assert a["unit"] == "nanometer" + + def test_dataset_path_uses_scale_level(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="test", + scale_level=2, + voxel_size=[16.0, 16.0, 16.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + datasets = z_attrs["multiscales"][0]["datasets"] + assert datasets[0]["path"] == "s2" + + def test_voxel_size_stored(self): + voxel_size = [4.0, 8.0, 16.0] + z_attrs = generate_base_multiscales_metadata( + ds_name="test", + scale_level=0, + voxel_size=voxel_size, + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + datasets = z_attrs["multiscales"][0]["datasets"] + transforms = datasets[0]["coordinateTransformations"] + scale_transform = next(t for t in transforms if t.get("type") == "scale") + assert scale_transform["scale"] == voxel_size + + def test_zarr_suffix_stripped_from_name(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="some_path/dataset.zarr/subgroup", + scale_level=0, + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + name = z_attrs["multiscales"][0]["name"] + assert ".zarr" not in name + + def test_name_stored(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="my_group", + scale_level=0, + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + assert z_attrs["multiscales"][0]["name"] == "my_group" + + def test_2d_axes(self): + z_attrs = generate_base_multiscales_metadata( + ds_name="2d_test", + scale_level=0, + voxel_size=[4.0, 4.0], + translation=[0.0, 0.0], + units="nanometer", + axes=["y", "x"], + ) + axes = z_attrs["multiscales"][0]["axes"] + assert len(axes) == 2 + + +class TestAddMultiscaleMetadataLevels: + """Tests for add_multiscale_metadata_levels.""" + + @pytest.fixture + def base_metadata(self): + return generate_base_multiscales_metadata( + ds_name="test", + scale_level=0, + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + + def test_adds_correct_number_of_levels(self, base_metadata): + result = add_multiscale_metadata_levels(base_metadata, 0, 3) + datasets = result["multiscales"][0]["datasets"] + # Started with 1 level (s0), added 3 more (s1, s2, s3) + assert len(datasets) == 4 + + def test_added_paths_sequential(self, base_metadata): + result = add_multiscale_metadata_levels(base_metadata, 0, 2) + datasets = result["multiscales"][0]["datasets"] + paths = [d["path"] for d in datasets] + assert "s1" in paths + assert "s2" in paths + + def test_scale_formula(self, base_metadata): + # With base_scale_level=1, the added level uses pow(2, 1)=2, so scale doubles + result = add_multiscale_metadata_levels(base_metadata, 1, 1) + datasets = result["multiscales"][0]["datasets"] + s0_scale = datasets[0]["coordinateTransformations"][0]["scale"] + s1_scale = datasets[1]["coordinateTransformations"][0]["scale"] + # Formula: sn = dim * pow(2, level) where level=1 + for i in range(len(s0_scale)): + assert s1_scale[i] == pytest.approx(s0_scale[i] * 2, rel=1e-5) + + def test_zero_levels_adds_nothing(self, base_metadata): + original_count = len(base_metadata["multiscales"][0]["datasets"]) + result = add_multiscale_metadata_levels(base_metadata, 0, 0) + assert len(result["multiscales"][0]["datasets"]) == original_count + + +class TestCreateMultiscaleMetadata: + """Tests for create_multiscale_metadata.""" + + def test_returns_metadata_without_outpath(self): + result = create_multiscale_metadata( + ds_name="test", + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + ) + assert result is not None + assert "multiscales" in result + + def test_with_extra_levels(self): + result = create_multiscale_metadata( + ds_name="test", + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + levels_to_add=2, + ) + datasets = result["multiscales"][0]["datasets"] + assert len(datasets) == 3 + + def test_writes_to_file(self, tmp_path): + out_path = str(tmp_path / "zattrs.json") + result = create_multiscale_metadata( + ds_name="test", + voxel_size=[4.0, 4.0, 4.0], + translation=[0.0, 0.0, 0.0], + units="nanometer", + axes=["z", "y", "x"], + out_path=out_path, + ) + # When out_path given, should return None and write file + assert result is None + assert os.path.exists(out_path) + with open(out_path) as f: + data = json.load(f) + assert "multiscales" in data + + +class TestWriteMetadata: + """Tests for write_metadata.""" + + def test_writes_valid_json(self, tmp_path): + z_attrs = {"multiscales": [{"version": "0.4", "name": "test"}]} + out_path = str(tmp_path / "metadata.json") + write_metadata(z_attrs, out_path) + assert os.path.exists(out_path) + with open(out_path) as f: + loaded = json.load(f) + assert loaded == z_attrs + + def test_overwrites_existing_file(self, tmp_path): + out_path = str(tmp_path / "metadata.json") + write_metadata({"version": "old"}, out_path) + write_metadata({"version": "new"}, out_path) + with open(out_path) as f: + loaded = json.load(f) + assert loaded["version"] == "new" + + def test_indented_output(self, tmp_path): + z_attrs = {"multiscales": [{"version": "0.4"}]} + out_path = str(tmp_path / "indented.json") + write_metadata(z_attrs, out_path) + with open(out_path) as f: + content = f.read() + # Should be pretty-printed (indented) + assert "\n" in content + + +class TestFindLevel: + """Tests for find_level.""" + + @pytest.fixture + def multiscale_zarr(self, tmp_path): + """Create a Zarr group with multiple scale levels.""" + store = zarr.DirectoryStore(str(tmp_path / "test.zarr")) + root = zarr.group(store=store, overwrite=True) + + # Create two scale levels + root.create_dataset("s0", data=np.zeros((64, 64, 64), dtype=np.float32)) + root.create_dataset("s1", data=np.zeros((32, 32, 32), dtype=np.float32)) + + root.attrs["multiscales"] = [ + { + "version": "0.4", + "axes": [ + {"name": "z", "type": "space", "unit": "nanometer"}, + {"name": "y", "type": "space", "unit": "nanometer"}, + {"name": "x", "type": "space", "unit": "nanometer"}, + ], + "datasets": [ + { + "path": "s0", + "coordinateTransformations": [ + {"type": "scale", "scale": [4.0, 4.0, 4.0]}, + {"type": "translation", "translation": [0.0, 0.0, 0.0]}, + ], + }, + { + "path": "s1", + "coordinateTransformations": [ + {"type": "scale", "scale": [8.0, 8.0, 8.0]}, + {"type": "translation", "translation": [2.0, 2.0, 2.0]}, + ], + }, + ], + } + ] + return str(tmp_path / "test.zarr") + + def test_find_fine_level(self, multiscale_zarr): + # Target scale smaller than s0 -> should return s0 + level = find_level(multiscale_zarr, {"z": 2.0, "y": 2.0, "x": 2.0}) + assert level == "s0" + + def test_find_coarse_level(self, multiscale_zarr): + # Target scale between s0 and s1 -> should return s0 (last level not exceeding target) + level = find_level(multiscale_zarr, {"z": 6.0, "y": 6.0, "x": 6.0}) + assert level == "s0" + + def test_find_last_level(self, multiscale_zarr): + # Target scale larger than all levels -> should return last level + level = find_level(multiscale_zarr, {"z": 100.0, "y": 100.0, "x": 100.0}) + assert level == "s1" diff --git a/tests/test_multidataset_datasplit.py b/tests/test_multidataset_datasplit.py index b60eeb1..649a455 100644 --- a/tests/test_multidataset_datasplit.py +++ b/tests/test_multidataset_datasplit.py @@ -4,7 +4,12 @@ Tests combining multiple datasets and train/validation splits. """ +import csv +import os + import pytest +import torch +import torchvision.transforms.v2 as T from cellmap_data import CellMapDataset, CellMapDataSplit, CellMapMultiDataset @@ -451,3 +456,366 @@ def test_different_resolution_datasets(self, tmp_path): ) assert len(multi_dataset.datasets) == 2 + + +class TestCellMapMultiDatasetProperties: + """Tests for CellMapMultiDataset properties and methods not yet covered.""" + + @pytest.fixture + def multi_dataset(self, tmp_path): + """Build a CellMapMultiDataset from two real datasets.""" + datasets = [] + for i in range(2): + config = create_test_dataset( + tmp_path / f"ds_{i}", + raw_shape=(32, 32, 32), + num_classes=2, + raw_scale=(4.0, 4.0, 4.0), + seed=i, + ) + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + datasets.append(ds) + + return CellMapMultiDataset( + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + datasets=datasets, + ) + + def test_has_data_true(self, multi_dataset): + assert multi_dataset.has_data is True + + def test_class_counts_structure(self, multi_dataset): + counts = multi_dataset.class_counts + assert "totals" in counts + assert "class_0" in counts["totals"] + assert "class_1" in counts["totals"] + + def test_class_weights_keys(self, multi_dataset): + weights = multi_dataset.class_weights + assert "class_0" in weights + assert "class_1" in weights + for w in weights.values(): + assert w >= 0 + + def test_dataset_weights_keys(self, multi_dataset): + dw = multi_dataset.dataset_weights + # Should have one entry per dataset + assert len(dw) == len(multi_dataset.datasets) + for w in dw.values(): + assert w >= 0 + + def test_sample_weights_length(self, multi_dataset): + sw = multi_dataset.sample_weights + assert len(sw) == len(multi_dataset) + + def test_validation_indices_nonempty(self, multi_dataset): + indices = multi_dataset.validation_indices + assert isinstance(indices, list) + assert len(indices) > 0 + assert all(0 <= i < len(multi_dataset) for i in indices) + + def test_verify_true(self, multi_dataset): + assert multi_dataset.verify() is True + + def test_get_weighted_sampler(self, multi_dataset): + sampler = multi_dataset.get_weighted_sampler(batch_size=4) + assert sampler is not None + + def test_get_random_subset_indices(self, multi_dataset): + indices = multi_dataset.get_random_subset_indices(4, weighted=False) + assert len(indices) == 4 + + def test_get_random_subset_indices_weighted(self, multi_dataset): + indices = multi_dataset.get_random_subset_indices(4, weighted=True) + assert len(indices) == 4 + + def test_get_subset_random_sampler(self, multi_dataset): + sampler = multi_dataset.get_subset_random_sampler(4) + assert sampler is not None + + def test_get_indices(self, multi_dataset): + indices = multi_dataset.get_indices({"x": 8, "y": 8, "z": 8}) + assert isinstance(indices, list) + assert len(indices) > 0 + + def test_set_raw_value_transforms(self, multi_dataset): + new_transforms = T.Compose([T.ToDtype(torch.float, scale=True)]) + multi_dataset.set_raw_value_transforms(new_transforms) + + def test_set_target_value_transforms(self, multi_dataset): + new_transforms = T.Compose([T.ToDtype(torch.float)]) + multi_dataset.set_target_value_transforms(new_transforms) + + def test_set_spatial_transforms(self, multi_dataset): + transforms = {"mirror": {"axes": {"x": 0.5}}} + multi_dataset.set_spatial_transforms(transforms) + + def test_repr(self, multi_dataset): + r = repr(multi_dataset) + assert "CellMapMultiDataset" in r + + def test_empty_class_method(self): + empty = CellMapMultiDataset.empty() + assert empty is not None + assert empty.has_data is False + assert empty.classes == [] + assert empty.validation_indices == [] + + def test_verify_empty_returns_false(self): + empty = CellMapMultiDataset.empty() + assert empty.verify() is False + + def test_no_classes_dataset_weights(self, tmp_path): + """Dataset weights with no classes should give equal weights.""" + config = create_test_dataset(tmp_path / "ds", raw_shape=(32, 32, 32)) + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + multi = CellMapMultiDataset( + classes=[], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={}, + datasets=[ds], + ) + dw = multi.dataset_weights + assert list(dw.values())[0] == 1.0 + + +class TestCellMapDataSplitExtended: + """Extended tests for CellMapDataSplit.""" + + @pytest.fixture + def train_val_configs(self, tmp_path): + train = [] + for i in range(2): + train.append( + create_test_dataset( + tmp_path / f"train_{i}", + raw_shape=(32, 32, 32), + num_classes=2, + seed=i, + ) + ) + val = [ + create_test_dataset( + tmp_path / "val_0", + raw_shape=(32, 32, 32), + num_classes=2, + seed=99, + ) + ] + return train, val + + @pytest.fixture + def datasplit(self, train_val_configs): + train, val = train_val_configs + dataset_dict = { + "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], + "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], + } + return CellMapDataSplit( + dataset_dict=dataset_dict, + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + + def test_from_csv(self, tmp_path, train_val_configs): + """Test CellMapDataSplit.from_csv loads the dataset_dict correctly.""" + train, val = train_val_configs + + csv_path = str(tmp_path / "splits.csv") + rows = [] + for c in train: + raw_dir, raw_file = os.path.split(c["raw_path"]) + gt_dir, gt_file = os.path.split(c["gt_path"]) + rows.append(["train", raw_dir, raw_file, gt_dir, gt_file]) + for c in val: + raw_dir, raw_file = os.path.split(c["raw_path"]) + gt_dir, gt_file = os.path.split(c["gt_path"]) + rows.append(["validate", raw_dir, raw_file, gt_dir, gt_file]) + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerows(rows) + + # Use from_csv via the constructor + split = CellMapDataSplit( + csv_path=csv_path, + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + ) + assert len(split.train_datasets) == 2 + assert len(split.validation_datasets) == 1 + + def test_from_csv_no_gt(self, tmp_path, train_val_configs): + """Test CSV rows without gt columns.""" + train, _ = train_val_configs + + csv_path = str(tmp_path / "splits_no_gt.csv") + rows = [] + for c in train: + raw_dir, raw_file = os.path.split(c["raw_path"]) + rows.append(["train", raw_dir, raw_file]) + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerows(rows) + + # Direct call to from_csv + split = CellMapDataSplit( + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + force_has_data=True, + datasets={"train": []}, + ) + # Read CSV manually using the method + result = split.from_csv(csv_path) + assert "train" in result + assert len(result["train"]) == 2 + for entry in result["train"]: + assert entry["gt"] == "" + + def test_train_datasets_combined_property(self, datasplit): + combined = datasplit.train_datasets_combined + assert combined is not None + assert len(combined) > 0 + + def test_validation_datasets_combined_property(self, datasplit): + combined = datasplit.validation_datasets_combined + assert combined is not None + + def test_class_counts_property(self, datasplit): + counts = datasplit.class_counts + assert "train" in counts + assert "validate" in counts + + def test_repr(self, datasplit): + r = repr(datasplit) + assert "CellMapDataSplit" in r + + def test_no_source_raises(self): + """Providing no data source should raise ValueError.""" + with pytest.raises(ValueError, match="One of"): + CellMapDataSplit( + classes=["class_0"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + + def test_set_raw_value_transforms(self, datasplit): + new_transform = T.Compose([T.ToDtype(torch.float, scale=True)]) + datasplit.set_raw_value_transforms( + train_transforms=new_transform, val_transforms=new_transform + ) + + def test_set_target_value_transforms(self, datasplit): + new_transform = T.Compose([T.ToDtype(torch.float)]) + datasplit.set_target_value_transforms(new_transform) + + def test_set_spatial_transforms(self, datasplit): + train_transforms = {"mirror": {"axes": {"x": 0.5}}} + datasplit.set_spatial_transforms(train_transforms=train_transforms) + + def test_set_raw_value_transforms_after_combined(self, datasplit): + """Test set_raw_value_transforms after train_datasets_combined is cached.""" + _ = datasplit.train_datasets_combined + new_transform = T.Compose([T.ToDtype(torch.float, scale=True)]) + datasplit.set_raw_value_transforms(train_transforms=new_transform) + + def test_set_target_value_transforms_after_combined(self, datasplit): + """Test set_target_value_transforms after combined datasets are cached.""" + _ = datasplit.train_datasets_combined + _ = datasplit.validation_datasets_combined + new_transform = T.Compose([T.ToDtype(torch.float)]) + datasplit.set_target_value_transforms(new_transform) + + def test_set_spatial_transforms_after_combined(self, datasplit): + """Test set_spatial_transforms after train_datasets_combined is cached.""" + _ = datasplit.train_datasets_combined + _ = datasplit.validation_datasets_combined + transforms = {"mirror": {"axes": {"x": 0.5}}} + datasplit.set_spatial_transforms( + train_transforms=transforms, val_transforms=transforms + ) + + def test_to_device(self, datasplit): + datasplit.to("cpu") + assert datasplit.device == "cpu" + + def test_to_device_after_combined(self, datasplit): + _ = datasplit.train_datasets_combined + _ = datasplit.validation_datasets_combined + datasplit.to("cpu") + + def test_pad_string_train(self, train_val_configs): + train, val = train_val_configs + dataset_dict = { + "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], + "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], + } + split = CellMapDataSplit( + dataset_dict=dataset_dict, + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + pad="train", + force_has_data=True, + ) + assert split.pad_training is True + assert split.pad_validation is False + + def test_pad_string_validate(self, train_val_configs): + train, val = train_val_configs + dataset_dict = { + "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], + "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], + } + split = CellMapDataSplit( + dataset_dict=dataset_dict, + classes=["class_0", "class_1"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + pad="validate", + force_has_data=True, + ) + assert split.pad_training is False + assert split.pad_validation is True + + def test_initialization_with_datasets_no_validate(self, tmp_path): + """Test providing datasets dict without validate key.""" + config = create_test_dataset(tmp_path / "ds", raw_shape=(32, 32, 32)) + ds = CellMapDataset( + raw_path=config["raw_path"], + target_path=config["gt_path"], + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + ) + split = CellMapDataSplit( + classes=config["classes"], + input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, + datasets={"train": [ds]}, + force_has_data=True, + ) + assert split.validation_datasets == [] + + def test_validation_blocks_property(self, datasplit): + blocks = datasplit.validation_blocks + assert blocks is not None diff --git a/tests/test_utils.py b/tests/test_utils.py index f63ba15..9052276 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,16 @@ import numpy as np import torch -from cellmap_data.utils.misc import get_sliced_shape, torch_max_value +from cellmap_data.utils.misc import ( + array_has_singleton_dim, + expand_scale, + get_sliced_shape, + longest_common_substring, + permute_singleton_dimension, + split_target_path, + torch_max_value, +) +from cellmap_data.utils.sampling import min_redundant_inds class TestUtilsMisc: @@ -266,3 +275,180 @@ def test_dtype_max_values(self): # Float types return 1 (normalized) assert torch_max_value(torch.float32) == 1 assert torch_max_value(torch.float64) == 1 + + +class TestLongestCommonSubstring: + """Tests for longest_common_substring utility.""" + + def test_identical_strings(self): + result = longest_common_substring("abcdef", "abcdef") + assert result == "abcdef" + + def test_partial_overlap(self): + result = longest_common_substring("abcXYZ", "XYZdef") + assert result == "XYZ" + + def test_no_overlap(self): + result = longest_common_substring("abc", "xyz") + assert result == "" + + def test_substring_at_start(self): + result = longest_common_substring("hello world", "hello there") + assert result == "hello " + + def test_single_char_overlap(self): + result = longest_common_substring("abc", "cde") + assert result == "c" + + def test_empty_string(self): + result = longest_common_substring("", "abc") + assert result == "" + + def test_path_like_strings(self): + a = "/data/train/dataset_0/raw" + b = "/data/train/dataset_1/raw" + result = longest_common_substring(a, b) + assert len(result) > 0 + assert result in a and result in b + + +class TestExpandScale: + """Tests for expand_scale utility.""" + + def test_2d_scale_expanded(self): + scale = [4.0, 8.0] + result = expand_scale(scale) + assert len(result) == 3 + assert result[0] == 4.0 # first element duplicated at front + + def test_3d_scale_unchanged(self): + scale = [4.0, 8.0, 16.0] + result = expand_scale(scale) + assert result == [4.0, 8.0, 16.0] + + def test_isotropic_2d(self): + scale = [4.0, 4.0] + result = expand_scale(scale) + assert len(result) == 3 + assert result == [4.0, 4.0, 4.0] + + def test_single_element(self): + scale = [8.0] + result = expand_scale(scale) + assert len(result) == 1 # no change for 1D + + +class TestArrayHasSingletonDim: + """Tests for array_has_singleton_dim utility.""" + + def test_with_singleton(self): + arr_info = {"shape": (1, 64, 64)} + assert array_has_singleton_dim(arr_info) is True + + def test_without_singleton(self): + arr_info = {"shape": (8, 64, 64)} + assert array_has_singleton_dim(arr_info) is False + + def test_none_input(self): + assert array_has_singleton_dim(None) is False + + def test_empty_dict(self): + assert array_has_singleton_dim({}) is False + + def test_nested_dict_any(self): + arr_info = { + "raw": {"shape": (1, 64, 64)}, + "labels": {"shape": (8, 64, 64)}, + } + # summary=True (default) returns True if any has singleton + assert array_has_singleton_dim(arr_info, summary=True) is True + + def test_nested_dict_none_singleton(self): + arr_info = { + "raw": {"shape": (4, 64, 64)}, + "labels": {"shape": (8, 64, 64)}, + } + assert array_has_singleton_dim(arr_info, summary=True) is False + + def test_nested_dict_per_key(self): + arr_info = { + "raw": {"shape": (1, 64, 64)}, + "labels": {"shape": (8, 64, 64)}, + } + result = array_has_singleton_dim(arr_info, summary=False) + assert isinstance(result, dict) + assert result["raw"] is True + assert result["labels"] is False + + +class TestPermutesSingletonDimension: + """Tests for permute_singleton_dimension utility.""" + + def test_single_array_dict(self): + arr_dict = {"shape": (64, 64), "scale": (4.0, 4.0)} + permute_singleton_dimension(arr_dict, axis=0) + assert len(arr_dict["shape"]) == 3 + assert arr_dict["shape"][0] == 1 + assert len(arr_dict["scale"]) == 3 + + def test_nested_array_dict(self): + arr_dict = { + "raw": {"shape": (64, 64), "scale": (4.0, 4.0)}, + "labels": {"shape": (64, 64), "scale": (4.0, 4.0)}, + } + permute_singleton_dimension(arr_dict, axis=1) + assert len(arr_dict["raw"]["shape"]) == 3 + assert len(arr_dict["labels"]["shape"]) == 3 + + def test_axis_placement(self): + arr_dict = {"shape": (64, 64), "scale": (4.0, 8.0)} + permute_singleton_dimension(arr_dict, axis=2) + assert arr_dict["shape"][2] == 1 + + def test_existing_singleton_moved(self): + # shape already has a singleton, but at wrong position + arr_dict = {"shape": (1, 64, 64), "scale": (4.0, 4.0, 4.0)} + permute_singleton_dimension(arr_dict, axis=2) + assert arr_dict["shape"][2] == 1 + + +class TestMinRedundantInds: + """Tests for min_redundant_inds from utils.sampling.""" + + def test_basic_sampling_under_size(self): + result = min_redundant_inds(10, 5) + assert len(result) == 5 + assert result.max() < 10 + + def test_exact_size(self): + result = min_redundant_inds(10, 10) + assert len(result) == 10 + # Should be a permutation + assert set(result.tolist()) == set(range(10)) + + def test_oversample(self): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = min_redundant_inds(5, 12) + assert len(result) == 12 + assert result.max() < 5 + + def test_with_rng(self): + rng = torch.Generator() + rng.manual_seed(42) + result1 = min_redundant_inds(10, 5, rng=rng) + rng.manual_seed(42) + result2 = min_redundant_inds(10, 5, rng=rng) + assert torch.equal(result1, result2) + + def test_invalid_size_raises(self): + import pytest + + with pytest.raises(ValueError): + min_redundant_inds(0, 5) + + def test_returns_tensor(self): + result = min_redundant_inds(10, 5) + assert isinstance(result, torch.Tensor)