Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,17 @@ class BaseRecordingSegment(TimeSeriesSegment):
Abstract class representing a multichannel timeseries, or block of raw ephys traces
"""

# Segments that know their channel count at construction (e.g. BinaryRecordingSegment,
# which needs it before being attached to a parent to compute the on-disk layout) set
# self.num_channels. Segments that don't leave this default and inherit the count from the
# parent recording, which is always attached by the time get_traces runs.
num_channels = None

def get_num_channels(self) -> int:
if self.num_channels is not None:
return self.num_channels
return self.parent_extractor.get_num_channels()

def get_traces(
self,
start_frame: int | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,9 +2061,9 @@ def get_traces(
channel_indices: list | None = None,
) -> np.ndarray:
if channel_indices is None:
n_channels = self.templates.shape[2]
n_channels = self.get_num_channels()
elif isinstance(channel_indices, slice):
stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2]
stop = channel_indices.stop if channel_indices.stop is not None else self.get_num_channels()
start = channel_indices.start if channel_indices.start is not None else 0
step = channel_indices.step if channel_indices.step is not None else 1
n_channels = math.ceil((stop - start) / step)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,9 @@ def get_traces(
end_frame = self.num_samples if end_frame is None else end_frame

if channel_indices is None:
n_channels = self.drifting_templates.num_channels
n_channels = self.get_num_channels()
elif isinstance(channel_indices, slice):
stop = channel_indices.stop if channel_indices.stop is not None else self.drifting_templates.num_channels
stop = channel_indices.stop if channel_indices.stop is not None else self.get_num_channels()
start = channel_indices.start if channel_indices.start is not None else 0
step = channel_indices.step if channel_indices.step is not None else 1
n_channels = math.ceil((stop - start) / step)
Expand Down
12 changes: 4 additions & 8 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end
for segment in recording.segments:
recording_segment = TracePaddedRecordingSegment(
segment,
recording.get_num_channels(),
self.dtype,
self.padding_start,
self.padding_end,
Expand All @@ -55,7 +54,6 @@ class TracePaddedRecordingSegment(BasePreprocessorSegment):
def __init__(
self,
recording_segment: BaseRecordingSegment,
num_channels,
dtype,
padding_left,
padding_end,
Expand All @@ -64,7 +62,6 @@ def __init__(
self.padding_start = padding_left
self.padding_end = padding_end
self.fill_value = fill_value
self.num_channels = num_channels
self.num_samples_in_original_segment = recording_segment.get_num_samples()
self.dtype = dtype

Expand All @@ -76,7 +73,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
if isinstance(channel_indices, (np.ndarray, list)):
num_channels = len(channel_indices)
elif channel_indices == slice(None):
num_channels = self.num_channels
num_channels = self.get_num_channels()
else:
raise ValueError(f"Unsupported channel_indices type: {type(channel_indices)} raise an issue on github ")

Expand Down Expand Up @@ -165,7 +162,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping:
self.parent_recording = recording
self.num_channels = num_channels
for segment in recording.segments:
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping)
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.channel_mapping)
self.add_recording_segment(recording_segment)

# only copy relevant metadata and properties
Expand All @@ -182,14 +179,13 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping:


class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment):
def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
def __init__(self, recording_segment: BaseRecordingSegment, channel_mapping: list):
BasePreprocessorSegment.__init__(self, recording_segment)
self.parent_recording_segment = recording_segment
self.num_channels = num_channels
self.channel_mapping = channel_mapping

def get_traces(self, start_frame, end_frame, channel_indices):
traces = np.zeros((end_frame - start_frame, self.num_channels))
traces = np.zeros((end_frame - start_frame, self.get_num_channels()))
traces[:, self.channel_mapping] = self.parent_recording_segment.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=self.channel_mapping
)
Expand Down
Loading