diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b0f75930d3..bf2eff9c70 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -654,6 +654,17 @@ class BaseRecordingSegment(TimeSeriesSegment): Abstract class representing a multichannel timeseries, or block of raw ephys traces """ + # Segments that know their channel count at construction (e.g. BinaryRecordingSegment, + # which needs it before being attached to a parent to compute the on-disk layout) set + # self.num_channels. Segments that don't leave this default and inherit the count from the + # parent recording, which is always attached by the time get_traces runs. + num_channels = None + + def get_num_channels(self) -> int: + if self.num_channels is not None: + return self.num_channels + return self.parent_extractor.get_num_channels() + def get_traces( self, start_frame: int | None = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..b82568d643 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2061,9 +2061,9 @@ def get_traces( channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: - n_channels = self.templates.shape[2] + n_channels = self.get_num_channels() elif isinstance(channel_indices, slice): - stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2] + stop = channel_indices.stop if channel_indices.stop is not None else self.get_num_channels() start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 1800138dae..f7747091a3 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -542,9 +542,9 @@ def get_traces( end_frame = self.num_samples if end_frame is None else end_frame if channel_indices is None: - n_channels = self.drifting_templates.num_channels + n_channels = self.get_num_channels() elif isinstance(channel_indices, slice): - stop = channel_indices.stop if channel_indices.stop is not None else self.drifting_templates.num_channels + stop = channel_indices.stop if channel_indices.stop is not None else self.get_num_channels() start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..12f09ebcc6 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -35,7 +35,6 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end for segment in recording.segments: recording_segment = TracePaddedRecordingSegment( segment, - recording.get_num_channels(), self.dtype, self.padding_start, self.padding_end, @@ -55,7 +54,6 @@ class TracePaddedRecordingSegment(BasePreprocessorSegment): def __init__( self, recording_segment: BaseRecordingSegment, - num_channels, dtype, padding_left, padding_end, @@ -64,7 +62,6 @@ def __init__( self.padding_start = padding_left self.padding_end = padding_end self.fill_value = fill_value - self.num_channels = num_channels self.num_samples_in_original_segment = recording_segment.get_num_samples() self.dtype = dtype @@ -76,7 +73,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if isinstance(channel_indices, (np.ndarray, list)): num_channels = len(channel_indices) elif channel_indices == slice(None): - num_channels = self.num_channels + num_channels = self.get_num_channels() else: raise ValueError(f"Unsupported channel_indices type: {type(channel_indices)} raise an issue on github ") @@ -165,7 +162,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: self.parent_recording = recording self.num_channels = num_channels for segment in recording.segments: - recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping) + recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.channel_mapping) self.add_recording_segment(recording_segment) # only copy relevant metadata and properties @@ -182,14 +179,13 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment): - def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list): + def __init__(self, recording_segment: BaseRecordingSegment, channel_mapping: list): BasePreprocessorSegment.__init__(self, recording_segment) self.parent_recording_segment = recording_segment - self.num_channels = num_channels self.channel_mapping = channel_mapping def get_traces(self, start_frame, end_frame, channel_indices): - traces = np.zeros((end_frame - start_frame, self.num_channels)) + traces = np.zeros((end_frame - start_frame, self.get_num_channels())) traces[:, self.channel_mapping] = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=self.channel_mapping )