diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index c9754f9..249ea65 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -56,7 +56,7 @@ def get_unit_data(self, unit_id, segment_index=0): return spike_times, spike_data, np.array([1]), np.array([ymin, ymax]), ymin, ymax, inds # avoid clear outliers in the plot and histogram by using percentiles - ymin, ymax = np.percentile(spike_data, [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']]) + ymin, ymax = np.percentile(spike_data[~np.isnan(spike_data)], [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']]) min_bin_size = np.min(np.diff(np.unique(spike_data))) bins = np.linspace(ymin, ymax, self.settings['num_bins']) # if bins are too small, adjust the number of bins to ensure a minimum bin size and avoid jumps in the histogram @@ -329,8 +329,8 @@ def _qt_refresh(self, set_scatter_range=False): # set x range to time range of the current segment for scatter, and max count for histogram # set y range to min and max of visible spike amplitudes if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done): - ymin = np.min(ymins) - ymax = np.max(ymaxs) + ymin = np.nanmin(ymins) + ymax = np.nanmax(ymaxs) t_start, t_stop = self.controller.get_t_start_t_stop() self.viewBox.setXRange(t_start, t_stop, padding = 0.0) self.viewBox.setYRange(ymin, ymax, padding = 0.0) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 105783f..6722c02 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -1,4 +1,5 @@ import time +from copy import deepcopy import numpy as np @@ -10,9 +11,13 @@ from spikeinterface import compute_sparsity from spikeinterface.core import get_template_extremum_channel, BaseEvent from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.curation import validate_curation_dict +from spikeinterface.core.core_tools import check_json +from spikeinterface.curation import validate_curation_dict, apply_curation from spikeinterface.curation.curation_model import Curation from spikeinterface.widgets.utils import make_units_table_from_analyzer +from spikeinterface.widgets.utils import make_units_table_from_analyzer + +from .utils_global import add_new_unit_ids_to_curation_dict from .curation_tools import add_merge, default_label_definitions, empty_curation_data from .event_tools import parse_events @@ -25,7 +30,9 @@ _default_main_settings = dict( max_visible_units=10, color_mode='color_by_unit', - use_times=False + use_times=False, + merge_new_id_strategy = 'take_first', + split_new_id_strategy = 'append', ) from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties @@ -60,6 +67,10 @@ def __init__( self.backend = backend self.disable_save_settings_button = disable_save_settings_button self.current_curation_saved = True + self.applied_curations = [] + + if extra_unit_properties is None: + self.extra_unit_properties_names = [] self.external_data = external_data if self.backend == "qt": @@ -72,19 +83,43 @@ def __init__( self.with_traces = with_traces - self.analyzer = analyzer - assert self.analyzer.get_extension("random_spikes") is not None - - self.return_in_uV = self.analyzer.return_in_uV self.save_on_compute = save_on_compute self.verbose = verbose - t0 = time.perf_counter() + self.original_analyzer = None self.main_settings = _default_main_settings.copy() if user_main_settings is not None: self.main_settings.update(user_main_settings) + self.set_analyzer_info(analyzer) + self.units_table = make_units_table_from_analyzer(self.analyzer, extra_properties=extra_unit_properties) + + self.set_curation_info(curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs) + + # parse events + self.events = None + if events is not None: + self.events = parse_events(events, self, verbose=verbose) + if len(self.events) == 0: + self.events = None + + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + self.extra_unit_properties_names = list(extra_unit_properties.keys()) + displayed_unit_properties += self.extra_unit_properties_names + displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] + self.displayed_unit_properties = displayed_unit_properties + + def set_analyzer_info(self, analyzer): + + self.analyzer = analyzer + assert self.analyzer.get_extension("random_spikes") is not None + + self.return_in_uV = self.analyzer.return_in_uV + t0 = time.perf_counter() + self.num_channels = self.analyzer.get_num_channels() # this now private and should be access using function self._visible_unit_ids = [self.unit_ids[0]] @@ -98,7 +133,7 @@ def __init__( self.analyzer_sparsity = self.analyzer.sparsity # Mandatory extensions: computation forced - if verbose: + if self.verbose: print('\tLoading templates') temp_ext = self.analyzer.get_extension("templates") if temp_ext is None: @@ -112,7 +147,7 @@ def __init__( else: self.templates_std = None - if verbose: + if self.verbose: print('\tLoading unit_locations') ext = analyzer.get_extension('unit_locations') if ext is None: @@ -122,7 +157,7 @@ def __init__( self.unit_positions = ext.get_data()[:, :2] # Optional extensions : can be None or skipped - if verbose: + if self.verbose: print('\tLoading noise_levels') ext = analyzer.get_extension('noise_levels') if ext is None and self.has_extension('recording'): @@ -130,12 +165,12 @@ def __init__( ext = analyzer.compute_one_extension('noise_levels') self.noise_levels = ext.get_data() if ext is not None else None - if "quality_metrics" in skip_extensions: + if "quality_metrics" in self.skip_extensions: if self.verbose: print('\tSkipping quality_metrics') self.metrics = None else: - if verbose: + if self.verbose: print('\tLoading quality_metrics') qm_ext = analyzer.get_extension('quality_metrics') if qm_ext is not None: @@ -143,12 +178,12 @@ def __init__( else: self.metrics = None - if "spike_amplitudes" in skip_extensions: + if "spike_amplitudes" in self.skip_extensions: if self.verbose: print('\tSkipping spike_amplitudes') self.spike_amplitudes = None else: - if verbose: + if self.verbose: print('\tLoading spike_amplitudes') sa_ext = analyzer.get_extension('spike_amplitudes') if sa_ext is not None: @@ -156,12 +191,12 @@ def __init__( else: self.spike_amplitudes = None - if "amplitude_scalings" in skip_extensions: + if "amplitude_scalings" in self.skip_extensions: if self.verbose: print('\tSkipping amplitude_scalings') self.amplitude_scalings = None else: - if verbose: + if self.verbose: print('\tLoading amplitude_scalings') sa_ext = analyzer.get_extension('amplitude_scalings') if sa_ext is not None: @@ -169,12 +204,12 @@ def __init__( else: self.amplitude_scalings = None - if "spike_locations" in skip_extensions: + if "spike_locations" in self.skip_extensions: if self.verbose: print('\tSkipping spike_locations') self.spike_depths = None else: - if verbose: + if self.verbose: print('\tLoading spike_locations') sl_ext = analyzer.get_extension('spike_locations') if sl_ext is not None: @@ -182,13 +217,13 @@ def __init__( else: self.spike_depths = None - if "correlograms" in skip_extensions: + if "correlograms" in self.skip_extensions: if self.verbose: print('\tSkipping correlograms') self.correlograms = None self.correlograms_bins = None else: - if verbose: + if self.verbose: print('\tLoading correlograms') ccg_ext = analyzer.get_extension('correlograms') if ccg_ext is not None: @@ -196,13 +231,13 @@ def __init__( else: self.correlograms, self.correlograms_bins = None, None - if "isi_histograms" in skip_extensions: + if "isi_histograms" in self.skip_extensions: if self.verbose: print('\tSkipping isi_histograms') self.isi_histograms = None self.isi_bins = None else: - if verbose: + if self.verbose: print('\tLoading isi_histograms') isi_ext = analyzer.get_extension('isi_histograms') if isi_ext is not None: @@ -211,11 +246,11 @@ def __init__( self.isi_histograms, self.isi_bins = None, None self._similarity_by_method = {} - if "template_similarity" in skip_extensions: + if "template_similarity" in self.skip_extensions: if self.verbose: print('\tSkipping template_similarity') else: - if verbose: + if self.verbose: print('\tLoading template_similarity') ts_ext = analyzer.get_extension('template_similarity') if ts_ext is not None: @@ -228,12 +263,12 @@ def __init__( ts_ext = analyzer.compute_one_extension('template_similarity', method=method, save=save_on_compute) self._similarity_by_method[method] = ts_ext.get_data() - if "waveforms" in skip_extensions: + if "waveforms" in self.skip_extensions: if self.verbose: print('\tSkipping waveforms') self.waveforms_ext = None else: - if verbose: + if self.verbose: print('\tLoading waveforms') wf_ext = analyzer.get_extension('waveforms') if wf_ext is not None: @@ -241,12 +276,12 @@ def __init__( else: self.waveforms_ext = None self._pc_projections = None - if "principal_components" in skip_extensions: + if "principal_components" in self.skip_extensions: if self.verbose: print('\tSkipping principal_components') self.pc_ext = None else: - if verbose: + if self.verbose: print('\tLoading principal_components') pc_ext = analyzer.get_extension('principal_components') self.pc_ext = pc_ext @@ -262,15 +297,8 @@ def __init__( self.num_segments = self.analyzer.get_num_segments() self.sampling_frequency = self.analyzer.sampling_frequency - # parse events - self.events = None - if events is not None: - self.events = parse_events(events, self, verbose=verbose) - if len(self.events) == 0: - self.events = None - t1 = time.perf_counter() - if verbose: + if self.verbose: print('Loading extensions took', t1 - t0) t0 = time.perf_counter() @@ -332,7 +360,7 @@ def __init__( self._spike_index_by_units[unit_id] = np.concatenate(inds) t1 = time.perf_counter() - if verbose: + if self.verbose: print('Gathering all spikes took', t1 - t0) self._spike_visible_indices = np.array([], dtype='int64') @@ -341,22 +369,17 @@ def __init__( self._traces_cached = {} - self.units_table = make_units_table_from_analyzer(analyzer, extra_properties=extra_unit_properties) - - if displayed_unit_properties is None: - displayed_unit_properties = list(_default_displayed_unit_properties) - if extra_unit_properties is not None: - displayed_unit_properties += list(extra_unit_properties.keys()) - displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] - self.displayed_unit_properties = displayed_unit_properties - # set default time info self.update_time_info() + + def set_curation_info(self, curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs): self.curation = curation self.curation_callback = curation_callback self.curation_callback_kwargs = curation_callback_kwargs + self._potential_merges = None + # TODO: Reload the dictionary if it already exists if self.curation: # rules: # * if user sends curation_data, then it is used @@ -375,6 +398,24 @@ def __init__( except Exception as e: raise ValueError(f"Invalid curation data.\nError: {e}") + if curation_data.get("merges") is None: + curation_data["merges"] = [] + else: + # here we reset the merges for better formatting (str) + existing_merges = curation_data["merges"] + new_merges = [] + for m in existing_merges: + if "unit_ids" not in m: + continue + if len(m["unit_ids"]) < 2: + continue + new_merges = add_merge(new_merges, m["unit_ids"]) + curation_data["merges"] = new_merges + if curation_data.get("splits") is None: + curation_data["splits"] = [] + if curation_data.get("removed") is None: + curation_data["removed"] = [] + elif self.analyzer.format == "binary_folder": json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json" if json_file.exists(): @@ -390,26 +431,23 @@ def __init__( if curation_data is None: curation_data = deepcopy(empty_curation_data) curation_data["unit_ids"] = self.unit_ids.tolist() + curation_data["label_definitions"] = default_label_definitions.copy() - if "label_definitions" not in curation_data: + self.curation_data = curation_data + + if "label_definitions" not in self.curation_data: if label_definitions is not None: - curation_data["label_definitions"] = label_definitions - else: - curation_data["label_definitions"] = default_label_definitions.copy() + self.curation_data["label_definitions"] = label_definitions - # This will enable the default shortcuts if has default quality labels self.has_default_quality_labels = False - if "quality" in curation_data["label_definitions"]: - curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"] + if "quality" in self.curation_data["label_definitions"]: + curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"] default_quality_labels = default_label_definitions["quality"]["label_options"] if set(curation_dict_quality_labels) == set(default_quality_labels): if self.verbose: print('Curation quality labels are the default ones') self.has_default_quality_labels = True - curation_data = Curation(**curation_data).model_dump() - self.curation_data = curation_data - def check_is_view_possible(self, view_name): from .viewlist import get_all_possible_views possible_class_views = get_all_possible_views() @@ -548,15 +586,22 @@ def get_information_txt(self): return txt - def refresh_colors(self): + def refresh_colors(self, existing_colors=None): if self.backend == "qt": self._cached_qcolors = {} elif self.backend == "panel": pass if self.main_settings['color_mode'] == 'color_by_unit': - self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', + shuffle=True, seed=42) + if existing_colors is None: + self.colors = unit_colors + else: + for unit_id, unit_color in unit_colors.items(): + if unit_id not in self.colors.keys(): + self.colors[unit_id] = unit_color + elif self.main_settings['color_mode'] == 'color_only_visible': unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', shuffle=True, seed=42) @@ -859,9 +904,6 @@ def compute_isi_histograms(self, window_ms, bin_ms): self.isi_histograms, self.isi_bins = ext.get_data() return self.isi_histograms, self.isi_bins - def get_units_table(self): - return self.units_table - def compute_auto_merge(self, **params): from spikeinterface.curation import compute_merge_unit_groups @@ -878,14 +920,58 @@ def compute_auto_merge(self, **params): def curation_can_be_saved(self): return self.analyzer.format != "memory" - def construct_final_curation(self): + def construct_final_curation(self, with_explicit_new_unit_ids=False): d = dict() d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) + if with_explicit_new_unit_ids: + split_new_id_strategy = self.main_settings.get('split_new_id_strategy') + merge_new_id_strategy = self.main_settings.get('merge_new_id_strategy') + d = add_new_unit_ids_to_curation_dict(d, self.analyzer.sorting, split_new_id_strategy=split_new_id_strategy, merge_new_id_strategy=merge_new_id_strategy) + model = Curation(**d) return model + def apply_curation(self): + + if self.original_analyzer is None: + self.original_analyzer = deepcopy(self.analyzer) + self.original_analyzer.extensions = {} + + curation = self.construct_final_curation(with_explicit_new_unit_ids=True) + curated_analyzer = apply_curation(self.analyzer, curation) + + self.applied_curations.append(curation) + self.remove_curation(curated_analyzer) + + self.set_analyzer_info(curated_analyzer) + + # for now, don't show externally provided properties after curation + self.displayed_unit_properties = [displayed_property for displayed_property in self.displayed_unit_properties if displayed_property not in self.extra_unit_properties_names] + self.units_table = make_units_table_from_analyzer(self.analyzer) + self.refresh_colors(existing_colors=self.colors) + + for view in self.views: + view.reinitialize() + + def remove_curation(self, curated_analyzer): + """Removes curation from the controller, retaining quality labels.""" + + curation_data = deepcopy(empty_curation_data) + # retain label definitions and 'quality' label + label_definitioins = self.curation_data.get("label_definitions", None) + curation_data["label_definitions"] = label_definitioins + + if (quality_labels := curated_analyzer.get_sorting_property('quality')) is not None: + manual_labels = [] + for unit_id, quality_label in zip(curated_analyzer.unit_ids, quality_labels): + manual_labels.append({'unit_id': unit_id, 'labels': {'quality': [quality_label]}}) + + curation_data['manual_labels'] = manual_labels + + self.curation_data = curation_data + def set_curation_data(self, curation_data): print("Setting curation data") new_curation_data = empty_curation_data.copy() diff --git a/spikeinterface_gui/correlogramview.py b/spikeinterface_gui/correlogramview.py index 9ca6fa6..8e2d585 100644 --- a/spikeinterface_gui/correlogramview.py +++ b/spikeinterface_gui/correlogramview.py @@ -48,6 +48,10 @@ def _qt_make_layout(self): self.grid = pg.GraphicsLayoutWidget() self.layout.addWidget(self.grid) + def _reinitialize(self): + self.ccg, self.bins = self.controller.get_correlograms() + self.figure_cache = {} + self._refresh() def _qt_refresh(self): import pyqtgraph as pg diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index dc92163..22e19ff 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -1,10 +1,8 @@ -import json from pathlib import Path from .view_base import ViewBase -from spikeinterface.core.core_tools import check_json - +from spikeinterface.curation.curation_model import SequentialCuration class CurationView(ViewBase): id = "curation" @@ -74,6 +72,11 @@ def _qt_make_layout(self): but = QT.QPushButton("Save in analyzer") tb.addWidget(but) but.clicked.connect(self.controller.save_curation_in_analyzer) + + but_apply = QT.QPushButton("Apply curation") + tb.addWidget(but_apply) + but_apply.clicked.connect(self.apply_curation_to_analyzer) + but = QT.QPushButton("Export JSON") but.clicked.connect(self._qt_export_json) tb.addWidget(but) @@ -277,6 +280,10 @@ def _qt_on_unit_visibility_changed(self): def on_manual_curation_updated(self): self.refresh() + def apply_curation_to_analyzer(self): + with self.busy_cursor(): + self.controller.apply_curation() + def _qt_export_json(self): from .myqt import QT @@ -286,10 +293,20 @@ def _qt_export_json(self): fd.setViewMode(QT.QFileDialog.Detail) if fd.exec_(): json_file = Path(fd.selectedFiles()[0]) - curation_model = self.controller.construct_final_curation() - with json_file.open("w") as f: - f.write(curation_model.model_dump_json(indent=4)) - self.controller.current_curation_saved = True + if len(self.controller.applied_curations) == 0: + curation_model = self.controller.construct_final_curation() + with json_file.open("w") as f: + f.write(curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True + else: + current_curation_model = self.controller.construct_final_curation() + applied_curations = self.controller.applied_curations + current_and_applied_curations = applied_curations + [current_curation_model] + + sequential_curation_model = SequentialCuration(curation_steps=current_and_applied_curations) + with json_file.open("w") as f: + f.write(sequential_curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True # PANEL def _panel_make_layout(self): @@ -363,6 +380,13 @@ def _panel_make_layout(self): ) save_button.on_click(save_button_callback) + apply_button = pn.widgets.Button( + name="Apply curation", + button_type="primary", + height=30 + ) + apply_button.on_click(self._panel_apply_curation_to_analyzer) + download_button = pn.widgets.FileDownload( button_type="primary", filename="curation.json", callback=self._panel_generate_json, height=30 ) @@ -380,6 +404,8 @@ def _panel_make_layout(self): buttons_save = pn.Row( save_button, download_button, + apply_button, + submit_button, sizing_mode="stretch_width", ) save_sections = pn.Column( @@ -495,9 +521,13 @@ def _panel_restore_units(self, event): def _panel_unmerge(self, event): self.unmerge() + def _panel_apply_curation_to_analyzer(self, event): + self.apply_curation_to_analyzer() + def _panel_unsplit(self, event): self.unsplit() + def _panel_save_in_analyzer(self, event): self.controller.save_curation_in_analyzer() self.refresh() diff --git a/spikeinterface_gui/isiview.py b/spikeinterface_gui/isiview.py index f9fa293..c894ce0 100644 --- a/spikeinterface_gui/isiview.py +++ b/spikeinterface_gui/isiview.py @@ -25,6 +25,10 @@ def _on_settings_changed(self): self.isi_histograms, self.isi_bins = None, None self.refresh() + def _reinitialize(self): + self.isi_histograms, self.isi_bins = self.controller.get_isi_histograms() + self._refresh() + ## QT ## def _qt_make_layout(self): diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 79a2638..abdeb9f 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -8,7 +8,9 @@ {'name': 'max_visible_units', 'type': 'int', 'value' : 10 }, {'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit', 'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']}, - {'name': 'use_times', 'type': 'bool', 'value': False} + {'name': 'use_times', 'type': 'bool', 'value': False}, + {'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join'], 'value': 'take_first'}, + {'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split'], 'value': 'append'}, ] @@ -51,6 +53,12 @@ def on_use_times(self): self.controller.update_time_info() self.notify_use_times_updated() + def on_merge_new_id_strategy(self): + self.controller.main_settings['merge_new_id_strategy'] = self.main_settings['merge_new_id_strategy'] + + def on_split_new_id_strategy(self): + self.controller.main_settings['split_new_id_strategy'] = self.main_settings['split_new_id_strategy'] + def save_current_settings(self, event=None): backend = self.controller.backend @@ -116,6 +124,8 @@ def _qt_make_layout(self): self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed) self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode) self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times) + self.main_settings.param('merge_new_id_strategy').sigValueChanged.connect(self.on_merge_new_id_strategy) + self.main_settings.param('split_new_id_strategy').sigValueChanged.connect(self.on_split_new_id_strategy) def qt_make_settings_dict(self, view): """For a given view, return the current settings in a dict""" @@ -151,6 +161,8 @@ def _panel_make_layout(self): self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units') self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode') self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times') + self.main_settings._parameterized.param.watch(self._panel_on_merge_new_id_strategy, 'merge_new_id_strategy') + self.main_settings._parameterized.param.watch(self._panel_on_split_new_id_strategy, 'split_new_id_strategy') self.layout = pn.Column(self.save_setting_button, self.main_settings_layout, sizing_mode="stretch_both") def panel_make_settings_dict(self, view): @@ -170,6 +182,12 @@ def _panel_on_max_visible_units_changed(self, event): def _panel_on_change_color_mode(self, event): self.on_change_color_mode() + def _panel_on_merge_new_id_strategy(self, event): + self.on_merge_new_id_strategy() + + def _panel_on_split_new_id_strategy(self, event): + self.on_split_new_id_strategy() + def _panel_on_use_times(self, event): self.on_use_times() diff --git a/spikeinterface_gui/maintemplateview.py b/spikeinterface_gui/maintemplateview.py index 0849a0b..4a143eb 100644 --- a/spikeinterface_gui/maintemplateview.py +++ b/spikeinterface_gui/maintemplateview.py @@ -92,19 +92,19 @@ def _qt_refresh(self): if peak_data is not None: # trough - peak_inds = peak_data[['trough_index']].values + peak_inds = peak_data[['trough_index']].values.astype(int) scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds], size=10, pxMode = True, color="white", symbol="t") plot.addItem(scatter) names = ('peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in names]].values + peak_inds = peak_data[[f'{k}_index' for k in names]].values.astype(int) scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds], size=10, pxMode = True, color="white", symbol="t1") plot.addItem(scatter) all_names = ('trough', 'peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in all_names]].values + peak_inds = peak_data[[f'{k}_index' for k in all_names]].values.astype(int) # Vertical dotted lines from peak to zero for ind in peak_inds: x = [times[ind], times[ind]] diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 66712ea..1196d4d 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -160,6 +160,12 @@ def accept_group_merge(self, group_ids): self.notify_manual_curation_updated() self.refresh() + def _reinitialize(self): + self.proposed_merge_unit_groups_all = [] + self.proposed_merge_unit_groups = [] + self.merge_info = {} + self._refresh() + ### QT def _qt_get_selected_group_ids(self): inds = self.table.selectedIndexes() diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index ebba875..27be847 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -149,6 +149,17 @@ def _qt_make_layout(self): self.roi_units.sigRegionChangeFinished.connect(self._qt_on_roi_units_changed) + def _qt_reinitialize(self): + import pyqtgraph as pg + + self.plot.removeItem(self.scatter) + unit_positions = self.controller.unit_positions + brush = [self.get_unit_color(u) for u in self.controller.unit_ids] + self.scatter = pg.ScatterPlotItem(pos=unit_positions, pxMode=False, size=10, brush=brush) + self.plot.addItem(self.scatter) + + self._qt_refresh() + def _qt_refresh(self): current_unit_positions = self.controller.unit_positions # if not np.array_equal(current_unit_positions, self._unit_positions): @@ -478,11 +489,14 @@ def _panel_make_layout(self): self.should_resize_unit_circle = None # Main layout - self.layout = pn.Column( - self.figure, - styles={"display": "flex", "flex-direction": "column"}, - sizing_mode="stretch_both", - ) + if self.layout is None: + self.layout = pn.Column( + self.figure, + styles={"display": "flex", "flex-direction": "column"}, + sizing_mode="stretch_both", + ) + else: + self.layout.objects = [self.figure] def _panel_refresh(self): import panel as pn @@ -554,6 +568,9 @@ def _panel_refresh(self): self.y_range.start = zoom_bounds[2] self.y_range.end = zoom_bounds[3] + def _panel_reinitialize(self): + self._panel_make_layout() + self._refresh() def _panel_compute_unit_glyph_patches(self): """Compute glyph patches without modifying Bokeh models.""" diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index ee4ac61..bfaf642 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -25,6 +25,10 @@ def __init__(self, controller=None, parent=None, backend="qt"): spike_data=spike_data, ) + def _reinitialize(self): + self.spike_data = self.controller.spike_amplitudes + self._refresh() + def _qt_make_layout(self): from .myqt import QT diff --git a/spikeinterface_gui/spikedepthview.py b/spikeinterface_gui/spikedepthview.py index 0bee9df..032e4b5 100644 --- a/spikeinterface_gui/spikedepthview.py +++ b/spikeinterface_gui/spikedepthview.py @@ -17,6 +17,9 @@ def __init__(self, controller=None, parent=None, backend="qt"): spike_data=spike_data, ) + def _reinitialize(self): + self.spike_data = self.controller.spike_depths + self._refresh() SpikeDepthView._gui_help_txt = """ diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index bae636b..0dd2b6a 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -36,8 +36,6 @@ def update_manual_labels(self): def _qt_make_layout(self): from .myqt import QT - import pyqtgraph as pg - self.menu = None self.layout = QT.QVBoxLayout() @@ -47,21 +45,7 @@ def _qt_make_layout(self): but.clicked.connect(self._qt_select_columns) tb.addWidget(but) - - visible_cols = [] - for col in self.controller.units_table.columns: - visible_cols.append( - {'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True} - ) - self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols) - self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget) - self.tree_visible_columns.header().hide() - self.tree_visible_columns.setParameters(self.visible_columns, showTop=True) - # self.tree_visible_columns.setWindowTitle(u'visible columns') - # self.tree_visible_columns.setWindowFlags(QT.Qt.Window) - self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed) - self.layout.addWidget(self.tree_visible_columns) - self.tree_visible_columns.hide() + self._qt_set_up_visible_columns() # h = QT.QHBoxLayout() # self.layout.addLayout(h) @@ -127,6 +111,28 @@ def _qt_make_layout(self): self.shortcut_noise.setKey(QT.QKeySequence('n')) self.shortcut_noise.activated.connect(lambda: self._qt_set_default_label('noise')) + def _qt_set_up_visible_columns(self): + + import pyqtgraph as pg + visible_cols = [] + for col in self.controller.units_table.columns: + visible_cols.append( + {'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True} + ) + self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols) + self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget) + self.tree_visible_columns.header().hide() + self.tree_visible_columns.setParameters(self.visible_columns, showTop=True) + + self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed) + self.layout.addWidget(self.tree_visible_columns) + self.tree_visible_columns.hide() + + def _qt_reinitialize(self): + + #self._qt_set_up_visible_columns() + self._qt_full_table_refresh() + self._qt_refresh() def _qt_on_column_moved(self, logical_index, old_visual_index, new_visual_index): # Update stored column order @@ -221,7 +227,6 @@ def _qt_full_table_refresh(self): self.table.clear() - internal_column_names = ['unit_id', 'visible', 'channel_id'] # internal labels @@ -572,16 +577,22 @@ def _panel_make_layout(self): shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) shortcuts_component.on_msg(self._panel_handle_shortcut) - self.layout = pn.Column( - pn.Row( - self.info_text, - ), - buttons, - sizing_mode="stretch_width", - ) + if self.layout is None: + self.layout = pn.Column( + pn.Row( + self.info_text, + ), + buttons, + sizing_mode="stretch_width", + ) - self.layout.append(self.table) - self.layout.append(shortcuts_component) + self.layout.append(self.table) + self.layout.append(shortcuts_component) + else: + self.layout[0][0] = self.info_text + self.layout[1] = buttons + self.layout[2] = self.table + self.layout[3] = shortcuts_component self.table.tabulator.on_edit(self._panel_on_edit) self.refresh_button.on_click(self._panel_refresh_click) @@ -636,6 +647,10 @@ def _panel_refresh(self): # refresh header self._panel_refresh_header() + def _panel_reinitialize(self): + self._panel_make_layout() + self._panel_refresh() + def _panel_refresh_header(self): unit_ids = self.controller.unit_ids n1 = len(unit_ids) diff --git a/spikeinterface_gui/utils_global.py b/spikeinterface_gui/utils_global.py index 23fc61d..d885aa2 100644 --- a/spikeinterface_gui/utils_global.py +++ b/spikeinterface_gui/utils_global.py @@ -1,6 +1,7 @@ import numpy as np from pathlib import Path import os +from copy import copy def get_config_folder() -> Path: """Get the config folder for spikeinterface-gui settings files. @@ -58,3 +59,37 @@ def get_present_zones_in_half_of_layout(layout_zone, shift): is_present = [views is not None and len(views) > 0 for views in half_dict.values()] present_zones = set(np.array(list(half_dict.keys()))[np.array(is_present)]) return present_zones + + +def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strategy, merge_new_id_strategy): + """ + Explicitly adds the new unit ids to `curation_dict` based on the split and merge new id strategies. + These *should* be the ids that would have been generated during `apply_curation` with these strategies. + """ + + from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group + from spikeinterface.curation.curation_model import Curation + + curation_model = Curation(**curation_dict) + old_unit_ids = copy(curation_model.unit_ids) + + if len(curation_model.splits) > 0: + unit_splits = {split.unit_id: split.get_full_spike_indices(sorting) for split in curation_model.splits} + new_split_unit_ids = generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy=split_new_id_strategy) + + all_new_unit_ids = [] + for split_index, new_unit_ids in enumerate(new_split_unit_ids): + curation_dict['splits'][split_index]['new_unit_ids'] = new_unit_ids + all_new_unit_ids = all_new_unit_ids + new_unit_ids + + # update old unit ids with the newly split units + old_unit_ids = np.setdiff1d(old_unit_ids, np.array(list(unit_splits.keys()))) + old_unit_ids = np.concat([old_unit_ids, all_new_unit_ids]) + + if len(curation_model.merges) > 0: + merge_unit_groups = [m.unit_ids for m in curation_model.merges] + new_merge_unit_ids = generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy=merge_new_id_strategy) + for merge_index, new_unit_id in enumerate(new_merge_unit_ids): + curation_dict['merges'][merge_index]['new_unit_id'] = new_unit_id + + return curation_dict diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 06ae99f..8bc728d 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -40,7 +40,6 @@ def __init__(self, controller=None, parent=None, backend="qt"): create_settings(self) self.notifier = SignalNotifier(view=self) self.busy = pn.indicators.LoadingSpinner(value=True, size=20, name='busy...') - make_layout() if self._settings is not None: listen_setting_changes(self) @@ -115,6 +114,14 @@ def refresh(self, **kwargs): t1 = time.perf_counter() print(f"Refresh {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def reinitialize(self, **kwargs): + if self.controller.verbose: + t0 = time.perf_counter() + self._reinitialize(**kwargs) + if self.controller.verbose: + t1 = time.perf_counter() + print(f"Reinitialize {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def compute(self, event=None): with self.busy_cursor(): self._compute() @@ -130,6 +137,12 @@ def _refresh(self, **kwargs): import panel as pn pn.state.execute(lambda: self._panel_refresh(**kwargs), schedule=True) + def _reinitialize(self, **kwargs): + if self.backend == "qt": + self._qt_reinitialize(**kwargs) + elif self.backend == "panel": + self._panel_reinitialize(**kwargs) + def warning(self, warning_msg): if self.backend == "qt": self._qt_insert_warning(warning_msg) @@ -266,6 +279,9 @@ def _qt_make_layout(self): def _qt_refresh(self): raise (NotImplementedError) + + def _qt_reinitialize(self): + self._qt_refresh() def _qt_on_spike_selection_changed(self): pass @@ -326,6 +342,9 @@ def _panel_make_layout(self): def _panel_refresh(self): raise (NotImplementedError) + + def _panel_reinitialize(self): + self._panel_refresh() def _panel_on_spike_selection_changed(self): pass