diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index 9adda83..6e6a7d7 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -171,16 +171,11 @@ def __init__(self, controller, parent=None, layout_dict=None, user_settings=None self.make_views(user_settings) self.create_main_layout() - # refresh all views without notiying - self.controller.signal_handler.deactivate() - for view in self.views.values(): - # refresh do not work because view are not yet visible at init - view._refresh() - self.controller.signal_handler.activate() - # TODO sam : all views are always refreshed at the moment so this is useless. - # uncommen this when ViewBase.is_view_visible() work correctly - # for view_name, dock in self.docks.items(): - # dock.visibilityChanged.connect(self.views[view_name].refresh) + for view_name, dock in self.docks.items(): + view = self.views[view_name] + dock.visibilityChanged.connect( + lambda visible, v=view: visible and QT.QTimer.singleShot(0, v.refresh) + ) # Deferred so visibleRegion is populated before refresh() re-checks is_view_visible; def make_views(self, user_settings): self.views = {} diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index c9754f9..5ba831b 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -79,14 +79,12 @@ def get_unit_data(self, unit_id, segment_index=0): def get_selected_spikes_data(self, segment_index=0, visible_inds=None): sl = self.controller.segment_slices[segment_index] - spikes_in_seg = self.controller.spikes[sl] selected_indices = self.controller.get_indices_spike_selected() if visible_inds is not None: selected_indices = np.intersect1d(selected_indices, visible_inds) - mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices) - selected_spikes = spikes_in_seg[mask] - spike_times = self.controller.sample_index_to_time(selected_spikes['sample_index']) - spike_data = self.spike_data[sl][mask] + in_seg = selected_indices[(selected_indices >= sl.start) & (selected_indices < sl.stop)] + spike_times = self.controller.sample_index_to_time(self.controller.spikes['sample_index'][in_seg]) + spike_data = self.spike_data[in_seg] return (spike_times, spike_data) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..82196df 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -302,7 +302,9 @@ def __init__( self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data() - self.spikes = np.zeros(spike_vector.size, dtype=spike_dtype) + # align=True is required for np.searchsorted (and therefore trace + # views) to be fast. + self.spikes = np.zeros(spike_vector.size, dtype=np.dtype(spike_dtype, align=True)) self.spikes['sample_index'] = spike_vector['sample_index'] self.spikes['unit_index'] = spike_vector['unit_index'] self.spikes['segment_index'] = spike_vector['segment_index'] @@ -492,7 +494,9 @@ def get_times_chunk(self, segment_index, t1, t2): ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index) if self.main_settings["use_times"]: recording = self.analyzer.recording - times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2] + # Passing frame bounds slices lazily if the time vector supports it (e.g. zarr). + # Can save 10s of GB of RAM on long recordings. + times_chunk = recording.get_times(segment_index=segment_index, start_frame=ind1, end_frame=ind2) else: times_chunk = np.arange(ind2 - ind1, dtype='float64') / self.sampling_frequency + max(t1, 0) return times_chunk @@ -1058,15 +1062,22 @@ def make_manual_split_if_possible(self, unit_id): visible_unit_ids = self.get_visible_unit_ids() if unit_id not in visible_unit_ids: return False - indices = self.get_indices_spike_selected() - if len(indices) == 0: + indices = np.asarray(self.get_indices_spike_selected()) + if indices.size == 0: return False spike_inds = self.get_spike_indices(unit_id, segment_index=None) - if not np.all(np.isin(indices, spike_inds)): - return False - # convert selected indices to indices within the spike train of the unit - indices = [np.where(spike_inds == ind)[0][0] for ind in indices] + # convert selected indices to indices within the spike train of the unit, + # and validate that they all belong to the unit. + # np.searchsorted does both (because spike_inds is sorted ascending) + positions = np.searchsorted(spike_inds, indices) + # positions == spike_inds.size means the index sorts past the end (absent); + # otherwise the index belongs to the unit iff spike_inds[position] matches. + if np.any(positions >= spike_inds.size) or not np.array_equal( + spike_inds[np.minimum(positions, spike_inds.size - 1)], indices + ): + return False + indices = positions.tolist() new_split = { "unit_id": unit_id, diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 050b614..9972a14 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -348,19 +348,23 @@ def run_mainwindow_cli(): try: if args.verbose: print('Loading recording...') - recording_base_path = args.recording_base_path - recording = load(args.recording, base_folder=recording_base_path) + recording = load(args.recording, base_folder=args.recording_base_folder) if args.verbose: print('Recording loaded') except Exception as e: - print('Error when loading recording. Please check the path or the file format') - if recording is not None: - if analyzer.get_num_channels() != recording.get_num_channels(): - print('Recording and analyzer have different number of channels. Slicing recording') - channel_mask = np.isin(recording.channel_ids, analyzer.channel_ids) - if np.sum(channel_mask) != analyzer.get_num_channels(): - raise ValueError('The recording does not have the same channel ids as the analyzer') - recording = recording.select_channels(recording.channel_ids[channel_mask]) + raise RuntimeError( + f"Could not load recording from '{args.recording}' " + f"(base folder: {args.recording_base_folder}). " + "Check that the path exists and is readable by spikeinterface.load." + ) from e + # --recording loaded successfully here (a failure raises above), so the + # analyzer/recording channel counts can be reconciled directly. + if analyzer.get_num_channels() != recording.get_num_channels(): + print('Recording and analyzer have different number of channels. Slicing recording') + channel_mask = np.isin(recording.channel_ids, analyzer.channel_ids) + if np.sum(channel_mask) != analyzer.get_num_channels(): + raise ValueError('The recording does not have the same channel ids as the analyzer') + recording = recording.select_channels(recording.channel_ids[channel_mask]) if args.curation_file is not None: with open(args.curation_file, "r") as f: diff --git a/spikeinterface_gui/tests/test_panel_embedded.py b/spikeinterface_gui/tests/debug_panel_embedded.py similarity index 100% rename from spikeinterface_gui/tests/test_panel_embedded.py rename to spikeinterface_gui/tests/debug_panel_embedded.py diff --git a/spikeinterface_gui/tests/test_controller.py b/spikeinterface_gui/tests/test_controller.py index 3bac96f..0282c01 100644 --- a/spikeinterface_gui/tests/test_controller.py +++ b/spikeinterface_gui/tests/test_controller.py @@ -1,10 +1,11 @@ -import spikeinterface_gui as sigui +from pathlib import Path -from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder +import numpy as np import spikeinterface.full as si -from pathlib import Path +from spikeinterface_gui.controller import Controller +from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder test_folder = Path('my_dataset') @@ -12,21 +13,33 @@ def setup_module(): make_analyzer_folder(test_folder) + def teardown_module(): clean_all(test_folder) -def test_controller(): +def _load_controller(curation=False): sorting_analyzer = si.load_sorting_analyzer(test_folder / "sorting_analyzer") - print() - controller = sigui.SpikeinterfaceController(sorting_analyzer) - print(controller) - - # print(controller.segment_slices) - print(controller.get_isi_histograms()) - + return Controller(sorting_analyzer, curation=curation) + + +def test_controller(): + controller = _load_controller() + + # unit_ids mirror the analyzer + assert list(controller.unit_ids) == list(controller.analyzer.unit_ids) + + # isi histograms were computed and are shaped consistently with the units + isi_histograms, isi_bins = controller.get_isi_histograms() + assert isi_histograms.shape[0] == controller.unit_ids.size + assert isi_bins.shape[0] == isi_histograms.shape[1] + 1 + + + if __name__ == '__main__': - - # setup_module() - test_controller() + setup_module() + try: + test_controller() + finally: + teardown_module() diff --git a/spikeinterface_gui/tests/test_curation_tools.py b/spikeinterface_gui/tests/test_curation_tools.py index 54ba496..4473c3b 100644 --- a/spikeinterface_gui/tests/test_curation_tools.py +++ b/spikeinterface_gui/tests/test_curation_tools.py @@ -1,22 +1,43 @@ -from spikeinterface_gui.curation_tools import adding_group +from spikeinterface_gui.curation_tools import add_merge -def test_adding_group(): - original_groups = [[1, 2, 3], [4, 5, 6], [7, 8]] - new_group_0 = [12, 10] - new_group_1 = [1, 10] - new_group_2 = [1, 10, 4] - new_group_3 = [1, 10, 8] - new_group_4 = [1, 4, 8] - r0 = adding_group(original_groups, new_group_0) - r1 = adding_group(original_groups, new_group_1) - r2 = adding_group(original_groups, new_group_2) - r3 = adding_group(original_groups, new_group_3) - r4 = adding_group(original_groups, new_group_4) - assert r0 == [[10, 12], [1, 2, 3], [4, 5, 6], [8, 7]] - assert r1 == [[3, 1, 10, 2], [4, 5, 6], [8, 7]] - assert r2 == [[1, 2, 3, 4, 5, 6, 10], [8, 7]] - assert r3 == [[1, 2, 3, 7, 8, 10], [4, 5, 6]] - assert r4 == [[1, 2, 3, 4, 5, 6, 7, 8]] - print(f'{r0} \n {r1} \n {r2} \n {r3} \n {r4}') +def _as_sets(merges): + # add_merge returns [{"unit_ids": [...]}, ...] with the group contents + # de-duplicated via set(), so compare order-insensitively. + return {frozenset(merge["unit_ids"]) for merge in merges} + +def test_add_merge(): + # previous merges use the curation_data["merges"] format: list of + # {"unit_ids": [...]} groups. Adding a new group transitively merges every + # existing group that shares a unit with it. + previous_merges = [ + {"unit_ids": [1, 2, 3]}, + {"unit_ids": [4, 5, 6]}, + {"unit_ids": [7, 8]}, + ] + + # disjoint new group -> kept as its own group, others untouched + assert _as_sets(add_merge(previous_merges, [12, 10])) == { + frozenset({10, 12}), frozenset({1, 2, 3}), frozenset({4, 5, 6}), frozenset({7, 8}), + } + # shares unit 1 -> folds into the {1,2,3} group + assert _as_sets(add_merge(previous_merges, [1, 10])) == { + frozenset({1, 2, 3, 10}), frozenset({4, 5, 6}), frozenset({7, 8}), + } + # bridges the {1,2,3} and {4,5,6} groups + assert _as_sets(add_merge(previous_merges, [1, 10, 4])) == { + frozenset({1, 2, 3, 4, 5, 6, 10}), frozenset({7, 8}), + } + # bridges the {1,2,3} and {7,8} groups + assert _as_sets(add_merge(previous_merges, [1, 10, 8])) == { + frozenset({1, 2, 3, 7, 8, 10}), frozenset({4, 5, 6}), + } + # bridges all three groups + assert _as_sets(add_merge(previous_merges, [1, 4, 8])) == { + frozenset({1, 2, 3, 4, 5, 6, 7, 8}), + } + + +if __name__ == '__main__': + test_add_merge() diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index dcd87af..39e7d41 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -1,4 +1,6 @@ +import gc import shutil +import time from pathlib import Path import numpy as np @@ -8,11 +10,23 @@ def clean_all(test_folder): - folders = [test_folder] - for folder in folders: - if Path(folder).exists(): + folder = Path(test_folder) + if not folder.exists(): + return + # Release lingering memmap handles (e.g. the binary_folder waveforms extension) + # before deleting. On NFS, unlinking a still-open file creates a ".nfs*" file, which + # makes shutil.rmtree raise an error when it tries to rmdir the parent too quickly. + for attempt in range(5): + gc.collect() # force release of memmap handles + try: shutil.rmtree(folder) - + return + except OSError: + if attempt < 4: + # retry after a short delay, to give the OS time to release the file handles + time.sleep(0.5) + # don't let a failure here cause an otherwise-passing test to fail + shutil.rmtree(folder, ignore_errors=True) def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): clean_all(test_folder) @@ -120,7 +134,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): sorting_analyzer.compute("templates", **job_kwargs) sorting_analyzer.compute("noise_levels", **job_kwargs) sorting_analyzer.compute("unit_locations") - ext = sorting_analyzer.compute("isi_histograms", window_ms=50., bin_ms=1., method="numba") + sorting_analyzer.compute("isi_histograms", window_ms=50., bin_ms=1., method="numba") sorting_analyzer.compute("correlograms", window_ms=50., bin_ms=1.) sorting_analyzer.compute("template_similarity", method="l1") sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs) diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 25c36b0..175059e 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -94,8 +94,9 @@ def on_settings_changed(self, *params): def is_view_visible(self): if self.backend == "qt": - # a widget is visible even is it is hidden under another tab!! TODO fix this - return self.qt_widget.isVisible() + # isVisible() (confusingly) stays True for a view tabbed behind another dock. + # But an obscured widget paints nothing, so its visibleRegion is empty. + return self.qt_widget.isVisible() and not self.qt_widget.visibleRegion().isEmpty() elif self.backend == "panel": return self._panel_view_is_visible