Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 31 additions & 61 deletions src/napari_easytrack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _convert_to_labels(frame):

# ============= SEGMENTATION PROCESSING =============

def clean_segmentation(segmentation, verbose=True, min_size=5):
def clean_segmentation(segmentation, verbose=True, min_size=5, inplace=False):
"""
Clean segmentation by:
1. Keeping only the largest connected component per label
Expand All @@ -342,10 +342,12 @@ def clean_segmentation(segmentation, verbose=True, min_size=5):
segmentation: 3D array (T, Y, X) or 4D array (T, Z, Y, X) with integer labels
verbose: If True, print progress information
min_size: Minimum number of pixels/voxels for a label (default: 5)
inplace: If True, modify segmentation in place instead of creating a copy.
Use this to reduce memory usage on large arrays (default: False)
Returns:
Cleaned segmentation with same shape and dtype
"""
cleaned = segmentation.copy()
cleaned = segmentation if inplace else segmentation.copy()
is_3d = (segmentation.ndim == 4)
n_timepoints = segmentation.shape[0]

Expand All @@ -359,25 +361,8 @@ def clean_segmentation(segmentation, verbose=True, min_size=5):
total_removed = 0

for t in range(n_timepoints):
# DEBUG: Only for frame 13
if t == 13:
structure = ndimage.generate_binary_structure(cleaned[t].ndim, connectivity=1)
label_31_mask = (cleaned[t] == 31)
if np.any(label_31_mask):
labeled_31, num_31 = ndimage.label(label_31_mask, structure=structure)
sizes_31 = ndimage.sum(label_31_mask, labeled_31, range(1, num_31 + 1)) if num_31 > 0 else []
print(f"\nBEFORE Frame 13: Label 31 has {num_31} components, sizes: {sizes_31}")

stats = _clean_frame(cleaned[t], min_size, is_3d)

if t == 13:
structure = ndimage.generate_binary_structure(cleaned[t].ndim, connectivity=1)
label_31_mask = (cleaned[t] == 31)
if np.any(label_31_mask):
labeled_31, num_31 = ndimage.label(label_31_mask, structure=structure)
sizes_31 = ndimage.sum(label_31_mask, labeled_31, range(1, num_31 + 1)) if num_31 > 0 else []
print(f"AFTER Frame 13: Label 31 has {num_31} components, sizes: {sizes_31}\n")

if verbose and (stats['reassigned'] > 0 or stats['removed'] > 0):
timepoint_name = "Timepoint" if is_3d else "Frame"
unit = "voxels" if is_3d else "pixels"
Expand All @@ -398,61 +383,46 @@ def clean_segmentation(segmentation, verbose=True, min_size=5):

def _clean_frame(frame, min_size, is_3d):
"""
Single-pass cleaning: Find and reassign disconnected fragments
Single-pass cleaning: Find and reassign disconnected fragments.
Processes each label independently to minimise peak memory usage.
"""
structure = ndimage.generate_binary_structure(frame.ndim, connectivity=1)
total_stats = {'reassigned': 0, 'removed': 0}

# Get all unique labels
unique_labels = np.unique(frame)
unique_labels = unique_labels[unique_labels > 0]

component_info = {}
next_id = 0
label_components = {}


for label in unique_labels:
label_mask = (frame == label)
labeled, num_comps = ndimage.label(label_mask, structure=structure)

del label_mask # free memory immediately

if num_comps <= 1:
continue # Skip if only 1 component

label_components[label] = []

for comp_idx in range(1, num_comps + 1):
comp_mask = (labeled == comp_idx)
size = int(np.sum(comp_mask))

component_info[next_id] = {
'original_label': int(label),
'size': size,
'mask': comp_mask
}
label_components[label].append(next_id)
next_id += 1

components_to_reassign = []

for label, comp_ids in label_components.items():
sizes = [component_info[cid]['size'] for cid in comp_ids]
largest_idx = np.argmax(sizes)
largest_size = sizes[largest_idx]

