Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ clean/
.pytest_cache/
__pycache__/
mypy_cache/
.claude/
4 changes: 3 additions & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def device(self) -> torch.device:
return self._device

def __len__(self) -> int:
"""Returns the number of patches in the dataset."""
"""Returns the number of unique patches in the dataset."""
if not self.has_data and not self.force_has_data:
return 0
# Return at least 1 if the dataset has data, so that samplers can be initialized
Expand Down Expand Up @@ -976,4 +976,6 @@ def empty() -> "CellMapDataset":
# Directly instantiate to bypass __new__ logic
instance = super(CellMapDataset, CellMapDataset).__new__(CellMapDataset)
instance.__init__("", "", [], {}, {}, force_has_data=False)
instance.has_data = False
instance._sampling_box_shape = {c: 0 for c in instance.axis_order}
return instance
4 changes: 4 additions & 0 deletions src/cellmap_data/dataset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ def get_image_writer(
shape = array_info["shape"]
if not isinstance(shape, (Mapping, Sequence)):
raise TypeError(f"Shape must be a Mapping or Sequence, not {type(shape)}")
if "n_channels" in array_info:
shape = [array_info["n_channels"]] + list(shape)
if "c" not in self.axis_order:
self.axis_order = "c" + self.axis_order
scale_level = array_info.get("scale_level", 0)
if not isinstance(scale_level, int):
raise TypeError(f"Scale level must be an int, not {type(scale_level)}")
Expand Down
31 changes: 18 additions & 13 deletions src/cellmap_data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,34 @@ def __init__(
self.path = (UPath(path) / f"s{scale_level}").path
self.label_class = self.target_class = target_class
if isinstance(scale, Sequence):
if len(axis_order) > len(scale):
scale = [scale[0]] * (len(axis_order) - len(scale)) + list(scale)
scale = {c: s for c, s in zip(axis_order, scale)}
scale = {c: s for c, s in zip(axis_order[::-1], scale[::-1])}
self.scale = scale
if isinstance(write_voxel_shape, Sequence):
if len(axis_order) > len(write_voxel_shape):
if len(axis_order) > len(write_voxel_shape): # TODO: This might be a bug
write_voxel_shape = [1] * (
len(axis_order) - len(write_voxel_shape)
) + list(write_voxel_shape)
elif (
len(axis_order) + 1 == len(write_voxel_shape) and "c" not in axis_order
):
axis_order = "c" + axis_order
write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)}
self.scale = scale
self.axes = axis_order
# Assume axes correspond to last dimensions of voxel shape
self.spatial_axes = axis_order[-len(scale) :]
self.bounding_box = bounding_box
self.write_voxel_shape = write_voxel_shape
self.write_world_shape = {
c: write_voxel_shape[c] * scale[c] for c in axis_order
c: write_voxel_shape[c] * scale[c] for c in self.spatial_axes
}
self.axes = axis_order[: len(write_voxel_shape)]
self.scale_level = scale_level
self.context = context
self.overwrite = overwrite
self.dtype = dtype
self.fill_value = fill_value
dims = [c for c in axis_order]
self.metadata = {
"offset": list(self.offset.values()),
"axes": dims,
"axes": [c for c in axis_order],
"voxel_size": list(self.scale.values()),
"shape": list(self.shape.values()),
"units": "nanometer",
Expand Down Expand Up @@ -167,7 +170,8 @@ def world_shape(self) -> Mapping[str, float]:
return self._world_shape
except AttributeError:
self._world_shape = {
c: self.bounding_box[c][1] - self.bounding_box[c][0] for c in self.axes
c: self.bounding_box[c][1] - self.bounding_box[c][0]
for c in self.spatial_axes
}
return self._world_shape

Expand All @@ -177,7 +181,8 @@ def shape(self) -> Mapping[str, int]:
return self._shape
except AttributeError:
self._shape = {
c: int(np.ceil(self.world_shape[c] / self.scale[c])) for c in self.axes
c: int(np.ceil(self.world_shape[c] / self.scale[c]))
for c in self.spatial_axes
}
return self._shape

Expand All @@ -196,7 +201,7 @@ def offset(self) -> Mapping[str, float]:
try:
return self._offset
except AttributeError:
self._offset = {c: self.bounding_box[c][0] for c in self.axes}
self._offset = {c: self.bounding_box[c][0] for c in self.spatial_axes}
return self._offset

@property
Expand Down Expand Up @@ -225,7 +230,7 @@ def align_coords(
self, coords: Mapping[str, tuple[Sequence, np.ndarray]]
) -> Mapping[str, tuple[Sequence, np.ndarray]]:
aligned_coords = {}
for c in self.axes:
for c in self.spatial_axes:
aligned_coords[c] = np.array(
self.array.coords[c][
np.abs(np.array(self.array.coords[c])[:, None] - coords[c]).argmin(
Expand Down
Loading