diff --git a/src/napari_easytrack/utils.py b/src/napari_easytrack/utils.py index eb276f4..db8e6be 100644 --- a/src/napari_easytrack/utils.py +++ b/src/napari_easytrack/utils.py @@ -333,7 +333,7 @@ def _convert_to_labels(frame): # ============= SEGMENTATION PROCESSING ============= -def clean_segmentation(segmentation, verbose=True, min_size=5): +def clean_segmentation(segmentation, verbose=True, min_size=5, inplace=False): """ Clean segmentation by: 1. Keeping only the largest connected component per label @@ -342,10 +342,12 @@ def clean_segmentation(segmentation, verbose=True, min_size=5): segmentation: 3D array (T, Y, X) or 4D array (T, Z, Y, X) with integer labels verbose: If True, print progress information min_size: Minimum number of pixels/voxels for a label (default: 5) + inplace: If True, modify segmentation in place instead of creating a copy. + Use this to reduce memory usage on large arrays (default: False) Returns: Cleaned segmentation with same shape and dtype """ - cleaned = segmentation.copy() + cleaned = segmentation if inplace else segmentation.copy() is_3d = (segmentation.ndim == 4) n_timepoints = segmentation.shape[0] @@ -359,25 +361,8 @@ def clean_segmentation(segmentation, verbose=True, min_size=5): total_removed = 0 for t in range(n_timepoints): - # DEBUG: Only for frame 13 - if t == 13: - structure = ndimage.generate_binary_structure(cleaned[t].ndim, connectivity=1) - label_31_mask = (cleaned[t] == 31) - if np.any(label_31_mask): - labeled_31, num_31 = ndimage.label(label_31_mask, structure=structure) - sizes_31 = ndimage.sum(label_31_mask, labeled_31, range(1, num_31 + 1)) if num_31 > 0 else [] - print(f"\nBEFORE Frame 13: Label 31 has {num_31} components, sizes: {sizes_31}") - stats = _clean_frame(cleaned[t], min_size, is_3d) - if t == 13: - structure = ndimage.generate_binary_structure(cleaned[t].ndim, connectivity=1) - label_31_mask = (cleaned[t] == 31) - if np.any(label_31_mask): - labeled_31, num_31 = ndimage.label(label_31_mask, structure=structure) - sizes_31 = ndimage.sum(label_31_mask, labeled_31, range(1, num_31 + 1)) if num_31 > 0 else [] - print(f"AFTER Frame 13: Label 31 has {num_31} components, sizes: {sizes_31}\n") - if verbose and (stats['reassigned'] > 0 or stats['removed'] > 0): timepoint_name = "Timepoint" if is_3d else "Frame" unit = "voxels" if is_3d else "pixels" @@ -398,61 +383,46 @@ def clean_segmentation(segmentation, verbose=True, min_size=5): def _clean_frame(frame, min_size, is_3d): """ - Single-pass cleaning: Find and reassign disconnected fragments + Single-pass cleaning: Find and reassign disconnected fragments. + Processes each label independently to minimise peak memory usage. """ structure = ndimage.generate_binary_structure(frame.ndim, connectivity=1) total_stats = {'reassigned': 0, 'removed': 0} - + # Get all unique labels unique_labels = np.unique(frame) unique_labels = unique_labels[unique_labels > 0] - - component_info = {} - next_id = 0 - label_components = {} - + for label in unique_labels: label_mask = (frame == label) labeled, num_comps = ndimage.label(label_mask, structure=structure) - + del label_mask # free memory immediately + if num_comps <= 1: - continue # Skip if only 1 component - - label_components[label] = [] - - for comp_idx in range(1, num_comps + 1): - comp_mask = (labeled == comp_idx) - size = int(np.sum(comp_mask)) - - component_info[next_id] = { - 'original_label': int(label), - 'size': size, - 'mask': comp_mask - } - label_components[label].append(next_id) - next_id += 1 - - components_to_reassign = [] - - for label, comp_ids in label_components.items(): - sizes = [component_info[cid]['size'] for cid in comp_ids] - largest_idx = np.argmax(sizes) - largest_size = sizes[largest_idx] - + continue # nothing to do for a single connected component + + # Compute component sizes without creating per-component boolean arrays + sizes = np.bincount(labeled.ravel())[1:] # skip background (index 0) + largest_idx = int(np.argmax(sizes)) + largest_size = int(sizes[largest_idx]) + if largest_size < min_size: - components_to_reassign.extend(comp_ids) + # Reassign every component of this label + for comp_idx in range(1, num_comps + 1): + comp_mask = labeled == comp_idx + neighbor = _find_neighbor_label(comp_mask, frame, int(label), structure) + frame[comp_mask] = neighbor + total_stats['reassigned'] += int(sizes[comp_idx - 1]) total_stats['removed'] += 1 else: - for i, comp_id in enumerate(comp_ids): - if i != largest_idx: - components_to_reassign.append(comp_id) - - for comp_id in components_to_reassign: - info = component_info[comp_id] - neighbor = _find_neighbor_label(info['mask'], frame, info['original_label'], structure) - frame[info['mask']] = neighbor - total_stats['reassigned'] += info['size'] - + # Reassign only the non-largest components + for comp_idx in range(1, num_comps + 1): + if comp_idx - 1 != largest_idx: + comp_mask = labeled == comp_idx + neighbor = _find_neighbor_label(comp_mask, frame, int(label), structure) + frame[comp_mask] = neighbor + total_stats['reassigned'] += int(sizes[comp_idx - 1]) + return total_stats diff --git a/tests/test_utils.py b/tests/test_utils.py index a80ed12..284f7c0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -157,6 +157,43 @@ def test_preserves_dtype(self): # Should return integer dtype assert np.issubdtype(cleaned.dtype, np.integer) + def test_inplace_false_does_not_modify_input(self): + """Test that inplace=False (default) does not modify the original array.""" + seg = np.zeros((2, 15, 15), dtype=np.int32) + seg[0, 2:8, 2:8] = 1 + seg[0, 12:14, 12:14] = 1 # small disconnected fragment + original = seg.copy() + + cleaned = clean_segmentation(seg, verbose=False, min_size=5, inplace=False) + + # Input array must be unchanged + np.testing.assert_array_equal(seg, original) + # Returned array is a different object + assert cleaned is not seg + + def test_inplace_true_modifies_input(self): + """Test that inplace=True modifies the original array and returns it.""" + seg = np.zeros((2, 15, 15), dtype=np.int32) + seg[0, 2:8, 2:8] = 1 + seg[0, 12:14, 12:14] = 1 # small disconnected fragment + + result = clean_segmentation(seg, verbose=False, min_size=5, inplace=True) + + # The returned object is the same array that was passed in + assert result is seg + + def test_inplace_true_produces_same_result_as_copy(self): + """Test that inplace=True produces identical output to inplace=False.""" + seg = np.zeros((2, 15, 15), dtype=np.int32) + seg[0, 2:8, 2:8] = 1 + seg[0, 12:14, 12:14] = 1 # small disconnected fragment + seg[1, 3:9, 3:9] = 2 + + cleaned_copy = clean_segmentation(seg.copy(), verbose=False, min_size=5, inplace=False) + cleaned_inplace = clean_segmentation(seg.copy(), verbose=False, min_size=5, inplace=True) + + np.testing.assert_array_equal(cleaned_copy, cleaned_inplace) + class TestGetCleaningStats: """Tests for the get_cleaning_stats function."""