Skip to content
113 changes: 104 additions & 9 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ProbeGroup:
def __init__(self):
self.probes = []

def add_probe(self, probe: Probe):
def add_probe(self, probe: Probe) -> None:
"""
Add an additional probe to the ProbeGroup

Expand All @@ -30,7 +30,7 @@ def add_probe(self, probe: Probe):
self.probes.append(probe)
probe._probe_group = self

def _check_compatible(self, probe: Probe):
def _check_compatible(self, probe: Probe) -> None:
if probe._probe_group is not None:
raise ValueError(
"This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup"
Expand All @@ -47,9 +47,25 @@ def _check_compatible(self, probe: Probe):
self.probes = self.probes[:-1]

@property
def ndim(self):
def ndim(self) -> int:
return self.probes[0].ndim

def copy(self) -> "ProbeGroup":
"""
Create a copy of the ProbeGroup

Returns
-------
copy: ProbeGroup
A copy of the ProbeGroup
"""
copy = ProbeGroup()
for probe in self.probes:
copy.add_probe(probe.copy())
global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"]
copy.set_global_device_channel_indices(global_device_channel_indices)
return copy

def get_contact_count(self) -> int:
"""
Total number of channels.
Expand Down Expand Up @@ -147,7 +163,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
df.index = np.arange(df.shape[0], dtype="int64")
return df

def to_dict(self, array_as_list: bool = False):
def to_dict(self, array_as_list: bool = False) -> dict:
"""Create a dictionary of all necessary attributes.

Parameters
Expand All @@ -168,7 +184,7 @@ def to_dict(self, array_as_list: bool = False):
return d

@staticmethod
def from_dict(d: dict):
def from_dict(d: dict) -> "ProbeGroup":
"""Instantiate a ProbeGroup from a dictionary

Parameters
Expand Down Expand Up @@ -210,7 +226,7 @@ def get_global_device_channel_indices(self) -> np.ndarray:
channels["device_channel_indices"] = arr["device_channel_indices"]
return channels

def set_global_device_channel_indices(self, channels: np.ndarray | list):
def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None:
"""
Set global indices for all probes

Expand Down Expand Up @@ -249,7 +265,86 @@ def get_global_contact_ids(self) -> np.ndarray:
contact_ids = self.to_numpy(complete=True)["contact_ids"]
return contact_ids

def check_global_device_wiring_and_ids(self):
def get_global_contact_positions(self) -> np.ndarray:
"""
Gets all contact positions concatenated across probes

Returns
-------
contact_positions: np.ndarray
An array of the contact positions across all probes
"""
contact_positions = np.vstack([probe.contact_positions for probe in self.probes])
return contact_positions

def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
"""
Get a copy of the ProbeGroup with a sub selection of contacts.

Selection can be boolean or by index

Parameters
----------
selection : np.array of bool or int (for index)
Either an np.array of bool or for desired selection of contacts
or the indices of the desired contacts

Returns
-------
sliced_probe_group: ProbeGroup
The sliced probe group

"""

n = self.get_contact_count()

selection = np.asarray(selection)
if selection.dtype == "bool":
assert selection.shape == (
n,
), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}"
(selection_indices,) = np.nonzero(selection)
elif selection.dtype.kind == "i":
assert np.unique(selection).size == selection.size
if len(selection) > 0:
assert (
0 <= np.min(selection) < n
), f"An index within your selection is out of bounds {np.min(selection)}"
assert (
0 <= np.max(selection) < n
), f"An index within your selection is out of bounds {np.max(selection)}"
selection_indices = selection
else:
selection_indices = []
else:
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")

if len(selection_indices) == 0:
return ProbeGroup()

# Map selection to indices of individual probes
ind = 0
sliced_probes = []
for probe in self.probes:
n = probe.get_contact_count()
probe_limits = (ind, ind + n)
ind += n

probe_selection_indices = selection_indices[
(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])
]
if len(probe_selection_indices) == 0:
continue
sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0])
sliced_probes.append(sliced_probe)

sliced_probe_group = ProbeGroup()
for probe in sliced_probes:
sliced_probe_group.add_probe(probe)

return sliced_probe_group

def check_global_device_wiring_and_ids(self) -> None:
# check unique device_channel_indices for !=-1
chans = self.get_global_device_channel_indices()
keep = chans["device_channel_indices"] >= 0
Expand All @@ -258,7 +353,7 @@ def check_global_device_wiring_and_ids(self):
if valid_chans.size != np.unique(valid_chans).size:
raise ValueError("channel device indices are not unique across probes")

def auto_generate_probe_ids(self, *args, **kwargs):
def auto_generate_probe_ids(self, *args, **kwargs) -> None:
"""
Annotate all probes with unique probe_id values.

Expand All @@ -282,7 +377,7 @@ def auto_generate_probe_ids(self, *args, **kwargs):
for pid, probe in enumerate(self.probes):
probe.annotate(probe_id=probe_ids[pid])

