Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions spikeinterface_gui/backend_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,11 @@ def __init__(self, controller, parent=None, layout_dict=None, user_settings=None
self.make_views(user_settings)
self.create_main_layout()

# refresh all views without notiying
self.controller.signal_handler.deactivate()
for view in self.views.values():
# refresh do not work because view are not yet visible at init
view._refresh()
self.controller.signal_handler.activate()
# TODO sam : all views are always refreshed at the moment so this is useless.
# uncommen this when ViewBase.is_view_visible() work correctly
# for view_name, dock in self.docks.items():
# dock.visibilityChanged.connect(self.views[view_name].refresh)
for view_name, dock in self.docks.items():
view = self.views[view_name]
dock.visibilityChanged.connect(
lambda visible, v=view: visible and QT.QTimer.singleShot(0, v.refresh)
) # Deferred so visibleRegion is populated before refresh() re-checks is_view_visible;

def make_views(self, user_settings):
self.views = {}
Expand Down
8 changes: 3 additions & 5 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,12 @@ def get_unit_data(self, unit_id, segment_index=0):

def get_selected_spikes_data(self, segment_index=0, visible_inds=None):
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]
selected_indices = self.controller.get_indices_spike_selected()
if visible_inds is not None:
selected_indices = np.intersect1d(selected_indices, visible_inds)
mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices)
selected_spikes = spikes_in_seg[mask]
spike_times = self.controller.sample_index_to_time(selected_spikes['sample_index'])
spike_data = self.spike_data[sl][mask]
in_seg = selected_indices[(selected_indices >= sl.start) & (selected_indices < sl.stop)]
spike_times = self.controller.sample_index_to_time(self.controller.spikes['sample_index'][in_seg])
spike_data = self.spike_data[in_seg]
return (spike_times, spike_data)


Expand Down
27 changes: 19 additions & 8 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ def __init__(

self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data()

self.spikes = np.zeros(spike_vector.size, dtype=spike_dtype)
# align=True is required for np.searchsorted (and therefore trace
# views) to be fast.
self.spikes = np.zeros(spike_vector.size, dtype=np.dtype(spike_dtype, align=True))
self.spikes['sample_index'] = spike_vector['sample_index']
self.spikes['unit_index'] = spike_vector['unit_index']
self.spikes['segment_index'] = spike_vector['segment_index']
Expand Down Expand Up @@ -492,7 +494,9 @@ def get_times_chunk(self, segment_index, t1, t2):
ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index)
if self.main_settings["use_times"]:
recording = self.analyzer.recording
times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2]
# Passing frame bounds slices lazily if the time vector supports it (e.g. zarr).
# Can save 10s of GB of RAM on long recordings.
times_chunk = recording.get_times(segment_index=segment_index, start_frame=ind1, end_frame=ind2)
else:
times_chunk = np.arange(ind2 - ind1, dtype='float64') / self.sampling_frequency + max(t1, 0)
return times_chunk
Expand Down Expand Up @@ -1058,15 +1062,22 @@ def make_manual_split_if_possible(self, unit_id):
visible_unit_ids = self.get_visible_unit_ids()
if unit_id not in visible_unit_ids:
return False
indices = self.get_indices_spike_selected()
if len(indices) == 0:
indices = np.asarray(self.get_indices_spike_selected())
if indices.size == 0:
return False
spike_inds = self.get_spike_indices(unit_id, segment_index=None)
if not np.all(np.isin(indices, spike_inds)):
return False

# convert selected indices to indices within the spike train of the unit
indices = [np.where(spike_inds == ind)[0][0] for ind in indices]
# convert selected indices to indices within the spike train of the unit,
# and validate that they all belong to the unit.
# np.searchsorted does both (because spike_inds is sorted ascending)
positions = np.searchsorted(spike_inds, indices)
# positions == spike_inds.size means the index sorts past the end (absent);
# otherwise the index belongs to the unit iff spike_inds[position] matches.
if np.any(positions >= spike_inds.size) or not np.array_equal(
spike_inds[np.minimum(positions, spike_inds.size - 1)], indices
):
return False
indices = positions.tolist()

