diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..bdaecfc 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -152,7 +152,7 @@ def __init__( print('\tLoading spike_amplitudes') sa_ext = analyzer.get_extension('spike_amplitudes') if sa_ext is not None: - self.spike_amplitudes = sa_ext.get_data() + self.spike_amplitudes = sa_ext.get_data(copy=False) else: self.spike_amplitudes = None @@ -165,7 +165,7 @@ def __init__( print('\tLoading amplitude_scalings') sa_ext = analyzer.get_extension('amplitude_scalings') if sa_ext is not None: - self.amplitude_scalings = sa_ext.get_data() + self.amplitude_scalings = sa_ext.get_data(copy=False) else: self.amplitude_scalings = None @@ -178,7 +178,7 @@ def __init__( print('\tLoading spike_locations') sl_ext = analyzer.get_extension('spike_locations') if sl_ext is not None: - self.spike_depths = sl_ext.get_data()["y"] + self.spike_depths = sl_ext.get_data(copy=False)["y"] else: self.spike_depths = None @@ -295,25 +295,23 @@ def __init__( unit_ids = self.analyzer.unit_ids num_seg = self.analyzer.get_num_segments() self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") - # print("self.num_spikes", self.num_spikes) - spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel) - # spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True) + self.spikes = self.analyzer.sorting.to_spike_vector() self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data() - self.spikes = np.zeros(spike_vector.size, dtype=spike_dtype) - self.spikes['sample_index'] = spike_vector['sample_index'] - self.spikes['unit_index'] = spike_vector['unit_index'] - self.spikes['segment_index'] = spike_vector['segment_index'] - self.spikes['channel_index'] = spike_vector['channel_index'] - self.spikes['rand_selected'][:] = False - self.spikes['rand_selected'][self.random_spikes_indices] = True + ext_channel_inds = np.array([self._extremum_channel[unit_id] for unit_id in self.unit_ids]) + self.spike_channel_index = ext_channel_inds[self.spikes["unit_index"]] + self.spike_rand_selected = np.zeros(len(self.spikes), dtype=bool) + self.spike_rand_selected[self.random_spikes_indices] = True - # self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") - seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) + if self.analyzer.sorting._cached_spike_vector_segment_slices is not None: + seg_limits = self.analyzer.sorting._cached_spike_vector_segment_slices + else: + seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} - + + # TODO: minimize memory here spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False) self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2] # this is dict of list because per segment spike_indices[segment_index][unit_id] diff --git a/spikeinterface_gui/spikelistview.py b/spikeinterface_gui/spikelistview.py index c8cd73c..faca435 100644 --- a/spikeinterface_gui/spikelistview.py +++ b/spikeinterface_gui/spikelistview.py @@ -53,6 +53,8 @@ def data(self, index, role): abs_ind = self.visible_ind[row] spike = self.controller.spikes[abs_ind] + channel_index = self.controller.spike_channel_index[abs_ind] + rand_selected = self.controller.spike_rand_selected[abs_ind] unit_id = self.controller.unit_ids[spike['unit_index']] if role ==QT.Qt.DisplayRole : @@ -65,9 +67,9 @@ def data(self, index, role): elif col == 3: return '{}'.format(spike['sample_index']) elif col == 4: - return '{}'.format(spike['channel_index']) + return '{}'.format(channel_index) elif col == 5: - return '{}'.format(spike['rand_selected']) + return '{}'.format(rand_selected) else: return None elif role == QT.Qt.DecorationRole : @@ -309,6 +311,8 @@ def _panel_refresh_table(self): visible_inds = self.controller.get_indices_spike_visible() unit_ids = self.controller.unit_ids spikes = self.controller.spikes[visible_inds] + channel_inds = self.controller.spike_channel_index + rand_selected = self.controller.spike_rand_selected spike_unit_ids = [] for i, spike in enumerate(spikes): @@ -322,8 +326,8 @@ def _panel_refresh_table(self): 'unit_id': spike_unit_ids, 'segment_index': spikes['segment_index'], 'sample_index': spikes['sample_index'], - 'channel_index': spikes['channel_index'], - 'rand_selected': spikes['rand_selected'] + 'channel_index': channel_inds, + 'rand_selected': rand_selected } # Update table data without replacing entire dataframe diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 3971078..66f7dab 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -33,8 +33,7 @@ def teardown_module(): def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False, port=0): - - analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer", load_extensions=False) # analyzer = load_analyzer(test_folder / "sorting_analyzer.zarr") print(analyzer) diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 3349eba..2d42554 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -35,15 +35,12 @@ def teardown_module(): clean_all(test_folder) -def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False): +def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False, lazy_load=False): - analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer", load_extensions=False, lazy=lazy_load) # analyzer = load_analyzer(test_folder / "sorting_analyzer.zarr") - tm = analyzer.get_extension("template_metrics").get_data().iloc[0, :] - # print(tm) - # return print(analyzer) @@ -109,7 +106,8 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext displayed_unit_properties=None, extra_unit_properties=extra_unit_properties, layout_preset='default', - events=events_dict + events=events_dict, + lazy_load=lazy_load # user_settings={"mainsettings": {"color_mode": "color_by_visibility", "max_visible_units": 5}} ) @@ -144,6 +142,8 @@ def test_launcher(verbose=True): parser = ArgumentParser() parser.add_argument('--dataset', default="small", help='Path to the dataset folder') parser.add_argument('--events', action="store_true", help='Simulate and add events') +parser.add_argument('--lazy', action="store_true", help='Lazy load') + if __name__ == '__main__': args = parser.parse_args() @@ -155,7 +155,7 @@ def test_launcher(verbose=True): if not test_folder.is_dir(): setup_module() - win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events) + win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events, lazy_load=args.lazy) # win = test_mainwindow(start_app=True, verbose=True, curation=False) # test_launcher(verbose=True) diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index e79f29d..032b21a 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -38,6 +38,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): spikes_seg = self.controller.spikes[sl] i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2]) spikes_chunk = spikes_seg[i1:i2].copy() + spikes_channel_chunk = self.controller.spike_channel_index[sl] spikes_chunk["sample_index"] -= ind1 # for trace map view, this returns the channels ordered by depth @@ -73,7 +74,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): # Get spikes for this unit unit_spikes = spikes_chunk[inds] - channel_inds = unit_spikes["channel_index"] + channel_inds = spikes_channel_chunk[inds] sample_inds = unit_spikes["sample_index"] chan_mask = np.isin(channel_inds, visible_channel_inds)