diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9dc270d38d..4e0bdc1b6a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -213,6 +213,14 @@ def id_to_index(self, id) -> int: return ind def annotate(self, **new_annotations) -> None: + """Adds annotations. + + Parameters + ---------- + **new_annotations : dict + Key-value pairs of annotations to add. If an annotation key already exists, + it will be overwritten. + """ self._annotations.update(new_annotations) def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None: @@ -236,6 +244,24 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No else: raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it") + def delete_annotation(self, annotation_key: str) -> None: + """Deletes existing annotation. + + Parameters + ---------- + annotation_key : str + The annotation key to delete + + Raises + ------ + ValueError + If the annotation key does not exist + """ + if annotation_key in self._annotations.keys(): + del self._annotations[annotation_key] + else: + raise ValueError(f"{annotation_key} is not an annotation key") + def get_preferred_mp_context(self): """ Get the preferred context for multiprocessing. @@ -434,6 +460,15 @@ def copy_metadata( if self._preferred_mp_context is not None: other._preferred_mp_context = self._preferred_mp_context + if not only_main: + self._extra_metadata_copy(other) + + def _extra_metadata_copy(self, other: BaseExtractor): + """ + This is a hook to copy extra metadata that is not in the annotations/properties dict. + """ + pass + def to_dict( self, include_annotations: bool = False, @@ -567,6 +602,8 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + self._extra_metadata_to_dict(dump_dict) + return dump_dict @staticmethod @@ -855,6 +892,14 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass + def _extra_metadata_from_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + + def _extra_metadata_to_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + def save(self, **kwargs) -> BaseExtractor: """ Save a SpikeInterface object. @@ -1154,6 +1199,8 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: for k, v in dic["properties"].items(): extractor.set_property(k, v) + extractor._extra_metadata_from_dict(dic) + return extractor diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 75bd47597b..8dc27f56c0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -21,7 +21,6 @@ class BaseRecording(BaseRecordingSnippets): _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = [ "group", - "location", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", @@ -591,6 +590,9 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": + from .binaryfolder import BinaryFolderRecording + from .binaryrecordingextractor import BinaryRecordingExtractor + folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() @@ -598,8 +600,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) - from .binaryrecordingextractor import BinaryRecordingExtractor - # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` binary_rec = BinaryRecordingExtractor( @@ -616,8 +616,9 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): offset_to_uV=self.get_channel_offsets(), ) binary_rec.dump(folder / "binary.json", relative_to=folder) - - from .binaryfolder import BinaryFolderRecording + if self.has_probe(): + probegroup = self.get_probegroup() + write_probeinterface(folder / "probe.json", probegroup) cached = BinaryFolderRecording(folder_path=folder) @@ -648,10 +649,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) - + # TODO: write binary should save timestamps too for segment_index in range(self.get_num_segments()): if self.has_time_vector(segment_index): # the use of get_times is preferred since timestamps are converted to array @@ -676,7 +674,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..db033d2393 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,5 +1,5 @@ from pathlib import Path - +import warnings import numpy as np from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes @@ -19,6 +19,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,15 +52,31 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + if self._probegroup is None and self.get_property("contact_vector") is not None: + # if contact_vector is present we can reconstruct the probe + self._probegroup = self._build_probegroup_from_properties() + return self._probegroup is not None + + def has_3d_probe(self) -> bool: + if self.has_probe(): + probe = self.get_probegroup().probes[0] + return probe.ndim == 3 + else: + return False def has_channel_location(self) -> bool: - return self.has_probe() or "location" in self.get_property_keys() + return self.has_probe() def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) + def reset_probe(self): + """ + Removes probe information + """ + self._probegroup = None + def set_probe(self, probe, group_mode="auto", in_place=False): """ Attach a list of Probe object to a recording. @@ -85,12 +102,43 @@ def set_probe(self, probe, group_mode="auto", in_place=False): probegroup.add_probe(probe) return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - def set_probegroup(self, probegroup, group_mode="auto", in_place=False): - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) + def set_probegroup(self, probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + """ + Attach a ProbeGroup to a recording. + For this ProbeGroup.get_global_device_channel_indices() is used to link contacts to recording channels. + If some contacts of the probe group are not connected (device_channel_indices=-1) + then the recording is "sliced" and only connected channel are kept. + + The probe group order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. - def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): + Parameters + ---------- + probe_or_probegroup: Probe, list of Probe, or ProbeGroup + The probe(s) to be attached to the recording + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. + in_place: bool + False by default. + Useful internally when extractor do self.set_probegroup(probe) + raise_if_overlapping_probes: bool + If True, raises an error if the probes overlap. If False, it will just warn + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) """ - Attach a list of Probe objects to a recording. + return self._set_probes( + probegroup, + group_mode=group_mode, + in_place=in_place, + raise_if_overlapping_probes=raise_if_overlapping_probes, + ) + + def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False, raise_if_overlapping_probes=True): + """ + Attach a list of Probe objects or a ProbeGroup to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. If some contacts of the Probe are not connected (device_channel_indices=-1) then the recording is "sliced" and only connected channel are kept. @@ -100,7 +148,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup + probe_or_probegroup: Probe, list of Probes, ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. @@ -108,6 +156,8 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): in_place: bool False by default. Useful internally when extractor do self.set_probegroup(probe) + raise_if_overlapping_probes: bool + If True, raises an error if the probes overlap. If False, it will just warn Returns ------- @@ -132,12 +182,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probegroup = ProbeGroup() for probe in probe_or_probegroup: probegroup.add_probe(probe) + elif isinstance(probe_or_probegroup, dict): + probegroup = ProbeGroup.from_dict(probe_or_probegroup) else: raise ValueError("must give Probe or ProbeGroup or list of Probe") # check that the probe do not overlap num_probes = len(probegroup.probes) - if num_probes > 1: + if num_probes > 1 and raise_if_overlapping_probes: check_probe_do_not_overlap(probegroup.probes) # handle not connected channels @@ -145,36 +197,36 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 + device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep = device_channel_indices >= 0 if np.any(~keep): warn("The given probes have unconnected contacts: they are removed") - + device_channel_indices = device_channel_indices[keep] probe_as_numpy_array = probe_as_numpy_array[keep] + if len(device_channel_indices) > 0: + probegroup = probegroup.get_slice(device_channel_indices) + order = np.argsort(device_channel_indices) + device_channel_indices = device_channel_indices[order] + probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices))) + + # check TODO: Where did this came from? + number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) + if number_of_device_channel_indices >= self.get_num_channels(): + error_msg = ( + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" + f"device_channel_indices are the following: {device_channel_indices} \n" + f"recording channels are the following: {self.get_channel_ids()} \n" + ) + raise ValueError(error_msg) + else: + warn("No connected channel in the probe! The probe will be attached but no channel will be selected.") + probegroup = ProbeGroup() # empty probegroup - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - - # check TODO: Where did this came from? - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + new_channel_ids = self.channel_ids[device_channel_indices] # create recording : channel slice or clone or self if in_place: @@ -187,21 +239,8 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) - - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property - ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) + # Set probegroup + sub_recording._probegroup = probegroup # handle groups has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields @@ -232,17 +271,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): groups[mask] = group sub_recording.set_property("group", groups, ids=None) - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) - return sub_recording def get_probe(self): probes = self.get_probes() - assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" + assert len(probes) == 1, "There are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): @@ -250,11 +283,22 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): + if self._probegroup is not None: + return self._probegroup + else: # Backward compatibility: if contact_vector is present we reconstruct the probe, otherwise we look for + probegroup = self._build_probegroup_from_properties() + if probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + self._probegroup = probegroup + return probegroup + + def _build_probegroup_from_properties(self): + # location and create a dummy probe arr = self.get_property("contact_vector") if arr is None: positions = self.get_property("location") if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + return None else: warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") probe = self.create_dummy_probe_from_locations(positions) @@ -273,8 +317,15 @@ def get_probegroup(self): contour = self.get_annotation(f"probe_{probe_index}_planar_contour") if contour is not None: probe.set_planar_contour(contour) + self.delete_annotation(f"probe_{probe_index}_planar_contour") + # delete contact_vector as it is not needed anymore + self.delete_property("contact_vector") return probegroup + def _extra_metadata_copy(self, other): + if self._probegroup is not None: + other._probegroup = self._probegroup.copy() + def _extra_metadata_from_folder(self, folder): # load probe folder = Path(folder) @@ -284,10 +335,22 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) + def _extra_metadata_from_dict(self, dump_dict): + # load probe + if "probegroup" in dump_dict: + probegroup = dump_dict["probegroup"] + self.set_probegroup(probegroup, in_place=True) + + def _extra_metadata_to_dict(self, dump_dict): + # save probe + if self.has_probe(): + probegroup = self.get_probegroup() + dump_dict["probegroup"] = probegroup + def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ Creates a "dummy" probe based on locations. @@ -330,51 +393,55 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, default: default: "circle" + shape : str, default: "circle" Electrode shapes shape_params : dict, default: {"radius": 1} Shape parameters axes : "xy" | "yz" | "xz", default: "xy" If ndim is 3, indicates the axes that define the plane of the electrodes """ - probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes) + probe = self.create_dummy_probe_from_locations( + np.array(locations), shape=shape, shape_params=shape_params, axes=axes + ) self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: - raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "set_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to set probe information, use `set_dummy_probe_from_locations()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.set_dummy_probe_from_locations(locations, axes="xy") def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + contact_positions = probegroup.get_global_contact_positions() + return select_axes(contact_positions, axes)[channel_indices] - def has_3d_locations(self) -> bool: - return self.get_property("location").shape[1] == 3 + def is_probe_3d(self) -> bool: + if not self.has_probe(): + raise ValueError("is_probe_3d() needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + return probegroup.ndim == 3 def clear_channel_locations(self, channel_ids=None): - if channel_ids is None: - n = self.get_num_channel() - else: - n = len(channel_ids) - locations = np.zeros((n, 2)) * np.nan - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "clear_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to remove probe information, use `reset_probe()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.reset_probe() def set_channel_groups(self, groups, channel_ids=None): if "probes" in self._annotations: @@ -429,7 +496,7 @@ def planarize(self, axes: str = "xy"): BaseRecording The recording with 2D positions """ - assert self.has_3d_locations, "The 'planarize' function needs a recording with 3d locations" + assert self.has_3d_probe(), "The 'planarize' function needs a recording with 3d locations" assert len(axes) == 2, "You need to specify 2 dimensions (e.g. 'xy', 'zy')" probe2d = self.get_probe().to_2d(axes=axes) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..39e8c51225 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): @@ -259,9 +259,9 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index 4b9d7b7d09..8d8ed3c206 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -3,6 +3,8 @@ import numpy as np +from probeinterface import read_probeinterface + from .binaryrecordingextractor import BinaryRecordingExtractor from .core_tools import define_function_from_class, make_paths_absolute diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 0da4797440..4e933c9e9d 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,7 @@ import numpy as np +from probeinterface import ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,31 +91,28 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] - assert len(set(location_tuple)) == self.get_num_channels(), ( - "Locations are not unique! " "Cannot aggregate recordings!" - ) - - planar_contour_keys = [ - key for recording in recording_list for key in recording.get_annotation_keys() if "planar_contour" in key - ] - if len(planar_contour_keys) > 0: - if all( - k == planar_contour_keys[0] for k in planar_contour_keys - ): # we add the 'planar_contour' annotations only if there is a unique one in the recording_list - planar_contour_key = planar_contour_keys[0] - collect_planar_contours = [] - for rec in recording_list: - collect_planar_contours.append(rec.get_annotation(planar_contour_key)) - if all(np.array_equal(arr, collect_planar_contours[0]) for arr in collect_planar_contours): - self.set_annotation(planar_contour_key, collect_planar_contours[0]) + # Aggregate probe information + all_probegroups = [rec.get_probegroup() for rec in recording_list if rec.has_probe()] + if len(all_probegroups) == len(recording_list): + # check that contact positions are unique across all recordings + all_positions = [] + for probegroup in all_probegroups: + for probe in probegroup.probes: + all_positions.extend(probe.contact_positions) + assert len(np.unique(np.array(all_positions), axis=0)) == len( + all_positions + ), "Contact positions are not unique! Cannot aggregate recordings." + + # Now make a new probegroup with all probes and set global device channel indices + all_probes = [] + for probegroup in all_probegroups: + all_probes.extend([p.copy() for p in probegroup.probes]) + probegroup_agg = ProbeGroup() + probegroup_agg.probes = all_probes + probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) + self.set_probegroup(probegroup_agg, in_place=True, raise_if_overlapping_probes=False) # finally add segments, we need a channel mapping ch_id = 0 diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 67d25b2925..748d052530 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -62,10 +62,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent = parent_recording # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent.has_probe(): + parent_probegroup = self._parent.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -152,10 +153,11 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent_snippets.has_probe(): + parent_probegroup = self._parent_snippets.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index ed98613553..ba11642a5e 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -10,6 +10,7 @@ from collections import namedtuple import inspect +from probeinterface import ProbeGroup import numpy as np @@ -148,6 +149,9 @@ def default(self, obj): if isinstance(obj, Motion): return obj.to_dict() + if isinstance(obj, ProbeGroup): + return obj.to_dict() + # The base-class handles the assertion return super().default(obj) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 48eb2d7fd4..19e54d1fc9 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1023,8 +1023,6 @@ def get_rec_attributes(recording): The rec_attributes dictionary """ properties_to_attrs = deepcopy(recording._properties) - if "contact_vector" in properties_to_attrs: - del properties_to_attrs["contact_vector"] rec_attributes = dict( channel_ids=recording.channel_ids, sampling_frequency=recording.get_sampling_frequency(), diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..e712b881d5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -365,7 +365,6 @@ def create( ) # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..c724cbba98 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -196,7 +196,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() rec_p = rec.set_probe(probe, group_mode="auto") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_shank") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + rec_p = rec.set_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) @@ -204,7 +210,6 @@ def test_BaseRecording(create_cache_folder): probe2 = rec_p.get_probe() positions3 = probe2.contact_positions assert np.array_equal(positions2, positions3) - assert np.array_equal(probe2.device_channel_indices, [0, 1]) # test save with probe @@ -284,8 +289,9 @@ def test_BaseRecording(create_cache_folder): rec_int16.set_property("offset_to_uV", [0.0] * 5) # Test deprecated return_scaled parameter - traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning - assert traces_float32_old.dtype == "float32" + with pytest.warns(DeprecationWarning, match="`return_scaled` is deprecated"): + traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning + assert traces_float32_old.dtype == "float32" # Test new return_in_uV parameter traces_float32_new = rec_int16.get_traces(return_in_uV=True) @@ -342,7 +348,7 @@ def test_BaseRecording(create_cache_folder): # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) - locations_3d = rec_3d.get_property("location") + locations_3d = rec_3d.get_probe().contact_positions locations_xy = rec_3d.get_channel_locations(axes="xy") assert np.allclose(locations_xy, locations_3d[:, [0, 1]]) @@ -411,8 +417,8 @@ def test_json_pickle_equivalence(create_cache_folder): for key, value in data_json.items(): # skip probe info, since pickle keeps some additional information - if key not in ["properties"]: - if isinstance(value, dict): + if key not in ["properties", "probegroup"]: + if isinstance(value, dict) and isinstance(data_pickle[key], dict): for sub_key, sub_value in value.items(): assert np.all(sub_value == data_pickle[key][sub_key]) else: diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e58ef4ee68..a0fa2a24a1 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -494,7 +494,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c85e82b574 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -70,9 +70,11 @@ def __init__( if electrode_width is not None: probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) + rows = probe.contact_annotations["row"] + cols = probe.contact_annotations["col"] self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + self.set_property("row", rows) + self.set_property("col", cols) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..5eaa49e6b8 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -74,8 +74,9 @@ def __init__( # rec_name auto set by neo rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) + electrodes = probe.contact_annotations["electrode"] self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 5306de2441..87eb4df47a 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,8 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", - "location", "group", "shank", "shank_row", @@ -97,6 +95,9 @@ def test_property_keys(self): ] self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys) + def test_has_probe(self): + assert self.recording.has_probe() is True + def test_trace_shape(self): expected_shape = (21, 384) self.assertTupleEqual(tuple1=self.small_scaled_trace.shape, tuple2=expected_shape) diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 113d1e22f1..dfef781ec1 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -63,7 +63,7 @@ def __init__( # my geometry channel_locations = np.zeros( (n_pos_unique, parent_channel_locations.shape[1]), - dtype=parent_channel_locations.dtype, + dtype=np.float32, ) # average other dimensions in the geometry other_dim = np.arange(parent_channel_locations.shape[1]) != dim diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..1294b57a91 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -2,6 +2,8 @@ import numpy as np import os +import probeinterface as pi + import spikeinterface as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se @@ -125,9 +127,12 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) - x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + x_new = rng.choice(shanks, num_channels) + probe = recording.get_probe() + new_positions = probe.contact_positions.copy() + new_positions[:, 0] = x_new # column 0 is x + recording._probegroup.probes[0]._contact_positions = new_positions + recording.set_probe(probe, in_place=True) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -161,18 +166,21 @@ def test_output_values(): the non-interpolated channels is also an implicit test these were not accidently changed. """ - recording = generate_recording(num_channels=5, durations=[1]) + recording = generate_recording(num_channels=5, durations=[1], set_probe=False) bad_channel_indexes = np.array([0]) bad_channel_ids = recording.channel_ids[bad_channel_indexes] - new_probe_locs = [ - [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) - [5, 5, 5, 7, 3], - ] # all others equal distance away. - # Overwrite the probe information with the new locations - for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe_locs = np.array( + [ + [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) + [5, 5, 5, 7, 3], + ] # all others equal distance away. + ).T + # Set the probe information with the new locations + probe = pi.Probe(ndim=2) + probe.set_contacts(positions=probe_locs) + probe.set_device_channel_indices(np.arange(len(probe_locs))) + recording.set_probe(probe, in_place=True) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +194,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + recording._probegroup.probes[0]._contact_positions[-1] = [5, 9] expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 35b984449d..aaede9de5e 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index 3daaf85b7a..07d59332d2 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -276,8 +276,8 @@ def write_hdsort_input_format(cls, recording, save_path, chunk_memory="500M"): [("electrode", np.int32), ("x", np.float64), ("y", np.float64), ("channel", np.int32)] ) - locations = recording.get_property("location") - assert locations is not None, "'location' property is needed to run HDSort" + assert recording.has_probe(), "The recording must have a probe to run HDSort" + locations = recording.get_channel_locations() with h5py.File(save_path, "w") as f: f.create_group("ephys") diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a50b9609b9..aa6f43936b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -405,11 +405,12 @@ def __init__( if border_mode == "remove_channels": # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if recording.has_probe(): + probegroup = recording.get_probegroup() + channel_indices = recording.ids_to_indices(channel_ids) + probegroup_sliced = probegroup.get_slice(channel_indices) + probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(probegroup_sliced, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below