From 5aaa30a06fdae0720723ba42439e4534add1343e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 24 Jun 2026 20:02:46 -0600 Subject: [PATCH 1/6] refactor recording Replace the parallel _pynwb/_backend methods in the NWB recording path with a single _NwbGeneralReader (reading_method discriminator). The reader owns open/locate/close and the electrode-column reads; the extractor keeps only the SpikeInterface mapping (uV scaling, brain_area rename, core-vs-extra column policy). Sorting and time-series extractors are unchanged. --- .../extractors/nwbextractors.py | 654 +++++++++--------- 1 file changed, 326 insertions(+), 328 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b223b97398..42223cd6a1 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -345,37 +345,6 @@ def _find_neurodata_type_from_backend(group, path="", result=None, neurodata_typ return result -def _fetch_time_info_pynwb(electrical_series, samples_for_rate_estimation, load_time_vector=False): - """ - Extracts the sampling frequency and the time vector from an ElectricalSeries object. - """ - sampling_frequency = None - if hasattr(electrical_series, "rate"): - sampling_frequency = electrical_series.rate - - if hasattr(electrical_series, "starting_time"): - t_start = electrical_series.starting_time - else: - t_start = None - - timestamps = None - if hasattr(electrical_series, "timestamps"): - if electrical_series.timestamps is not None: - timestamps = electrical_series.timestamps - t_start = electrical_series.timestamps[0] - - # TimeSeries need to have either timestamps or rate - if sampling_frequency is None: - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=electrical_series.timestamps) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - return sampling_frequency, times_kwargs - - def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, electrical_series, backend="hdf5"): """ Retrieves the indices of the electrodes from the electrical series. @@ -400,6 +369,266 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect return electrodes_indices +class _NwbGeneralReader: + """Read an NWB recording (its ElectricalSeries and electrodes table) for one storage format. + + SpikeInterface can read an NWB file three ways, and they used to be spread across parallel + ``*_pynwb`` / ``*_backend`` methods with ``if use_pynwb`` / ``if backend == "zarr"`` branches + throughout the extractor. This single reader stores the choice in ``self.reading_method`` (one of + ``"use_pynwb"``, ``"use_hdf5"``, ``"use_zarr"``), and every method switches on that one variable, + so the recording extractor talks to one object and never branches on format itself. + + ``load`` opens the file, locates the series, and sets ``self.nwbfile`` (pynwb) or ``self.file`` (raw) and populates ``self.series``, + ``self.electrodes_table`` and ``self.electrodes_indices``. + """ + + def __init__( + self, + *, + reading_method, + electrical_series_path=None, + ): + assert reading_method in ("use_pynwb", "use_hdf5", "use_zarr"), f"Unknown reading_method {reading_method}" + self.reading_method = reading_method + self.electrical_series_path = electrical_series_path + + self.file = None + self.nwbfile = None + self.series = None + self.electrodes_table = None + self.electrodes_indices = None + + @staticmethod + def _storage_backend(file_path=None, file=None, stream_mode=None): + # Whether the bytes are hdf5 or zarr; needed to pick pynwb's IO and to scan neurodata types. + if stream_mode is None and file is None: + return _get_backend_from_local_file(file_path) + return "zarr" if stream_mode == "zarr" else "hdf5" + + @staticmethod + def available_electrical_series(file_path, stream_mode=None, storage_options=None): + """Paths of every ElectricalSeries in the file, read directly (without pynwb).""" + backend = _NwbGeneralReader._storage_backend(file_path, stream_mode=stream_mode) + file_handle = read_file_from_backend( + file_path=file_path, stream_mode=stream_mode, storage_options=storage_options + ) + return _find_neurodata_type_from_backend(file_handle, neurodata_type="ElectricalSeries", backend=backend) + + def load(self, *, file_path=None, file=None, stream_mode=None, cache=False, stream_cache_path=None): + if self.reading_method == "use_pynwb": + # pynwb opens hdf5 or zarr transparently; detect which only to pick its IO class. + self.nwbfile = read_nwbfile( + backend=self._storage_backend(file_path, file, stream_mode), + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + ) + self.series = _retrieve_electrical_series_pynwb(self.nwbfile, self.electrical_series_path) + self.electrodes_indices = self.series.electrodes.data[:] + self.electrodes_table = self.nwbfile.electrodes + else: + backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" + self.file = read_file_from_backend( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + ) + self.series = self._locate_electrical_series(backend) + self.electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( + self.file, self.series, backend + ) + self.electrodes_table = self.file["/general/extracellular_ephys/electrodes"] + + def _locate_electrical_series(self, backend): + # Resolve electrical_series_path (auto-discovering it when the file has exactly one series) + # and return the series handle, raising a helpful error that lists the options on failure. + if self.electrical_series_path is None: + available = _find_neurodata_type_from_backend(self.file, neurodata_type="ElectricalSeries", backend=backend) + if len(available) != 1: + raise ValueError( + "Multiple ElectricalSeries found in the file. " + "Please specify the 'electrical_series_path' argument:" + f"Available options are: {available}." + ) + self.electrical_series_path = available[0] + try: + return self.file[self.electrical_series_path] + except KeyError: + available = _find_neurodata_type_from_backend(self.file, neurodata_type="ElectricalSeries", backend=backend) + raise ValueError( + f"{self.electrical_series_path} not found in the NWB file!" f"Available options are: {available}." + ) + + def close(self): + # Release the open handle on garbage collection (raw hdf5 / zarr store, or the pynwb read IO). + if self.file is not None: + if hasattr(self.file, "store"): # zarr + self.file.store.close() + else: # hdf5: close every object still open on the file id + import h5py + + for object_id in h5py.h5f.get_obj_ids(self.file.id, types=h5py.h5f.OBJ_ALL): + try: + object_id.close() + except Exception: + warnings.warn(f"Error closing object {h5py.h5i.get_name(object_id).decode('utf-8')}") + elif self.nwbfile is not None: + io = self.nwbfile.get_read_io() + if io is not None: + io.close() + + # --- electrodes table ---------------------------------------------------------------------- + @property + def column_names(self): + if self.reading_method == "use_pynwb": + return list(self.electrodes_table.colnames) + elif self.reading_method == "use_hdf5": + return list(self.electrodes_table.attrs["colnames"]) + elif self.reading_method == "use_zarr": + return list(self.electrodes_table.attrs["colnames"]) + + def read_electrode_property(self, name): + # The electrodes region may reference rows in any order, but h5py only fancy-indexes in + # strictly increasing order, so materialize the column to numpy first (GH-4619). A pynwb + # VectorData and a raw h5py/zarr dataset both support ``[name][:]``. HDF5 stores strings as + # bytes, so decode string columns here once, on behalf of every caller. + values = np.asarray(self.electrodes_table[name][:])[self.electrodes_indices] + if values.dtype.kind in ("S", "O"): + values = np.array([v.decode("utf-8") if isinstance(v, bytes) else v for v in values]) + return values + + def _read_ids(self): + if self.reading_method == "use_pynwb": + ids = self.electrodes_table.id + elif self.reading_method == "use_hdf5": + ids = self.electrodes_table["id"] + elif self.reading_method == "use_zarr": + ids = self.electrodes_table["id"] + return np.asarray(ids[:])[self.electrodes_indices] + + def channel_ids(self): + if "channel_name" in self.column_names: + return list(self.read_electrode_property("channel_name")) + return list(self._read_ids()) + + # --- electrical series --------------------------------------------------------------------- + def data(self): + if self.reading_method == "use_pynwb": + return self.series.data + elif self.reading_method == "use_hdf5": + return self.series["data"] + elif self.reading_method == "use_zarr": + return self.series["data"] + + def dtype(self): + return self.data().dtype + + def time_info(self, samples_for_rate_estimation): + # Return the NWB time facts; the extractor decides how to turn them into times_kwargs. + if self.reading_method == "use_pynwb": + series = self.series + sampling_frequency = series.rate if hasattr(series, "rate") else None + t_start = series.starting_time if hasattr(series, "starting_time") else None + timestamps = series.timestamps if getattr(series, "timestamps", None) is not None else None + if timestamps is not None: + t_start = timestamps[0] + if sampling_frequency is None: + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + return sampling_frequency, t_start, timestamps + elif self.reading_method == "use_hdf5": + series = self.series + if "starting_time" in series.keys(): + t_start = series["starting_time"][()] + sampling_frequency = series["starting_time"].attrs["rate"] + timestamps = None + elif "timestamps" in series.keys(): + timestamps = series["timestamps"] + t_start = timestamps[0] + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + else: + raise ValueError("TimeSeries must have either starting_time or timestamps") + return sampling_frequency, t_start, timestamps + elif self.reading_method == "use_zarr": + series = self.series + if "starting_time" in series.keys(): + t_start = series["starting_time"][()] + sampling_frequency = series["starting_time"].attrs["rate"] + timestamps = None + elif "timestamps" in series.keys(): + timestamps = series["timestamps"] + t_start = timestamps[0] + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + else: + raise ValueError("TimeSeries must have either starting_time or timestamps") + return sampling_frequency, t_start, timestamps + + def _conversion(self): + if self.reading_method == "use_pynwb": + return self.series.conversion + elif self.reading_method == "use_hdf5": + return self.series["data"].attrs["conversion"] + elif self.reading_method == "use_zarr": + return self.series["data"].attrs["conversion"] + + def _channel_conversion(self): + if self.reading_method == "use_pynwb": + channel_conversion = self.series.channel_conversion + return channel_conversion[:] if channel_conversion is not None else None + elif self.reading_method == "use_hdf5": + if self.series.get("channel_conversion", None) is not None: + return self.series["channel_conversion"][:] + return None + elif self.reading_method == "use_zarr": + if self.series.get("channel_conversion", None) is not None: + return self.series["channel_conversion"][:] + return None + + def _series_offset(self): + if self.reading_method == "use_pynwb": + return self.series.offset if hasattr(self.series, "offset") else 0 + elif self.reading_method == "use_hdf5": + data_attributes = self.series["data"].attrs + return data_attributes["offset"] if "offset" in data_attributes else 0 + elif self.reading_method == "use_zarr": + data_attributes = self.series["data"].attrs + return data_attributes["offset"] if "offset" in data_attributes else 0 + + # --- recording ingredients (still volts / NWB names; the extractor applies uV + SI naming) ---- + def gain_to_volts(self): + gain = self._conversion() + channel_conversion = self._channel_conversion() + if channel_conversion is not None: + gain = gain * channel_conversion + return gain + + def offset_to_volts(self): + # NWB stores the offset on the series, or (fallback) per channel in the electrodes table. + offset = self._series_offset() + if offset == 0 and "offset" in self.column_names: + offset = self.read_electrode_property("offset") + return offset + + def locations(self): + if not ("rel_x" in self.column_names and "rel_y" in self.column_names): + return None + ndim = 3 if "rel_z" in self.column_names else 2 + locations = np.zeros((len(self.electrodes_indices), ndim), dtype=float) + locations[:, 0] = self.read_electrode_property("rel_x") + locations[:, 1] = self.read_electrode_property("rel_y") + if "rel_z" in self.column_names: + locations[:, 2] = self.read_electrode_property("rel_z") + return locations + + def groups(self): + if "group_name" not in self.column_names: + return None + return self.read_electrode_property("group_name") + + class _BaseNWBExtractor: "A class for common methods for NWB extractors." @@ -437,7 +666,7 @@ def __del__(self): io.close() -class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): +class NwbRecordingExtractor(BaseRecording): """Load an NWBFile as a RecordingExtractor. Parameters @@ -536,34 +765,36 @@ def __init__( self.storage_options = storage_options self.electrical_series_path = electrical_series_path - if self.stream_mode is None and file is None: - self.backend = _get_backend_from_local_file(file_path) - else: - if self.stream_mode == "zarr": - self.backend = "zarr" - else: - self.backend = "hdf5" - # extract info + if use_pynwb and not HAVE_PYNWB: + raise ImportError(self.installation_mesg) + if use_pynwb: - if not HAVE_PYNWB: - raise ImportError(self.installation_mesg) + reading_method = "use_pynwb" # the reader detects hdf5 vs zarr itself for pynwb + else: + reading_method = f"use_{_NwbGeneralReader._storage_backend(file_path, file, self.stream_mode)}" + + self._reader = _NwbGeneralReader( + reading_method=reading_method, + electrical_series_path=self.electrical_series_path, + ) + self._reader.load( + file_path=self.file_path, + file=file, + stream_mode=self.stream_mode, + cache=cache, + stream_cache_path=self.stream_cache_path, + ) + self.electrical_series_path = self._reader.electrical_series_path - ( - channel_ids, - sampling_frequency, - dtype, - segment_data, - times_kwargs, - ) = self._fetch_recording_segment_info_pynwb(file, cache, load_time_vector, samples_for_rate_estimation) + channel_ids = self._reader.channel_ids() + sampling_frequency, t_start, timestamps = self._reader.time_info(samples_for_rate_estimation) + if load_time_vector and timestamps is not None: + times_kwargs = dict(time_vector=timestamps) else: - ( - channel_ids, - sampling_frequency, - dtype, - segment_data, - times_kwargs, - ) = self._fetch_recording_segment_info_backend(file, cache, load_time_vector, samples_for_rate_estimation) + times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) + segment_data = self._reader.data() + dtype = self._reader.dtype() BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) recording_segment = NwbRecordingSegment( @@ -572,36 +803,37 @@ def __init__( ) self.add_recording_segment(recording_segment) - # fetch and add main recording properties - if use_pynwb: - gains, offsets, locations, groups = self._fetch_main_properties_pynwb() - else: - gains, offsets, locations, groups = self._fetch_main_properties_backend() - - self.set_channel_gains(gains) - self.set_channel_offsets(offsets) + # fetch and add main recording properties (gains/offsets are cast to uV here) + gains_to_uV = self._reader.gain_to_volts() * 1e6 + offsets_to_uV = self._reader.offset_to_volts() * 1e6 + self.set_channel_gains(gains_to_uV) + self.set_channel_offsets(offsets_to_uV) + locations = self._reader.locations() if locations is not None: self.set_channel_locations(locations) + groups = self._reader.groups() if groups is not None: self.set_channel_groups(groups) - # fetch and add additional recording properties + # Every other electrodes-table column becomes a generic channel property. The columns below + # are skipped because they were already mapped to core recording fields above (channel ids, + # locations, groups, offsets); "location" is exposed but renamed to SpikeInterface's brain_area. if load_channel_properties: - if use_pynwb: - electrodes_table = self._nwbfile.electrodes - electrodes_indices = self.electrical_series.electrodes.data[:] - columns = electrodes_table.colnames - else: - electrodes_table = self._file["/general/extracellular_ephys/electrodes"] - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - self._file, self.electrical_series, self.backend - ) - columns = electrodes_table.attrs["colnames"] - properties = self._fetch_other_properties(electrodes_table, electrodes_indices, columns) - - for property_name, property_values in properties.items(): - values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] - self.set_property(property_name, values) + columns_mapped_to_core_fields = [ + "id", + "rel_x", + "rel_y", + "rel_z", + "group", + "group_name", + "channel_name", + "offset", + ] + for column in self._reader.column_names: + if column in columns_mapped_to_core_fields: + continue + property_name = "brain_area" if column == "location" else column + self.set_property(property_name, self._reader.read_electrode_property(column)) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) @@ -637,241 +869,24 @@ def __init__( } # Set extra requirements for the extractor, so they can be installed when using docker - if use_pynwb: + if self._reader.reading_method == "use_pynwb": self.extra_requirements.append("pynwb") - else: - if self.backend == "hdf5": - self.extra_requirements.append("h5py") - if self.backend == "zarr": - self.extra_requirements.append("zarr") + elif self._reader.reading_method == "use_hdf5": + self.extra_requirements.append("h5py") + elif self._reader.reading_method == "use_zarr": + self.extra_requirements.append("zarr") if self.stream_mode == "fsspec": self.extra_requirements.append("fsspec") if self.stream_mode == "remfile": self.extra_requirements.append("remfile") - def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation): - self._nwbfile = read_nwbfile( - backend=self.backend, - file_path=self.file_path, - file=file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - ) - electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) - # The indices in the electrode table corresponding to this electrical series - electrodes_indices = electrical_series.electrodes.data[:] - # The table for all the electrodes in the nwbfile - electrodes_table = self._nwbfile.electrodes - - sampling_frequency, times_kwargs = _fetch_time_info_pynwb( - electrical_series=electrical_series, - samples_for_rate_estimation=samples_for_rate_estimation, - load_time_vector=load_time_vector, - ) - - # Fill channel properties dictionary from electrodes table - if "channel_name" in electrodes_table.colnames: - channel_ids = [ - electrical_series.electrodes["channel_name"][electrodes_index] - for electrodes_index in electrodes_indices - ] - else: - channel_ids = [electrical_series.electrodes.table.id[x] for x in electrodes_indices] - electrical_series_data = electrical_series.data - dtype = electrical_series_data.dtype - - # need this later - self.electrical_series = electrical_series - - return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs - - def _fetch_recording_segment_info_backend(self, file, cache, load_time_vector, samples_for_rate_estimation): - open_file = read_file_from_backend( - file_path=self.file_path, - file=file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - ) - - # If the electrical_series_path is not given, `_find_neurodata_type_from_backend` will be called - # And returns a list with the electrical_series_paths available in the file. - # If there is only one electrical series, the electrical_series_path is set to the name of the series, - # otherwise an error is raised. - if self.electrical_series_path is None: - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - # if electrical_series_path is None: - if len(available_electrical_series) == 1: - self.electrical_series_path = available_electrical_series[0] - else: - raise ValueError( - "Multiple ElectricalSeries found in the file. " - "Please specify the 'electrical_series_path' argument:" - f"Available options are: {available_electrical_series}." - ) - - # Open the electrical series. In case of failure, raise an error with the available options. - try: - electrical_series = open_file[self.electrical_series_path] - except KeyError: - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - raise ValueError( - f"{self.electrical_series_path} not found in the NWB file!" - f"Available options are: {available_electrical_series}." - ) - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - open_file, electrical_series, self.backend - ) - # The table for all the electrodes in the nwbfile - electrodes_table = open_file["/general/extracellular_ephys/electrodes"] - electrode_table_columns = electrodes_table.attrs["colnames"] - - # Get sampling frequency - if "starting_time" in electrical_series.keys(): - t_start = electrical_series["starting_time"][()] - sampling_frequency = electrical_series["starting_time"].attrs["rate"] - timestamps = None - elif "timestamps" in electrical_series.keys(): - timestamps = electrical_series["timestamps"][:] - t_start = timestamps[0] - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - else: - raise ValueError("TimeSeries must have either starting_time or timestamps") - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=electrical_series["timestamps"]) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - # If channel names are present, use them as channel_ids instead of the electrode ids - if "channel_name" in electrode_table_columns: - channel_names = electrodes_table["channel_name"] - channel_ids = channel_names[electrodes_indices] - # Decode if bytes with utf-8 - channel_ids = [x.decode("utf-8") if isinstance(x, bytes) else x for x in channel_ids] - - else: - channel_ids = [electrodes_table["id"][x] for x in electrodes_indices] - - dtype = electrical_series["data"].dtype - electrical_series_data = electrical_series["data"] - - # need this for later - self.electrical_series = electrical_series - self._file = open_file - - return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs - - def _fetch_locations_and_groups(self, electrodes_table, electrodes_indices): - # Channel locations - locations = None - if "rel_x" in electrodes_table: - if "rel_y" in electrodes_table: - ndim = 3 if "rel_z" in electrodes_table else 2 - locations = np.zeros((self.get_num_channels(), ndim), dtype=float) - locations[:, 0] = electrodes_table["rel_x"][electrodes_indices] - locations[:, 1] = electrodes_table["rel_y"][electrodes_indices] - if "rel_z" in electrodes_table: - locations[:, 2] = electrodes_table["rel_z"][electrodes_indices] - - # Channel groups - groups = None - if "group_name" in electrodes_table: - groups = electrodes_table["group_name"][electrodes_indices][:] - if groups is not None: - groups = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in groups]) - return locations, groups - - def _fetch_other_properties(self, electrodes_table, electrodes_indices, columns): - ######### - # Extract and re-name properties from nwbfile TODO: Should be a function - ######## - properties = dict() - properties_to_skip = [ - "id", - "rel_x", - "rel_y", - "rel_z", - "group", - "group_name", - "channel_name", - "offset", - ] - rename_properties = dict(location="brain_area") - - for column in columns: - if column in properties_to_skip: - continue - else: - column_name = rename_properties.get(column, column) - properties[column_name] = electrodes_table[column][electrodes_indices] - - return properties - - def _fetch_main_properties_pynwb(self): - """ - Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: - - - gains - - offsets - - locations - - groups - """ - electrodes_indices = self.electrical_series.electrodes.data[:] - electrodes_table = self._nwbfile.electrodes - - # Channels gains - for RecordingExtractor, these are values to cast traces to uV - gains = self.electrical_series.conversion * 1e6 - if self.electrical_series.channel_conversion is not None: - gains = self.electrical_series.conversion * self.electrical_series.channel_conversion[:] * 1e6 - - # Channel offsets - offset = self.electrical_series.offset if hasattr(self.electrical_series, "offset") else 0 - if offset == 0 and "offset" in electrodes_table: - offset = electrodes_table["offset"].data[electrodes_indices] - offsets = offset * 1e6 - - locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) - - return gains, offsets, locations, groups - - def _fetch_main_properties_backend(self): - """ - Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: - - - gains - - offsets - - locations - - groups - """ - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - self._file, self.electrical_series, self.backend - ) - electrodes_table = self._file["/general/extracellular_ephys/electrodes"] - - # Channels gains - for RecordingExtractor, these are values to cast traces to uV - data_attributes = self.electrical_series["data"].attrs - electrical_series_conversion = data_attributes["conversion"] - gains = electrical_series_conversion * 1e6 - channel_conversion = self.electrical_series.get("channel_conversion", None) - if channel_conversion: - gains *= self.electrical_series["channel_conversion"][:] - - # Channel offsets - offset = data_attributes["offset"] if "offset" in data_attributes else 0 - if offset == 0 and "offset" in electrodes_table: - offset = electrodes_table["offset"][electrodes_indices] - offsets = offset * 1e6 - - # Channel locations and groups - locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) - - return gains, offsets, locations, groups + def __del__(self): + # Avoid impossible import errors during interpreter shutdown to reduce logging noise. + if getattr(sys, "meta_path", None) is None: + return + reader = getattr(self, "_reader", None) # may be unset if __init__ raised early + if reader is not None: + reader.close() @staticmethod def fetch_available_electrical_series_paths( @@ -908,26 +923,9 @@ def fetch_available_electrical_series_paths( - "processing/my_custom_module/MyContainer/ElectricalSeries2" """ - if stream_mode is None: - backend = _get_backend_from_local_file(file_path) - else: - if stream_mode == "zarr": - backend = "zarr" - else: - backend = "hdf5" - - file_handle = read_file_from_backend( - file_path=file_path, - stream_mode=stream_mode, - storage_options=storage_options, - ) - - electrical_series_paths = _find_neurodata_type_from_backend( - file_handle, - neurodata_type="ElectricalSeries", - backend=backend, + return _NwbGeneralReader.available_electrical_series( + file_path, stream_mode=stream_mode, storage_options=storage_options ) - return electrical_series_paths class NwbRecordingSegment(BaseRecordingSegment): From 55450c3c4b4edd26d19785ac2e79db754b8d8aa5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 24 Jun 2026 20:23:44 -0600 Subject: [PATCH 2/6] sorting done Migrate NwbSortingExtractor onto the single _NWBReader: add the units side (load_units, unit_ids, spike_times, spike_times_index, units_column_names, available_units_tables, rate_and_t_start_from_electrical_series) and thread storage_options through _open_handle. The sorting extractor drops its parallel _fetch_sorting_segment_info_pynwb/_backend pair, self.backend, and the _BaseNWBExtractor mixin; the ragged-property logic in _fetch_properties is preserved, just sourcing units_table from the reader. Time-series extractor unchanged. --- .../extractors/nwbextractors.py | 452 ++++++++---------- 1 file changed, 212 insertions(+), 240 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 42223cd6a1..b87191aadb 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -369,34 +369,28 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect return electrodes_indices -class _NwbGeneralReader: +class _NWBReader: """Read an NWB recording (its ElectricalSeries and electrodes table) for one storage format. - SpikeInterface can read an NWB file three ways, and they used to be spread across parallel - ``*_pynwb`` / ``*_backend`` methods with ``if use_pynwb`` / ``if backend == "zarr"`` branches - throughout the extractor. This single reader stores the choice in ``self.reading_method`` (one of - ``"use_pynwb"``, ``"use_hdf5"``, ``"use_zarr"``), and every method switches on that one variable, - so the recording extractor talks to one object and never branches on format itself. + SpikeInterface can read an NWB file three ways; the choice is stored in ``self.reading_method`` + (``"use_pynwb"`` / ``"use_hdf5"`` / ``"use_zarr"``), and every method switches on that one + variable, so the extractor talks to one object and never branches on format itself. - ``load`` opens the file, locates the series, and sets ``self.nwbfile`` (pynwb) or ``self.file`` (raw) and populates ``self.series``, - ``self.electrodes_table`` and ``self.electrodes_indices``. + ``load`` opens the file handle, locates the ElectricalSeries, and binds ``self.series``, + ``self.electrodes_table`` and ``self.electrodes_indices``; ``close`` releases the handle. """ - def __init__( - self, - *, - reading_method, - electrical_series_path=None, - ): + def __init__(self, *, reading_method, electrical_series_path=None, unit_table_path=None): assert reading_method in ("use_pynwb", "use_hdf5", "use_zarr"), f"Unknown reading_method {reading_method}" self.reading_method = reading_method self.electrical_series_path = electrical_series_path - + self.unit_table_path = unit_table_path self.file = None self.nwbfile = None self.series = None self.electrodes_table = None self.electrodes_indices = None + self.units_table = None @staticmethod def _storage_backend(file_path=None, file=None, stream_mode=None): @@ -405,18 +399,9 @@ def _storage_backend(file_path=None, file=None, stream_mode=None): return _get_backend_from_local_file(file_path) return "zarr" if stream_mode == "zarr" else "hdf5" - @staticmethod - def available_electrical_series(file_path, stream_mode=None, storage_options=None): - """Paths of every ElectricalSeries in the file, read directly (without pynwb).""" - backend = _NwbGeneralReader._storage_backend(file_path, stream_mode=stream_mode) - file_handle = read_file_from_backend( - file_path=file_path, stream_mode=stream_mode, storage_options=storage_options - ) - return _find_neurodata_type_from_backend(file_handle, neurodata_type="ElectricalSeries", backend=backend) - - def load(self, *, file_path=None, file=None, stream_mode=None, cache=False, stream_cache_path=None): + def _open_handle(self, *, file_path, file, stream_mode, cache, stream_cache_path, storage_options=None): + # Open the file handle: a pynwb NWBFile, or a raw h5py.File / zarr group. if self.reading_method == "use_pynwb": - # pynwb opens hdf5 or zarr transparently; detect which only to pick its IO class. self.nwbfile = read_nwbfile( backend=self._storage_backend(file_path, file, stream_mode), file_path=file_path, @@ -424,25 +409,168 @@ def load(self, *, file_path=None, file=None, stream_mode=None, cache=False, stre stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path, + storage_options=storage_options, ) - self.series = _retrieve_electrical_series_pynwb(self.nwbfile, self.electrical_series_path) - self.electrodes_indices = self.series.electrodes.data[:] - self.electrodes_table = self.nwbfile.electrodes else: - backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" self.file = read_file_from_backend( file_path=file_path, file=file, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path, + storage_options=storage_options, ) + + def _read_column(self, table, name, indices=None): + # Materialize a table column to numpy (h5py only fancy-indexes increasing, GH-4619), optionally + # reorder to `indices`, and decode HDF5 byte-strings once on behalf of every caller. + values = np.asarray(table[name][:]) + if indices is not None: + values = values[indices] + if values.dtype.kind in ("S", "O"): + values = np.array([v.decode("utf-8") if isinstance(v, bytes) else v for v in values]) + return values + + def close(self): + # Release the open handle (raw hdf5 / zarr store, or the pynwb read IO). + if self.file is not None: + if hasattr(self.file, "store"): # zarr + self.file.store.close() + else: # hdf5: close every object still open on the file id + import h5py + + for object_id in h5py.h5f.get_obj_ids(self.file.id, types=h5py.h5f.OBJ_ALL): + try: + object_id.close() + except Exception: + warnings.warn(f"Error closing object {h5py.h5i.get_name(object_id).decode('utf-8')}") + elif self.nwbfile is not None: + io = self.nwbfile.get_read_io() + if io is not None: + io.close() + + def __del__(self): + # Release the file handle on garbage collection. + if getattr(sys, "meta_path", None) is None: # avoid import errors during interpreter shutdown + return + self.close() + + @staticmethod + def available_electrical_series(file_path, stream_mode=None, storage_options=None): + """Paths of every ElectricalSeries in the file, read directly (without pynwb).""" + backend = _NWBReader._storage_backend(file_path, stream_mode=stream_mode) + file_handle = read_file_from_backend( + file_path=file_path, stream_mode=stream_mode, storage_options=storage_options + ) + return _find_neurodata_type_from_backend(file_handle, neurodata_type="ElectricalSeries", backend=backend) + + def load_recording( + self, *, file_path=None, file=None, stream_mode=None, cache=False, stream_cache_path=None, storage_options=None + ): + # Open the handle and bind the ElectricalSeries + electrodes table. + self._open_handle( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + if self.reading_method == "use_pynwb": + self.series = _retrieve_electrical_series_pynwb(self.nwbfile, self.electrical_series_path) + self.electrodes_indices = self.series.electrodes.data[:] + self.electrodes_table = self.nwbfile.electrodes + else: + backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" self.series = self._locate_electrical_series(backend) self.electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( self.file, self.series, backend ) self.electrodes_table = self.file["/general/extracellular_ephys/electrodes"] + def load_units( + self, *, file_path=None, file=None, stream_mode=None, cache=False, stream_cache_path=None, storage_options=None + ): + # Open the handle and bind the Units table. + self._open_handle( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + if self.reading_method == "use_pynwb": + if self.unit_table_path == "units": + self.units_table = self.nwbfile.units + else: + self.units_table = _retrieve_unit_table_pynwb(self.nwbfile, unit_table_path=self.unit_table_path) + else: + backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" + self.units_table = self._locate_units_table(backend) + + def _locate_units_table(self, backend): + # Resolve unit_table_path (auto-discovering it when the file has exactly one Units table) + # and return the table handle, raising a helpful error that lists the options on failure. + if self.unit_table_path is None: + available = _find_neurodata_type_from_backend(self.file, neurodata_type="Units", backend=backend) + if len(available) != 1: + raise ValueError( + "Multiple Units tables found in the file. " + "Please specify the 'unit_table_path' argument:" + f"Available options are: {available}." + ) + self.unit_table_path = available[0] + try: + return self.file[self.unit_table_path] + except KeyError: + available = _find_neurodata_type_from_backend(self.file, neurodata_type="Units", backend=backend) + raise ValueError( + f"{self.unit_table_path} not found in the NWB file!" f"Available options are: {available}." + ) + + @staticmethod + def available_units_tables(file_path, stream_mode=None, storage_options=None): + """Paths of every Units table in the file, read directly (without pynwb).""" + backend = _NWBReader._storage_backend(file_path, stream_mode=stream_mode) + file_handle = read_file_from_backend( + file_path=file_path, stream_mode=stream_mode, storage_options=storage_options + ) + return _find_neurodata_type_from_backend(file_handle, neurodata_type="Units", backend=backend) + + @property + def units_column_names(self): + if self.reading_method == "use_pynwb": + return [column.name for column in self.units_table.columns] + return list(self.units_table.keys()) + + def unit_ids(self): + if "unit_name" in self.units_column_names: + return list(self._read_column(self.units_table, "unit_name")) + if self.reading_method == "use_pynwb": + return list(np.asarray(self.units_table.id[:])) + return list(np.asarray(self.units_table["id"][:])) + + def spike_times(self): + if self.reading_method == "use_pynwb": + return {column.name: column for column in self.units_table.columns}["spike_times"].data + return self.units_table["spike_times"] + + def spike_times_index(self): + if self.reading_method == "use_pynwb": + return {column.name: column for column in self.units_table.columns}["spike_times_index"].data + return self.units_table["spike_times_index"] + + def rate_and_t_start_from_electrical_series(self, samples_for_rate_estimation): + # Locate the ElectricalSeries the sorting came from and read its rate / t_start. + if self.reading_method == "use_pynwb": + self.series = _retrieve_electrical_series_pynwb(self.nwbfile, self.electrical_series_path) + else: + backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" + self.series = self._locate_electrical_series(backend) + sampling_frequency, t_start, _ = self.time_info(samples_for_rate_estimation) + return sampling_frequency, t_start + def _locate_electrical_series(self, backend): # Resolve electrical_series_path (auto-discovering it when the file has exactly one series) # and return the series handle, raising a helpful error that lists the options on failure. @@ -463,24 +591,6 @@ def _locate_electrical_series(self, backend): f"{self.electrical_series_path} not found in the NWB file!" f"Available options are: {available}." ) - def close(self): - # Release the open handle on garbage collection (raw hdf5 / zarr store, or the pynwb read IO). - if self.file is not None: - if hasattr(self.file, "store"): # zarr - self.file.store.close() - else: # hdf5: close every object still open on the file id - import h5py - - for object_id in h5py.h5f.get_obj_ids(self.file.id, types=h5py.h5f.OBJ_ALL): - try: - object_id.close() - except Exception: - warnings.warn(f"Error closing object {h5py.h5i.get_name(object_id).decode('utf-8')}") - elif self.nwbfile is not None: - io = self.nwbfile.get_read_io() - if io is not None: - io.close() - # --- electrodes table ---------------------------------------------------------------------- @property def column_names(self): @@ -492,14 +602,8 @@ def column_names(self): return list(self.electrodes_table.attrs["colnames"]) def read_electrode_property(self, name): - # The electrodes region may reference rows in any order, but h5py only fancy-indexes in - # strictly increasing order, so materialize the column to numpy first (GH-4619). A pynwb - # VectorData and a raw h5py/zarr dataset both support ``[name][:]``. HDF5 stores strings as - # bytes, so decode string columns here once, on behalf of every caller. - values = np.asarray(self.electrodes_table[name][:])[self.electrodes_indices] - if values.dtype.kind in ("S", "O"): - values = np.array([v.decode("utf-8") if isinstance(v, bytes) else v for v in values]) - return values + # An electrode property is one electrodes-table column, reordered to this series' channels. + return self._read_column(self.electrodes_table, name, self.electrodes_indices) def _read_ids(self): if self.reading_method == "use_pynwb": @@ -772,18 +876,19 @@ def __init__( if use_pynwb: reading_method = "use_pynwb" # the reader detects hdf5 vs zarr itself for pynwb else: - reading_method = f"use_{_NwbGeneralReader._storage_backend(file_path, file, self.stream_mode)}" + reading_method = f"use_{_NWBReader._storage_backend(file_path, file, self.stream_mode)}" - self._reader = _NwbGeneralReader( + self._reader = _NWBReader( reading_method=reading_method, electrical_series_path=self.electrical_series_path, ) - self._reader.load( + self._reader.load_recording( file_path=self.file_path, file=file, stream_mode=self.stream_mode, cache=cache, stream_cache_path=self.stream_cache_path, + storage_options=self.storage_options, ) self.electrical_series_path = self._reader.electrical_series_path @@ -880,14 +985,6 @@ def __init__( if self.stream_mode == "remfile": self.extra_requirements.append("remfile") - def __del__(self): - # Avoid impossible import errors during interpreter shutdown to reduce logging noise. - if getattr(sys, "meta_path", None) is None: - return - reader = getattr(self, "_reader", None) # may be unset if __init__ raised early - if reader is not None: - reader.close() - @staticmethod def fetch_available_electrical_series_paths( file_path: str | Path, @@ -923,7 +1020,7 @@ def fetch_available_electrical_series_paths( - "processing/my_custom_module/MyContainer/ElectricalSeries2" """ - return _NwbGeneralReader.available_electrical_series( + return _NWBReader.available_electrical_series( file_path, stream_mode=stream_mode, storage_options=storage_options ) @@ -963,7 +1060,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor): +class NwbSortingExtractor(BaseSorting): """Load an NWBFile as a SortingExtractor. Parameters @@ -1041,27 +1138,49 @@ def __init__( self.t_start = t_start self.provided_or_electrical_series_sampling_frequency = sampling_frequency self.storage_options = storage_options - self.units_table = None - if self.stream_mode is None: - self.backend = _get_backend_from_local_file(file_path) - else: - if self.stream_mode == "zarr": - self.backend = "zarr" - else: - self.backend = "hdf5" + if use_pynwb and not HAVE_PYNWB: + raise ImportError(self.installation_mesg) if use_pynwb: - if not HAVE_PYNWB: - raise ImportError(self.installation_mesg) - - unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_pynwb( - unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache - ) + reading_method = "use_pynwb" else: - unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_backend( - unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache + reading_method = f"use_{_NWBReader._storage_backend(file_path, stream_mode=stream_mode)}" + + self._reader = _NWBReader( + reading_method=reading_method, + electrical_series_path=electrical_series_path, + unit_table_path=unit_table_path, + ) + self._reader.load_units( + file_path=file_path, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + self.units_table = self._reader.units_table + + # A sorting needs a sampling_frequency and t_start; when not provided, take them from the + # ElectricalSeries the sorting was computed from. + if self.provided_or_electrical_series_sampling_frequency is None or self.t_start is None: + series_sampling_frequency, series_t_start = self._reader.rate_and_t_start_from_electrical_series( + samples_for_rate_estimation ) + if self.provided_or_electrical_series_sampling_frequency is None: + self.provided_or_electrical_series_sampling_frequency = series_sampling_frequency + if self.t_start is None: + self.t_start = series_t_start + assert ( + self.provided_or_electrical_series_sampling_frequency is not None + ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" + assert ( + self.t_start is not None + ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" + + unit_ids = self._reader.unit_ids() + spike_times_data = self._reader.spike_times() + spike_times_index_data = self._reader.spike_times_index() BaseSorting.__init__( self, sampling_frequency=self.provided_or_electrical_series_sampling_frequency, unit_ids=unit_ids @@ -1077,16 +1196,17 @@ def __init__( # fetch and add sorting properties if load_unit_properties: - if use_pynwb: - columns = [c.name for c in self.units_table.columns] - self.extra_requirements.append("pynwb") - else: - columns = list(self.units_table.keys()) - self.extra_requirements.append("h5py") - properties = self._fetch_properties(columns) + properties = self._fetch_properties(self._reader.units_column_names) for property_name, property_values in properties.items(): values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] self.set_property(property_name, values) + + if reading_method == "use_pynwb": + self.extra_requirements.append("pynwb") + elif reading_method == "use_hdf5": + self.extra_requirements.append("h5py") + elif reading_method == "use_zarr": + self.extra_requirements.append("zarr") if stream_mode is not None: self.extra_requirements.append(stream_mode) @@ -1115,154 +1235,6 @@ def __init__( "t_start": self.t_start, } - def _fetch_sorting_segment_info_pynwb( - self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False - ): - self._nwbfile = read_nwbfile( - backend=self.backend, - file_path=self.file_path, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - timestamps = None - if self.provided_or_electrical_series_sampling_frequency is None: - # defines the electrical series from where the sorting came from - # important to know the sampling_frequency - self.electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) - # get rate - if self.electrical_series.rate is not None: - self.provided_or_electrical_series_sampling_frequency = self.electrical_series.rate - self.t_start = self.electrical_series.starting_time - else: - if hasattr(self.electrical_series, "timestamps"): - if self.electrical_series.timestamps is not None: - timestamps = self.electrical_series.timestamps - self.provided_or_electrical_series_sampling_frequency = 1 / np.median( - np.diff(timestamps[:samples_for_rate_estimation]) - ) - self.t_start = timestamps[0] - assert ( - self.provided_or_electrical_series_sampling_frequency is not None - ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" - assert ( - self.t_start is not None - ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" - if unit_table_path == "units": - units_table = self._nwbfile.units - else: - units_table = _retrieve_unit_table_pynwb(self._nwbfile, unit_table_path=unit_table_path) - - name_to_column_data = {c.name: c for c in units_table.columns} - spike_times_data = name_to_column_data.pop("spike_times").data - spike_times_index_data = name_to_column_data.pop("spike_times_index").data - - units_ids = name_to_column_data.pop("unit_name", None) - if units_ids is None: - units_ids = units_table["id"].data - - # need this for later - self.units_table = units_table - - return units_ids, spike_times_data, spike_times_index_data - - def _fetch_sorting_segment_info_backend( - self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False - ): - open_file = read_file_from_backend( - file_path=self.file_path, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - timestamps = None - - if self.provided_or_electrical_series_sampling_frequency is None or self.t_start is None: - # defines the electrical series from where the sorting came from - # important to know the sampling_frequency - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - if self.electrical_series_path is None: - if len(available_electrical_series) == 1: - self.electrical_series_path = available_electrical_series[0] - else: - raise ValueError( - "Multiple ElectricalSeries found in the file. " - "Please specify the 'electrical_series_path' argument:" - f"Available options are: {available_electrical_series}." - ) - else: - if self.electrical_series_path not in available_electrical_series: - raise ValueError( - f"'{self.electrical_series_path}' not found in the file. " - f"Available options are: {available_electrical_series}" - ) - electrical_series = open_file[self.electrical_series_path] - - # Get sampling frequency - if "starting_time" in electrical_series.keys(): - self.t_start = electrical_series["starting_time"][()] - self.provided_or_electrical_series_sampling_frequency = electrical_series["starting_time"].attrs["rate"] - elif "timestamps" in electrical_series.keys(): - timestamps = electrical_series["timestamps"][:] - self.t_start = timestamps[0] - self.provided_or_electrical_series_sampling_frequency = 1.0 / np.median( - np.diff(timestamps[:samples_for_rate_estimation]) - ) - - assert ( - self.provided_or_electrical_series_sampling_frequency is not None - ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" - assert ( - self.t_start is not None - ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" - - if unit_table_path is None: - available_unit_table_paths = _find_neurodata_type_from_backend( - open_file, neurodata_type="Units", backend=self.backend - ) - if len(available_unit_table_paths) == 1: - unit_table_path = available_unit_table_paths[0] - else: - raise ValueError( - "Multiple Units tables found in the file. " - "Please specify the 'unit_table_path' argument:" - f"Available options are: {available_unit_table_paths}." - ) - # Try to open the unit table. If it fails, raise an error with the available options. - try: - units_table = open_file[unit_table_path] - except KeyError: - available_unit_table_paths = _find_neurodata_type_from_backend( - open_file, neurodata_type="Units", backend=self.backend - ) - raise ValueError( - f"{unit_table_path} not found in the NWB file!" f"Available options are: {available_unit_table_paths}." - ) - self.units_table_location = unit_table_path - units_table = open_file[self.units_table_location] - - spike_times_data = units_table["spike_times"] - spike_times_index_data = units_table["spike_times_index"] - - if "unit_name" in units_table: - unit_ids = units_table["unit_name"] - else: - unit_ids = units_table["id"] - - decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x - unit_ids = [decode_to_string(id) for id in unit_ids] - - # need this for later - self.units_table = units_table - - return unit_ids, spike_times_data, spike_times_index_data - def _fetch_properties(self, columns): units_table = self.units_table From d84a278f5151d0ee0f5f8dc3eb68a6f91e7bdea2 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 24 Jun 2026 20:31:09 -0600 Subject: [PATCH 3/6] timeseries done Migrate NwbTimeSeriesExtractor onto the single _NWBReader (load_timeseries / available_timeseries) and retire the _BaseNWBExtractor mixin: no extractor inherits it anymore, so file-handle cleanup lives entirely in _NWBReader.close()/__del__ shared by all three extractors via composition. Also drop a stray no-op '1' line. All three NWB extractors now speak one idiom. --- .../extractors/nwbextractors.py | 287 ++++++------------ 1 file changed, 91 insertions(+), 196 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b87191aadb..202d880244 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -156,9 +156,6 @@ def read_nwbfile( return nwbfile -1 - - def _retrieve_electrical_series_pynwb( nwbfile: "NWBFile", electrical_series_path: Optional[str] = None ) -> "ElectricalSeries": @@ -380,11 +377,12 @@ class _NWBReader: ``self.electrodes_table`` and ``self.electrodes_indices``; ``close`` releases the handle. """ - def __init__(self, *, reading_method, electrical_series_path=None, unit_table_path=None): + def __init__(self, *, reading_method, electrical_series_path=None, unit_table_path=None, timeseries_path=None): assert reading_method in ("use_pynwb", "use_hdf5", "use_zarr"), f"Unknown reading_method {reading_method}" self.reading_method = reading_method self.electrical_series_path = electrical_series_path self.unit_table_path = unit_table_path + self.timeseries_path = timeseries_path self.file = None self.nwbfile = None self.series = None @@ -571,6 +569,64 @@ def rate_and_t_start_from_electrical_series(self, samples_for_rate_estimation): sampling_frequency, t_start, _ = self.time_info(samples_for_rate_estimation) return sampling_frequency, t_start + # --- generic TimeSeries (time-series recording) -------------------------------------------- + def load_timeseries( + self, *, file_path=None, file=None, stream_mode=None, cache=False, stream_cache_path=None, storage_options=None + ): + # Open the handle and bind a generic TimeSeries as self.series (no electrodes table). + self._open_handle( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + if self.reading_method == "use_pynwb": + self.series = self._retrieve_timeseries_pynwb() + else: + backend = "zarr" if self.reading_method == "use_zarr" else "hdf5" + self.series = self._locate_timeseries(backend) + + def _retrieve_timeseries_pynwb(self): + from pynwb.base import TimeSeries + + time_series_dict = {} + for item in self.nwbfile.all_children(): + if isinstance(item, TimeSeries): + time_series_dict[item.data.name.replace("/data", "")[1:]] = item + if self.timeseries_path is not None: + if self.timeseries_path not in time_series_dict: + raise ValueError(f"TimeSeries {self.timeseries_path} not found in file") + elif len(time_series_dict) == 1: + self.timeseries_path = list(time_series_dict.keys())[0] + else: + raise ValueError( + f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {list(time_series_dict.keys())}" + ) + return time_series_dict[self.timeseries_path] + + def _locate_timeseries(self, backend): + if self.timeseries_path is None: + available = _find_timeseries_from_backend(self.file, backend=backend) + if len(available) != 1: + raise ValueError(f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {available}") + self.timeseries_path = available[0] + try: + return self.file[self.timeseries_path] + except KeyError: + available = _find_timeseries_from_backend(self.file, backend=backend) + raise ValueError(f"{self.timeseries_path} not found! Available options: {available}") + + @staticmethod + def available_timeseries(file_path, stream_mode=None, storage_options=None): + """Paths of every TimeSeries in the file, read directly (without pynwb).""" + backend = _NWBReader._storage_backend(file_path, stream_mode=stream_mode) + file_handle = read_file_from_backend( + file_path=file_path, stream_mode=stream_mode, storage_options=storage_options + ) + return _find_timeseries_from_backend(file_handle, backend=backend) + def _locate_electrical_series(self, backend): # Resolve electrical_series_path (auto-discovering it when the file has exactly one series) # and return the series handle, raising a helpful error that lists the options on failure. @@ -733,43 +789,6 @@ def groups(self): return self.read_electrode_property("group_name") -class _BaseNWBExtractor: - "A class for common methods for NWB extractors." - - def _close_hdf5_file(self): - has_hdf5_backend = hasattr(self, "_file") - if has_hdf5_backend: - import h5py - - main_file_id = self._file.id - open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) - for object_id in open_object_ids_main: - object_name = h5py.h5i.get_name(object_id).decode("utf-8") - try: - object_id.close() - except: - import warnings - - warnings.warn(f"Error closing object {object_name}") - - def __del__(self): - # Avoid impossible import errors during deletion to reduce logging noise - if getattr(sys, "meta_path", None) is None: - return - - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._close_hdf5_file() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - - class NwbRecordingExtractor(BaseRecording): """Load an NWBFile as a RecordingExtractor. @@ -1387,7 +1406,7 @@ def _find_timeseries_from_backend(group, path="", result=None, backend="hdf5"): return result -class NwbTimeSeriesExtractor(BaseRecording, _BaseNWBExtractor): +class NwbTimeSeriesExtractor(BaseRecording): """Load a TimeSeries from an NWBFile as a RecordingExtractor. Parameters @@ -1454,22 +1473,34 @@ def __init__( self.storage_options = storage_options self.timeseries_path = timeseries_path - if self.stream_mode is None and file is None: - self.backend = _get_backend_from_local_file(file_path) - else: - self.backend = "zarr" if self.stream_mode == "zarr" else "hdf5" + if use_pynwb and not HAVE_PYNWB: + raise ImportError(self.installation_mesg) if use_pynwb: - if not HAVE_PYNWB: - raise ImportError(self.installation_mesg) + reading_method = "use_pynwb" + else: + reading_method = f"use_{_NWBReader._storage_backend(file_path, file, stream_mode)}" - channel_ids, sampling_frequency, dtype, segment_data, times_kwargs = self._fetch_recording_segment_info( - file, cache, load_time_vector, samples_for_rate_estimation - ) + self._reader = _NWBReader(reading_method=reading_method, timeseries_path=timeseries_path) + self._reader.load_timeseries( + file_path=file_path, + file=file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + self.timeseries_path = self._reader.timeseries_path + + sampling_frequency, t_start, timestamps = self._reader.time_info(samples_for_rate_estimation) + if load_time_vector and timestamps is not None: + times_kwargs = dict(time_vector=timestamps) else: - channel_ids, sampling_frequency, dtype, segment_data, times_kwargs = ( - self._fetch_recording_segment_info_backend(file, cache, load_time_vector, samples_for_rate_estimation) - ) + times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) + segment_data = self._reader.data() + num_channels = 1 if segment_data.ndim == 1 else segment_data.shape[1] + channel_ids = np.arange(num_channels) + dtype = self._reader.dtype() BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) recording_segment = NwbTimeSeriesSegment( @@ -1498,139 +1529,18 @@ def __init__( "file": file, } - if use_pynwb: + if reading_method == "use_pynwb": self.extra_requirements.append("pynwb") - else: - if self.backend == "hdf5": - self.extra_requirements.append("h5py") - if self.backend == "zarr": - self.extra_requirements.append("zarr") + elif reading_method == "use_hdf5": + self.extra_requirements.append("h5py") + elif reading_method == "use_zarr": + self.extra_requirements.append("zarr") if self.stream_mode == "fsspec": self.extra_requirements.append("fsspec") elif self.stream_mode == "remfile": self.extra_requirements.append("remfile") - def _fetch_recording_segment_info(self, file, cache, load_time_vector, samples_for_rate_estimation): - self._nwbfile = read_nwbfile( - backend=self.backend, - file_path=self.file_path, - file=file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - from pynwb.base import TimeSeries - - time_series_dict: dict[str, TimeSeries] = {} - - for item in self._nwbfile.all_children(): - if isinstance(item, TimeSeries): - time_series_dict[item.data.name.replace("/data", "")[1:]] = item - - if self.timeseries_path is not None: - if self.timeseries_path not in time_series_dict: - raise ValueError(f"TimeSeries {self.timeseries_path} not found in file") - - else: - if len(time_series_dict) == 1: - self.timeseries_path = list(time_series_dict.keys())[0] - else: - raise ValueError( - f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {list(time_series_dict.keys())}" - ) - - timeseries = time_series_dict[self.timeseries_path] - - # Get sampling frequency and timing info - if hasattr(timeseries, "rate") and timeseries.rate is not None: - sampling_frequency = timeseries.rate - t_start = timeseries.starting_time if hasattr(timeseries, "starting_time") else 0 - timestamps = None - elif hasattr(timeseries, "timestamps"): - timestamps = timeseries.timestamps - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - t_start = timestamps[0] - else: - raise ValueError("TimeSeries must have either starting_time or timestamps") - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=timestamps) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - # Create channel IDs based on data shape - data = timeseries.data - if data.ndim == 1: - num_channels = 1 - else: - num_channels = data.shape[1] - channel_ids = np.arange(num_channels) - dtype = data.dtype - - return channel_ids, sampling_frequency, dtype, data, times_kwargs - - def _fetch_recording_segment_info_backend(self, file, cache, load_time_vector, samples_for_rate_estimation): - open_file = read_file_from_backend( - file_path=self.file_path, - file=file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - # If timeseries_path not provided, find all TimeSeries objects - if self.timeseries_path is None: - available_timeseries = _find_timeseries_from_backend(open_file, backend=self.backend) - if len(available_timeseries) == 1: - self.timeseries_path = available_timeseries[0] - else: - raise ValueError( - f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {available_timeseries}" - ) - - # Get TimeSeries object - try: - timeseries = open_file[self.timeseries_path] - except KeyError: - available_timeseries = _find_timeseries_from_backend(open_file, backend=self.backend) - raise ValueError(f"{self.timeseries_path} not found! Available options: {available_timeseries}") - - # Get timing information - if "starting_time" in timeseries: - t_start = timeseries["starting_time"][()] - sampling_frequency = timeseries["starting_time"].attrs["rate"] - timestamps = None - elif "timestamps" in timeseries: - timestamps = timeseries["timestamps"][:] - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - t_start = timestamps[0] - else: - raise ValueError("TimeSeries must have either starting_time or timestamps") - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=timestamps) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - # Create channel IDs based on data shape - data = timeseries["data"] - if data.ndim == 1: - num_channels = 1 - else: - num_channels = data.shape[1] - channel_ids = np.arange(num_channels) - dtype = data.dtype - - # Store for later use - self.timeseries = timeseries - self._file = open_file - - return channel_ids, sampling_frequency, dtype, data, times_kwargs - @staticmethod def fetch_available_timeseries_paths( file_path: str | Path, @@ -1654,22 +1564,7 @@ def fetch_available_timeseries_paths( list[str] List of paths to TimeSeries objects. """ - if stream_mode is None: - backend = _get_backend_from_local_file(file_path) - else: - backend = "zarr" if stream_mode == "zarr" else "hdf5" - - file_handle = read_file_from_backend( - file_path=file_path, - stream_mode=stream_mode, - storage_options=storage_options, - ) - - timeseries_paths = _find_timeseries_from_backend( - file_handle, - backend=backend, - ) - return timeseries_paths + return _NWBReader.available_timeseries(file_path, stream_mode=stream_mode, storage_options=storage_options) class NwbTimeSeriesSegment(BaseRecordingSegment): From b7ba5f9093f0dd9c92a2fd1c9c0f3435e861c9d5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 24 Jun 2026 20:36:19 -0600 Subject: [PATCH 4/6] add non-increasing electrode tests --- .../extractors/tests/test_nwbextractors.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index c2422600e4..af899e513b 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -349,6 +349,63 @@ def test_failure_with_wrong_electrical_series_path(generate_nwbfile, use_pynwb): ) +@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_nwb_extractor_electrodes_region_out_of_order(tmp_path, use_pynwb): + """An ElectricalSeries may reference its electrodes in any order (e.g. channels reordered by + depth during processing). h5py rejects fancy indexing with non-increasing indices, so reading + such a file used to raise "Indexing elements must be in increasing order" (GH-4619).""" + from pynwb import NWBHDF5IO + from pynwb.ecephys import ElectricalSeries + from pynwb.testing.mock.file import mock_NWBFile + from pynwb.testing.mock.device import mock_Device + from pynwb.testing.mock.ecephys import mock_ElectrodeGroup + + nwbfile = mock_NWBFile() + device = mock_Device(name="probe") + nwbfile.add_device(device) + nwbfile.add_electrode_column(name="channel_name", description="channel name") + nwbfile.add_electrode_column(name="rel_x", description="rel_x") + nwbfile.add_electrode_column(name="rel_y", description="rel_y") + nwbfile.add_electrode_column(name="property", description="A property") + nwbfile.add_electrode_column(name="offset", description="offset") + electrode_group = mock_ElectrodeGroup(device=device) + nwbfile.add_electrode_group(electrode_group) + + num_electrodes = 5 + for index in range(num_electrodes): + nwbfile.add_electrode( + id=index, + group=electrode_group, + location="brain", + channel_name=f"ch{index}", + rel_x=float(index), + rel_y=float(index), + property=f"prop{index}", + offset=float(index), + ) + + # The region is deliberately not in increasing order. + region = [4, 2, 0, 3, 1] + electrode_region = nwbfile.create_electrode_table_region(region=region, description="electrodes") + data = np.random.default_rng(0).random(size=(100, num_electrodes)) + electrical_series = ElectricalSeries(name="ElectricalSeries", data=data, electrodes=electrode_region, rate=30_000.0) + nwbfile.add_acquisition(electrical_series) + + nwbfile_path = tmp_path / "out_of_order.nwb" + with NWBHDF5IO(str(nwbfile_path), mode="w") as io: + io.write(nwbfile) + + recording = NwbRecordingExtractor( + nwbfile_path, electrical_series_path="acquisition/ElectricalSeries", use_pynwb=use_pynwb + ) + + # Everything pulled from the electrodes table must follow the region order, not the table order. + assert np.array_equal(recording.channel_ids, np.array([f"ch{i}" for i in region])) + assert np.array_equal(recording.get_channel_locations(), np.array([[float(i), float(i)] for i in region])) + assert np.array_equal(recording.get_property("property"), np.array([f"prop{i}" for i in region])) + assert np.array_equal(recording.get_channel_offsets(), np.array([float(i) for i in region]) * 1e6) + + @pytest.mark.parametrize("use_pynwb", [True, False]) def test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb): from pynwb import NWBHDF5IO From c5c2f5a03e5b12a989ba9c97e221c110e1328c61 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 25 Jun 2026 15:44:59 -0600 Subject: [PATCH 5/6] fix issue --- .../extractors/nwbextractors.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 202d880244..3c5fa0aa27 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1299,6 +1299,31 @@ def _fetch_properties(self, columns): return properties + @staticmethod + def fetch_available_units_tables( + file_path: str | Path, + stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, + storage_options: dict | None = None, + ) -> list[str]: + """ + Retrieves the paths to all Units tables within an NWB (Neurodata Without Borders) file. + + Parameters + ---------- + file_path : str or Path + The path to the NWB (Neurodata Without Borders) file. + stream_mode : "fsspec" | "remfile" | "zarr" | None, optional + Determines the streaming mode for reading the file. + storage_options : dict | None, default: None + Additional kwargs (e.g. AWS credentials) passed to zarr.open. Only used with "zarr" stream_mode. + + Returns + ------- + list of str + Paths to all Units tables found in the file. + """ + return _NWBReader.available_units_tables(file_path, stream_mode=stream_mode, storage_options=storage_options) + class NwbSortingSegment(BaseSortingSegment): def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): From c7768d66830d584672bc4f00cd1819be5029cfcc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 25 Jun 2026 15:52:45 -0600 Subject: [PATCH 6/6] remove tests --- .../extractors/tests/test_nwbextractors.py | 57 ------------------- 1 file changed, 57 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index ab4bf0880a..eaf1fb1f24 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -406,63 +406,6 @@ def test_failure_with_wrong_electrical_series_path(generate_nwbfile, use_pynwb): ) -@pytest.mark.parametrize("use_pynwb", [True, False]) -def test_nwb_extractor_electrodes_region_out_of_order(tmp_path, use_pynwb): - """An ElectricalSeries may reference its electrodes in any order (e.g. channels reordered by - depth during processing). h5py rejects fancy indexing with non-increasing indices, so reading - such a file used to raise "Indexing elements must be in increasing order" (GH-4619).""" - from pynwb import NWBHDF5IO - from pynwb.ecephys import ElectricalSeries - from pynwb.testing.mock.file import mock_NWBFile - from pynwb.testing.mock.device import mock_Device - from pynwb.testing.mock.ecephys import mock_ElectrodeGroup - - nwbfile = mock_NWBFile() - device = mock_Device(name="probe") - nwbfile.add_device(device) - nwbfile.add_electrode_column(name="channel_name", description="channel name") - nwbfile.add_electrode_column(name="rel_x", description="rel_x") - nwbfile.add_electrode_column(name="rel_y", description="rel_y") - nwbfile.add_electrode_column(name="property", description="A property") - nwbfile.add_electrode_column(name="offset", description="offset") - electrode_group = mock_ElectrodeGroup(device=device) - nwbfile.add_electrode_group(electrode_group) - - num_electrodes = 5 - for index in range(num_electrodes): - nwbfile.add_electrode( - id=index, - group=electrode_group, - location="brain", - channel_name=f"ch{index}", - rel_x=float(index), - rel_y=float(index), - property=f"prop{index}", - offset=float(index), - ) - - # The region is deliberately not in increasing order. - region = [4, 2, 0, 3, 1] - electrode_region = nwbfile.create_electrode_table_region(region=region, description="electrodes") - data = np.random.default_rng(0).random(size=(100, num_electrodes)) - electrical_series = ElectricalSeries(name="ElectricalSeries", data=data, electrodes=electrode_region, rate=30_000.0) - nwbfile.add_acquisition(electrical_series) - - nwbfile_path = tmp_path / "out_of_order.nwb" - with NWBHDF5IO(str(nwbfile_path), mode="w") as io: - io.write(nwbfile) - - recording = NwbRecordingExtractor( - nwbfile_path, electrical_series_path="acquisition/ElectricalSeries", use_pynwb=use_pynwb - ) - - # Everything pulled from the electrodes table must follow the region order, not the table order. - assert np.array_equal(recording.channel_ids, np.array([f"ch{i}" for i in region])) - assert np.array_equal(recording.get_channel_locations(), np.array([[float(i), float(i)] for i in region])) - assert np.array_equal(recording.get_property("property"), np.array([f"prop{i}" for i in region])) - assert np.array_equal(recording.get_channel_offsets(), np.array([float(i) for i in region]) * 1e6) - - @pytest.mark.parametrize("use_pynwb", [True, False]) def test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb): from pynwb import NWBHDF5IO