Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions .github/scripts/generate_zarr_v2_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""
Generate zarr v2 fixtures for backward compatibility tests.

Run this script with an old spikeinterface version and zarr<3, e.g.:
pip install "spikeinterface==0.104.0" "zarr<3"
python .github/scripts/generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures

The script saves:
- recording.zarr : a small ZarrRecordingExtractor
- sorting.zarr : a small ZarrSortingExtractor
- expected_values.json : key values used to verify correct loading
"""
import argparse
import shutil
import json
from pathlib import Path

import numpy as np
import zarr

import spikeinterface as si


def main(output_dir: Path) -> None:
    """
    Generate zarr v2 fixtures and a JSON file of expected values.

    Writes into ``output_dir``:
      - recording.zarr / sorting.zarr : zarr-saved extractors
      - analyzer.zarr                 : a SortingAnalyzer with templates
      - expected_values.json          : reference values for the
        backward-compatibility tests to assert against after reloading
        the fixtures with zarr v3.

    Parameters
    ----------
    output_dir : Path
        Directory to write the fixtures into (created if missing).
    """
    print(f"spikeinterface version : {si.__version__}")
    print(f"zarr version : {zarr.__version__}")

    output_dir.mkdir(parents=True, exist_ok=True)

    recording, sorting = si.generate_ground_truth_recording(
        durations=[10, 5], num_channels=32, num_units=10, seed=0
    )
    # save to binary to make them JSON serializable for later expected values extraction
    recording = recording.save(folder=output_dir / "recording_binary", overwrite=True)
    sorting = sorting.save(folder=output_dir / "sorting_binary", overwrite=True)

    # --- save recording ---
    recording_path = output_dir / "recording.zarr"
    recording_zarr = recording.save(format="zarr", folder=recording_path, overwrite=True)
    print(f"Saved recording -> {recording_path}")

    # --- save sorting ---
    sorting_path = output_dir / "sorting.zarr"
    sorting_zarr = sorting.save(format="zarr", folder=sorting_path, overwrite=True)
    print(f"Saved sorting -> {sorting_path}")

    # --- save SortingAnalyzer ---
    # Use the zarr-saved extractors so the analyzer stores serializable
    # ZarrRecordingExtractor / ZarrSortingExtractor objects as provenance.
    analyzer_path = output_dir / "analyzer.zarr"
    # Defensive cleanup: older spikeinterface versions may not fully honor
    # overwrite=True for an existing zarr folder — TODO confirm.
    if analyzer_path.is_dir():
        shutil.rmtree(analyzer_path)
    analyzer = si.create_sorting_analyzer(
        sorting_zarr, recording_zarr, format="zarr", folder=analyzer_path, overwrite=True
    )
    analyzer.compute(["random_spikes", "templates"])
    print(f"Saved analyzer -> {analyzer_path}")

    # Fetch the templates so their shape can be recorded in the expected values
    # (also exercises the extension data path before the JSON is written).
    templates_array = analyzer.get_extension("templates").get_data()

    # --- capture expected values for later assertion ---
    expected = {
        "spikeinterface_version": si.__version__,
        "zarr_version": zarr.__version__,
        "recording": {
            "num_channels": int(recording.get_num_channels()),
            "num_segments": int(recording.get_num_segments()),
            "sampling_frequency": float(recording.get_sampling_frequency()),
            "num_samples_per_segment": [
                int(recording.get_num_samples(seg)) for seg in range(recording.get_num_segments())
            ],
            "channel_ids": recording.get_channel_ids().tolist(),
            "dtype": str(recording.get_dtype()),
            # first 10 frames of segment 0 for all channels
            "traces_seg0_first10": recording.get_traces(start_frame=0, end_frame=10, segment_index=0).tolist(),
        },
        "sorting": {
            "num_segments": int(sorting.get_num_segments()),
            "sampling_frequency": float(sorting.get_sampling_frequency()),
            "unit_ids": sorting.get_unit_ids().tolist(),
            "spike_trains_seg0": {
                str(uid): sorting.get_unit_spike_train(unit_id=uid, segment_index=0).tolist()
                for uid in sorting.unit_ids
            },
        },
        "analyzer": {
            "num_units": int(analyzer.get_num_units()),
            "num_channels": int(analyzer.get_num_channels()),
            "templates_shape": list(templates_array.shape),
        },
    }

    expected_path = output_dir / "expected_values.json"
    with open(expected_path, "w") as f:
        json.dump(expected, f, indent=2)
    print(f"Saved expected -> {expected_path}")


if __name__ == "__main__":
    # CLI entry point: an explicit output directory is required so fixtures
    # never land in an implicit default location.
    cli = argparse.ArgumentParser(description="Generate zarr v2 fixtures for backward compatibility tests")
    cli.add_argument("--output", type=Path, required=True, help="Directory to write fixtures into")
    main(cli.parse_args().output)
2 changes: 1 addition & 1 deletion .github/workflows/all-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.13"] # Lower and higher versions we support
python-version: ["3.11", "3.13"] # Lower and higher versions we support
os: [macos-latest, windows-latest, ubuntu-latest]
steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 2 additions & 4 deletions .github/workflows/deepinterpolation.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
name: Testing deepinterpolation

# Manual only — deepinterpolation requires Python 3.10, incompatible with 3.11+ required by Zarr 3.0.0+
on:
pull_request:
types: [synchronize, opened, reopened]
branches:
- main
workflow_dispatch:

concurrency: # Cancel previous workflows on the same pull request
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_containers_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: Python version
run:
python --version
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_containers_singularity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- uses: eWaterCycle/setup-singularity@v7
with:
singularity-version: 3.8.7
Expand Down
47 changes: 47 additions & 0 deletions .github/workflows/test_zarr_compat.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: Test zarr backwards compatibility

on:
workflow_dispatch:
pull_request:
types: [synchronize, opened, reopened]
branches:
- main
paths:
- "src/spikeinterface/core/zarrextractors.py"
- "src/spikeinterface/core/zarrrecordingextractor.py"
- "src/spikeinterface/core/tests/test_zarr_backwards_compat.py"
- ".github/workflows/test_zarr_compat.yml"
- ".github/scripts/generate_zarr_v2_fixtures.py"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test-zarr-compat:
name: zarr v2 -> v3 backwards compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install SI 0.104.0 with zarr v2
run: pip install "spikeinterface==0.104.0" "zarr<3"

- name: Generate zarr v2 fixtures
run: python .github/scripts/generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures

- name: Install current SI with zarr v3
run: pip install -e ".[test_core]"

- name: Check zarr version is v3
run: python -c "import zarr; v = zarr.__version__; print(f'zarr {v}'); assert int(v.split('.')[0]) >= 3"

- name: Run backward compatibility tests
env:
ZARR_V2_FIXTURES_PATH: /tmp/zarr_v2_fixtures
run: pytest src/spikeinterface/core/tests/test_zarr_backwards_compat.py -v
21 changes: 14 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
description = "Python toolkit for analysis, visualization, and comparison of spike sorting output"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: MIT License",
Expand All @@ -24,12 +24,11 @@ dependencies = [
"numpy>=2.0.0;python_version>='3.13'",
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=2.18,<3",
"zarr>=3,<4",
"neo>=0.14.4",
"probeinterface>=0.3.2",
"packaging",
"pydantic",
"numcodecs<0.16.0", # For supporting zarr < 3
]

[build-system]
Expand Down Expand Up @@ -127,7 +126,9 @@ test_core = [

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs,
Expand All @@ -139,7 +140,9 @@ test_extractors = [
"pooch>=1.8.2",
"datalad>=1.0.2",
# Commenting out for release
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
]

Expand Down Expand Up @@ -190,7 +193,9 @@ test = [

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs
Expand Down Expand Up @@ -219,7 +224,9 @@ docs = [
"huggingface_hub", # For automated curation

# for release we need pypi, so this needs to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version
# FOR TESTING: use probeinterface zarrv3 branch
"probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3",
# "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version
]

Expand Down
42 changes: 19 additions & 23 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,8 @@ def save(self, **kwargs) -> BaseExtractor:
* dump_ext: "json" or "pkl", default "json" (if format is "folder")
* verbose: if True output is verbose
* **save_kwargs: additional kwargs format-dependent and job kwargs for recording
(check `save_to_memory()`, `save_to_folder()`, `save_to_zarr()` for more details on format-dependent
kwargs)
{}

Returns
Expand All @@ -892,13 +894,27 @@ def save(self, **kwargs) -> BaseExtractor:
save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc)

def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor:
    """
    Save the object to memory.

    Parameters
    ----------
    sharedmem : bool, default: True
        If True, the object is saved to shared memory, allowing it to be accessed by multiple processes without
        copying. If False, the object is saved to regular memory, which may involve copying when accessed by
        multiple processes.

    Returns
    -------
    BaseExtractor
        A saved copy of the extractor in memory.
    """
    # The format is fixed to "memory" here, so discard any caller-supplied value.
    save_kwargs.pop("format", None)

    in_memory_copy = self._save(format="memory", sharedmem=sharedmem, **save_kwargs)
    self.copy_metadata(in_memory_copy)
    return in_memory_copy

# TODO rename to saveto_binary_folder
def save_to_folder(
self,
name: str | None = None,
Expand Down Expand Up @@ -944,8 +960,7 @@ def save_to_folder(
If True, an existing folder at the specified path will be deleted before saving.
verbose : bool, default: True
If True, print information about the cache folder being used.
**save_kwargs
Additional keyword arguments to be passed to the underlying save method.
{}

Returns
-------
Expand Down Expand Up @@ -1010,7 +1025,6 @@ def save_to_zarr(
folder=None,
overwrite=False,
storage_options=None,
channel_chunk_size=None,
verbose=True,
**save_kwargs,
):
Expand All @@ -1030,26 +1044,9 @@ def save_to_zarr(
storage_options: dict or None, default: None
Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc.
For cloud storage locations, this should not be None (in case of default values, use an empty dict)
channel_chunk_size: int or None, default: None
Channels per chunk (only for BaseRecording)
compressor: numcodecs.Codec or None, default: None
Global compressor. If None, Blosc-zstd, level 5, with bit shuffle is used
filters: list[numcodecs.Codec] or None, default: None
Global filters for zarr (global)
compressor_by_dataset: dict or None, default: None
Optional compressor per dataset:
- traces
- times
If None, the global compressor is used
filters_by_dataset: dict or None, default: None
Optional filters per dataset:
- traces
- times
If None, the global filters are used
verbose: bool, default: True
If True, the output is verbose
auto_cast_uint: bool, default: True
If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording)
{}

Returns
-------
Expand Down Expand Up @@ -1085,7 +1082,6 @@ def save_to_zarr(
assert not zarr_path.exists(), f"Path {zarr_path} already exists, choose another name"
save_kwargs["zarr_path"] = zarr_path
save_kwargs["storage_options"] = storage_options
save_kwargs["channel_chunk_size"] = channel_chunk_size
cached = self._save(format="zarr", verbose=verbose, **save_kwargs)
cached = read_zarr(zarr_path)

Expand Down
Loading
Loading