continue # nothing to do for a single connected component

# Compute component sizes without creating per-component boolean arrays
sizes = np.bincount(labeled.ravel())[1:] # skip background (index 0)
largest_idx = int(np.argmax(sizes))
largest_size = int(sizes[largest_idx])

if largest_size < min_size:
components_to_reassign.extend(comp_ids)
# Reassign every component of this label
for comp_idx in range(1, num_comps + 1):
comp_mask = labeled == comp_idx
neighbor = _find_neighbor_label(comp_mask, frame, int(label), structure)
frame[comp_mask] = neighbor
total_stats['reassigned'] += int(sizes[comp_idx - 1])
total_stats['removed'] += 1
else:
for i, comp_id in enumerate(comp_ids):
if i != largest_idx:
components_to_reassign.append(comp_id)

for comp_id in components_to_reassign:
info = component_info[comp_id]
neighbor = _find_neighbor_label(info['mask'], frame, info['original_label'], structure)
frame[info['mask']] = neighbor
total_stats['reassigned'] += info['size']

# Reassign only the non-largest components
for comp_idx in range(1, num_comps + 1):
if comp_idx - 1 != largest_idx:
comp_mask = labeled == comp_idx
neighbor = _find_neighbor_label(comp_mask, frame, int(label), structure)
frame[comp_mask] = neighbor
total_stats['reassigned'] += int(sizes[comp_idx - 1])

return total_stats


Expand Down
37 changes: 37 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,43 @@ def test_preserves_dtype(self):
# Should return integer dtype
assert np.issubdtype(cleaned.dtype, np.integer)

def test_inplace_false_does_not_modify_input(self):
    """With inplace=False (the default) the caller's array is left untouched."""
    source = np.zeros((2, 15, 15), dtype=np.int32)
    # Frame 0: a 6x6 body plus a detached 2x2 fragment sharing label 1.
    body = (0, slice(2, 8), slice(2, 8))
    fragment = (0, slice(12, 14), slice(12, 14))
    source[body] = 1
    source[fragment] = 1
    snapshot = source.copy()

    result = clean_segmentation(source, verbose=False, min_size=5, inplace=False)

    # The array passed in still matches the snapshot taken before the call.
    np.testing.assert_array_equal(source, snapshot)
    # The result is a separate object from the input.
    assert result is not source

def test_inplace_true_modifies_input(self):
    """Test that inplace=True modifies the original array and returns it."""
    seg = np.zeros((2, 15, 15), dtype=np.int32)
    seg[0, 2:8, 2:8] = 1
    seg[0, 12:14, 12:14] = 1  # small disconnected fragment
    original = seg.copy()

    result = clean_segmentation(seg, verbose=False, min_size=5, inplace=True)

    # The returned object is the same array that was passed in
    assert result is seg
    # The input itself was cleaned: identity alone would also hold if the
    # call did nothing, so check the contents actually changed.
    assert not np.array_equal(seg, original)
    # The detached fragment no longer carries label 1 ...
    assert not np.any(seg[0, 12:14, 12:14] == 1)
    # ... while the largest component of label 1 is kept.
    assert np.all(seg[0, 2:8, 2:8] == 1)

def test_inplace_true_produces_same_result_as_copy(self):
    """Both inplace modes must yield element-wise identical output."""
    template = np.zeros((2, 15, 15), dtype=np.int32)
    template[0, 2:8, 2:8] = 1
    template[0, 12:14, 12:14] = 1  # detached fragment sharing label 1
    template[1, 3:9, 3:9] = 2

    # Each call receives its own fresh copy of the same template.
    outputs = {
        flag: clean_segmentation(template.copy(), verbose=False, min_size=5, inplace=flag)
        for flag in (False, True)
    }

    np.testing.assert_array_equal(outputs[False], outputs[True])


class TestGetCleaningStats:
"""Tests for the get_cleaning_stats function."""
Expand Down