From 72a2bb039c185f538f3ff5902ff40e3463dfaeab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 22 Jun 2026 15:13:20 +0200 Subject: [PATCH 01/11] feat: enhance probegroup API --- examples/ex_03_generate_probe_group.py | 60 +++++ examples/ex_05_device_channel_indices.py | 42 ++++ src/probeinterface/__init__.py | 6 +- src/probeinterface/io.py | 4 +- src/probeinterface/probegroup.py | 146 ++++++++++-- src/probeinterface/wiring.py | 19 ++ tests/test_probegroup.py | 273 +++++++++++++++++++++++ 7 files changed, 525 insertions(+), 25 deletions(-) diff --git a/examples/ex_03_generate_probe_group.py b/examples/ex_03_generate_probe_group.py index 8a640d3a..f8764278 100644 --- a/examples/ex_03_generate_probe_group.py +++ b/examples/ex_03_generate_probe_group.py @@ -46,4 +46,64 @@ plot_probegroup(probegroup, same_axes=False, with_contact_id=True) +############################################################################## +# Identifying probes with a ``probe_id`` +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Each probe in a `ProbeGroup` can be given a human-readable ``probe_id`` when +# it is added. This is handy to keep track of which probe targets which brain +# area or hemisphere. If no ``probe_id`` is given, a default one +# (``"probe_1"``, ``"probe_2"``, ...) is generated automatically. + +probe0 = generate_dummy_probe(elec_shapes='square') +probe1 = generate_dummy_probe(elec_shapes='circle') +probe1.move([250, -90]) + +probegroup = ProbeGroup() +probegroup.add_probe(probe0, probe_id="left_hemisphere") +probegroup.add_probe(probe1, probe_id="right_hemisphere") + +print(probegroup) +print("probe_ids:", probegroup.probe_ids) + +############################################################################## +# `ProbeGroup.select_contacts()` returns a new `ProbeGroup` with a sub-selection +# of contacts. The selection can be done by ``contact_ids``, by ``probe_ids``, +# or by both at the same time. +# +# Selecting by ``probe_ids`` alone keeps every contact of the matching probes, +# which is a convenient way to grab a whole hemisphere: + +left_hemisphere = probegroup.select_contacts(probe_ids=["left_hemisphere"]) +print("contacts in the left hemisphere:", left_hemisphere.get_contact_count()) + +############################################################################## +# We can also select by ``contact_ids``. Note that if ``contact_ids`` are not +# unique across probes, the selection will be ambiguous and an error will be +# raised. In this case, providing ``probe_ids`` disambiguates the selection: + +# check if any contact_id is not unique across probes +contact_ids = probegroup.get_global_contact_ids() +if len(contact_ids) != len(set(contact_ids)): + print("contact_ids are not unique across probes, you should provide probe_ids to disambiguate") + +############################################################################## +# Because the contact ids are not unique across probes, combining ``contact_ids`` +# with ``probe_ids`` lets us pull specific contacts from a single hemisphere: + +left_contacts = probegroup.select_contacts(contact_ids=["0", "1", "2"], probe_ids=["left_hemisphere"]) +print("contacts selected from the left hemisphere:", left_contacts.get_contact_count()) + +left_and_right_contacts = probegroup.select_contacts( + contact_ids=["0", "1", "2"], + probe_ids=["left_hemisphere", "right_hemisphere"] +) +print("contacts selected from the left and right hemispheres:", left_and_right_contacts.get_contact_count()) + +# Without providing probe_ids, the selection is ambiguous and an error is raised: +try: + ambiguous_selection = probegroup.select_contacts(contact_ids=["0", "1", "2"]) +except ValueError as e: + print("Error raised for ambiguous selection:", e) + plt.show() diff --git a/examples/ex_05_device_channel_indices.py b/examples/ex_05_device_channel_indices.py index 5731c910..e9d0c726 100644 --- a/examples/ex_05_device_channel_indices.py +++ b/examples/ex_05_device_channel_indices.py @@ -87,4 +87,46 @@ fig, ax = plt.subplots() plot_probegroup(probegroup, with_contact_id=True, same_axes=True, ax=ax) +############################################################################## +# Reordering contacts with a global contact order +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# By default the contact order of a `ProbeGroup` is the "natural" one: the +# contacts of each probe are stacked one probe after the other. But sometimes +# the contacts of the different probes are *interleaved* in the recording file +# (e.g. the acquisition system alternates between probes sample by sample). +# +# `ProbeGroup.set_global_contact_order()` lets us store this external ordering. +# The order is an array of indices into the natural (stacked) order, and it is +# applied whenever the group is exported with `to_numpy()` / `to_dataframe()`. + +probegroup = ProbeGroup() +probegroup.add_probe(probe0.copy()) +probegroup.add_probe(probe1.copy()) + +n = probegroup.get_contact_count() +print("default global contact order:", probegroup._global_contact_order) + +# interleave probe0 and probe1 contacts as they appear in the recording file +global_contact_order = np.zeros(n, dtype="int64") +global_contact_order[0::2] = np.arange(0, n // 2) # probe0 contacts +global_contact_order[1::2] = np.arange(n // 2, n) # probe1 contacts +probegroup.set_global_contact_order(global_contact_order) + +############################################################################## +# Now `to_numpy()` returns the contacts in the interleaved order: the +# ``probe_index`` column alternates between the two probes. + +contact_vector = probegroup.to_numpy() +print("probe_index in global order:", contact_vector["probe_index"][:8]) + +############################################################################## +# The global order interacts with `set_global_device_channel_indices()`: the +# ``device_channel_indices`` you pass are interpreted in the (reordered) order +# returned by `to_numpy()`, so they map directly onto the acquisition channels. + +probegroup.set_global_device_channel_indices(np.arange(n)) +print("device_channel_indices (global order):", + probegroup.to_numpy(complete=True)["device_channel_indices"][:8]) + plt.show() diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 45e102bf..ff52e8a3 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -52,4 +52,8 @@ cache_full_library, clear_cache, ) -from .wiring import get_available_pathways +from .wiring import ( + get_available_pathways, + get_pathway, + wire_probe +) diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 90849e81..263cd2a7 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -326,7 +326,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probegroup = probe_or_probegroup else: raise TypeError( - f"probe_or_probegroup has to be" "of type Probe or ProbeGroup " f"not type: {type(probe_or_probegroup)}" + f"probe_or_probegroup has to be" "of type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" ) folder = Path(folder) @@ -352,7 +352,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup ) if "type" not in probe.annotations: raise ValueError( - "Export to BIDS probe format requires " "the probe type to be specified as an " "annotation (type)" + "Export to BIDS probe format requires " "the probe type to be specified as an annotation (type)" ) # extract all used annotation keys diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 26f23bff..6fdcd7a4 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -9,21 +9,30 @@ class ProbeGroup: Internally, this is represented as a list of Probe object. - The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe. + The ProbeGroup is the object saved in the json based probeinterface format, even if there is only one probe. Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order - is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance - some contact of each probe are interleaved, in this case a optional reordering can be applied. - - - + is the "natural" one (stacked order of each probe). An external contact order can be applied using the + ``ProbeGroup.set_global_contact_order()`` method, and the contact order is then stored in the + ``ProbeGroup._global_contact_order`` attribute. In this case, the contact order of the ProbeGroup is not "natural" + anymore, but the one defined by the user. This is useful for instance when some contact of each probe are + interleaved in the recording file. """ def __init__(self): self.probes = [] + self.probe_ids = [] self._global_contact_order = None - def add_probe(self, probe: Probe) -> None: + def __repr__(self): + repr_str = f"ProbeGroup: {len(self.probes)} probes - {self.get_contact_count()} contacts" + if self._global_contact_order is not None: + repr_str += " (with custom global contact order)" + for probe, probe_id in zip(self.probes, self.probe_ids): + repr_str += f"\n\t{probe_id}: {probe}" + return repr_str + + def add_probe(self, probe: Probe, probe_id: str = None) -> None: """ Add an additional probe to the ProbeGroup @@ -31,14 +40,36 @@ def add_probe(self, probe: Probe) -> None: ---------- probe: Probe The probe to add to the ProbeGroup + probe_id: str, optional + The ID to assign to the probe. If None, a unique ID will be generated. """ if len(self.probes) > 0: self._check_compatible(probe) self.probes.append(probe) + if probe_id is not None: + self.probe_ids.append(probe_id) + else: + self.probe_ids.append(f"probe_{len(self.probes)}") probe._probe_group = self + def set_probe_ids(self, probe_ids: list) -> None: + """ + Set the probe IDs for the ProbeGroup. + + Parameters + ---------- + probe_ids: list + A list of IDs to assign to the probes. + The length of the list must match the number of probes in the ProbeGroup. + """ + if len(probe_ids) != len(self.probes): + raise ValueError( + f"Length of probe_ids ({len(probe_ids)}) does not match number of probes ({len(self.probes)})" + ) + self.probe_ids = probe_ids + def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: raise ValueError( @@ -68,12 +99,7 @@ def copy(self) -> "ProbeGroup": copy: ProbeGroup A copy of the ProbeGroup """ - copy = ProbeGroup() - for probe in self.probes: - copy.add_probe(probe.copy()) - global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"] - copy.set_global_device_channel_indices(global_device_channel_indices) - return copy + return ProbeGroup.from_dict(self.to_dict(array_as_list=False)) def get_contact_count(self) -> int: """ @@ -207,13 +233,13 @@ def to_dict(self, array_as_list: bool = False) -> dict: """ d = {} d["probes"] = [] - for probe_ind, probe in enumerate(self.probes): + for probe in self.probes: probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) if self._global_contact_order is not None: global_contact_order = self._global_contact_order if array_as_list: - global_contact_order = global_contact_order.to_list() + global_contact_order = global_contact_order.tolist() d["global_contact_order"] = global_contact_order return d @@ -242,6 +268,7 @@ def from_dict(d: dict) -> "ProbeGroup": return probegroup + # TODO: this should only return the device_channel_indices, not the probe_index!!! def get_global_device_channel_indices(self) -> np.ndarray: """ Gets the global device channels indices and returns as @@ -267,15 +294,15 @@ def get_global_device_channel_indices(self) -> np.ndarray: def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None: """ - Set global indices for all probes. + Set global device channel indices for all probes. - Important note : if the order of contacts is not "natural" then the device_channel_indices - is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to + Important note: if the probegroup has ``_global_contact_order``, then the device_channel_indices + are reordered before being set. In short, the ``device_channel_indices`` is zipped to ProbeGroup.to_numpy() (always ordered). Parameters ---------- - channels: np.ndarray | list + device_channel_indices: np.ndarray | list The device channal indices to be set """ device_channel_indices = np.asarray(device_channel_indices) @@ -308,7 +335,7 @@ def get_global_contact_ids(self) -> np.ndarray: Returns ------- contact_ids: np.ndarray - An array of the contaact ids across all probes + An array of the contact ids across all probes """ contact_ids = self.to_numpy(complete=True)["contact_ids"] return contact_ids @@ -368,12 +395,87 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] + original_probe_indices = np.unique(contact_arr["probe_index"]) sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - - # TODO annoatation probe per probe!! + new_probe_indices = np.unique(sliced_probe_group.to_numpy(complete=True)["probe_index"]) + + # Map annotations of the original probegroup to the sliced one + new_probe_ids = [self.probe_ids[i] for i in original_probe_indices] + sliced_probe_group.set_probe_ids(new_probe_ids) + for original_probe_index, new_probe_index in zip(original_probe_indices, new_probe_indices): + orig_probe = self.probes[original_probe_index] + new_probe = sliced_probe_group.probes[new_probe_index] + + for k in orig_probe.annotations: + if k not in new_probe.annotations: + new_probe.annotate(**{k: orig_probe.annotations[k]}) return sliced_probe_group + def select_contacts(self, contact_ids: np.ndarray | list | None = None, probe_ids: np.ndarray | list | None = None) -> "ProbeGroup": + """ + Get a copy of the ProbeGroup with a sub selection of contacts based on contact ids and probe ids. + + Parameters + ---------- + contact_ids : np.array or list or None, default: None + The contact ids to select. If None, all contacts are selected, but probe_ids must be provided. + probe_ids : np.array or list or None, default: None + The probe ids to select. If contact_ids are not unique across probes, + then probe_ids should be provided to disambiguate. + If contact_ids are unique across probes, then probe_ids can be None. + + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group + """ + if contact_ids is None and probe_ids is None: + raise ValueError( + "Either contact_ids or probe_ids must be provided for selection." + ) + if contact_ids is None: + contact_mask = np.ones(self.get_contact_count(), dtype=bool) + else: + contact_ids = np.asarray(contact_ids) + all_contact_ids = self.get_global_contact_ids() + contact_mask = np.isin(all_contact_ids, contact_ids) + if probe_ids is None: + # without probe_ids the selection must be unambiguous: every requested + # contact id must match a single contact across the whole ProbeGroup + matched_ids = all_contact_ids[contact_mask] + unique_ids, counts = np.unique(matched_ids, return_counts=True) + ambiguous_ids = unique_ids[counts > 1] + if ambiguous_ids.size > 0: + raise ValueError( + f"contact_ids {ambiguous_ids.tolist()} are not unique across probes, " + "you should provide probe_ids to disambiguate" + ) + if probe_ids is None: + probe_mask = np.ones(self.get_contact_count(), dtype=bool) + else: + all_probe_ids = np.asarray(self.probe_ids)[self.to_numpy(complete=True)["probe_index"]] + probe_ids = np.asarray(probe_ids) + probe_mask = np.isin(all_probe_ids, probe_ids) + selection_mask = contact_mask & probe_mask + return self.get_slice(selection_mask) + + def set_global_contact_order(self, global_contact_order: np.ndarray | list) -> None: + """ + Set the global contact order for the ProbeGroup. This is useful when some contact of each probe are interleaved in the recording file. + + Parameters + ---------- + global_contact_order: np.ndarray | list + The global contact order to be set. It should be an array of indices that defines the new order of contacts across all probes. + """ + global_contact_order = np.asarray(global_contact_order) + if global_contact_order.size != self.get_contact_count(): + raise ValueError( + f"Wrong global contact order size {global_contact_order.size} for the number of channels {self.get_contact_count()}" + ) + self._global_contact_order = global_contact_order + def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() diff --git a/src/probeinterface/wiring.py b/src/probeinterface/wiring.py index 8378ad7b..6f8cb900 100644 --- a/src/probeinterface/wiring.py +++ b/src/probeinterface/wiring.py @@ -82,6 +82,25 @@ def get_available_pathways() -> list: return list(pathways.keys()) +def get_pathway(pathway: str) -> np.ndarray: + """Return the channel indices for a given pathway + + Parameters + ---------- + pathway : str + The pathway to use + + Returns + ------- + chan_indices : np.ndarray + The channel indices for the given pathway + """ + assert pathway in pathways, ( + f"{pathway} is not a currently supported pathway " f"run `get_available_pathways to see options" + ) + return np.array(pathways[pathway], dtype="int64") + + def wire_probe(probe: "Probe", pathway: str, channel_offset: int = 0): """Inplace wiring for a Probe using a pathway diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 089c642a..a7bd13b2 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -264,6 +264,91 @@ def test_get_slice_all_contacts(probegroup): ) +# ── get_slice : probe annotations and probe_ids propagation ───────────────── + + +def _annotated_probegroup(): + """ProbeGroup with 3 probes, each carrying distinct annotations and probe_id.""" + pg = ProbeGroup() + for i in range(3): + probe = generate_dummy_probe() + probe.move([i * 200, 0]) + probe.annotate(brain_area=f"area_{i}", shank=f"s{i}") + pg.add_probe(probe, probe_id=f"probe_{i}") + return pg + + +def test_get_slice_propagates_annotations(): + """Annotations of each original probe are propagated to the sliced probe.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # take a few contacts from each of the 3 probes + sel = np.array([0, 1, n_each, n_each + 1, 2 * n_each, 2 * n_each + 1]) + sub = pg.get_slice(sel) + + assert len(sub.probes) == 3 + for i, probe in enumerate(sub.probes): + assert probe.annotations["brain_area"] == f"area_{i}" + assert probe.annotations["shank"] == f"s{i}" + + +def test_get_slice_maps_annotations_to_correct_probe_when_skipping(): + """ + When the selection skips a middle probe, annotations must still map to the + correct sliced probe (not shift by position). + """ + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # contacts only from probe 0 and probe 2 (probe 1 is skipped entirely) + sel = np.zeros(pg.get_contact_count(), dtype=bool) + sel[0:3] = True + sel[2 * n_each : 2 * n_each + 4] = True + sub = pg.get_slice(sel) + + assert len(sub.probes) == 2 + # first sliced probe corresponds to original probe 0, second to original probe 2 + assert sub.probes[0].annotations["brain_area"] == "area_0" + assert sub.probes[1].annotations["brain_area"] == "area_2" + assert sub.probes[0].get_contact_count() == 3 + assert sub.probes[1].get_contact_count() == 4 + + +def test_get_slice_sets_probe_ids(): + """probe_ids are carried over to the sliced ProbeGroup.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + sel = np.array([0, 1, n_each, 2 * n_each]) + sub = pg.get_slice(sel) + assert sub.probe_ids == ["probe_0", "probe_1", "probe_2"] + + +def test_get_slice_sets_probe_ids_when_skipping(): + """probe_ids reflect only the probes present in the selection, in order.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # contacts only from probe 0 and probe 2 + sel = np.array([0, 2 * n_each]) + sub = pg.get_slice(sel) + assert len(sub.probes) == 2 + assert sub.probe_ids == ["probe_0", "probe_2"] + + +def test_get_slice_single_probe_keeps_probe_id_and_annotations(): + """Slicing contacts from a single probe keeps that probe's id and annotations.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + sel = np.arange(n_each, n_each + 3) # only probe 1 + sub = pg.get_slice(sel) + assert len(sub.probes) == 1 + assert sub.probe_ids == ["probe_1"] + assert sub.probes[0].annotations["brain_area"] == "area_1" + + # ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice @@ -299,6 +384,194 @@ def test_reordred_probegroup(probegroup): assert probegroup6._global_contact_order is None +# ── set_global_contact_order() tests ──────────────────────────────────────── + + +def _reorder_indices(): + """An interleaved order over the 96 contacts of the default probegroup.""" + return np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + +def test_set_global_contact_order_reorders_to_numpy(probegroup): + """set_global_contact_order reorders the contact vector returned by to_numpy.""" + order = _reorder_indices() + natural = probegroup.to_numpy(complete=True).copy() + + probegroup.set_global_contact_order(order) + + assert probegroup._global_contact_order is not None + reordered = probegroup.to_numpy(complete=True) + np.testing.assert_array_equal(reordered, natural[order]) + + +def test_set_global_contact_order_reorders_positions(probegroup): + """get_global_contact_positions reflects the custom order.""" + order = _reorder_indices() + natural_positions = probegroup.get_global_contact_positions().copy() + + probegroup.set_global_contact_order(order) + + reordered_positions = probegroup.get_global_contact_positions() + np.testing.assert_array_equal(reordered_positions, natural_positions[order]) + + +def test_set_global_contact_order_wrong_size(probegroup): + """A global contact order that does not match the contact count raises ValueError.""" + with pytest.raises(ValueError, match="Wrong global contact order size"): + probegroup.set_global_contact_order(np.arange(5)) + + +def test_set_global_contact_order_accepts_list(probegroup): + """A plain list is accepted and stored as an array.""" + order = list(_reorder_indices()) + probegroup.set_global_contact_order(order) + assert isinstance(probegroup._global_contact_order, np.ndarray) + + +def test_set_global_contact_order_device_channel_indices_consistency(probegroup): + """ + device_channel_indices are zipped to to_numpy() (which is reordered), + so setting them after a custom order must roundtrip through to_numpy. + """ + order = _reorder_indices() + probegroup.set_global_contact_order(order) + + n = probegroup.get_contact_count() + device_channel_indices = np.arange(n) + probegroup.set_global_device_channel_indices(device_channel_indices) + + got = probegroup.to_numpy(complete=True)["device_channel_indices"] + np.testing.assert_array_equal(got, device_channel_indices) + + +def test_set_global_contact_order_roundtrip_dict(probegroup): + """The custom order survives a to_dict/from_dict roundtrip.""" + order = _reorder_indices() + probegroup.set_global_contact_order(order) + + other = ProbeGroup.from_dict(probegroup.to_dict()) + assert other._global_contact_order is not None + np.testing.assert_array_equal( + other.to_numpy(complete=True), + probegroup.to_numpy(complete=True), + ) + + +# ── select_contacts() tests ───────────────────────────────────────────────── + + +def _probegroup_with_contact_ids(unique=True): + """ProbeGroup with 3 probes whose contact_ids are unique (or duplicated) across probes.""" + pg = ProbeGroup() + for i in range(3): + probe = generate_dummy_probe() + probe.move([i * 100, i * 80]) + n = probe.get_contact_count() + if unique: + probe.set_contact_ids([f"p{i}c{j}" for j in range(n)]) + else: + probe.set_contact_ids([f"c{j}" for j in range(n)]) + pg.add_probe(probe) + return pg + + +def test_select_contacts_unique_ids(): + """Selecting by globally unique contact ids returns exactly those contacts.""" + pg = _probegroup_with_contact_ids(unique=True) + selected_ids = ["p0c0", "p0c1", "p2c5"] + sub = pg.select_contacts(selected_ids) + + assert sub.get_contact_count() == 3 + # contacts come from two distinct probes + assert len(sub.probes) == 2 + assert set(sub.get_global_contact_ids()) == set(selected_ids) + + +def test_select_contacts_single_probe(): + """Selecting contacts from a single probe keeps a single probe.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_contacts(["p1c0", "p1c1", "p1c2"]) + assert sub.get_contact_count() == 3 + assert len(sub.probes) == 1 + + +def test_select_contacts_ambiguous_ids_without_probe_ids_raises(): + """ + Without probe_ids, a contact id that exists on more than one probe is + ambiguous and raises a ValueError naming the offending id(s). + """ + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError, match="c0"): + pg.select_contacts(["c0"]) + + +def test_select_contacts_with_probe_ids(): + """probe_ids disambiguate duplicated contact ids to a single probe.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_contacts(["c0", "c1"], probe_ids=["probe_1"]) + assert sub.get_contact_count() == 2 + assert len(sub.probes) == 1 + np.testing.assert_array_equal(sorted(sub.get_global_contact_ids()), ["c0", "c1"]) + + +def test_select_contacts_probe_ids_subset_of_probes(): + """probe_ids can restrict the selection to a subset of probes.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_contacts(["c0"], probe_ids=["probe_1", "probe_3"]) + assert sub.get_contact_count() == 2 + assert len(sub.probes) == 2 + + +def test_select_contacts_by_probe_ids_only(): + """Selecting by probe_ids alone keeps every contact of the matching probes.""" + pg = _probegroup_with_contact_ids(unique=False) + n_per_probe = pg.probes[0].get_contact_count() + + sub = pg.select_contacts(probe_ids=["probe_1"]) + assert sub.get_contact_count() == n_per_probe + assert len(sub.probes) == 1 + + sub_two = pg.select_contacts(probe_ids=["probe_1", "probe_3"]) + assert sub_two.get_contact_count() == 2 * n_per_probe + assert len(sub_two.probes) == 2 + + +def test_select_contacts_requires_some_selection(): + """Calling with neither contact_ids nor probe_ids raises ValueError.""" + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError, match="Either contact_ids or probe_ids"): + pg.select_contacts() + + +def test_select_contacts_too_many_ids_without_probe_ids_raises(): + """ + Requesting more contact ids than the number of unique ids without probe_ids + raises a ValueError. + """ + pg = _probegroup_with_contact_ids(unique=False) + n_unique = len(np.unique(pg.get_global_contact_ids())) + too_many = [f"c{j}" for j in range(n_unique + 1)] + with pytest.raises(ValueError, match="not unique across probes"): + pg.select_contacts(too_many) + + +def test_select_contacts_preserves_positions(): + """Selected contacts keep their global positions.""" + pg = _probegroup_with_contact_ids(unique=True) + selected_ids = ["p0c0", "p0c1", "p2c5"] + + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + expected = np.vstack([all_positions[all_ids == cid] for cid in selected_ids]) + + sub = pg.select_contacts(selected_ids) + sub_ids = sub.get_global_contact_ids() + sub_positions = sub.get_global_contact_positions() + got = np.vstack([sub_positions[sub_ids == cid] for cid in selected_ids]) + + np.testing.assert_array_equal(got, expected) + + if __name__ == "__main__": probegroup = _make_probegroup() From 6b78f9de0a8eb6ed58a262f2afa831fdcababc6e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 22 Jun 2026 16:20:56 +0200 Subject: [PATCH 02/11] fix: update BIDS writer/reader to use built-in probe_ids --- src/probeinterface/io.py | 27 +++------- src/probeinterface/probegroup.py | 93 +++++++++++++++++--------------- tests/test_probegroup.py | 7 ++- 3 files changed, 62 insertions(+), 65 deletions(-) diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 263cd2a7..c1efb239 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -203,10 +203,9 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup # create probe object and register with probegroup probe = Probe.from_dataframe(df=df_probe) - probe.annotate(probe_id=probe_id) probes[str(probe_id)] = probe - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=str(probe_id)) ignore_annotations = [ "probe_ids", @@ -326,7 +325,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probegroup = probe_or_probegroup else: raise TypeError( - f"probe_or_probegroup has to be" "of type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" + f"probe_or_probegroup has to beof type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" ) folder = Path(folder) @@ -337,22 +336,12 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probes = probegroup.probes # Step 1: GENERATION OF PROBE.TSV - # ensure required keys (probe_id, probe_type) are present - - if any("probe_id" not in p.annotations for p in probes): - probegroup.auto_generate_probe_ids() + # ensure required keys (probe_type) are present for probe in probes: - if "probe_id" not in probe.annotations: - raise ValueError( - "Export to BIDS probe format requires " - "the probe id to be specified as an annotation " - "(probe_id). You can do this via " - "`probegroup.auto_generate_ids." - ) if "type" not in probe.annotations: raise ValueError( - "Export to BIDS probe format requires " "the probe type to be specified as an annotation (type)" + "Export to BIDS probe format requires the probe type to be specified as an annotation (type)" ) # extract all used annotation keys @@ -361,11 +350,12 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup annotation_keys = np.unique(keys_concatenated) # generate a tsv table capturing probe information - index = range(len([p.annotations["probe_id"] for p in probes])) + index = range(len(probes)) df = pd.DataFrame(index=index) for annotation_key in annotation_keys: df[annotation_key] = [p.annotations[annotation_key] for p in probes] df["n_shanks"] = [len(np.unique(p.shank_ids)) for p in probes] + df["probe_id"] = probegroup.probe_ids # Note: in principle it would also be possible to add the probe width and # depth here based on the probe contour information. However this would @@ -378,8 +368,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup # Step 2: GENERATION OF PROBE.JSON probes_dict = {} - for probe in probes: - probe_id = probe.annotations["probe_id"] + for probe_id, probe in zip(probegroup.probe_ids, probes): probes_dict[probe_id] = { "contour": probe.probe_planar_contour.tolist(), "units": probe.si_units, @@ -403,7 +392,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup index = range(sum([p.get_contact_count() for p in probes])) df.rename(columns=tsv_label_map_to_BIDS, inplace=True) - df["probe_id"] = [p.annotations["probe_id"] for p in probes for _ in p.contact_ids] + df["probe_id"] = [probe_id for probe_id, probe in zip(probegroup.probe_ids, probes) for _ in probe.contact_ids] df["coordinate_system"] = ["relative cartesian"] * len(index) channel_indices = [] diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 6fdcd7a4..11105402 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -12,23 +12,23 @@ class ProbeGroup: The ProbeGroup is the object saved in the json based probeinterface format, even if there is only one probe. Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order - is the "natural" one (stacked order of each probe). An external contact order can be applied using the - ``ProbeGroup.set_global_contact_order()`` method, and the contact order is then stored in the - ``ProbeGroup._global_contact_order`` attribute. In this case, the contact order of the ProbeGroup is not "natural" - anymore, but the one defined by the user. This is useful for instance when some contact of each probe are + is the "natural" one (stacked order of each probe). An external contact order can be applied using the + ``ProbeGroup.set_global_contact_order()`` method, and the contact order is then stored in the + ``ProbeGroup._global_contact_order`` attribute. In this case, the contact order of the ProbeGroup is not "natural" + anymore, but the one defined by the user. This is useful for instance when some contact of each probe are interleaved in the recording file. """ def __init__(self): self.probes = [] - self.probe_ids = [] + self._probe_ids = [] self._global_contact_order = None def __repr__(self): repr_str = f"ProbeGroup: {len(self.probes)} probes - {self.get_contact_count()} contacts" if self._global_contact_order is not None: repr_str += " (with custom global contact order)" - for probe, probe_id in zip(self.probes, self.probe_ids): + for probe, probe_id in zip(self.probes, self._probe_ids): repr_str += f"\n\t{probe_id}: {probe}" return repr_str @@ -49,26 +49,31 @@ def add_probe(self, probe: Probe, probe_id: str = None) -> None: self.probes.append(probe) if probe_id is not None: - self.probe_ids.append(probe_id) + self._probe_ids.append(probe_id) else: - self.probe_ids.append(f"probe_{len(self.probes)}") + self._probe_ids.append(f"probe_{len(self.probes)}") probe._probe_group = self - def set_probe_ids(self, probe_ids: list) -> None: + @property + def probe_ids(self) -> list: + return self._probe_ids + + @probe_ids.setter + def probe_ids(self, probe_ids: list) -> None: """ Set the probe IDs for the ProbeGroup. Parameters ---------- probe_ids: list - A list of IDs to assign to the probes. + A list of IDs to assign to the probes. The length of the list must match the number of probes in the ProbeGroup. """ if len(probe_ids) != len(self.probes): raise ValueError( f"Length of probe_ids ({len(probe_ids)}) does not match number of probes ({len(self.probes)})" ) - self.probe_ids = probe_ids + self._probe_ids = probe_ids def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: @@ -236,6 +241,7 @@ def to_dict(self, array_as_list: bool = False) -> dict: for probe in self.probes: probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) + d["probe_ids"] = self.probe_ids if self._global_contact_order is not None: global_contact_order = self._global_contact_order if array_as_list: @@ -261,6 +267,9 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) + probe_ids = d.get("probe_ids", None) + if probe_ids is not None: + probegroup.probe_ids = probe_ids global_contact_order = d.get("global_contact_order", None) if global_contact_order is not None: @@ -401,18 +410,20 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": # Map annotations of the original probegroup to the sliced one new_probe_ids = [self.probe_ids[i] for i in original_probe_indices] - sliced_probe_group.set_probe_ids(new_probe_ids) + sliced_probe_group.probe_ids = new_probe_ids for original_probe_index, new_probe_index in zip(original_probe_indices, new_probe_indices): orig_probe = self.probes[original_probe_index] new_probe = sliced_probe_group.probes[new_probe_index] - + for k in orig_probe.annotations: if k not in new_probe.annotations: new_probe.annotate(**{k: orig_probe.annotations[k]}) return sliced_probe_group - def select_contacts(self, contact_ids: np.ndarray | list | None = None, probe_ids: np.ndarray | list | None = None) -> "ProbeGroup": + def select_contacts( + self, contact_ids: np.ndarray | list | None = None, probe_ids: np.ndarray | list | None = None + ) -> "ProbeGroup": """ Get a copy of the ProbeGroup with a sub selection of contacts based on contact ids and probe ids. @@ -421,8 +432,8 @@ def select_contacts(self, contact_ids: np.ndarray | list | None = None, probe_id contact_ids : np.array or list or None, default: None The contact ids to select. If None, all contacts are selected, but probe_ids must be provided. probe_ids : np.array or list or None, default: None - The probe ids to select. If contact_ids are not unique across probes, - then probe_ids should be provided to disambiguate. + The probe ids to select. If contact_ids are not unique across probes, + then probe_ids should be provided to disambiguate. If contact_ids are unique across probes, then probe_ids can be None. Returns @@ -431,9 +442,7 @@ def select_contacts(self, contact_ids: np.ndarray | list | None = None, probe_id The sliced probe group """ if contact_ids is None and probe_ids is None: - raise ValueError( - "Either contact_ids or probe_ids must be provided for selection." - ) + raise ValueError("Either contact_ids or probe_ids must be provided for selection.") if contact_ids is None: contact_mask = np.ones(self.get_contact_count(), dtype=bool) else: @@ -485,29 +494,29 @@ def check_global_device_wiring_and_ids(self) -> None: if valid_chans.size != np.unique(valid_chans).size: raise ValueError("channel device indices are not unique across probes") - def auto_generate_probe_ids(self, *args, **kwargs) -> None: - """ - Annotate all probes with unique probe_id values. - - Parameters - ---------- - *args: will be forwarded to `probeinterface.utils.generate_unique_ids` - **kwargs: will be forwarded to - `probeinterface.utils.generate_unique_ids` - """ - - if any("probe_id" in p.annotations for p in self.probes): - raise ValueError("Probe already has a `probe_id` annotation.") - - if not args: - args = 1e7, 1e8 - # 3rd argument has to be the number of probes - args = args[:2] + (len(self.probes),) - - # creating unique probe ids in case probes do not have any yet - probe_ids = generate_unique_ids(*args, **kwargs).astype(str) - for pid, probe in enumerate(self.probes): - probe.annotate(probe_id=probe_ids[pid]) + # def auto_generate_probe_ids(self, *args, **kwargs) -> None: + # """ + # Annotate all probes with unique probe_id values. + + # Parameters + # ---------- + # *args: will be forwarded to `probeinterface.utils.generate_unique_ids` + # **kwargs: will be forwarded to + # `probeinterface.utils.generate_unique_ids` + # """ + + # if any("probe_id" in p.annotations for p in self.probes): + # raise ValueError("Probe already has a `probe_id` annotation.") + + # if not args: + # args = 1e7, 1e8 + # # 3rd argument has to be the number of probes + # args = args[:2] + (len(self.probes),) + + # # creating unique probe ids in case probes do not have any yet + # probe_ids = generate_unique_ids(*args, **kwargs).astype(str) + # for pid, probe in enumerate(self.probes): + # probe.annotate(probe_id=probe_ids[pid]) def auto_generate_contact_ids(self, *args, **kwargs) -> None: """ diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index a7bd13b2..20125e22 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -15,7 +15,7 @@ def _make_probegroup(): probe.move([i * 100, i * 80]) n = probe.get_contact_count() probe.set_device_channel_indices(np.arange(n) + nchan) - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=f"probe_00{i}") nchan += n return probegroup @@ -39,17 +39,16 @@ def test_probegroup(probegroup): d = probegroup.to_dict() other = ProbeGroup.from_dict(d) + assert probegroup.probe_ids == other.probe_ids # checking automatic generation of ids with new dummy probes probegroup.probes = [] for i in range(3): - probegroup.add_probe(generate_dummy_probe()) + probegroup.add_probe(generate_dummy_probe(), probe_id=f"probe_00{i}") probegroup.auto_generate_contact_ids() - probegroup.auto_generate_probe_ids() for p in probegroup.probes: assert p.contact_ids is not None - assert "probe_id" in p.annotations def test_probegroup_3d(): From 47598ec5a6449a0d72fcb6334412aee4bec41a2f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 22 Jun 2026 16:21:37 +0200 Subject: [PATCH 03/11] refac: remove auto_generate_probe_ids --- src/probeinterface/probegroup.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 11105402..553f3a83 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -494,30 +494,6 @@ def check_global_device_wiring_and_ids(self) -> None: if valid_chans.size != np.unique(valid_chans).size: raise ValueError("channel device indices are not unique across probes") - # def auto_generate_probe_ids(self, *args, **kwargs) -> None: - # """ - # Annotate all probes with unique probe_id values. - - # Parameters - # ---------- - # *args: will be forwarded to `probeinterface.utils.generate_unique_ids` - # **kwargs: will be forwarded to - # `probeinterface.utils.generate_unique_ids` - # """ - - # if any("probe_id" in p.annotations for p in self.probes): - # raise ValueError("Probe already has a `probe_id` annotation.") - - # if not args: - # args = 1e7, 1e8 - # # 3rd argument has to be the number of probes - # args = args[:2] + (len(self.probes),) - - # # creating unique probe ids in case probes do not have any yet - # probe_ids = generate_unique_ids(*args, **kwargs).astype(str) - # for pid, probe in enumerate(self.probes): - # probe.annotate(probe_id=probe_ids[pid]) - def auto_generate_contact_ids(self, *args, **kwargs) -> None: """ Annotate all contacts with unique contact_id values. From de675a5d97fa2543cbd3383094a41a17bda5fb61 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 23 Jun 2026 10:17:14 +0200 Subject: [PATCH 04/11] fix: select_contacts should maintain requested contact order --- src/probeinterface/probegroup.py | 43 +++++++++++++++++++++++--------- tests/test_probegroup.py | 32 ++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 553f3a83..e342b3a7 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -443,16 +443,30 @@ def select_contacts( """ if contact_ids is None and probe_ids is None: raise ValueError("Either contact_ids or probe_ids must be provided for selection.") + + # both arrays are in the global contact order + arr = self.to_numpy(complete=True) + all_contact_ids = arr["contact_ids"] + all_probe_ids = np.asarray(self.probe_ids)[arr["probe_index"]] + if contact_ids is None: - contact_mask = np.ones(self.get_contact_count(), dtype=bool) + # select whole probes, following the requested probe_ids order + probe_ids = np.asarray(probe_ids) + blocks = [np.flatnonzero(all_probe_ids == probe_id) for probe_id in probe_ids] else: contact_ids = np.asarray(contact_ids) - all_contact_ids = self.get_global_contact_ids() - contact_mask = np.isin(all_contact_ids, contact_ids) + unique_requested, counts = np.unique(contact_ids, return_counts=True) + duplicated = unique_requested[counts > 1] + if duplicated.size > 0: + raise ValueError( + f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. " + "Contact ids are unique across probes; if the same contact id is on multiple " + "probes, use probe_ids to disambiguate." + ) if probe_ids is None: # without probe_ids the selection must be unambiguous: every requested # contact id must match a single contact across the whole ProbeGroup - matched_ids = all_contact_ids[contact_mask] + matched_ids = all_contact_ids[np.isin(all_contact_ids, contact_ids)] unique_ids, counts = np.unique(matched_ids, return_counts=True) ambiguous_ids = unique_ids[counts > 1] if ambiguous_ids.size > 0: @@ -460,14 +474,19 @@ def select_contacts( f"contact_ids {ambiguous_ids.tolist()} are not unique across probes, " "you should provide probe_ids to disambiguate" ) - if probe_ids is None: - probe_mask = np.ones(self.get_contact_count(), dtype=bool) - else: - all_probe_ids = np.asarray(self.probe_ids)[self.to_numpy(complete=True)["probe_index"]] - probe_ids = np.asarray(probe_ids) - probe_mask = np.isin(all_probe_ids, probe_ids) - selection_mask = contact_mask & probe_mask - return self.get_slice(selection_mask) + # follow the requested contact_ids order + blocks = [np.flatnonzero(all_contact_ids == contact_id) for contact_id in contact_ids] + else: + # contact_ids drives the order, probe_ids breaks ties between duplicated ids + probe_ids = np.asarray(probe_ids) + blocks = [ + np.flatnonzero((all_contact_ids == contact_id) & (all_probe_ids == probe_id)) + for contact_id in contact_ids + for probe_id in probe_ids + ] + + selection = np.concatenate(blocks) if len(blocks) else np.array([], dtype=int) + return self.get_slice(selection) def set_global_contact_order(self, global_contact_order: np.ndarray | list) -> None: """ diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 20125e22..d74dfa10 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -554,6 +554,38 @@ def test_select_contacts_too_many_ids_without_probe_ids_raises(): pg.select_contacts(too_many) +def test_select_contacts_follows_requested_order(): + """The selection follows the order of the provided contact_ids, even across probes.""" + pg = _probegroup_with_contact_ids(unique=True) + # interleave contacts from different probes in a non-natural order + selected_ids = ["p2c5", "p0c1", "p1c0", "p0c0"] + sub = pg.select_contacts(selected_ids) + + np.testing.assert_array_equal(sub.get_global_contact_ids(), selected_ids) + + # positions must follow the same order as the requested ids + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + expected = np.vstack([all_positions[all_ids == cid] for cid in selected_ids]) + np.testing.assert_array_equal(sub.get_global_contact_positions(), expected) + + +def test_select_contacts_by_probe_ids_follows_requested_order(): + """Selecting by probe_ids alone follows the requested probe order.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_contacts(probe_ids=["probe_3", "probe_1"]) + # probe_3's contacts come first since it is requested first + probe_index_per_contact = sub.to_numpy(complete=True)["probe_index"] + assert probe_index_per_contact[0] == sub.probe_ids.index("probe_3") + + +def test_select_contacts_duplicated_ids_raises(): + """Passing the same contact id more than once raises a ValueError.""" + pg = _probegroup_with_contact_ids(unique=True) + with pytest.raises(ValueError, match="must be unique"): + pg.select_contacts(["p0c0", "p0c1", "p0c0"]) + + def test_select_contacts_preserves_positions(): """Selected contacts keep their global positions.""" pg = _probegroup_with_contact_ids(unique=True) From c3e6150eb8c9b7b7cd8679ffcd3b91e882a6d3a0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 23 Jun 2026 12:35:24 +0200 Subject: [PATCH 05/11] fix: to/from numpy has probe_id, select_contacts behavior, add select_probes --- examples/ex_03_generate_probe_group.py | 33 ++-- examples/ex_05_device_channel_indices.py | 42 ----- src/probeinterface/probe.py | 4 +- src/probeinterface/probegroup.py | 181 ++++++++++++---------- tests/test_probegroup.py | 187 ++++++++++------------- 5 files changed, 201 insertions(+), 246 deletions(-) diff --git a/examples/ex_03_generate_probe_group.py b/examples/ex_03_generate_probe_group.py index f8764278..5986f13f 100644 --- a/examples/ex_03_generate_probe_group.py +++ b/examples/ex_03_generate_probe_group.py @@ -67,20 +67,17 @@ print("probe_ids:", probegroup.probe_ids) ############################################################################## -# `ProbeGroup.select_contacts()` returns a new `ProbeGroup` with a sub-selection -# of contacts. The selection can be done by ``contact_ids``, by ``probe_ids``, -# or by both at the same time. -# -# Selecting by ``probe_ids`` alone keeps every contact of the matching probes, -# which is a convenient way to grab a whole hemisphere: +# `ProbeGroup.select_probes()` returns a new `ProbeGroup` with a sub-selection +# of probes given by probe_ids. -left_hemisphere = probegroup.select_contacts(probe_ids=["left_hemisphere"]) -print("contacts in the left hemisphere:", left_hemisphere.get_contact_count()) +left_hemisphere_probe = probegroup.select_probes(probe_ids=["left_hemisphere"]) +print(left_hemisphere_probe) ############################################################################## -# We can also select by ``contact_ids``. Note that if ``contact_ids`` are not -# unique across probes, the selection will be ambiguous and an error will be -# raised. In this case, providing ``probe_ids`` disambiguates the selection: +# We can also select by specific contacts from a probegroup with the +# ``select_contacts`` function. Note that if ``contact_ids`` are not +# unique across probes, you need to disambiguate the selection by specifying the +# probe_ids as well. Otherwise, a ValueError is raised. # check if any contact_id is not unique across probes contact_ids = probegroup.get_global_contact_ids() @@ -91,14 +88,18 @@ # Because the contact ids are not unique across probes, combining ``contact_ids`` # with ``probe_ids`` lets us pull specific contacts from a single hemisphere: -left_contacts = probegroup.select_contacts(contact_ids=["0", "1", "2"], probe_ids=["left_hemisphere"]) -print("contacts selected from the left hemisphere:", left_contacts.get_contact_count()) +left_probegroup = probegroup.select_contacts( + contact_ids=["0", "1", "2"], + probe_ids=["left_hemisphere", "left_hemisphere", "left_hemisphere"] +) +print(left_probegroup) -left_and_right_contacts = probegroup.select_contacts( +# Now select contacts from both hemispheres by providing the corresponding probe_ids for each contact_id: +left_and_right_probegroup = probegroup.select_contacts( contact_ids=["0", "1", "2"], - probe_ids=["left_hemisphere", "right_hemisphere"] + probe_ids=["left_hemisphere", "right_hemisphere", "left_hemisphere"] ) -print("contacts selected from the left and right hemispheres:", left_and_right_contacts.get_contact_count()) +print(left_and_right_probegroup) # Without providing probe_ids, the selection is ambiguous and an error is raised: try: diff --git a/examples/ex_05_device_channel_indices.py b/examples/ex_05_device_channel_indices.py index e9d0c726..5731c910 100644 --- a/examples/ex_05_device_channel_indices.py +++ b/examples/ex_05_device_channel_indices.py @@ -87,46 +87,4 @@ fig, ax = plt.subplots() plot_probegroup(probegroup, with_contact_id=True, same_axes=True, ax=ax) -############################################################################## -# Reordering contacts with a global contact order -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# -# By default the contact order of a `ProbeGroup` is the "natural" one: the -# contacts of each probe are stacked one probe after the other. But sometimes -# the contacts of the different probes are *interleaved* in the recording file -# (e.g. the acquisition system alternates between probes sample by sample). -# -# `ProbeGroup.set_global_contact_order()` lets us store this external ordering. -# The order is an array of indices into the natural (stacked) order, and it is -# applied whenever the group is exported with `to_numpy()` / `to_dataframe()`. - -probegroup = ProbeGroup() -probegroup.add_probe(probe0.copy()) -probegroup.add_probe(probe1.copy()) - -n = probegroup.get_contact_count() -print("default global contact order:", probegroup._global_contact_order) - -# interleave probe0 and probe1 contacts as they appear in the recording file -global_contact_order = np.zeros(n, dtype="int64") -global_contact_order[0::2] = np.arange(0, n // 2) # probe0 contacts -global_contact_order[1::2] = np.arange(n // 2, n) # probe1 contacts -probegroup.set_global_contact_order(global_contact_order) - -############################################################################## -# Now `to_numpy()` returns the contacts in the interleaved order: the -# ``probe_index`` column alternates between the two probes. - -contact_vector = probegroup.to_numpy() -print("probe_index in global order:", contact_vector["probe_index"][:8]) - -############################################################################## -# The global order interacts with `set_global_device_channel_indices()`: the -# ``device_channel_indices`` you pass are interpreted in the (reordered) order -# returned by `to_numpy()`, so they map directly onto the acquisition channels. - -probegroup.set_global_device_channel_indices(np.arange(n)) -print("device_channel_indices (global order):", - probegroup.to_numpy(complete=True)["device_channel_indices"][:8]) - plt.show() diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 26284289..4174f199 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -1140,8 +1140,10 @@ def from_numpy(arr: np.ndarray) -> "Probe": "plane_axis_y_1", "plane_axis_z_0", "plane_axis_z_1", - "probe_index", "si_units", + # these two are for ProbeGroup to avoid duplication of fields + "probe_index", + "probe_id", ] contact_annotation_fields = [f for f in fields if f not in main_fields] diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index e342b3a7..755d2d43 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -20,15 +20,15 @@ class ProbeGroup: """ def __init__(self): - self.probes = [] + self._probes = [] self._probe_ids = [] self._global_contact_order = None def __repr__(self): - repr_str = f"ProbeGroup: {len(self.probes)} probes - {self.get_contact_count()} contacts" + repr_str = f"ProbeGroup: {len(self._probes)} probes - {self.get_contact_count()} contacts" if self._global_contact_order is not None: repr_str += " (with custom global contact order)" - for probe, probe_id in zip(self.probes, self._probe_ids): + for probe, probe_id in zip(self._probes, self._probe_ids): repr_str += f"\n\t{probe_id}: {probe}" return repr_str @@ -44,16 +44,22 @@ def add_probe(self, probe: Probe, probe_id: str = None) -> None: The ID to assign to the probe. If None, a unique ID will be generated. """ - if len(self.probes) > 0: + if len(self._probes) > 0: self._check_compatible(probe) - self.probes.append(probe) - if probe_id is not None: - self._probe_ids.append(probe_id) - else: - self._probe_ids.append(f"probe_{len(self.probes)}") + if probe_id is None: + probe_id = f"{len(self._probes)}" + if probe_id in self._probe_ids: + raise ValueError(f"Probe ID '{probe_id}' is already used in this ProbeGroup.") + self._probe_ids.append(probe_id) + + self._probes.append(probe) probe._probe_group = self + @property + def probes(self) -> list: + return self._probes + @property def probe_ids(self) -> list: return self._probe_ids @@ -69,9 +75,9 @@ def probe_ids(self, probe_ids: list) -> None: A list of IDs to assign to the probes. The length of the list must match the number of probes in the ProbeGroup. """ - if len(probe_ids) != len(self.probes): + if len(probe_ids) != len(self._probes): raise ValueError( - f"Length of probe_ids ({len(probe_ids)}) does not match number of probes ({len(self.probes)})" + f"Length of probe_ids ({len(probe_ids)}) does not match number of probes ({len(self._probes)})" ) self._probe_ids = probe_ids @@ -87,9 +93,11 @@ def _check_compatible(self, probe: Probe) -> None: ) # check global channel maps - self.probes.append(probe) + self._probes.append(probe) + self._probe_ids.append(f"{len(self._probes)-1}") self.check_global_device_wiring_and_ids() - self.probes = self.probes[:-1] + self._probes = self.probes[:-1] + self._probe_ids = self.probe_ids[:-1] @property def ndim(self) -> int: @@ -133,7 +141,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: probe_arr = [] # loop over probes to get all fields - dtype = [("probe_index", "int64")] + dtype = [("probe_index", "int64"), ("probe_id", "U100")] fields = [] for probe_index, probe in enumerate(self.probes): arr = probe.to_numpy(complete=complete) @@ -148,6 +156,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: arr = probe_arr[probe_index] arr_ext = np.zeros(probe.get_contact_count(), dtype=dtype) arr_ext["probe_index"] = probe_index + arr_ext["probe_id"] = self._probe_ids[probe_index] for k in fields: if k in arr.dtype.fields: arr_ext[k] = arr[k] @@ -185,12 +194,13 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": if is_interleaved: global_contact_order = [] - probes_indices = np.unique(arr["probe_index"]) + probes_indices = np.sort(np.unique(arr["probe_index"])) probegroup = ProbeGroup() for probe_index in probes_indices: mask = arr["probe_index"] == probe_index + probe_id = arr["probe_id"][mask][0] probe = Probe.from_numpy(arr[mask]) - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=probe_id) if is_interleaved: global_contact_order.append(np.flatnonzero(mask)) @@ -264,12 +274,12 @@ def from_dict(d: dict) -> "ProbeGroup": The instantiated ProbeGroup object """ probegroup = ProbeGroup() - for probe_dict in d["probes"]: - probe = Probe.from_dict(probe_dict) - probegroup.add_probe(probe) probe_ids = d.get("probe_ids", None) - if probe_ids is not None: - probegroup.probe_ids = probe_ids + if probe_ids is None: + probe_ids = [str(i) for i in range(len(d["probes"]))] + for probe_id, probe_dict in zip(probe_ids, d["probes"]): + probe = Probe.from_dict(probe_dict) + probegroup.add_probe(probe, probe_id=probe_id) global_contact_order = d.get("global_contact_order", None) if global_contact_order is not None: @@ -404,16 +414,12 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] - original_probe_indices = np.unique(contact_arr["probe_index"]) sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - new_probe_indices = np.unique(sliced_probe_group.to_numpy(complete=True)["probe_index"]) # Map annotations of the original probegroup to the sliced one - new_probe_ids = [self.probe_ids[i] for i in original_probe_indices] - sliced_probe_group.probe_ids = new_probe_ids - for original_probe_index, new_probe_index in zip(original_probe_indices, new_probe_indices): + for probe_id, new_probe in zip(sliced_probe_group.probe_ids, sliced_probe_group.probes): + original_probe_index = self.probe_ids.index(probe_id) orig_probe = self.probes[original_probe_index] - new_probe = sliced_probe_group.probes[new_probe_index] for k in orig_probe.annotations: if k not in new_probe.annotations: @@ -421,89 +427,100 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": return sliced_probe_group + def select_probes(self, probe_ids: str | np.ndarray | list) -> "ProbeGroup": + """ + Get a copy of the ProbeGroup with a sub selection of probes based on probe ids. + + Parameters + ---------- + probe_ids : str | np.array or list + The probe id or ids to select. + + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group + """ + if probe_ids is None: + raise ValueError("probe_ids must be provided for selection.") + + if isinstance(probe_ids, str): + probe_ids = [probe_ids] + + # selection is over the global contact vector, following the requested probe order + all_probe_ids = self.to_numpy(complete=True)["probe_id"] + probe_ids = np.asarray(probe_ids) + blocks = [np.flatnonzero(all_probe_ids == probe_id) for probe_id in probe_ids] + + selection = np.concatenate(blocks) if len(blocks) else np.array([], dtype=int) + return self.get_slice(selection) + def select_contacts( - self, contact_ids: np.ndarray | list | None = None, probe_ids: np.ndarray | list | None = None + self, contact_ids: np.ndarray | list, probe_ids: np.ndarray | list | None = None ) -> "ProbeGroup": """ Get a copy of the ProbeGroup with a sub selection of contacts based on contact ids and probe ids. Parameters ---------- - contact_ids : np.array or list or None, default: None - The contact ids to select. If None, all contacts are selected, but probe_ids must be provided. + contact_ids : np.array or list + The contact ids to select. probe_ids : np.array or list or None, default: None - The probe ids to select. If contact_ids are not unique across probes, - then probe_ids should be provided to disambiguate. - If contact_ids are unique across probes, then probe_ids can be None. + If multiple probes and contact ids not unique across probes, an array with the same length + as contact ids to specify which probe each contact id belongs to. Returns ------- sliced_probe_group: ProbeGroup The sliced probe group """ - if contact_ids is None and probe_ids is None: - raise ValueError("Either contact_ids or probe_ids must be provided for selection.") - # both arrays are in the global contact order arr = self.to_numpy(complete=True) all_contact_ids = arr["contact_ids"] - all_probe_ids = np.asarray(self.probe_ids)[arr["probe_index"]] + all_probe_ids = arr["probe_id"] - if contact_ids is None: - # select whole probes, following the requested probe_ids order - probe_ids = np.asarray(probe_ids) - blocks = [np.flatnonzero(all_probe_ids == probe_id) for probe_id in probe_ids] - else: - contact_ids = np.asarray(contact_ids) + contact_ids = np.asarray(contact_ids) + + if probe_ids is None: + # without probe_ids the request must be unambiguous: each requested contact + # id must appear once in the request and match a single contact in the group unique_requested, counts = np.unique(contact_ids, return_counts=True) duplicated = unique_requested[counts > 1] if duplicated.size > 0: raise ValueError( f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. " - "Contact ids are unique across probes; if the same contact id is on multiple " - "probes, use probe_ids to disambiguate." + "If the same contact id is on multiple probes, use probe_ids to disambiguate." + ) + matched_ids = all_contact_ids[np.isin(all_contact_ids, contact_ids)] + unique_ids, match_counts = np.unique(matched_ids, return_counts=True) + ambiguous_ids = unique_ids[match_counts > 1] + if ambiguous_ids.size > 0: + raise ValueError( + f"contact_ids {ambiguous_ids.tolist()} are not unique across probes, " + "you should provide probe_ids to disambiguate" + ) + # follow the requested contact_ids order + blocks = [np.flatnonzero(all_contact_ids == contact_id) for contact_id in contact_ids] + else: + # probe_ids is paired with contact_ids: probe_ids[i] tells which probe + # contact_ids[i] belongs to, disambiguating ids shared across probes + probe_ids = np.asarray(probe_ids) + if probe_ids.size != contact_ids.size: + raise ValueError( + f"probe_ids must have the same length as contact_ids " f"({probe_ids.size} != {contact_ids.size})" ) - if probe_ids is None: - # without probe_ids the selection must be unambiguous: every requested - # contact id must match a single contact across the whole ProbeGroup - matched_ids = all_contact_ids[np.isin(all_contact_ids, contact_ids)] - unique_ids, counts = np.unique(matched_ids, return_counts=True) - ambiguous_ids = unique_ids[counts > 1] - if ambiguous_ids.size > 0: - raise ValueError( - f"contact_ids {ambiguous_ids.tolist()} are not unique across probes, " - "you should provide probe_ids to disambiguate" - ) - # follow the requested contact_ids order - blocks = [np.flatnonzero(all_contact_ids == contact_id) for contact_id in contact_ids] - else: - # contact_ids drives the order, probe_ids breaks ties between duplicated ids - probe_ids = np.asarray(probe_ids) - blocks = [ - np.flatnonzero((all_contact_ids == contact_id) & (all_probe_ids == probe_id)) - for contact_id in contact_ids - for probe_id in probe_ids - ] + pairs = list(zip(contact_ids.tolist(), probe_ids.tolist())) + if len(set(pairs)) != len(pairs): + raise ValueError("(contact_id, probe_id) pairs must be unique") + # contact_ids drives the order, probe_ids breaks ties between duplicated ids + blocks = [ + np.flatnonzero((all_contact_ids == contact_id) & (all_probe_ids == probe_id)) + for contact_id, probe_id in zip(contact_ids, probe_ids) + ] selection = np.concatenate(blocks) if len(blocks) else np.array([], dtype=int) return self.get_slice(selection) - def set_global_contact_order(self, global_contact_order: np.ndarray | list) -> None: - """ - Set the global contact order for the ProbeGroup. This is useful when some contact of each probe are interleaved in the recording file. - - Parameters - ---------- - global_contact_order: np.ndarray | list - The global contact order to be set. It should be an array of indices that defines the new order of contacts across all probes. - """ - global_contact_order = np.asarray(global_contact_order) - if global_contact_order.size != self.get_contact_count(): - raise ValueError( - f"Wrong global contact order size {global_contact_order.size} for the number of channels {self.get_contact_count()}" - ) - self._global_contact_order = global_contact_order - def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index d74dfa10..23d1fd7f 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -42,7 +42,8 @@ def test_probegroup(probegroup): assert probegroup.probe_ids == other.probe_ids # checking automatic generation of ids with new dummy probes - probegroup.probes = [] + probegroup._probes = [] + probegroup._probe_ids = [] for i in range(3): probegroup.add_probe(generate_dummy_probe(), probe_id=f"probe_00{i}") probegroup.auto_generate_contact_ids() @@ -351,7 +352,7 @@ def test_get_slice_single_probe_keeps_probe_id_and_annotations(): # ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice -def test_reordred_probegroup(probegroup): +def test_reordered_probegroup(probegroup): order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) contact_vector = probegroup.to_numpy(complete=True) @@ -374,7 +375,7 @@ def test_reordred_probegroup(probegroup): probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict()) assert probegroup5._global_contact_order is not None - contact_vector5 = probegroup3.to_numpy(complete=True) + contact_vector5 = probegroup5.to_numpy(complete=True) assert np.array_equal(contact_vector4, contact_vector5) # let go back to original order @@ -383,79 +384,6 @@ def test_reordred_probegroup(probegroup): assert probegroup6._global_contact_order is None -# ── set_global_contact_order() tests ──────────────────────────────────────── - - -def _reorder_indices(): - """An interleaved order over the 96 contacts of the default probegroup.""" - return np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) - - -def test_set_global_contact_order_reorders_to_numpy(probegroup): - """set_global_contact_order reorders the contact vector returned by to_numpy.""" - order = _reorder_indices() - natural = probegroup.to_numpy(complete=True).copy() - - probegroup.set_global_contact_order(order) - - assert probegroup._global_contact_order is not None - reordered = probegroup.to_numpy(complete=True) - np.testing.assert_array_equal(reordered, natural[order]) - - -def test_set_global_contact_order_reorders_positions(probegroup): - """get_global_contact_positions reflects the custom order.""" - order = _reorder_indices() - natural_positions = probegroup.get_global_contact_positions().copy() - - probegroup.set_global_contact_order(order) - - reordered_positions = probegroup.get_global_contact_positions() - np.testing.assert_array_equal(reordered_positions, natural_positions[order]) - - -def test_set_global_contact_order_wrong_size(probegroup): - """A global contact order that does not match the contact count raises ValueError.""" - with pytest.raises(ValueError, match="Wrong global contact order size"): - probegroup.set_global_contact_order(np.arange(5)) - - -def test_set_global_contact_order_accepts_list(probegroup): - """A plain list is accepted and stored as an array.""" - order = list(_reorder_indices()) - probegroup.set_global_contact_order(order) - assert isinstance(probegroup._global_contact_order, np.ndarray) - - -def test_set_global_contact_order_device_channel_indices_consistency(probegroup): - """ - device_channel_indices are zipped to to_numpy() (which is reordered), - so setting them after a custom order must roundtrip through to_numpy. - """ - order = _reorder_indices() - probegroup.set_global_contact_order(order) - - n = probegroup.get_contact_count() - device_channel_indices = np.arange(n) - probegroup.set_global_device_channel_indices(device_channel_indices) - - got = probegroup.to_numpy(complete=True)["device_channel_indices"] - np.testing.assert_array_equal(got, device_channel_indices) - - -def test_set_global_contact_order_roundtrip_dict(probegroup): - """The custom order survives a to_dict/from_dict roundtrip.""" - order = _reorder_indices() - probegroup.set_global_contact_order(order) - - other = ProbeGroup.from_dict(probegroup.to_dict()) - assert other._global_contact_order is not None - np.testing.assert_array_equal( - other.to_numpy(complete=True), - probegroup.to_numpy(complete=True), - ) - - # ── select_contacts() tests ───────────────────────────────────────────────── @@ -505,41 +433,27 @@ def test_select_contacts_ambiguous_ids_without_probe_ids_raises(): def test_select_contacts_with_probe_ids(): - """probe_ids disambiguate duplicated contact ids to a single probe.""" + """probe_ids (paired with contact_ids) disambiguate duplicated contact ids.""" pg = _probegroup_with_contact_ids(unique=False) - sub = pg.select_contacts(["c0", "c1"], probe_ids=["probe_1"]) + sub = pg.select_contacts(["c0", "c1"], probe_ids=["1", "1"]) assert sub.get_contact_count() == 2 assert len(sub.probes) == 1 np.testing.assert_array_equal(sorted(sub.get_global_contact_ids()), ["c0", "c1"]) -def test_select_contacts_probe_ids_subset_of_probes(): - """probe_ids can restrict the selection to a subset of probes.""" +def test_select_contacts_same_id_across_probes_with_probe_ids(): + """The same contact id can be selected from several probes using probe_ids.""" pg = _probegroup_with_contact_ids(unique=False) - sub = pg.select_contacts(["c0"], probe_ids=["probe_1", "probe_3"]) + sub = pg.select_contacts(["c0", "c0"], probe_ids=["0", "2"]) assert sub.get_contact_count() == 2 assert len(sub.probes) == 2 -def test_select_contacts_by_probe_ids_only(): - """Selecting by probe_ids alone keeps every contact of the matching probes.""" +def test_select_contacts_probe_ids_length_mismatch_raises(): + """probe_ids must have the same length as contact_ids.""" pg = _probegroup_with_contact_ids(unique=False) - n_per_probe = pg.probes[0].get_contact_count() - - sub = pg.select_contacts(probe_ids=["probe_1"]) - assert sub.get_contact_count() == n_per_probe - assert len(sub.probes) == 1 - - sub_two = pg.select_contacts(probe_ids=["probe_1", "probe_3"]) - assert sub_two.get_contact_count() == 2 * n_per_probe - assert len(sub_two.probes) == 2 - - -def test_select_contacts_requires_some_selection(): - """Calling with neither contact_ids nor probe_ids raises ValueError.""" - pg = _probegroup_with_contact_ids(unique=False) - with pytest.raises(ValueError, match="Either contact_ids or probe_ids"): - pg.select_contacts() + with pytest.raises(ValueError, match="same length as contact_ids"): + pg.select_contacts(["c0", "c1"], probe_ids=["0"]) def test_select_contacts_too_many_ids_without_probe_ids_raises(): @@ -570,13 +484,76 @@ def test_select_contacts_follows_requested_order(): np.testing.assert_array_equal(sub.get_global_contact_positions(), expected) -def test_select_contacts_by_probe_ids_follows_requested_order(): - """Selecting by probe_ids alone follows the requested probe order.""" +def test_select_probes_keeps_every_contact_of_matching_probes(): + """select_probes keeps every contact of the matching probes.""" pg = _probegroup_with_contact_ids(unique=False) - sub = pg.select_contacts(probe_ids=["probe_3", "probe_1"]) - # probe_3's contacts come first since it is requested first + n_per_probe = pg.probes[0].get_contact_count() + + sub_str = pg.select_probes("1") + assert sub_str.get_contact_count() == n_per_probe + assert len(sub_str.probes) == 1 + + sub_one = pg.select_probes(["1"]) + assert sub_one.get_contact_count() == n_per_probe + assert len(sub_one.probes) == 1 + + sub_two = pg.select_probes(["1", "2"]) + assert sub_two.get_contact_count() == 2 * n_per_probe + assert len(sub_two.probes) == 2 + + +def test_select_probes_follows_requested_order(): + """select_probes follows the requested probe order.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_probes(["2", "0"]) + # probe "2"'s contacts come first since it is requested first probe_index_per_contact = sub.to_numpy(complete=True)["probe_index"] - assert probe_index_per_contact[0] == sub.probe_ids.index("probe_3") + assert probe_index_per_contact[0] == sub.probe_ids.index("2") + + +def test_select_probes_single_probe(): + """Selecting a single probe keeps a single probe with its contact ids.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_probes(["1"]) + assert len(sub.probes) == 1 + assert sub.probe_ids == ["1"] + assert all(cid.startswith("p1") for cid in sub.get_global_contact_ids()) + + +def test_select_probes_preserves_probe_ids(): + """The selected ProbeGroup keeps the requested probe ids.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_probes(["2", "0"]) + assert set(sub.probe_ids) == {"0", "2"} + + +def test_select_probes_preserves_positions(): + """Contacts of the selected probes keep their global positions.""" + pg = _probegroup_with_contact_ids(unique=True) + + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + + sub = pg.select_probes(["0", "2"]) + sub_ids = sub.get_global_contact_ids() + sub_positions = sub.get_global_contact_positions() + for cid, pos in zip(sub_ids, sub_positions): + np.testing.assert_array_equal(pos, all_positions[all_ids == cid][0]) + + +def test_select_probes_none_raises(): + """Calling select_probes without probe_ids raises a ValueError.""" + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError, match="probe_ids must be provided"): + pg.select_probes(None) + + +def test_select_probes_all_probes(): + """Selecting all probes returns the whole ProbeGroup.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_probes(["0", "1", "2"]) + assert sub.get_contact_count() == pg.get_contact_count() + assert len(sub.probes) == len(pg.probes) def test_select_contacts_duplicated_ids_raises(): @@ -608,4 +585,4 @@ def test_select_contacts_preserves_positions(): # test_probegroup(probegroup) # test_probegroup_3d() - test_reordred_probegroup(probegroup) + test_reordered_probegroup(probegroup) From 9aaf5c619d1483f5a78f2abe9ccd8e39508e5a70 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 23 Jun 2026 12:37:54 +0200 Subject: [PATCH 06/11] Update src/probeinterface/io.py --- src/probeinterface/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index c1efb239..a6cccc5c 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -325,7 +325,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probegroup = probe_or_probegroup else: raise TypeError( - f"probe_or_probegroup has to beof type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" + f"probe_or_probegroup has to be of type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" ) folder = Path(folder) From 1969a73980ffaab67e9de1046b4a9699c2140b41 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 23 Jun 2026 12:54:22 +0200 Subject: [PATCH 07/11] tests: extend tests and fix docstring --- src/probeinterface/probegroup.py | 14 ++++----- tests/test_probegroup.py | 51 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 755d2d43..c8708cf0 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -7,16 +7,16 @@ class ProbeGroup: """ Class to handle a group of Probe objects and the global wiring to a device. - Internally, this is represented as a list of Probe object. + Internally, this is represented as a list of Probe objects. The ProbeGroup is the object saved in the json based probeinterface format, even if there is only one probe. - Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order - is the "natural" one (stacked order of each probe). An external contact order can be applied using the - ``ProbeGroup.set_global_contact_order()`` method, and the contact order is then stored in the - ``ProbeGroup._global_contact_order`` attribute. In this case, the contact order of the ProbeGroup is not "natural" - anymore, but the one defined by the user. This is useful for instance when some contact of each probe are - interleaved in the recording file. + Tiny detail about contact order: ``ProbeGroup.to_numpy()`` / ``ProbeGroup.to_dataframe()`` return contacts in the + "natural" order (the contacts of each probe stacked one probe after another) unless contacts have become + interleaved across probes. Interleaving can arise from ``get_slice`` or ``select_contacts`` (e.g. selecting + contacts from different probes in an alternating order). When it does, the resulting ``ProbeGroup`` keeps a custom + contact order in the ``_global_contact_order`` attribute so the requested order is preserved. This order is only + ever set internally; there is no public method to set it. """ def __init__(self): diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 23d1fd7f..4ee3b457 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -384,6 +384,57 @@ def test_reordered_probegroup(probegroup): assert probegroup6._global_contact_order is None +def _interleaved_order(): + """An order interleaving contacts across probes (non-natural).""" + return np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + +def test_global_contact_order_natural_is_none(probegroup): + """A non-interleaved (natural) contact vector does not set a custom order.""" + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)) + assert pg._global_contact_order is None + + +def test_global_contact_order_positions_reflect_order(probegroup): + """get_global_contact_positions follows the custom global contact order.""" + order = _interleaved_order() + natural_positions = probegroup.get_global_contact_positions().copy() + + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + assert pg._global_contact_order is not None + np.testing.assert_array_equal(pg.get_global_contact_positions(), natural_positions[order]) + + +def test_global_contact_order_ids_reflect_order(probegroup): + """get_global_contact_ids follows the custom global contact order.""" + order = _interleaved_order() + natural_ids = probegroup.get_global_contact_ids().copy() + + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + np.testing.assert_array_equal(pg.get_global_contact_ids(), natural_ids[order]) + + +def test_global_contact_order_device_channel_indices_roundtrip(probegroup): + """ + With a custom global contact order, device_channel_indices are zipped to the + (reordered) to_numpy() vector. Setting them must roundtrip through both + to_numpy() and get_global_device_channel_indices(). + """ + order = _interleaved_order() + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + assert pg._global_contact_order is not None + + n = pg.get_contact_count() + device_channel_indices = np.arange(n) + pg.set_global_device_channel_indices(device_channel_indices) + + got = pg.to_numpy(complete=True)["device_channel_indices"] + np.testing.assert_array_equal(got, device_channel_indices) + + got_getter = pg.get_global_device_channel_indices()["device_channel_indices"] + np.testing.assert_array_equal(got_getter, device_channel_indices) + + # ── select_contacts() tests ───────────────────────────────────────────────── From df62a7c948f44b48d1d6318ea7eb445e4c6c4e2f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 25 Jun 2026 14:38:39 +0200 Subject: [PATCH 08/11] fix: suggestions from Sam's review --- src/probeinterface/probegroup.py | 87 +++++++++++++++++++------------- tests/test_probegroup.py | 37 ++++++++++---- 2 files changed, 79 insertions(+), 45 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index c8708cf0..19b32759 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from .utils import generate_unique_ids from .probe import Probe @@ -28,8 +30,6 @@ def __repr__(self): repr_str = f"ProbeGroup: {len(self._probes)} probes - {self.get_contact_count()} contacts" if self._global_contact_order is not None: repr_str += " (with custom global contact order)" - for probe, probe_id in zip(self._probes, self._probe_ids): - repr_str += f"\n\t{probe_id}: {probe}" return repr_str def add_probe(self, probe: Probe, probe_id: str = None) -> None: @@ -41,14 +41,28 @@ def add_probe(self, probe: Probe, probe_id: str = None) -> None: probe: Probe The probe to add to the ProbeGroup probe_id: str, optional - The ID to assign to the probe. If None, a unique ID will be generated. + The ID to assign to the probe. If None, a unique ID will be generated, + unless a probe_id is already present in the probe's annotations, + in which case that will be used. """ if len(self._probes) > 0: self._check_compatible(probe) + probe_id_annotation = probe.annotations.get("probe_id", None) + if probe_id is None: - probe_id = f"{len(self._probes)}" + if probe_id_annotation is not None: + probe_id = probe_id_annotation + else: + probe_id = f"{len(self._probes)}" + else: + if probe_id_annotation is not None and probe_id != probe_id_annotation: + warnings.warn( + f"Provided probe_id '{probe_id}' does not match probe's annotation 'probe_id' " + f"({probe_id_annotation}). Using provided probe_id." + ) + if probe_id in self._probe_ids: raise ValueError(f"Probe ID '{probe_id}' is already used in this ProbeGroup.") self._probe_ids.append(probe_id) @@ -56,6 +70,10 @@ def add_probe(self, probe: Probe, probe_id: str = None) -> None: self._probes.append(probe) probe._probe_group = self + @property + def probe_dict(self) -> dict: + return {probe_id: probe for probe_id, probe in zip(self._probe_ids, self._probes)} + @property def probes(self) -> list: return self._probes @@ -447,13 +465,14 @@ def select_probes(self, probe_ids: str | np.ndarray | list) -> "ProbeGroup": if isinstance(probe_ids, str): probe_ids = [probe_ids] - # selection is over the global contact vector, following the requested probe order - all_probe_ids = self.to_numpy(complete=True)["probe_id"] probe_ids = np.asarray(probe_ids) - blocks = [np.flatnonzero(all_probe_ids == probe_id) for probe_id in probe_ids] + if any(probe_id not in self.probe_ids for probe_id in probe_ids): + raise ValueError(f"Some probe_ids {probe_ids} are not present in the ProbeGroup.") - selection = np.concatenate(blocks) if len(blocks) else np.array([], dtype=int) - return self.get_slice(selection) + # selection keeps the order of the to_numpy vector + all_probe_ids = self.to_numpy(complete=True)["probe_id"] + keep_inds = np.flatnonzero(np.isin(all_probe_ids, probe_ids)) + return self.get_slice(keep_inds) def select_contacts( self, contact_ids: np.ndarray | list, probe_ids: np.ndarray | list | None = None @@ -491,35 +510,35 @@ def select_contacts( f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. " "If the same contact id is on multiple probes, use probe_ids to disambiguate." ) - matched_ids = all_contact_ids[np.isin(all_contact_ids, contact_ids)] - unique_ids, match_counts = np.unique(matched_ids, return_counts=True) - ambiguous_ids = unique_ids[match_counts > 1] - if ambiguous_ids.size > 0: + probe_ids = [None] * len(contact_ids) + else: + if len(probe_ids) != len(contact_ids): raise ValueError( - f"contact_ids {ambiguous_ids.tolist()} are not unique across probes, " - "you should provide probe_ids to disambiguate" + f"probe_ids must be the same length as contact_ids, but got {len(probe_ids)} probe_ids and {len(contact_ids)} contact_ids." ) - # follow the requested contact_ids order - blocks = [np.flatnonzero(all_contact_ids == contact_id) for contact_id in contact_ids] - else: - # probe_ids is paired with contact_ids: probe_ids[i] tells which probe - # contact_ids[i] belongs to, disambiguating ids shared across probes - probe_ids = np.asarray(probe_ids) - if probe_ids.size != contact_ids.size: + + indices = [] + for contact_id, probe_id in zip(contact_ids, probe_ids): + if probe_id is None: + probe_condition = True + else: + probe_condition = all_probe_ids == probe_id + + # find the contact id within the specified probe + matches = np.flatnonzero((all_contact_ids == contact_id) & probe_condition) + if matches.size == 0: + raise ValueError(f"contact_id {contact_id} not found in probe {probe_id}") + elif matches.size > 1: + raise ValueError( + f"contact_id {contact_id} is not unique within probe {probe_id}, " + "this should not happen unless the probe has duplicate contact ids" + ) + if matches[0] in indices: raise ValueError( - f"probe_ids must have the same length as contact_ids " f"({probe_ids.size} != {contact_ids.size})" + f"contact_id {contact_id} in probe {probe_id} has a duplicate selection, please check your input" ) - pairs = list(zip(contact_ids.tolist(), probe_ids.tolist())) - if len(set(pairs)) != len(pairs): - raise ValueError("(contact_id, probe_id) pairs must be unique") - # contact_ids drives the order, probe_ids breaks ties between duplicated ids - blocks = [ - np.flatnonzero((all_contact_ids == contact_id) & (all_probe_ids == probe_id)) - for contact_id, probe_id in zip(contact_ids, probe_ids) - ] - - selection = np.concatenate(blocks) if len(blocks) else np.array([], dtype=int) - return self.get_slice(selection) + indices.append(matches[0]) + return self.get_slice(indices) def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 4ee3b457..15e96bd0 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -96,7 +96,7 @@ def test_set_contact_ids_rejects_within_probe_duplicates(): probe = Probe(ndim=2, si_units="um") probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) - with pytest.raises(ValueError, match="unique within a Probe"): + with pytest.raises(ValueError): probe.set_contact_ids(["a", "a"]) @@ -108,7 +108,7 @@ def test_set_contact_ids_rejects_wrong_size(): probe = Probe(ndim=2, si_units="um") probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) - with pytest.raises(ValueError, match="do not have the same size"): + with pytest.raises(ValueError): probe.set_contact_ids(["a", "b", "c"]) @@ -503,7 +503,7 @@ def test_select_contacts_same_id_across_probes_with_probe_ids(): def test_select_contacts_probe_ids_length_mismatch_raises(): """probe_ids must have the same length as contact_ids.""" pg = _probegroup_with_contact_ids(unique=False) - with pytest.raises(ValueError, match="same length as contact_ids"): + with pytest.raises(ValueError): pg.select_contacts(["c0", "c1"], probe_ids=["0"]) @@ -515,7 +515,7 @@ def test_select_contacts_too_many_ids_without_probe_ids_raises(): pg = _probegroup_with_contact_ids(unique=False) n_unique = len(np.unique(pg.get_global_contact_ids())) too_many = [f"c{j}" for j in range(n_unique + 1)] - with pytest.raises(ValueError, match="not unique across probes"): + with pytest.raises(ValueError): pg.select_contacts(too_many) @@ -553,13 +553,13 @@ def test_select_probes_keeps_every_contact_of_matching_probes(): assert len(sub_two.probes) == 2 -def test_select_probes_follows_requested_order(): - """select_probes follows the requested probe order.""" +def test_select_probes_keeps_array_order(): + """select_probes preserves the contact order.""" pg = _probegroup_with_contact_ids(unique=False) sub = pg.select_probes(["2", "0"]) - # probe "2"'s contacts come first since it is requested first - probe_index_per_contact = sub.to_numpy(complete=True)["probe_index"] - assert probe_index_per_contact[0] == sub.probe_ids.index("2") + # even if we requested probes in a different order, the contacts are still ordered by their original global order + probe_index_per_contact = sub.to_numpy(complete=True)["probe_id"] + assert probe_index_per_contact[0] == "0" def test_select_probes_single_probe(): @@ -595,7 +595,7 @@ def test_select_probes_preserves_positions(): def test_select_probes_none_raises(): """Calling select_probes without probe_ids raises a ValueError.""" pg = _probegroup_with_contact_ids(unique=False) - with pytest.raises(ValueError, match="probe_ids must be provided"): + with pytest.raises(ValueError): pg.select_probes(None) @@ -610,10 +610,25 @@ def test_select_probes_all_probes(): def test_select_contacts_duplicated_ids_raises(): """Passing the same contact id more than once raises a ValueError.""" pg = _probegroup_with_contact_ids(unique=True) - with pytest.raises(ValueError, match="must be unique"): + with pytest.raises(ValueError): pg.select_contacts(["p0c0", "p0c1", "p0c0"]) +def test_select_contacts_preserves_order_in_array(): + """Selected contacts keep the order specified in the input array.""" + pg = _probegroup_with_contact_ids(unique=True) + contact_ids_list = [ + ["p0c1", "p0c0", "p2c5"], + ["p2c5", "p0c0", "p0c1"], + ["p0c1", "p2c5", "p0c0",] + ] + for selected_ids in contact_ids_list: + sub = pg.select_contacts(selected_ids) + contact_vector = sub.to_numpy(complete=True) + sub_ids = contact_vector["contact_ids"] + assert list(sub_ids) == selected_ids + + def test_select_contacts_preserves_positions(): """Selected contacts keep their global positions.""" pg = _probegroup_with_contact_ids(unique=True) From 4a8b212dd4a8cf13f1de9a42b913d076b380e4c4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 25 Jun 2026 14:40:40 +0200 Subject: [PATCH 09/11] other fixes --- src/probeinterface/probegroup.py | 17 ----------------- tests/test_probegroup.py | 10 ---------- 2 files changed, 27 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 19b32759..2da77641 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -82,23 +82,6 @@ def probes(self) -> list: def probe_ids(self) -> list: return self._probe_ids - @probe_ids.setter - def probe_ids(self, probe_ids: list) -> None: - """ - Set the probe IDs for the ProbeGroup. - - Parameters - ---------- - probe_ids: list - A list of IDs to assign to the probes. - The length of the list must match the number of probes in the ProbeGroup. - """ - if len(probe_ids) != len(self._probes): - raise ValueError( - f"Length of probe_ids ({len(probe_ids)}) does not match number of probes ({len(self._probes)})" - ) - self._probe_ids = probe_ids - def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: raise ValueError( diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 15e96bd0..435332d9 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -41,16 +41,6 @@ def test_probegroup(probegroup): other = ProbeGroup.from_dict(d) assert probegroup.probe_ids == other.probe_ids - # checking automatic generation of ids with new dummy probes - probegroup._probes = [] - probegroup._probe_ids = [] - for i in range(3): - probegroup.add_probe(generate_dummy_probe(), probe_id=f"probe_00{i}") - probegroup.auto_generate_contact_ids() - - for p in probegroup.probes: - assert p.contact_ids is not None - def test_probegroup_3d(): probegroup = ProbeGroup() From bbb83729d6f9b706ad95c24dc0c9ac3b19017ed4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Jun 2026 11:07:54 +0200 Subject: [PATCH 10/11] fix: ramon's suggestions --- src/probeinterface/probe.py | 4 ++-- src/probeinterface/probegroup.py | 19 +++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 4174f199..0fb695b1 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -534,7 +534,7 @@ def set_device_channel_indices(self, channel_indices: np.ndarray | list): ) self.device_channel_indices = channel_indices if self._probe_group is not None: - self._probe_group.check_global_device_wiring_and_ids() + self._probe_group._check_global_device_wiring_and_ids() def wiring_to_device(self, pathway: str, channel_offset: int = 0): """ @@ -584,7 +584,7 @@ def set_contact_ids(self, contact_ids: np.ndarray | list): self._contact_ids = contact_ids if self._probe_group is not None: - self._probe_group.check_global_device_wiring_and_ids() + self._probe_group._check_global_device_wiring_and_ids() def set_shank_ids(self, shank_ids: np.ndarray | list): """ diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 2da77641..6214599f 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -55,7 +55,8 @@ def add_probe(self, probe: Probe, probe_id: str = None) -> None: if probe_id_annotation is not None: probe_id = probe_id_annotation else: - probe_id = f"{len(self._probes)}" + existing_int_ids = [int(pid) for pid in self._probe_ids if pid.isdigit()] + probe_id = str(max(existing_int_ids, default=-1) + 1) else: if probe_id_annotation is not None and probe_id != probe_id_annotation: warnings.warn( @@ -96,7 +97,7 @@ def _check_compatible(self, probe: Probe) -> None: # check global channel maps self._probes.append(probe) self._probe_ids.append(f"{len(self._probes)-1}") - self.check_global_device_wiring_and_ids() + self._check_global_device_wiring_and_ids() self._probes = self.probes[:-1] self._probe_ids = self.probe_ids[:-1] @@ -502,13 +503,10 @@ def select_contacts( indices = [] for contact_id, probe_id in zip(contact_ids, probe_ids): - if probe_id is None: - probe_condition = True - else: - probe_condition = all_probe_ids == probe_id - # find the contact id within the specified probe - matches = np.flatnonzero((all_contact_ids == contact_id) & probe_condition) + in_probe_mask = np.ones(all_contact_ids.size, dtype=bool) if probe_id is None else all_probe_ids == probe_id + matches = np.flatnonzero((all_contact_ids == contact_id) & in_probe_mask) + if matches.size == 0: raise ValueError(f"contact_id {contact_id} not found in probe {probe_id}") elif matches.size > 1: @@ -518,12 +516,13 @@ def select_contacts( ) if matches[0] in indices: raise ValueError( - f"contact_id {contact_id} in probe {probe_id} has a duplicate selection, please check your input" + f"contact_id {contact_id} matches multiple probes; " + "pass probe_ids to disambiguate which probe each contact_id belongs to." ) indices.append(matches[0]) return self.get_slice(indices) - def check_global_device_wiring_and_ids(self) -> None: + def _check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() keep = chans["device_channel_indices"] >= 0 From f83ae1db9f710aa903aa934be3cd2fb44eb3f1e5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Jun 2026 11:14:03 +0200 Subject: [PATCH 11/11] test: add test for probe id naming --- tests/test_probegroup.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 435332d9..deeb9e9e 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -636,6 +636,41 @@ def test_select_contacts_preserves_positions(): np.testing.assert_array_equal(got, expected) +# ── add_probe : default probe_id generation ───────────────────────────────── + + +def test_add_probe_default_id_does_not_recycle_after_gap(): + """ + The default probe_id must not collide with an existing id after a selection + leaves a gap in the numeric ids. Using ``len(self._probes)`` would point back + at an id that is still in use; ``max(numeric ids) + 1`` is gap-proof. + """ + pg = ProbeGroup() + for _ in range(3): + pg.add_probe(generate_dummy_probe()) + assert pg.probe_ids == ["0", "1", "2"] + + # drop the middle probe -> ids become ["0", "2"], len is 2 (would collide with "2") + sub = pg.select_probes(["0", "2"]) + assert sub.probe_ids == ["0", "2"] + + sub.add_probe(generate_dummy_probe()) + assert sub.probe_ids == ["0", "2", "3"] + + +def test_add_probe_default_id_with_non_numeric_ids(): + """ + With only non-numeric ids present, the generated id starts from "0" and can + never collide with a non-numeric name. + """ + pg = ProbeGroup() + pg.add_probe(generate_dummy_probe(), probe_id="left") + pg.add_probe(generate_dummy_probe(), probe_id="right") + + pg.add_probe(generate_dummy_probe()) + assert pg.probe_ids == ["left", "right", "0"] + + if __name__ == "__main__": probegroup = _make_probegroup()