diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..db5d99b5c3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -180,6 +180,7 @@ def get_unit_spike_train( segment_index=segment_index, start_time=start_time, end_time=end_time, + use_cache=use_cache, ) segment_index = self._check_segment_index(segment_index) @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, + use_cache: bool = True, ) -> np.ndarray: """ Get spike train for a unit in seconds. @@ -236,6 +238,8 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. Returns ------- @@ -246,7 +250,7 @@ def get_unit_spike_train_in_seconds( segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording - # Note that this take into account the segment start time of the recording + # Note that this takes into account the segment start time of the recording if self.has_recording(): # Get all the spike times and then slice them @@ -258,7 +262,7 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) @@ -288,13 +292,169 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) + """ + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_trains_in_seconds( + unit_ids=unit_ids, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + use_cache=use_cache, + ) + + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + if use_cache: + # TODO: speed things up + ordered_spike_vector, slices = self.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=False, + return_slices=True, + ) + unit_indices = self.ids_to_indices(unit_ids) + spike_trains = {} + for unit_index, unit_id in zip(unit_indices, unit_ids): + sl0, sl1 = slices[unit_index, segment_index, :] + spikes = ordered_spike_vector[sl0:sl1] + spike_frames = spikes["sample_index"] + if start_frame is not None: + start = np.searchsorted(spike_frames, start_frame) + spike_frames = spike_frames[start:] + if end_frame is not None: + end = np.searchsorted(spike_frames, end_frame) + spike_frames = spike_frames[:end] + spike_trains[unit_id] = spike_frames + else: + spike_trains = segment.get_unit_spike_trains( + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame + ) + return spike_trains + + def get_unit_spike_trains_in_seconds( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_time: float | None = None, + end_time: float | None = None, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units in seconds. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds) + """ + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + + # If sorting has a registered recording, get the frames and get the times from the recording + # Note that this takes into account the segment start time of the recording + spike_times = {} + if self.has_recording(): + # Get all the spike times and then slice them + start_frame = None + end_frame = None + spike_train_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + + for unit_id in unit_ids: + spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index) + + # Filter to return only the spikes within the specified time range + if start_time is not None: + spike_frames = spike_frames[spike_frames >= start_time] + if end_time is not None: + spike_frames = spike_frames[spike_frames <= end_time] + + spike_times[unit_id] = spike_frames + + return spike_times + + # If no recording attached and all back to frame-based conversion + # Get spike train in frames and convert to times using traditional method + start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None + end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None + + spike_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + for unit_id in unit_ids: + spike_frames_unit = spike_frames[unit_id] + t_start = segment._t_start if segment._t_start is not None else 0 + spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start + return spike_times + def register_recording(self, recording, check_spike_frames: bool = True): """ Register a recording to the sorting. If the sorting and recording both contain @@ -951,48 +1111,53 @@ def to_reordered_spike_vector( key = str(lexsort) if key not in self._cached_lexsorted_spike_vector.keys(): - spikes = self.to_spike_vector() - order = np.lexsort((spikes[lexsort[0]], spikes[lexsort[1]], spikes[lexsort[2]])) - ordered_spikes = spikes[order] - self._cached_lexsorted_spike_vector[key] = {} - self._cached_lexsorted_spike_vector[key]["ordered_spikes"] = ordered_spikes - self._cached_lexsorted_spike_vector[key]["order"] = order + from .sorting_tools import reorder_spike_vector_by_buckets + spikes = self.to_spike_vector() num_units = len(self.unit_ids) num_segments = self.get_num_segments() - # precompute the slices with nested search sorted + # Both supported `lexsort` keys are equivalent to grouping spikes into + # `num_units * num_segments` "buckets" while **preserving** ascending + # `sample_index`` within each bucket. + # (Within each bucket, samples are **already** sorted!) + # + # We can do this with a (stable) counting sort in O(N), if we know the + # bucket index for each spike. That's easy: the bucket index is just a + # straightforward linear combination of the `unit_index` and `segment_index` + # fields, with the order depending on the lexsort` key. if lexsort == ("sample_index", "segment_index", "unit_index"): - # this case make spiketrain per unit compact in memory - - slices = np.zeros((num_units, num_segments, 2), dtype=np.int64) - unit_slices = np.searchsorted(ordered_spikes["unit_index"], np.arange(num_units + 1), side="left") - for unit_index, unit_id in enumerate(self.unit_ids): - u0 = unit_slices[unit_index] - u1 = unit_slices[unit_index + 1] - seg_slices = np.searchsorted( - ordered_spikes[u0:u1]["segment_index"], np.arange(num_segments + 1), side="left" - ) - for segment_index in range(num_segments): - s0 = seg_slices[segment_index] - s1 = seg_slices[segment_index + 1] - slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - - elif ("sample_index", "unit_index", "segment_index"): - slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) - seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") - for segment_index in range(self.get_num_segments()): - s0 = seg_slices[segment_index] - s1 = seg_slices[segment_index + 1] - unit_slices = np.searchsorted( - ordered_spikes[s0:s1]["unit_index"], np.arange(num_units + 1), side="left" - ) - for unit_index, unit_id in enumerate(self.unit_ids): - u0 = unit_slices[unit_index] - u1 = unit_slices[unit_index + 1] - slices[segment_index, unit_index, :] = [s0 + u0, s0 + u1] - - self._cached_lexsorted_spike_vector[key]["slices"] = slices + # primary key unit_index, then segment_index, then sample_index + bucket_index = spikes["unit_index"].astype(np.int64, copy=False) * num_segments + spikes[ + "segment_index" + ].astype(np.int64, copy=False) + num_buckets = num_units * num_segments + ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets) + # counts is laid out as [unit_index * num_segments + segment_index] + counts_2d = counts.reshape(num_units, num_segments) + + else: # ("sample_index", "unit_index", "segment_index") + # primary key segment_index, then unit_index, then sample_index + bucket_index = spikes["segment_index"].astype(np.int64, copy=False) * num_units + spikes[ + "unit_index" + ].astype(np.int64, copy=False) + num_buckets = num_segments * num_units + ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets) + # counts is laid out as [segment_index * num_units + unit_index] + counts_2d = counts.reshape(num_segments, num_units) + + # Build slices from cumulative counts. Stops are exclusive cumulative sums + # (aka "prefix sums" in the language of counting sort) shifted by one; + # starts are the same prefix without the last element. + ends = np.cumsum(counts_2d.ravel()).reshape(counts_2d.shape) + starts = ends - counts_2d + slices = np.stack([starts, ends], axis=-1).astype(np.int64, copy=False) + + self._cached_lexsorted_spike_vector[key] = { + "ordered_spikes": ordered_spikes, + "order": order, + "slices": slices, + } ordered_spikes = self._cached_lexsorted_spike_vector[key]["ordered_spikes"] out = (ordered_spikes,) @@ -1083,7 +1248,7 @@ def __init__(self, t_start=None): def get_unit_spike_train( self, - unit_id, + unit_id: int | str, start_frame: int | None = None, end_frame: int | None = None, ) -> np.ndarray: @@ -1091,18 +1256,51 @@ def get_unit_spike_train( Parameters ---------- - unit_id + unit_id : int | str + The unit id for which to get the spike train. start_frame : int, default: None + The start frame for the spike train. If None, it is set to the beginning of the segment. end_frame : int, default: None + The end frame for the spike train. If None, it is set to the end of the segment. + Returns ------- np.ndarray - + The spike train for the given unit id and time interval. """ # must be implemented in subclass raise NotImplementedError + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict[int | str, np.ndarray]: + """Get the spike trains for several units. + Can be implemented in subclass for performance but the default implementation is to call + get_unit_spike_train for each unit_id. + + Parameters + ---------- + unit_ids : numpy.array or list + The unit ids for which to get the spike trains. + start_frame : int, default: None + The start frame for the spike trains. If None, it is set to the beginning of the segment. + end_frame : int, default: None + The end frame for the spike trains. If None, it is set to the end of the segment. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit_ids and values are the corresponding spike trains. + """ + spike_trains = {} + for unit_id in unit_ids: + spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) + return spike_trains + class SpikeVectorSortingSegment(BaseSortingSegment): """ diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index cbad29c806..5b904def19 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core.base import BaseExtractor, unit_period_dtype +from spikeinterface.core.base import BaseExtractor, minimum_spike_dtype, unit_period_dtype from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -236,14 +236,14 @@ def random_spikes_selection( elif method == "percentage": if percentage is None or not (0 < percentage <= 1): - raise ValueError(f"percentage must be in the interval (0, 1]") + raise ValueError("percentage must be in the interval (0, 1]") rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) elif method == "maximum_rate": if maximum_rate is None: - raise ValueError(f"maximum_rate must be defined") + raise ValueError("maximum_rate must be defined") t_duration = np.sum(get_segment_durations(sorting)) rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) @@ -949,6 +949,9 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee * select unit and recompute quickly the "unit_index" in the spike vector * merging/spliting periods or spikes and update the "unit_index" in the vector + Do not use this if you are operating on `minimum_spike_dtype` inputs! In such cases, + it is much more efficient to use a dense LUT + `filter_and_remap_spike_vector()` + (see `UnitSelectionSorting._compute_and_cache_spike_vector()`). Parameters ---------- @@ -996,3 +999,436 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee new_vector["unit_index"] = mapping[new_vector["unit_index"]] return new_vector, keep_mask_vector + + +def is_spike_vector_sorted( + spike_vector: np.ndarray, *, chunk_size: int | None = 10_000_000, assume_single_segment: bool = False +) -> bool: + """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). + + This is an O(n) sequential scan used to avoid an O(n log n) lexsort when the + vector already happens to be in canonical order. + + The strategy is: compare pairs of adjacent spikes in chunks to avoid allocating + (possibly big) temporary arrays for diffs. + + Each adjacent pair has to be fully "lexsorted": + + * segment_index is nondecreasing; + * within the same segment, sample_index is nondecreasing; + * within the same segment and same sample, unit_index is nondecreasing. + + Parameters + ---------- + spike_vector : np.ndarray + Spike vector with fields "sample_index", "unit_index", and + "segment_index". + chunk_size : int | None, default 10_000_000 + Number of adjacent pairs to check per chunk. None checks the full vector + in one chunk. + assume_single_segment : bool, default False + If True, skip segment_index checks and require only sample_index/unit_index + ordering. + """ + n = len(spike_vector) + if n <= 1: + return True + + if chunk_size is None: + chunk_size = n - 1 + elif chunk_size < 1: + raise ValueError("chunk_size must be >= 1 or None") + + sample_index = spike_vector["sample_index"] + unit_index = spike_vector["unit_index"] + + if assume_single_segment: + for start in range(0, n - 1, chunk_size): + stop = min(start + chunk_size, n - 1) + + # Compare each sample_index value to the following one. The shifted + # slices have equal length and represent adjacent spike pairs. + sample0 = sample_index[start:stop] + sample1 = sample_index[start + 1 : stop + 1] + if np.any(sample1 < sample0): + return False + + # Unit order only matters for cotemporal (same sample) spikes + same_sample = sample1 == sample0 + if np.any((unit_index[start + 1 : stop + 1] < unit_index[start:stop]) & same_sample): + return False + + return True + + segment_index = spike_vector["segment_index"] + + for start in range(0, n - 1, chunk_size): + stop = min(start + chunk_size, n - 1) + + # First enforce segment ordering. Later checks are masked to adjacent + # pairs in the same segment because sample/unit ordering is segment-local. + segment0 = segment_index[start:stop] + segment1 = segment_index[start + 1 : stop + 1] + if np.any(segment1 < segment0): + return False + + same_segment = segment1 == segment0 + + sample0 = sample_index[start:stop] + sample1 = sample_index[start + 1 : stop + 1] + if np.any((sample1 < sample0) & same_segment): + return False + + # Unit order is only part of canonical order for cotemporal spikes. + same_sample = same_segment & (sample1 == sample0) + if np.any((unit_index[start + 1 : stop + 1] < unit_index[start:stop]) & same_sample): + return False + + return True + + +def build_spike_vector_from_sorted_arrays( + sample_indices: np.ndarray, + unit_indices: np.ndarray, + segment_index: int = 0, +) -> np.ndarray: + """Build a `minimum_spike_dtype` spike vector when sample_indices is already sorted. + + Some sorting extractors (notably Phy/Kilosort) hold their spike samples in + a flat array that is already monotonic non-decreasing in `sample_index`. + Building the spike vector then only requires sorting `unit_index` *within* + runs of equal `sample_index` — a single O(N) pass, instead of an O(N log N) + global `np.lexsort`. + + Parameters + ---------- + sample_indices : np.ndarray + 1-D integer array of spike sample positions. Expected to be monotonic + non-decreasing; if a violation is detected the function falls back to + a global lexsort so the result is still correct. + unit_indices : np.ndarray + 1-D integer array, parallel to `sample_indices`, giving each spike's + unit index (position in the parent sorting's `unit_ids`). + segment_index : int, default 0 + Value to broadcast into the output `segment_index` field. + + Returns + ------- + spikes : np.ndarray + Structured array of length `sample_indices.size` with dtype + `minimum_spike_dtype`. The ordering is identical to what you would + get by building the structured array from the inputs and then + applying ``np.lexsort((unit_indices, sample_indices))`` — i.e. + primary key `sample_index` ascending, secondary key `unit_index` + ascending within ties. + """ + n = sample_indices.size + if unit_indices.size != n: + raise ValueError(f"sample_indices and unit_indices must have the same length; got {n} and {unit_indices.size}.") + + if n == 0: + return np.empty(0, dtype=minimum_spike_dtype) + + # Since the numba kernel is compiled for int64, this ensures we don't re-JIT if, + # for examples, the caller passes unit ids as int32. More importantly, this allows + # the kernel to index with a constant stride no matter what (e.g. if the caller + # passes a non-contiguous view like `arr[::2]`), and costs nothing if no-op. + sample_arr = np.ascontiguousarray(sample_indices, dtype=np.int64) + unit_arr = np.ascontiguousarray(unit_indices, dtype=np.int64) + + if HAVE_NUMBA: + # Allocate the output as a flat (N, 3) int64 buffer and let one numba + # kernel pass do everything: monotonicity check, unit-index + # tie resolution, and writing all three fields. + flat = np.empty((n, 3), dtype=np.int64) + is_monotonic = _build_spike_vector_kernel(sample_arr, unit_arr, int(segment_index), flat) + if is_monotonic: + # NB: This is zero-copy, because the (N, 3) int64 layout matches + # `minimum_spike_dtype` exactly. + return flat.view(minimum_spike_dtype).reshape(n) + + # Fallback: caller's sample_indices invariant did not hold (or numba + # is unavailable). Do a global lexsort. ='( + spikes = np.empty(n, dtype=minimum_spike_dtype) + spikes["segment_index"] = segment_index + order = np.lexsort((unit_arr, sample_arr)) + spikes["sample_index"] = sample_arr[order] + spikes["unit_index"] = unit_arr[order] + return spikes + + +def filter_and_remap_spike_vector( + spike_vector: np.ndarray, + unit_mapping: np.ndarray, +) -> np.ndarray: + """Filter a `minimum_spike_dtype` spike vector by unit and remap unit_index in one pass. + + For each spike `i` in `spike_vector`: + * look up ``new_idx = unit_mapping[spike_vector[i]["unit_index"]]``, + * if ``new_idx >= 0``, copy the spike to the output with that new unit_index; + * otherwise drop it. + + Spikes are written in their original order in `spike_vector`, so if the input + is sorted by ``(segment_index, sample_index, parent_unit_index)`` and + `unit_mapping`, restricted to its kept entries, is monotonic increasing, + then the output is also sorted by ``(segment_index, sample_index, new_unit_index)``. + + If `unit_mapping` is not order-preserving, the caller is responsible for re-sorting + cotemporal spike groups (e.g. by calling `is_spike_vector_sorted` + `np.lexsort`). + + Parameters + ---------- + spike_vector : np.ndarray + Structured array with dtype `minimum_spike_dtype`. + unit_mapping : np.ndarray + 1-D int64 array of length ``parent_num_units``. ``unit_mapping[old]`` + gives the new unit_index, or any negative value to drop the spike. + + Returns + ------- + out : np.ndarray + Structured array of `minimum_spike_dtype`, length `n_kept`. + """ + n_parent = spike_vector.size + if n_parent == 0: + return np.empty(0, dtype=minimum_spike_dtype) + + # See implementation note in `build_spike_vector_from_sorted_arrays()` + mapping_arr = np.ascontiguousarray(unit_mapping, dtype=np.int64) + + if HAVE_NUMBA: + # Same trick as `build_spike_vector_from_sorted_arrays()`: + # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match + # `minimum_spike_dtype` exactly. + parent_flat = spike_vector.view(np.int64).reshape(n_parent, 3) + out_flat = np.empty((n_parent, 3), dtype=np.int64) + n_kept = _filter_and_remap_kernel(parent_flat, mapping_arr, out_flat) + return out_flat[:n_kept].view(minimum_spike_dtype).reshape(n_kept) + + # NumPy fallback: bool-mask + remap. + old_unit_idx = spike_vector["unit_index"] + new_unit_idx_full = mapping_arr[old_unit_idx] + keep = new_unit_idx_full >= 0 + out = spike_vector[keep].copy() + out["unit_index"] = new_unit_idx_full[keep] + return out + + +def reorder_spike_vector_by_buckets( + spike_vector: np.ndarray, + bucket_index: np.ndarray, + num_buckets: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Stable counting-sort of a `minimum_spike_dtype` spike vector by precomputed + bucket index. + + If numba is available, runs in O(N) using a kernel based on the algorithm described + in Cormen, Leiserson, Rivest, and Stein (CLRS) Chapter 8.2. Uses the same + flat-buffer view trick as `build_spike_vector_from_sorted_arrays()` and + `filter_and_remap_spike_vector()` to avoid intermediate copies. + + Stability means rows are emitted in input order within each bucket, + so any pre-existing sort within a bucket + (e.g. ascending `sample_index` within a (segment, unit) group) is preserved. + + Parameters + ---------- + spike_vector : np.ndarray + Structured array with dtype `minimum_spike_dtype`. + bucket_index : np.ndarray + 1-D integer array of length ``spike_vector.size`` giving the + destination bucket for each spike. Values must be in + ``[0, num_buckets)``. + num_buckets : int + Total number of buckets. + + Returns + ------- + ordered_spikes : np.ndarray + Structured array of `minimum_spike_dtype`, same length as input, + with rows grouped by `bucket_index`. + order : np.ndarray + 1-D int64 array such that ``spike_vector[order] == ordered_spikes``. + counts : np.ndarray + 1-D int64 array of length `num_buckets`, the number of spikes in + each bucket. + """ + n = spike_vector.size + if n == 0: + return ( + np.empty(0, dtype=minimum_spike_dtype), + np.empty(0, dtype=np.int64), + np.zeros(int(num_buckets), dtype=np.int64), + ) + + # See implementation note in `build_spike_vector_from_sorted_arrays()`. + bucket_arr = np.ascontiguousarray(bucket_index, dtype=np.int64) + if bucket_arr.size != n: + raise ValueError(f"bucket_index and spike_vector must have the same length; got {bucket_arr.size} and {n}.") + + if HAVE_NUMBA: + # Same trick as `build_spike_vector_from_sorted_arrays()`: + # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match + # `minimum_spike_dtype` exactly. + in_flat = spike_vector.view(np.int64).reshape(n, 3) + out_flat = np.empty((n, 3), dtype=np.int64) + order = np.empty(n, dtype=np.int64) + counts = np.empty(int(num_buckets), dtype=np.int64) + _reorder_spike_vector_kernel(in_flat, bucket_arr, int(num_buckets), out_flat, order, counts) + ordered_spikes = out_flat.view(minimum_spike_dtype).reshape(n) + return ordered_spikes, order, counts + + # NumPy fallback: stable argsort by bucket, then fancy-index the structured array. + order = np.argsort(bucket_arr, kind="stable") + ordered_spikes = spike_vector[order] + counts = np.bincount(bucket_arr, minlength=int(num_buckets)).astype(np.int64, copy=False) + return ordered_spikes, order, counts + + +if HAVE_NUMBA: + import numba + + @numba.jit(nopython=True, nogil=True, cache=True) + def _build_spike_vector_kernel(sample_indices, unit_indices, segment_index, flat_out): + """Single-pass build of a `minimum_spike_dtype` spike vector. + + Walks `sample_indices` once and, for each spike, writes all three + fields (sample_index, unit_index, segment_index) into `flat_out` + — a contiguous (N, 3) int64 buffer whose memory layout matches + `minimum_spike_dtype` exactly. + + While walking, the kernel also: + * verifies `sample_indices` is monotonic non-decreasing — returns + False at the first violation (caller falls back to a global + lexsort), + * insertion-sorts `unit_indices` within each run of equal + `sample_indices` before emitting that run. Runs of length 1 + (the common case for Kilosort/Phy output) skip the sort + entirely. + """ + n = sample_indices.shape[0] + i = 0 + # Walk one tie-run at a time: [i, j) is the next run of equal sample_indices. + while i < n: + # Monotonicity guard — bail out and lexsort instead of it fails + if i > 0 and sample_indices[i] < sample_indices[i - 1]: + return False + + # Find the end of the current tie-run. + j = i + 1 + while j < n and sample_indices[j] == sample_indices[i]: + j += 1 + + sample = sample_indices[i] + if j - i == 1: + # Fast path: no co-temporal spikes. + # Column order (0, 1, 2) matches minimum_spike_dtype field order + # (sample_index, unit_index, segment_index). Essential! + flat_out[i, 0] = sample + flat_out[i, 1] = unit_indices[i] + flat_out[i, 2] = segment_index + else: + # Tied run: sort unit_indices within the run. + # In practice, runs are short (single-digit numbers of spikes) and rare + # (single-digit percentage of total spikes), so the per-run allocation + # + insertion sort are cheap. + run_len = j - i + + # Stage the tied unit_indices into a small working buffer. + # This _could_ be expensive if the runs were long, but I tested + # with/without, and this _always_ wins. + buf = np.empty(run_len, dtype=np.int64) + for k in range(run_len): + buf[k] = unit_indices[i + k] + + # Insertion-sort the buffer in place. Beats anything fancier on + # tiny arrays (maybe because there is zero setup cost). + for k in range(1, run_len): + key = buf[k] + m = k - 1 + while m >= 0 and buf[m] > key: + buf[m + 1] = buf[m] + m -= 1 + buf[m + 1] = key + + # Emit the run with sorted unit_indices. + for k in range(run_len): + flat_out[i + k, 0] = sample + flat_out[i + k, 1] = buf[k] + flat_out[i + k, 2] = segment_index + + # Advance past the run we just emitted. + i = j + + return True + + @numba.jit(nopython=True, nogil=True, cache=True) + def _filter_and_remap_kernel(parent_flat, unit_mapping, out_flat): + """Single-pass filter + remap for a (N, 3) int64 spike-vector view. + + For each row i of `parent_flat` (columns 0/1/2 = sample, unit, segment): + look up ``new_unit = unit_mapping[parent_flat[i, 1]]``; if it is + non-negative, copy (sample, new_unit, segment) to ``out_flat[write_pos]`` + and advance `write_pos`. + + Returns the number of spikes written. The caller is expected to slice + `out_flat[:n_kept]` and view as `minimum_spike_dtype`. + + Spikes are emitted in the order they appear in `parent_flat`, so + ordering on (segment_index, sample_index) is preserved automatically; + unit_index ordering within tied sample_index groups follows whatever + `unit_mapping` does to the parent's unit_index values. + """ + n = parent_flat.shape[0] + write_pos = 0 + for i in range(n): + new_unit = unit_mapping[parent_flat[i, 1]] + if new_unit >= 0: + # Column order (0, 1, 2) is coupled to minimum_spike_dtype field order + # (sample_index, unit_index, segment_index) — keep them in sync. + out_flat[write_pos, 0] = parent_flat[i, 0] + out_flat[write_pos, 1] = new_unit + out_flat[write_pos, 2] = parent_flat[i, 2] + write_pos += 1 + return write_pos + + @numba.jit(nopython=True, nogil=True, cache=True) + def _reorder_spike_vector_kernel(in_flat, bucket_index, num_buckets, out_flat, order, counts): + """Stable counting-sort of a (N, 3) int64 spike-vector flat-buffer view by bucket. + + Adapted from Cormen, Leiserson, Rivest, and Stein (CLRS) Chapter 8.2. + + Two O(N) passes: + 1. histogram `bucket_index` into `counts`, + 2. cumulative-sum to per-bucket write positions, then scatter each + row of `in_flat` to its destination in `out_flat` and record the + source index in `order` so that ``in[order] == out``. + + Stability: within each bucket, rows keep their input order, so any + ordering already present in `in_flat` (e.g. ascending sample_index + within a (segment, unit) group) carries over to `out_flat`. + """ + n = in_flat.shape[0] + + for b in range(num_buckets): + counts[b] = 0 + for i in range(n): + counts[bucket_index[i]] += 1 + + # Exclusive prefix sum into a local write_pos buffer. `counts` keeps + # the per-bucket sizes (the caller uses them to derive slices). + # (If slices weren't needed, maybe this could be done in-place in `counts`?) + write_pos = np.empty(num_buckets, dtype=np.int64) + running = 0 + for b in range(num_buckets): + write_pos[b] = running + running += counts[b] + + for i in range(n): + b = bucket_index[i] + pos = write_pos[b] + out_flat[pos, 0] = in_flat[i, 0] + out_flat[pos, 1] = in_flat[i, 1] + out_flat[pos, 2] = in_flat[i, 2] + order[pos] = i + write_pos[b] = pos + 1 diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6c06b212b8..5617b881d7 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -155,6 +155,131 @@ def test_BaseSorting(create_cache_folder): assert sorting.get_annotation(annotation_name) == sorting_zarr_loaded.get_annotation(annotation_name) +def _reference_reordered_spike_vector(spikes, lexsort, num_units, num_segments): + """Pre-optimization reference: np.lexsort + nested searchsorted. + + Mirrors the implementation that lived in `to_reordered_spike_vector` + before the counting-sort rewrite. Used to assert byte-for-byte parity + of the new implementation. + """ + order = np.lexsort((spikes[lexsort[0]], spikes[lexsort[1]], spikes[lexsort[2]])) + ordered_spikes = spikes[order] + + if lexsort == ("sample_index", "segment_index", "unit_index"): + slices = np.zeros((num_units, num_segments, 2), dtype=np.int64) + unit_slices = np.searchsorted(ordered_spikes["unit_index"], np.arange(num_units + 1), side="left") + for unit_index in range(num_units): + u0 = unit_slices[unit_index] + u1 = unit_slices[unit_index + 1] + seg_slices = np.searchsorted( + ordered_spikes[u0:u1]["segment_index"], np.arange(num_segments + 1), side="left" + ) + for segment_index in range(num_segments): + s0 = seg_slices[segment_index] + s1 = seg_slices[segment_index + 1] + slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] + elif lexsort == ("sample_index", "unit_index", "segment_index"): + slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) + seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") + for segment_index in range(num_segments): + s0 = seg_slices[segment_index] + s1 = seg_slices[segment_index + 1] + unit_slices = np.searchsorted(ordered_spikes[s0:s1]["unit_index"], np.arange(num_units + 1), side="left") + for unit_index in range(num_units): + u0 = unit_slices[unit_index] + u1 = unit_slices[unit_index + 1] + slices[segment_index, unit_index, :] = [s0 + u0, s0 + u1] + else: + raise ValueError(lexsort) + + return ordered_spikes, order, slices + + +def test_to_reordered_spike_vector_parity(): + """The counting-sort rewrite must match the prior np.lexsort implementation.""" + rng = np.random.default_rng(42) + num_units = 6 + num_segments = 3 + sampling_frequency = 30_000.0 + + # Build per-segment, per-unit spike trains with deliberate cotemporal spikes + # (multiple units firing at the same sample_index) so the unit-index tiebreaker + # is exercised. + spike_dicts = [] + for seg in range(num_segments): + seg_dict = {} + for u in range(num_units): + n = int(rng.integers(50, 200)) + times = np.sort(rng.integers(0, 10_000, size=n)) + # Inject a handful of cotemporal spikes that collide with the unit-0 train. + if u > 0 and n > 5: + times[:5] = np.array([100, 200, 300, 400, 500]) + seg * 10 + times = np.sort(times) + seg_dict[str(u)] = times.astype("int64") + spike_dicts.append(seg_dict) + + sorting = NumpySorting.from_unit_dict(spike_dicts, sampling_frequency) + spikes = sorting.to_spike_vector() + + for lexsort in [ + ("sample_index", "segment_index", "unit_index"), + ("sample_index", "unit_index", "segment_index"), + ]: + # Clear the cache between iterations so each call exercises the fresh build. + sorting._cached_lexsorted_spike_vector = {} + + ordered_spikes, order, slices = sorting.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + + ref_ordered, ref_order, ref_slices = _reference_reordered_spike_vector(spikes, lexsort, num_units, num_segments) + + # ordered_spikes must agree with the reference exactly (cotemporal spikes + # are now ordered by unit_index — stable counting sort by bucket preserves + # the canonical unit-index ordering within each tied sample_index). + assert np.array_equal(ordered_spikes, ref_ordered), f"ordered mismatch for {lexsort}" + # The invariant `spikes[order] == ordered_spikes` must hold; the exact + # `order` permutation can differ from np.lexsort's because stable counting + # sort and np.lexsort may pick different tie-break orderings of source rows + # that map to the same destination (different source rows can carry the + # same (sample, unit, segment) triple). + assert np.array_equal(spikes[order], ordered_spikes) + assert np.array_equal(slices, ref_slices), f"slices mismatch for {lexsort}" + + # Each (unit, segment) — or (segment, unit) — slice must yield exactly the + # spikes for that group, with monotonic sample_index. + if lexsort == ("sample_index", "segment_index", "unit_index"): + for u in range(num_units): + for s in range(num_segments): + s0, s1 = slices[u, s] + block = ordered_spikes[s0:s1] + assert np.all(block["unit_index"] == u) + assert np.all(block["segment_index"] == s) + assert np.all(np.diff(block["sample_index"]) >= 0) + else: + for s in range(num_segments): + for u in range(num_units): + s0, s1 = slices[s, u] + block = ordered_spikes[s0:s1] + assert np.all(block["unit_index"] == u) + assert np.all(block["segment_index"] == s) + assert np.all(np.diff(block["sample_index"]) >= 0) + + +def test_to_reordered_spike_vector_empty(): + """An empty sorting must round-trip through the counting-sort path.""" + sorting = NumpySorting.from_unit_dict({"0": np.array([], dtype="int64")}, 30_000.0) + ordered_spikes, order, slices = sorting.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=True, + return_slices=True, + ) + assert ordered_spikes.size == 0 + assert order.size == 0 + assert slices.shape == (1, 1, 2) + assert np.array_equal(slices, np.zeros((1, 1, 2), dtype=np.int64)) + + def test_npy_sorting(): sfreq = 10 spike_times_0 = { @@ -310,6 +435,31 @@ def test_select_periods(): np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) +@pytest.mark.parametrize("use_cache", [False, True]) +def test_get_unit_spike_trains(use_cache): + sampling_frequency = 10_000.0 + duration = 1.0 + num_units = 10 + sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units) + + all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(all_spike_trains, dict) + assert set(all_spike_trains.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache) + assert np.array_equal(all_spike_trains[unit_id], spiketrain) + + # test with times + spike_trains_times = sorting.get_unit_spike_trains_in_seconds(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(spike_trains_times, dict) + assert set(spike_trains_times.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain_times = sorting.get_unit_spike_train_in_seconds( + segment_index=0, unit_id=unit_id, use_cache=use_cache + ) + assert np.allclose(spike_trains_times[unit_id], spiketrain_times) + + if __name__ == "__main__": import tempfile diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 4194f459b3..407cd18334 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -13,6 +13,9 @@ _get_ids_after_merging, generate_unit_ids_for_merge_group, remap_unit_indices_in_vector, + is_spike_vector_sorted, + build_spike_vector_from_sorted_arrays, + filter_and_remap_spike_vector, ) from spikeinterface.core.base import minimum_spike_dtype @@ -45,6 +48,258 @@ def test_spike_vector_to_indices(): ) +def test_is_spike_vector_sorted(): + empty_spikes = np.zeros(0, dtype=minimum_spike_dtype) + assert is_spike_vector_sorted(empty_spikes) + + one_spike = np.zeros(1, dtype=minimum_spike_dtype) + assert is_spike_vector_sorted(one_spike) + + spikes = np.zeros(5, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 0, 1, 1, 1] + spikes["sample_index"] = [100, 200, 0, 100, 100] + spikes["unit_index"] = [0, 1, 0, 0, 1] + assert is_spike_vector_sorted(spikes) + assert is_spike_vector_sorted(spikes, chunk_size=None) + assert is_spike_vector_sorted(spikes, chunk_size=1) + + segment_unsorted = spikes.copy() + segment_unsorted["segment_index"] = [0, 1, 0, 1, 1] + segment_unsorted["sample_index"] = [0, 100, 200, 300, 400] + segment_unsorted["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(segment_unsorted) + + sample_unsorted = spikes.copy() + sample_unsorted["segment_index"] = 0 + sample_unsorted["sample_index"] = [0, 100, 50, 200, 300] + sample_unsorted["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(sample_unsorted) + + tie_unsorted = spikes.copy() + tie_unsorted["segment_index"] = 0 + tie_unsorted["sample_index"] = [0, 100, 100, 200, 300] + tie_unsorted["unit_index"] = [0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(tie_unsorted) + + with pytest.raises(ValueError, match="chunk_size"): + is_spike_vector_sorted(spikes, chunk_size=0) + + +def test_is_spike_vector_sorted_chunk_boundaries(): + spikes = np.zeros(6, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 0, 1, 0, 1, 1] + spikes["sample_index"] = [0, 100, 200, 300, 400, 500] + spikes["unit_index"] = 0 + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + spikes["segment_index"] = 0 + spikes["sample_index"] = [0, 100, 300, 200, 400, 500] + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + spikes["sample_index"] = [0, 100, 200, 200, 400, 500] + spikes["unit_index"] = [0, 0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + +def test_is_spike_vector_sorted_assume_single_segment(): + spikes = np.zeros(5, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 1, 0, 1, 1] + spikes["sample_index"] = [0, 100, 200, 300, 400] + spikes["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(spikes) + assert is_spike_vector_sorted(spikes, assume_single_segment=True) + + sample_unsorted = spikes.copy() + sample_unsorted["sample_index"] = [0, 100, 50, 200, 300] + assert not is_spike_vector_sorted(sample_unsorted, assume_single_segment=True) + + tie_unsorted = spikes.copy() + tie_unsorted["sample_index"] = [0, 100, 100, 200, 300] + tie_unsorted["unit_index"] = [0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(tie_unsorted, assume_single_segment=True) + + +def _reference_spike_vector(sample_indices, unit_indices, segment_index=0): + """Reference implementation: global lexsort, used as ground truth in tests.""" + n = sample_indices.size + spikes = np.empty(n, dtype=minimum_spike_dtype) + spikes["sample_index"] = sample_indices + spikes["unit_index"] = unit_indices + spikes["segment_index"] = segment_index + order = np.lexsort((spikes["unit_index"], spikes["sample_index"])) + return spikes[order] + + +@pytest.fixture(params=[True, False], ids=["numba", "numpy"]) +def force_numba(request, monkeypatch): + """Run each test once with numba enabled (if installed) and once with the fallback.""" + if request.param and importlib.util.find_spec("numba") is None: + pytest.skip("numba not installed") + monkeypatch.setattr("spikeinterface.core.sorting_tools.HAVE_NUMBA", request.param) + return request.param + + +def test_build_spike_vector_no_ties(force_numba): + sample_indices = np.array([10, 20, 30, 40, 50], dtype=np.int64) + unit_indices = np.array([3, 1, 4, 1, 5], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + assert out.dtype == np.dtype(minimum_spike_dtype) + assert np.array_equal(out["sample_index"], sample_indices) + assert np.array_equal(out["unit_index"], unit_indices) + assert np.all(out["segment_index"] == 0) + + +def test_build_spike_vector_with_ties(force_numba): + # Three runs of ties (lengths 3, 2, 1, 4) with shuffled unit_indices + sample_indices = np.array( + [10, 10, 10, 20, 20, 30, 40, 40, 40, 40, 50], + dtype=np.int64, + ) + unit_indices = np.array([7, 2, 5, 9, 1, 3, 4, 0, 8, 2, 6], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_all_same_sample_index(force_numba): + n = 64 + sample_indices = np.full(n, 42, dtype=np.int64) + rng = np.random.default_rng(0) + unit_indices = rng.permutation(n).astype(np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_ties_at_edges(force_numba): + # Ties at the very start, the very end, and an isolated single in between. + sample_indices = np.array([5, 5, 5, 9, 12, 12, 12], dtype=np.int64) + unit_indices = np.array([2, 0, 1, 7, 3, 1, 2], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_empty(force_numba): + out = build_spike_vector_from_sorted_arrays( + np.array([], dtype=np.int64), + np.array([], dtype=np.int64), + ) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_build_spike_vector_segment_index(force_numba): + sample_indices = np.array([0, 1, 2], dtype=np.int64) + unit_indices = np.array([0, 0, 0], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices, segment_index=3) + assert np.all(out["segment_index"] == 3) + + +def test_build_spike_vector_length_mismatch(): + with pytest.raises(ValueError): + build_spike_vector_from_sorted_arrays( + np.array([1, 2, 3], dtype=np.int64), + np.array([1, 2], dtype=np.int64), + ) + + +def test_build_spike_vector_randomized_against_lexsort(force_numba): + rng = np.random.default_rng(1234) + n = 10_000 + # Build ~30% ties by drawing sample positions from a small space. + sample_indices = np.sort(rng.integers(0, n // 3, size=n).astype(np.int64)) + unit_indices = rng.integers(0, 200, size=n).astype(np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_unsorted_falls_back(force_numba): + # Caller violates the "sample_indices is sorted" invariant; helper must + # still return a globally lexsorted vector via the fallback. + sample_indices = np.array([200, 100, 300, 100], dtype=np.int64) + unit_indices = np.array([0, 1, 2, 0], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def _make_spike_vector(samples, units, segments=None): + """Build a minimum_spike_dtype array from parallel arrays. Test helper.""" + n = len(samples) + sv = np.empty(n, dtype=minimum_spike_dtype) + sv["sample_index"] = samples + sv["unit_index"] = units + sv["segment_index"] = segments if segments is not None else 0 + return sv + + +def test_filter_and_remap_keep_all(force_numba): + # Identity mapping: every parent unit_index maps to itself. + sv = _make_spike_vector([10, 20, 30, 40], [0, 1, 2, 0]) + mapping = np.arange(3, dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert np.array_equal(out, sv) + + +def test_filter_and_remap_drop_some(force_numba): + # Drop unit 1 entirely; keep 0 and 2 with new indices [0, 1]. + sv = _make_spike_vector([10, 20, 30, 40, 50], [0, 1, 2, 0, 1]) + mapping = np.array([0, -1, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([10, 30, 40], [0, 1, 0]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_renamed_only(force_numba): + # Selection is full but unit indices are permuted: 0->2, 1->0, 2->1. + sv = _make_spike_vector([10, 20, 30], [0, 1, 2]) + mapping = np.array([2, 0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([10, 20, 30], [2, 0, 1]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_empty_selection(force_numba): + sv = _make_spike_vector([10, 20, 30], [0, 1, 2]) + mapping = np.full(3, -1, dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_filter_and_remap_empty_input(force_numba): + sv = np.empty(0, dtype=minimum_spike_dtype) + mapping = np.array([0, 1, 2], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_filter_and_remap_preserves_tie_order(force_numba): + # Two cotemporal spikes at sample 100 (units 1 and 2). After dropping unit 0, + # the two cotemporals must appear in their original relative order — the kernel + # never reorders within ties. + sv = _make_spike_vector( + [50, 100, 100, 200], + [0, 1, 2, 1], + ) + mapping = np.array([-1, 0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([100, 100, 200], [0, 1, 0]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_segment_index_preserved(force_numba): + sv = _make_spike_vector([10, 20, 30, 40], [0, 1, 0, 1], segments=[0, 0, 1, 1]) + mapping = np.array([0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert np.array_equal(out["segment_index"], [0, 0, 1, 1]) + assert np.array_equal(out["sample_index"], [10, 20, 30, 40]) + assert np.array_equal(out["unit_index"], [0, 1, 0, 1]) + + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( durations=[20.0, 10.0], @@ -61,7 +316,6 @@ def test_random_spikes_selection(): random_spikes_indices = random_spikes_selection( sorting, num_samples, method="uniform", max_spikes_per_unit=max_spikes_per_unit, margin_size=None, seed=2205 ) - random_spikes_indices1 = random_spikes_indices spikes = sorting.to_spike_vector() some_spikes = spikes[random_spikes_indices] for unit_index, unit_id in enumerate(sorting.unit_ids): diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 3aa7bc7577..d2b24d90ab 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -1,8 +1,9 @@ import pytest import numpy as np -from pathlib import Path from spikeinterface.core import UnitsSelectionSorting +from spikeinterface.core.numpyextractors import NumpySorting +from spikeinterface.core.sorting_tools import is_spike_vector_sorted from spikeinterface.core.generate import generate_sorting @@ -40,16 +41,140 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) + UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_compute_and_cache_spike_vector(): + """USS override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) - cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) - computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) - assert np.all(cached_spike_vector == computed_spike_vector) + + # USS override path + sub_sorting._compute_and_cache_spike_vector() + uss_vector = sub_sorting._cached_spike_vector.copy() + + # Base class (per-unit) path + sub_sorting._cached_spike_vector = None + sub_sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sub_sorting) + base_vector = sub_sorting._cached_spike_vector + + assert np.array_equal(uss_vector, base_vector) + + +@pytest.mark.parametrize("use_cache", [False, True]) +def test_uss_get_unit_spike_trains_with_renamed_ids(use_cache): + """get_unit_spike_trains on a USS with renamed ids must return dicts with child ids + (as opposed to parent ids) as keys.""" + sorting = generate_sorting(num_units=5, durations=[0.100], sampling_frequency=30000.0, seed=42) + + # Select a subset and rename + sub = UnitsSelectionSorting(sorting, unit_ids=["1", "3", "4"], renamed_unit_ids=["a", "b", "c"]) + renamed_ids = list(sub.unit_ids) + + batch = sub.get_unit_spike_trains(unit_ids=renamed_ids, segment_index=0, use_cache=use_cache) + + assert isinstance(batch, dict) + assert set(batch.keys()) == set(renamed_ids) + + for uid in renamed_ids: + single = sub.get_unit_spike_train(unit_id=uid, segment_index=0, use_cache=use_cache) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + +def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): + """USS spike vector must be correctly sorted even when selection reverses unit order + and co-temporal spikes exist (same sample_index, different units).""" + from spikeinterface.core.basesorting import BaseSorting + + # Build a sorting with guaranteed co-temporal spikes: + # units 0, 1, 2 all fire at sample 100 and 200 + samples = np.array([100, 100, 100, 200, 200, 200, 300, 400], dtype=np.int64) + labels = np.array([0, 1, 2, 0, 1, 2, 0, 1], dtype=np.int64) + sorting = NumpySorting.from_samples_and_labels( + samples_list=[samples], labels_list=[labels], sampling_frequency=30000.0 + ) + + # Reverse the unit order — _is_order_preserving_selection must return False + sub = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + + spike_vector = sub.to_spike_vector() + + sub._cached_spike_vector = None + sub._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sub) + base_vector = sub._cached_spike_vector + + assert np.array_equal(spike_vector, base_vector) + assert np.all(spike_vector["segment_index"] == 0) + assert is_spike_vector_sorted(spike_vector) + + +def test_compute_and_cache_spike_vector_identity_selection_shares_parent_cache(): + """A USS that selects all of its parent's units in parent order should reuse the + parent's cached spike vector by reference, not rebuild it.""" + from spikeinterface.core.basesorting import BaseSorting + + sorting = generate_sorting(num_units=4, durations=[0.100, 0.100], sampling_frequency=30000.0) + + # First USS: identity selection over `sorting`. Force its cache. + uss1 = UnitsSelectionSorting(sorting, unit_ids=list(sorting.unit_ids)) + uss1._compute_and_cache_spike_vector() + assert uss1._cached_spike_vector is not None + + # Second USS: identity selection over uss1, with renamed ids to exercise the + # rename-only path. The cached spike vector must be the same Python object. + renamed = [f"r{uid}" for uid in uss1.unit_ids] + uss2 = UnitsSelectionSorting(uss1, unit_ids=list(uss1.unit_ids), renamed_unit_ids=renamed) + uss2._compute_and_cache_spike_vector() + assert uss2._cached_spike_vector is uss1._cached_spike_vector + if uss1._cached_spike_vector_segment_slices is not None: + assert uss2._cached_spike_vector_segment_slices is uss1._cached_spike_vector_segment_slices + + # Belt-and-suspenders: the shared vector must still match the slow base-class path. + uss2._cached_spike_vector = None + uss2._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(uss2) + base_vector = uss2._cached_spike_vector + assert np.array_equal(uss1._cached_spike_vector, base_vector) + + +def test_to_reordered_spike_vector_identity_selection_shares_parent_cache(): + """A USS that selects all of its parent's units in parent order should reuse the + parent's lexsorted spike vector cache by reference, not re-run the counting sort.""" + sorting = generate_sorting(num_units=5, durations=[0.200, 0.200], sampling_frequency=30000.0) + + # Identity selection, with renamed ids to also exercise the rename-only path. + renamed = [f"r{uid}" for uid in sorting.unit_ids] + uss = UnitsSelectionSorting(sorting, unit_ids=list(sorting.unit_ids), renamed_unit_ids=renamed) + + for lexsort in [ + ("sample_index", "segment_index", "unit_index"), + ("sample_index", "unit_index", "segment_index"), + ]: + # Force the parent to build the lexsorted cache. + parent_ordered, _, parent_slices = sorting.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + key = str(lexsort) + assert key in sorting._cached_lexsorted_spike_vector + + # Reset USS cache and force a build through the override. + uss._cached_lexsorted_spike_vector = {} + uss_ordered, _, uss_slices = uss.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + + # The cache entry must be the *same* dict object as the parent's. + assert ( + uss._cached_lexsorted_spike_vector[key] is sorting._cached_lexsorted_spike_vector[key] + ), f"identity USS did not share parent lexsorted cache for {lexsort}" + assert uss_ordered is parent_ordered + assert uss_slices is parent_slices if __name__ == "__main__": diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 59356db976..823bc477c7 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -1,6 +1,7 @@ import numpy as np from .basesorting import BaseSorting, BaseSortingSegment +from .sorting_tools import filter_and_remap_spike_vector, is_spike_vector_sorted class UnitsSelectionSorting(BaseSorting): @@ -46,24 +47,99 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): self._kwargs = dict(parent_sorting=parent_sorting, unit_ids=unit_ids, renamed_unit_ids=renamed_unit_ids) def _compute_and_cache_spike_vector(self) -> None: - from spikeinterface.core.sorting_tools import remap_unit_indices_in_vector - if self._parent_sorting._cached_spike_vector is None: self._parent_sorting._compute_and_cache_spike_vector() if self._parent_sorting._cached_spike_vector is None: return - spike_vector, _ = remap_unit_indices_in_vector( - vector=self._parent_sorting._cached_spike_vector, - all_old_unit_ids=self._parent_sorting.unit_ids, - all_new_unit_ids=self._unit_ids, + parent_unit_ids = self._parent_sorting.unit_ids + + # If the user requested an "identity selection" (all parent units, in + # parent order, possibly renamed), the cached parent spike vector is + # identical to the one we want — share the reference and skip the rest. + # See `_is_identity_selection` for the definition. + if self._is_identity_selection(): + self._cached_spike_vector = self._parent_sorting._cached_spike_vector + parent_slices = self._parent_sorting._cached_spike_vector_segment_slices + if parent_slices is not None: + self._cached_spike_vector_segment_slices = parent_slices + return + + # Build a dense LUT from parent unit_index -> new unit_index (-1 = drop). + parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} + unit_mapping = np.full(parent_unit_ids.size, -1, dtype=np.int64) + for new_idx, uid in enumerate(self._unit_ids): + unit_mapping[parent_id_to_pos[uid]] = new_idx + + spike_vector = filter_and_remap_spike_vector( + spike_vector=self._parent_sorting._cached_spike_vector, + unit_mapping=unit_mapping, ) - # lexsort by segment_index, sample_index, unit_index - sort_indices = np.lexsort( - (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + + # The parent's spike vector is sorted by (segment_index, sample_index, unit_index). + # Filtering preserves that order and the remap only changes unit_index values. + # The result stays sorted iff the selected unit_ids appear in the same relative + # order as in the parent (an O(k) check). If not, the vector may still happen to + # be sorted -- verify with an O(n) scan before falling back to O(n log n) lexsort. + assume_single_segment = self.get_num_segments() == 1 + if not self._is_order_preserving_selection() and not is_spike_vector_sorted( + spike_vector, assume_single_segment=assume_single_segment + ): + if assume_single_segment: + sort_indices = np.lexsort((spike_vector["unit_index"], spike_vector["sample_index"])) + else: + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) + spike_vector = spike_vector[sort_indices] + + self._cached_spike_vector = spike_vector + + def _is_identity_selection(self) -> bool: + """Return True if self._unit_ids are exactly the parent's unit_ids, in parent order. + + Renaming via ``renamed_unit_ids`` does not affect this — the spike vector + carries unit *indices*, not ids. When True, every cached form of the + parent's spike vector (canonical, lexsorted, etc.) can be shared with + ``self`` by reference. + """ + parent_unit_ids = self._parent_sorting.unit_ids + return self._unit_ids.size == parent_unit_ids.size and np.array_equal(self._unit_ids, parent_unit_ids) + + def to_reordered_spike_vector( + self, lexsort=("sample_index", "segment_index", "unit_index"), return_order=True, return_slices=True + ): + # On an identity selection, the parent's lexsorted cache is exactly + # what we'd compute — just reference it so we don't re-run the counting sort! + if self._is_identity_selection(): + key = str(tuple(lexsort)) + if key not in self._cached_lexsorted_spike_vector: + # Force the parent to populate its own cache (a no-op if already + # cached) before we share the entry. + self._parent_sorting.to_reordered_spike_vector(lexsort=lexsort, return_order=True, return_slices=True) + parent_entry = self._parent_sorting._cached_lexsorted_spike_vector.get(key) + if parent_entry is not None: + self._cached_lexsorted_spike_vector[key] = parent_entry + return super().to_reordered_spike_vector( + lexsort=lexsort, return_order=return_order, return_slices=return_slices ) - self._cached_spike_vector = spike_vector[sort_indices] + + def _is_order_preserving_selection(self) -> bool: + """Return True if self._unit_ids appear in the same relative order as in the parent. + + O(k) where k is the number of selected units. When True, the remapped spike vector + is guaranteed to remain sorted by (segment, sample, unit) without re-sorting. + """ + parent_unit_ids = self._parent_sorting.unit_ids + parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} + prev_pos = -1 + for uid in self._unit_ids: + pos = parent_id_to_pos.get(uid) + if pos is None or pos <= prev_pos: + return False + prev_pos = pos + return True class UnitsSelectionSortingSegment(BaseSortingSegment): @@ -81,3 +157,13 @@ def get_unit_spike_train( unit_id_parent = self._ids_conversion[unit_id] times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame) return times + + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids] + parent_trains = self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) + return {child_id: parent_trains[parent_id] for child_id, parent_id in zip(unit_ids, unit_ids_parent)} diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..1a5e610a44 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,4 @@ +import importlib.util from pathlib import Path import warnings @@ -14,10 +15,13 @@ SortingAnalyzer, ) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.sorting_tools import build_spike_vector_from_sorted_arrays -from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations +from spikeinterface.postprocessing import ComputeSpikeLocations from probeinterface import read_prb, Probe +HAVE_NUMBA = importlib.util.find_spec("numba") is not None + class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. @@ -152,9 +156,14 @@ def __init__( # update spike clusters and times values bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values] - spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) - spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] - spike_times_clean = spike_times[spike_clusters_clean_idxs] + if len(bad_clusters) > 0: + spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) + spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] + spike_times_clean = spike_times[spike_clusters_clean_idxs] + else: + # No bad clusters — skip the O(N) isin mask and two N-sized copies. + spike_clusters_clean = spike_clusters + spike_times_clean = spike_times if "si_unit_id" in cluster_info.columns: unit_ids = cluster_info["si_unit_id"].values @@ -224,6 +233,46 @@ def __init__( self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + def _compute_and_cache_spike_vector(self) -> None: + """Build the spike vector directly from the flat single-segment arrays. + + Since Phy/Kilosort segment already holds the full spike_times and + spike_clusters arrays in memory, we can construct the spike vector + in one shot. + """ + assert self.get_num_segments() == 1 + + unit_ids = np.asarray(self.unit_ids) + seg = self.segments[0] + all_spikes = seg._all_spikes + all_clusters = seg._all_clusters + n = all_spikes.size + + # Map cluster ids -> unit indices via a direct lookup table. + # cluster_ids are non-negative integers (Phy/Kilosort convention) and + # the max id is small (one per neural unit), so a "dense" table of size + # max_id + 1 is cheap (kilobytes), even though it reserves space for unit ids + # that don't exist, and lets the mapping run in a single O(N) gather. + # This is ~10x faster than `sorter[searchsorted(sorted_unit_ids, all_clusters)]` + # on large N. + max_id = int(max(unit_ids.max() if unit_ids.size else -1, all_clusters.max() if n else -1)) + cluster_to_unit = np.empty(max_id + 1, dtype=np.int64) + cluster_to_unit[unit_ids] = np.arange(unit_ids.size, dtype=np.int64) + unit_indices = cluster_to_unit[all_clusters] + + # Kilosort/Phy always emit spikes ascending in sample_index but DO NOT + # order cluster_ids within a sample_index. The helper sorts unit_index + # within tied sample_index runs in O(N), avoiding a global lexsort. + spikes = build_spike_vector_from_sorted_arrays( + sample_indices=all_spikes, + unit_indices=unit_indices, + segment_index=0, + ) + segment_slices = np.array([[0, n]], dtype="int64") + + self._cached_spike_vector = spikes + self._cached_spike_vector_segment_slices = segment_slices + class PhySortingSegment(BaseSortingSegment): def __init__(self, all_spikes, all_clusters): @@ -240,6 +289,107 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id] return np.atleast_1d(spike_times.copy().squeeze()) + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + """Extract spike trains for several units in one pass. + + If you need to get ~20 or more spike trains, this is usually **much** faster + than calling get_unit_spike_train() for each unit. + + Numba-accelerated, if numba is available. Otherwise, falls back to NumPy. + """ + start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") + end = ( + len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left") + ) # Exclude end frame + + spikes = self._all_spikes[start:end] + clusters = self._all_clusters[start:end] + + unit_ids_arr = np.asarray(unit_ids) + num_units = unit_ids_arr.size + if num_units == 0: + return {} + + # Map cluster ids -> unit indices via a direct lookup table. + # See `_compute_and_cache_spike_vector()`. + max_id = int(max(unit_ids_arr.max(), clusters.max() if clusters.size else -1)) + cluster_to_dest = np.full(max_id + 1, -1, dtype=np.int64) + cluster_to_dest[unit_ids_arr] = np.arange(num_units, dtype=np.int64) + dest = cluster_to_dest[clusters] + + if HAVE_NUMBA: + offsets, flat_out = _counting_sort_spikes_by_unit(spikes, dest, num_units) + else: + # NumPy fallback: stable argsort by destination index, then split on offsets. + # Stable sort preserves the input order of spikes within each unit group, + # and since _all_spikes is sorted by sample_index, so is each group. + if dest.size and dest.min() >= 0: + # Trick: Every cluster in `clusters` is in `unit_ids`, so no + # boolean-mask filtering is needed. Skips two N-sized copies. + order = np.argsort(dest, kind="stable") + flat_out = spikes[order] + counts = np.bincount(dest, minlength=num_units) + else: + valid = dest >= 0 + valid_spikes = spikes[valid] + valid_dest = dest[valid] + order = np.argsort(valid_dest, kind="stable") + flat_out = valid_spikes[order] + counts = np.bincount(valid_dest, minlength=num_units) + offsets = np.empty(num_units + 1, dtype=np.int64) + offsets[0] = 0 + np.cumsum(counts, out=offsets[1:]) + + return {unit_ids[i]: flat_out[offsets[i] : offsets[i + 1]] for i in range(num_units)} + + +if HAVE_NUMBA: + import numba + + @numba.jit(nopython=True, nogil=True, cache=True) + def _counting_sort_spikes_by_unit(all_spikes, dest_unit_indices, num_units): + """Counting-sort `all_spikes` into per-unit groups. + + Parameters + ---------- + all_spikes : int64 array + Spike sample indices. + dest_unit_indices : int64 array (same length as all_spikes) + Destination unit index for each spike, or -1 to skip. + num_units : int + Number of destination units. + + Returns + ------- + offsets : int64 array of shape (num_units + 1,) + Offsets into `flat_out`; group k is `flat_out[offsets[k]:offsets[k+1]]`. + flat_out : int64 array + Concatenated spike times, grouped by destination unit index. + """ + n = all_spikes.shape[0] + counts = np.zeros(num_units + 1, dtype=np.int64) + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + counts[u + 1] += 1 + for k in range(1, num_units + 1): + counts[k] += counts[k - 1] + + flat_out = np.empty(counts[num_units], dtype=all_spikes.dtype) + write_pos = counts[:-1].copy() + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + flat_out[write_pos[u]] = all_spikes[i] + write_pos[u] += 1 + + return counts, flat_out + class PhySortingExtractor(BasePhyKilosortSortingExtractor): """Load Phy format data as a sorting extractor. diff --git a/src/spikeinterface/extractors/tests/test_phykilosortextractors.py b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py new file mode 100644 index 0000000000..3be84a341c --- /dev/null +++ b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py @@ -0,0 +1,129 @@ +import pytest +import numpy as np + +from spikeinterface.extractors.phykilosortextractors import PhySortingSegment +from spikeinterface.core.sorting_tools import is_spike_vector_sorted +import spikeinterface.extractors.phykilosortextractors as phymod + +# Sorted spike times with known cluster assignments. +# 3 units (ids 10, 20, 30), some co-temporal spikes. +ALL_SPIKES = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) +ALL_CLUSTERS = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) +UNIT_IDS = [10, 20, 30] + + +@pytest.mark.parametrize("force_numpy_fallback", [False, True]) +def test_phy_sorting_segment_get_unit_spike_trains(monkeypatch, force_numpy_fallback): + """get_unit_spike_trains must match per-unit calls, for both Numba and NumPy paths.""" + if force_numpy_fallback: + monkeypatch.setattr(phymod, "HAVE_NUMBA", False) + + seg = PhySortingSegment(ALL_SPIKES, ALL_CLUSTERS) + + # Full range, all units + batch = seg.get_unit_spike_trains(UNIT_IDS, start_frame=None, end_frame=None) + assert set(batch.keys()) == set(UNIT_IDS) + for uid in UNIT_IDS: + single = seg.get_unit_spike_train(uid, start_frame=None, end_frame=None) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + assert np.array_equal(batch[10], [100, 300, 400]) + assert np.array_equal(batch[20], [100, 300, 500]) + assert np.array_equal(batch[30], [200, 300]) + + # With start_frame / end_frame slicing + batch_sliced = seg.get_unit_spike_trains(UNIT_IDS, start_frame=200, end_frame=400) + assert np.array_equal(batch_sliced[10], [300]) + assert np.array_equal(batch_sliced[20], [300]) + assert np.array_equal(batch_sliced[30], [200, 300]) + + # Subset of unit_ids + batch_subset = seg.get_unit_spike_trains([20], start_frame=None, end_frame=None) + assert list(batch_subset.keys()) == [20] + assert np.array_equal(batch_subset[20], [100, 300, 500]) + + # Empty unit_ids + assert seg.get_unit_spike_trains([], start_frame=None, end_frame=None) == {} + + +def _make_phy_folder(tmp_path, spike_times=None, spike_clusters=None, cluster_ids=None): + """Create a minimal Phy output folder for testing.""" + if spike_times is None: + spike_times = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) + if spike_clusters is None: + spike_clusters = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) + + np.save(tmp_path / "spike_times.npy", spike_times) + np.save(tmp_path / "spike_clusters.npy", spike_clusters) + (tmp_path / "params.py").write_text("sample_rate = 30000.0\n") + if cluster_ids is not None: + cluster_lines = "\n".join(str(cluster_id) for cluster_id in cluster_ids) + (tmp_path / "cluster_info.tsv").write_text(f"cluster_id\n{cluster_lines}\n") + return tmp_path + + +@pytest.mark.parametrize( + ("spike_times", "spike_clusters", "cluster_ids"), + [ + pytest.param( + np.array([100, 200, 300, 400], dtype=np.int64), + np.array([20, 10, 30, 20], dtype=np.int64), + None, + id="canonical-no-cotemporal-ties", + ), + pytest.param( + np.array([100, 100, 200, 300, 300], dtype=np.int64), + np.array([10, 20, 30, 10, 30], dtype=np.int64), + None, + id="canonical-cotemporal-ties", + ), + pytest.param( + np.array([100, 100, 200, 300], dtype=np.int64), + np.array([20, 10, 30, 10], dtype=np.int64), + None, + id="cotemporal-ties-require-lexsort", + ), + pytest.param( + np.array([200, 100, 300, 100], dtype=np.int64), + np.array([10, 20, 30, 10], dtype=np.int64), + None, + id="sample-times-require-lexsort", + ), + pytest.param( + np.array([], dtype=np.int64), + np.array([], dtype=np.int64), + [10, 20], + id="empty-spike-vector", + ), + ], +) +def test_phy_compute_and_cache_spike_vector(tmp_path, spike_times, spike_clusters, cluster_ids): + """Phy override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + from spikeinterface.extractors.phykilosortextractors import BasePhyKilosortSortingExtractor + + phy_folder = _make_phy_folder( + tmp_path, + spike_times=spike_times, + spike_clusters=spike_clusters, + cluster_ids=cluster_ids, + ) + sorting = BasePhyKilosortSortingExtractor(phy_folder) + + # Phy override path + sorting._compute_and_cache_spike_vector() + phy_vector = sorting._cached_spike_vector.copy() + phy_segment_slices = sorting._cached_spike_vector_segment_slices.copy() + + # Base class (per-unit) path + sorting._cached_spike_vector = None + sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sorting) + base_vector = sorting._cached_spike_vector + + assert np.array_equal(phy_vector, base_vector) + assert np.array_equal(phy_segment_slices, np.array([[0, len(phy_vector)]], dtype="int64")) + assert len(phy_vector) == len(spike_times) + assert np.all(phy_vector["segment_index"] == 0) + assert is_spike_vector_sorted(phy_vector)