def auto_generate_contact_ids(self, *args, **kwargs):
def auto_generate_contact_ids(self, *args, **kwargs) -> None:
"""
Annotate all contacts with unique contact_id values.

Expand Down
168 changes: 153 additions & 15 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,27 @@
import numpy as np


def test_probegroup():
@pytest.fixture
def probegroup():
"""Fixture: a ProbeGroup with 3 probes, each with device channel indices set."""
probegroup = ProbeGroup()

nchan = 0
for i in range(3):
probe = generate_dummy_probe()
probe.move([i * 100, i * 80])
n = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(n)[::-1] + nchan)
shank_ids = np.ones(n)
shank_ids[: n // 2] *= i * 2
shank_ids[n // 2 :] *= i * 2 + 1
probe.set_shank_ids(shank_ids)
probe.set_device_channel_indices(np.arange(n) + nchan)
probegroup.add_probe(probe)

nchan += n
return probegroup


def test_probegroup(probegroup):
indices = probegroup.get_global_device_channel_indices()

ids = probegroup.get_global_contact_ids()

df = probegroup.to_dataframe()
# ~ print(df['global_contact_ids'])

arr = probegroup.to_numpy(complete=False)
other = ProbeGroup.from_numpy(arr)
Expand All @@ -38,12 +36,6 @@ def test_probegroup():
d = probegroup.to_dict()
other = ProbeGroup.from_dict(d)

# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup)
# ~ plot_probe_group(other)
# ~ plt.show()

# checking automatic generation of ids with new dummy probes
probegroup.probes = []
for i in range(3):
Expand Down Expand Up @@ -116,6 +108,152 @@ def test_set_contact_ids_rejects_wrong_size():
probe.set_contact_ids(["a", "b", "c"])


# ── get_global_contact_positions() tests ────────────────────────────────────


def test_get_global_contact_positions_shape(probegroup):
pos = probegroup.get_global_contact_positions()
assert pos.shape == (probegroup.get_contact_count(), probegroup.ndim)


def test_get_global_contact_positions_matches_per_probe(probegroup):
pos = probegroup.get_global_contact_positions()
offset = 0
for probe in probegroup.probes:
n = probe.get_contact_count()
np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions)
offset += n


def test_get_global_contact_positions_single_probe(probegroup):
pos = probegroup.get_global_contact_positions()
np.testing.assert_array_equal(
pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions
)


def test_get_global_contact_positions_3d():
pg = ProbeGroup()
for i in range(2):
probe = generate_dummy_probe().to_3d()
probe.move([i * 100, i * 80, i * 30])
pg.add_probe(probe)
pos = pg.get_global_contact_positions()
assert pos.shape[1] == 3
assert pos.shape[0] == pg.get_contact_count()


def test_get_global_contact_positions_reflects_move():
"""Positions should reflect probe movement."""
pg = ProbeGroup()
probe = generate_dummy_probe()
original_pos = probe.contact_positions.copy()
probe.move([50, 60])
pg.add_probe(probe)
pos = pg.get_global_contact_positions()
np.testing.assert_array_equal(pos, original_pos + np.array([50, 60]))


# ── copy() tests ────────────────────────────────────────────────────────────


def test_copy_returns_new_object(probegroup):
pg_copy = probegroup.copy()
assert pg_copy is not probegroup
assert len(pg_copy.probes) == len(probegroup.probes)
for orig, copied in zip(probegroup.probes, pg_copy.probes):
assert orig is not copied


def test_copy_preserves_positions(probegroup):
pg_copy = probegroup.copy()
for orig, copied in zip(probegroup.probes, pg_copy.probes):
np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions)


def test_copy_preserves_device_channel_indices(probegroup):
pg_copy = probegroup.copy()
np.testing.assert_array_equal(
probegroup.get_global_device_channel_indices(),
pg_copy.get_global_device_channel_indices(),
)


def test_copy_does_not_preserve_contact_ids(probegroup):
"""Probe.copy() intentionally does not copy contact_ids."""
pg_copy = probegroup.copy()
# All contact_ids should be empty strings after copy
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())


def test_copy_is_independent(probegroup):
"""Mutating the copy must not affect the original."""
original_positions = probegroup.probes[0].contact_positions.copy()
pg_copy = probegroup.copy()
pg_copy.probes[0].move([999, 999])
np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions)


# ── get_slice() tests ───────────────────────────────────────────────────────


def test_get_slice_by_bool(probegroup):
total = probegroup.get_contact_count()
sel = np.zeros(total, dtype=bool)
sel[:5] = True # first 5 contacts from the first probe
sliced = probegroup.get_slice(sel)
assert sliced.get_contact_count() == 5


def test_get_slice_by_index(probegroup):
indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes
sliced = probegroup.get_slice(indices)
assert sliced.get_contact_count() == 5


def test_get_slice_preserves_device_channel_indices(probegroup):
indices = np.array([0, 1, 2])
sliced = probegroup.get_slice(indices)
orig_chans = probegroup.get_global_device_channel_indices()["device_channel_indices"][:3]
sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"]
np.testing.assert_array_equal(sliced_chans, orig_chans)


def test_get_slice_preserves_positions(probegroup):
indices = np.array([0, 1, 2])
sliced = probegroup.get_slice(indices)
expected = probegroup.get_global_contact_positions()[indices]
np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected)


def test_get_slice_empty_selection(probegroup):
sliced = probegroup.get_slice(np.array([], dtype=int))
assert sliced.get_contact_count() == 0
assert len(sliced.probes) == 0


def test_get_slice_wrong_bool_size(probegroup):
with pytest.raises(AssertionError):
probegroup.get_slice(np.array([True, False])) # wrong size


def test_get_slice_out_of_bounds(probegroup):
total = probegroup.get_contact_count()
with pytest.raises(AssertionError):
probegroup.get_slice(np.array([total + 10]))


def test_get_slice_all_contacts(probegroup):
"""Slicing with all contacts should give an equivalent ProbeGroup."""
total = probegroup.get_contact_count()
sliced = probegroup.get_slice(np.arange(total))
assert sliced.get_contact_count() == total
np.testing.assert_array_equal(
sliced.get_global_contact_positions(),
probegroup.get_global_contact_positions(),
)


if __name__ == "__main__":
test_probegroup()
# ~ test_probegroup_3d()
Loading