diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 0ece283..d42906a 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -14,7 +14,7 @@ class ProbeGroup: def __init__(self): self.probes = [] - def add_probe(self, probe: Probe): + def add_probe(self, probe: Probe) -> None: """ Add an additional probe to the ProbeGroup @@ -30,7 +30,7 @@ def add_probe(self, probe: Probe): self.probes.append(probe) probe._probe_group = self - def _check_compatible(self, probe: Probe): + def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: raise ValueError( "This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup" @@ -47,9 +47,25 @@ def _check_compatible(self, probe: Probe): self.probes = self.probes[:-1] @property - def ndim(self): + def ndim(self) -> int: return self.probes[0].ndim + def copy(self) -> "ProbeGroup": + """ + Create a copy of the ProbeGroup + + Returns + ------- + copy: ProbeGroup + A copy of the ProbeGroup + """ + copy = ProbeGroup() + for probe in self.probes: + copy.add_probe(probe.copy()) + global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"] + copy.set_global_device_channel_indices(global_device_channel_indices) + return copy + def get_contact_count(self) -> int: """ Total number of channels. @@ -147,7 +163,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": df.index = np.arange(df.shape[0], dtype="int64") return df - def to_dict(self, array_as_list: bool = False): + def to_dict(self, array_as_list: bool = False) -> dict: """Create a dictionary of all necessary attributes. Parameters @@ -168,7 +184,7 @@ def to_dict(self, array_as_list: bool = False): return d @staticmethod - def from_dict(d: dict): + def from_dict(d: dict) -> "ProbeGroup": """Instantiate a ProbeGroup from a dictionary Parameters @@ -210,7 +226,7 @@ def get_global_device_channel_indices(self) -> np.ndarray: channels["device_channel_indices"] = arr["device_channel_indices"] return channels - def set_global_device_channel_indices(self, channels: np.ndarray | list): + def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: """ Set global indices for all probes @@ -249,7 +265,86 @@ def get_global_contact_ids(self) -> np.ndarray: contact_ids = self.to_numpy(complete=True)["contact_ids"] return contact_ids - def check_global_device_wiring_and_ids(self): + def get_global_contact_positions(self) -> np.ndarray: + """ + Gets all contact positions concatenated across probes + + Returns + ------- + contact_positions: np.ndarray + An array of the contact positions across all probes + """ + contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) + return contact_positions + + def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": + """ + Get a copy of the ProbeGroup with a sub selection of contacts. + + Selection can be boolean or by index + + Parameters + ---------- + selection : np.array of bool or int (for index) + Either an np.array of bool or for desired selection of contacts + or the indices of the desired contacts + + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group + + """ + + n = self.get_contact_count() + + selection = np.asarray(selection) + if selection.dtype == "bool": + assert selection.shape == ( + n, + ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" + (selection_indices,) = np.nonzero(selection) + elif selection.dtype.kind == "i": + assert np.unique(selection).size == selection.size + if len(selection) > 0: + assert ( + 0 <= np.min(selection) < n + ), f"An index within your selection is out of bounds {np.min(selection)}" + assert ( + 0 <= np.max(selection) < n + ), f"An index within your selection is out of bounds {np.max(selection)}" + selection_indices = selection + else: + selection_indices = [] + else: + raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + + if len(selection_indices) == 0: + return ProbeGroup() + + # Map selection to indices of individual probes + ind = 0 + sliced_probes = [] + for probe in self.probes: + n = probe.get_contact_count() + probe_limits = (ind, ind + n) + ind += n + + probe_selection_indices = selection_indices[ + (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) + ] + if len(probe_selection_indices) == 0: + continue + sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) + sliced_probes.append(sliced_probe) + + sliced_probe_group = ProbeGroup() + for probe in sliced_probes: + sliced_probe_group.add_probe(probe) + + return sliced_probe_group + + def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() keep = chans["device_channel_indices"] >= 0 @@ -258,7 +353,7 @@ def check_global_device_wiring_and_ids(self): if valid_chans.size != np.unique(valid_chans).size: raise ValueError("channel device indices are not unique across probes") - def auto_generate_probe_ids(self, *args, **kwargs): + def auto_generate_probe_ids(self, *args, **kwargs) -> None: """ Annotate all probes with unique probe_id values. @@ -282,7 +377,7 @@ def auto_generate_probe_ids(self, *args, **kwargs): for pid, probe in enumerate(self.probes): probe.annotate(probe_id=probe_ids[pid]) - def auto_generate_contact_ids(self, *args, **kwargs): + def auto_generate_contact_ids(self, *args, **kwargs) -> None: """ Annotate all contacts with unique contact_id values. diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 56bf97d..c942190 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,29 +6,27 @@ import numpy as np -def test_probegroup(): +@pytest.fixture +def probegroup(): + """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() - nchan = 0 for i in range(3): probe = generate_dummy_probe() probe.move([i * 100, i * 80]) n = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(n)[::-1] + nchan) - shank_ids = np.ones(n) - shank_ids[: n // 2] *= i * 2 - shank_ids[n // 2 :] *= i * 2 + 1 - probe.set_shank_ids(shank_ids) + probe.set_device_channel_indices(np.arange(n) + nchan) probegroup.add_probe(probe) - nchan += n + return probegroup + +def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() ids = probegroup.get_global_contact_ids() df = probegroup.to_dataframe() - # ~ print(df['global_contact_ids']) arr = probegroup.to_numpy(complete=False) other = ProbeGroup.from_numpy(arr) @@ -38,12 +36,6 @@ def test_probegroup(): d = probegroup.to_dict() other = ProbeGroup.from_dict(d) - # ~ from probeinterface.plotting import plot_probe_group, plot_probe - # ~ import matplotlib.pyplot as plt - # ~ plot_probe_group(probegroup) - # ~ plot_probe_group(other) - # ~ plt.show() - # checking automatic generation of ids with new dummy probes probegroup.probes = [] for i in range(3): @@ -116,6 +108,152 @@ def test_set_contact_ids_rejects_wrong_size(): probe.set_contact_ids(["a", "b", "c"]) +# ── get_global_contact_positions() tests ──────────────────────────────────── + + +def test_get_global_contact_positions_shape(probegroup): + pos = probegroup.get_global_contact_positions() + assert pos.shape == (probegroup.get_contact_count(), probegroup.ndim) + + +def test_get_global_contact_positions_matches_per_probe(probegroup): + pos = probegroup.get_global_contact_positions() + offset = 0 + for probe in probegroup.probes: + n = probe.get_contact_count() + np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions) + offset += n + + +def test_get_global_contact_positions_single_probe(probegroup): + pos = probegroup.get_global_contact_positions() + np.testing.assert_array_equal( + pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions + ) + + +def test_get_global_contact_positions_3d(): + pg = ProbeGroup() + for i in range(2): + probe = generate_dummy_probe().to_3d() + probe.move([i * 100, i * 80, i * 30]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + assert pos.shape[1] == 3 + assert pos.shape[0] == pg.get_contact_count() + + +def test_get_global_contact_positions_reflects_move(): + """Positions should reflect probe movement.""" + pg = ProbeGroup() + probe = generate_dummy_probe() + original_pos = probe.contact_positions.copy() + probe.move([50, 60]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + np.testing.assert_array_equal(pos, original_pos + np.array([50, 60])) + + +# ── copy() tests ──────────────────────────────────────────────────────────── + + +def test_copy_returns_new_object(probegroup): + pg_copy = probegroup.copy() + assert pg_copy is not probegroup + assert len(pg_copy.probes) == len(probegroup.probes) + for orig, copied in zip(probegroup.probes, pg_copy.probes): + assert orig is not copied + + +def test_copy_preserves_positions(probegroup): + pg_copy = probegroup.copy() + for orig, copied in zip(probegroup.probes, pg_copy.probes): + np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions) + + +def test_copy_preserves_device_channel_indices(probegroup): + pg_copy = probegroup.copy() + np.testing.assert_array_equal( + probegroup.get_global_device_channel_indices(), + pg_copy.get_global_device_channel_indices(), + ) + + +def test_copy_does_not_preserve_contact_ids(probegroup): + """Probe.copy() intentionally does not copy contact_ids.""" + pg_copy = probegroup.copy() + # All contact_ids should be empty strings after copy + assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) + + +def test_copy_is_independent(probegroup): + """Mutating the copy must not affect the original.""" + original_positions = probegroup.probes[0].contact_positions.copy() + pg_copy = probegroup.copy() + pg_copy.probes[0].move([999, 999]) + np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions) + + +# ── get_slice() tests ─────────────────────────────────────────────────────── + + +def test_get_slice_by_bool(probegroup): + total = probegroup.get_contact_count() + sel = np.zeros(total, dtype=bool) + sel[:5] = True # first 5 contacts from the first probe + sliced = probegroup.get_slice(sel) + assert sliced.get_contact_count() == 5 + + +def test_get_slice_by_index(probegroup): + indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes + sliced = probegroup.get_slice(indices) + assert sliced.get_contact_count() == 5 + + +def test_get_slice_preserves_device_channel_indices(probegroup): + indices = np.array([0, 1, 2]) + sliced = probegroup.get_slice(indices) + orig_chans = probegroup.get_global_device_channel_indices()["device_channel_indices"][:3] + sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"] + np.testing.assert_array_equal(sliced_chans, orig_chans) + + +def test_get_slice_preserves_positions(probegroup): + indices = np.array([0, 1, 2]) + sliced = probegroup.get_slice(indices) + expected = probegroup.get_global_contact_positions()[indices] + np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) + + +def test_get_slice_empty_selection(probegroup): + sliced = probegroup.get_slice(np.array([], dtype=int)) + assert sliced.get_contact_count() == 0 + assert len(sliced.probes) == 0 + + +def test_get_slice_wrong_bool_size(probegroup): + with pytest.raises(AssertionError): + probegroup.get_slice(np.array([True, False])) # wrong size + + +def test_get_slice_out_of_bounds(probegroup): + total = probegroup.get_contact_count() + with pytest.raises(AssertionError): + probegroup.get_slice(np.array([total + 10])) + + +def test_get_slice_all_contacts(probegroup): + """Slicing with all contacts should give an equivalent ProbeGroup.""" + total = probegroup.get_contact_count() + sliced = probegroup.get_slice(np.arange(total)) + assert sliced.get_contact_count() == total + np.testing.assert_array_equal( + sliced.get_global_contact_positions(), + probegroup.get_global_contact_positions(), + ) + + if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d()