new_split = {
"unit_id": unit_id,
Expand Down
24 changes: 14 additions & 10 deletions spikeinterface_gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,23 @@ def run_mainwindow_cli():
try:
if args.verbose:
print('Loading recording...')
recording_base_path = args.recording_base_path
recording = load(args.recording, base_folder=recording_base_path)
recording = load(args.recording, base_folder=args.recording_base_folder)
if args.verbose:
print('Recording loaded')
except Exception as e:
print('Error when loading recording. Please check the path or the file format')
if recording is not None:
if analyzer.get_num_channels() != recording.get_num_channels():
print('Recording and analyzer have different number of channels. Slicing recording')
channel_mask = np.isin(recording.channel_ids, analyzer.channel_ids)
if np.sum(channel_mask) != analyzer.get_num_channels():
raise ValueError('The recording does not have the same channel ids as the analyzer')
recording = recording.select_channels(recording.channel_ids[channel_mask])
raise RuntimeError(
f"Could not load recording from '{args.recording}' "
f"(base folder: {args.recording_base_folder}). "
"Check that the path exists and is readable by spikeinterface.load."
) from e
# --recording loaded successfully here (a failure raises above), so the
# analyzer/recording channel counts can be reconciled directly.
if analyzer.get_num_channels() != recording.get_num_channels():
print('Recording and analyzer have different number of channels. Slicing recording')
channel_mask = np.isin(recording.channel_ids, analyzer.channel_ids)
if np.sum(channel_mask) != analyzer.get_num_channels():
raise ValueError('The recording does not have the same channel ids as the analyzer')
recording = recording.select_channels(recording.channel_ids[channel_mask])

if args.curation_file is not None:
with open(args.curation_file, "r") as f:
Expand Down
41 changes: 27 additions & 14 deletions spikeinterface_gui/tests/test_controller.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,45 @@
import spikeinterface_gui as sigui
from pathlib import Path

from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder
import numpy as np

import spikeinterface.full as si

from pathlib import Path
from spikeinterface_gui.controller import Controller
from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder

test_folder = Path('my_dataset')


def setup_module():
make_analyzer_folder(test_folder)


def teardown_module():
clean_all(test_folder)


def test_controller():
def _load_controller(curation=False):
sorting_analyzer = si.load_sorting_analyzer(test_folder / "sorting_analyzer")
print()
controller = sigui.SpikeinterfaceController(sorting_analyzer)
print(controller)

# print(controller.segment_slices)
print(controller.get_isi_histograms())

return Controller(sorting_analyzer, curation=curation)


def test_controller():
controller = _load_controller()

# unit_ids mirror the analyzer
assert list(controller.unit_ids) == list(controller.analyzer.unit_ids)

# isi histograms were computed and are shaped consistently with the units
isi_histograms, isi_bins = controller.get_isi_histograms()
assert isi_histograms.shape[0] == controller.unit_ids.size
assert isi_bins.shape[0] == isi_histograms.shape[1] + 1




if __name__ == '__main__':

# setup_module()
test_controller()
setup_module()
try:
test_controller()
finally:
teardown_module()
59 changes: 40 additions & 19 deletions spikeinterface_gui/tests/test_curation_tools.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
from spikeinterface_gui.curation_tools import adding_group
from spikeinterface_gui.curation_tools import add_merge


def test_adding_group():
original_groups = [[1, 2, 3], [4, 5, 6], [7, 8]]
new_group_0 = [12, 10]
new_group_1 = [1, 10]
new_group_2 = [1, 10, 4]
new_group_3 = [1, 10, 8]
new_group_4 = [1, 4, 8]
r0 = adding_group(original_groups, new_group_0)
r1 = adding_group(original_groups, new_group_1)
r2 = adding_group(original_groups, new_group_2)
r3 = adding_group(original_groups, new_group_3)
r4 = adding_group(original_groups, new_group_4)
assert r0 == [[10, 12], [1, 2, 3], [4, 5, 6], [8, 7]]
assert r1 == [[3, 1, 10, 2], [4, 5, 6], [8, 7]]
assert r2 == [[1, 2, 3, 4, 5, 6, 10], [8, 7]]
assert r3 == [[1, 2, 3, 7, 8, 10], [4, 5, 6]]
assert r4 == [[1, 2, 3, 4, 5, 6, 7, 8]]
print(f'{r0} \n {r1} \n {r2} \n {r3} \n {r4}')
def _as_sets(merges):
# add_merge returns [{"unit_ids": [...]}, ...] with the group contents
# de-duplicated via set(), so compare order-insensitively.
return {frozenset(merge["unit_ids"]) for merge in merges}


def test_add_merge():
# previous merges use the curation_data["merges"] format: list of
# {"unit_ids": [...]} groups. Adding a new group transitively merges every
# existing group that shares a unit with it.
previous_merges = [
{"unit_ids": [1, 2, 3]},
{"unit_ids": [4, 5, 6]},
{"unit_ids": [7, 8]},
]

# disjoint new group -> kept as its own group, others untouched
assert _as_sets(add_merge(previous_merges, [12, 10])) == {
frozenset({10, 12}), frozenset({1, 2, 3}), frozenset({4, 5, 6}), frozenset({7, 8}),
}
# shares unit 1 -> folds into the {1,2,3} group
assert _as_sets(add_merge(previous_merges, [1, 10])) == {
frozenset({1, 2, 3, 10}), frozenset({4, 5, 6}), frozenset({7, 8}),
}
# bridges the {1,2,3} and {4,5,6} groups
assert _as_sets(add_merge(previous_merges, [1, 10, 4])) == {
frozenset({1, 2, 3, 4, 5, 6, 10}), frozenset({7, 8}),
}
# bridges the {1,2,3} and {7,8} groups
assert _as_sets(add_merge(previous_merges, [1, 10, 8])) == {
frozenset({1, 2, 3, 7, 8, 10}), frozenset({4, 5, 6}),
}
# bridges all three groups
assert _as_sets(add_merge(previous_merges, [1, 4, 8])) == {
frozenset({1, 2, 3, 4, 5, 6, 7, 8}),
}


if __name__ == '__main__':
test_add_merge()
24 changes: 19 additions & 5 deletions spikeinterface_gui/tests/testingtools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import gc
import shutil
import time
from pathlib import Path

import numpy as np
Expand All @@ -8,11 +10,23 @@


def clean_all(test_folder):
folders = [test_folder]
for folder in folders:
if Path(folder).exists():
folder = Path(test_folder)
if not folder.exists():
return
# Release lingering memmap handles (e.g. the binary_folder waveforms extension)
# before deleting. On NFS, unlinking a still-open file creates a ".nfs*" file, which
# makes shutil.rmtree raise an error when it tries to rmdir the parent too quickly.
for attempt in range(5):
gc.collect() # force release of memmap handles
try:
shutil.rmtree(folder)

return
except OSError:
if attempt < 4:
# retry after a short delay, to give the OS time to release the file handles
time.sleep(0.5)
# don't let a failure here cause an otherwise-passing test to fail
shutil.rmtree(folder, ignore_errors=True)
def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
clean_all(test_folder)

Expand Down Expand Up @@ -120,7 +134,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
sorting_analyzer.compute("templates", **job_kwargs)
sorting_analyzer.compute("noise_levels", **job_kwargs)
sorting_analyzer.compute("unit_locations")
ext = sorting_analyzer.compute("isi_histograms", window_ms=50., bin_ms=1., method="numba")
sorting_analyzer.compute("isi_histograms", window_ms=50., bin_ms=1., method="numba")
sorting_analyzer.compute("correlograms", window_ms=50., bin_ms=1.)
sorting_analyzer.compute("template_similarity", method="l1")
sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs)
Expand Down
5 changes: 3 additions & 2 deletions spikeinterface_gui/view_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ def on_settings_changed(self, *params):

def is_view_visible(self):
if self.backend == "qt":
# a widget is visible even is it is hidden under another tab!! TODO fix this
return self.qt_widget.isVisible()
# isVisible() (confusingly) stays True for a view tabbed behind another dock.
# But an obscured widget paints nothing, so its visibleRegion is empty.
return self.qt_widget.isVisible() and not self.qt_widget.visibleRegion().isEmpty()
elif self.backend == "panel":
return self._panel_view_is_visible

Expand Down
Loading