Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2827,11 +2827,17 @@ def __getitem__(self, selection: Selection) -> NDArrayLikeOrScalar:
"""
fields, pure_selection = pop_fields(selection)
if is_pure_fancy_indexing(pure_selection, self.ndim):
return self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
result = self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
return self.get_orthogonal_selection(pure_selection, fields=fields)
result = self.get_orthogonal_selection(pure_selection, fields=fields)
else:
return self.get_basic_selection(cast("BasicSelection", pure_selection), fields=fields)
result = self.get_basic_selection(cast("BasicSelection", pure_selection), fields=fields)

# Convert 0-d ndarray to numpy scalar for scalar indexing
# This matches numpy behavior where a[0] returns a scalar, not a 0-d array
if isinstance(result, np.ndarray) and result.ndim == 0:
return result[()]
return result

def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None:
"""Modify data for an item or region of the array.
Expand Down
117 changes: 117 additions & 0 deletions test_scalar_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for scalar indexing fix (Issue #3741)."""

import numpy as np
import pytest

import zarr


class TestScalarIndexing:
"""Test that scalar indexing returns numpy scalars, matching numpy behavior."""

def test_1d_scalar_indexing(self):
"""Test scalar indexing on 1-D array returns numpy scalar."""
arr_zarr = zarr.array([1, 2, 3, 4, 5], dtype='int64')
arr_numpy = np.array([1, 2, 3, 4, 5], dtype='int64')

result_zarr = arr_zarr[0]
result_numpy = arr_numpy[0]

assert type(result_zarr) == type(result_numpy)
assert result_zarr == result_numpy
assert not isinstance(result_zarr, np.ndarray)
assert isinstance(result_zarr, np.generic)

def test_2d_scalar_indexing(self):
"""Test scalar indexing on 2-D array returns numpy scalar."""
arr_zarr = zarr.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
arr_numpy = np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')

result_zarr = arr_zarr[0, 0]
result_numpy = arr_numpy[0, 0]

assert type(result_zarr) == type(result_numpy)
assert result_zarr == result_numpy
assert not isinstance(result_zarr, np.ndarray)

def test_3d_scalar_indexing(self):
"""Test scalar indexing on 3-D array returns numpy scalar."""
arr_zarr = zarr.arange(24, dtype='int64').reshape(2, 3, 4)
arr_numpy = np.arange(24, dtype='int64').reshape(2, 3, 4)

result_zarr = arr_zarr[0, 1, 2]
result_numpy = arr_numpy[0, 1, 2]

assert type(result_zarr) == type(result_numpy)
assert result_zarr == result_numpy

def test_slice_indexing_returns_array(self):
"""Test that slice indexing still returns arrays."""
arr_zarr = zarr.array([1, 2, 3, 4, 5])
result = arr_zarr[0:2]

assert isinstance(result, np.ndarray)
assert result.ndim == 1
assert len(result) == 2

def test_partial_scalar_indexing_on_2d(self):
"""Test partial scalar indexing on 2-D array returns 1-D array."""
arr_zarr = zarr.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
arr_numpy = np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')

result_zarr = arr_zarr[0]
result_numpy = arr_numpy[0]

assert type(result_zarr) == type(result_numpy)
assert isinstance(result_zarr, np.ndarray)
assert result_zarr.ndim == 1
np.testing.assert_array_equal(result_zarr, result_numpy)

def test_float_dtype_scalar_indexing(self):
"""Test scalar indexing with float dtype."""
arr_zarr = zarr.array([1.5, 2.5, 3.5], dtype='float64')
arr_numpy = np.array([1.5, 2.5, 3.5], dtype='float64')

result_zarr = arr_zarr[0]
result_numpy = arr_numpy[0]

assert type(result_zarr) == type(result_numpy)
assert result_zarr == result_numpy

def test_negative_indexing(self):
"""Test scalar indexing with negative indices."""
arr_zarr = zarr.array([1, 2, 3, 4, 5], dtype='int64')
arr_numpy = np.array([1, 2, 3, 4, 5], dtype='int64')

result_zarr = arr_zarr[-1]
result_numpy = arr_numpy[-1]

assert type(result_zarr) == type(result_numpy)
assert result_zarr == result_numpy

def test_ellipsis_indexing_returns_array(self):
"""Test that ellipsis indexing returns the full array."""
arr_zarr = zarr.array([1, 2, 3], dtype='int64')
result = arr_zarr[...]

assert isinstance(result, np.ndarray)
assert result.ndim == 1
np.testing.assert_array_equal(result, np.array([1, 2, 3]))

def test_mixed_slice_and_scalar(self):
"""Test mixed slice and scalar indexing."""
arr_zarr = zarr.arange(24, dtype='int64').reshape(2, 3, 4)
arr_numpy = np.arange(24, dtype='int64').reshape(2, 3, 4)

# [0, :, 2] should return 1-D array
result_zarr = arr_zarr[0, :, 2]
result_numpy = arr_numpy[0, :, 2]

assert type(result_zarr) == type(result_numpy)
assert isinstance(result_zarr, np.ndarray)
assert result_zarr.ndim == 1
np.testing.assert_array_equal(result_zarr, result_numpy)


if __name__ == '__main__':
pytest.main([__file__, '-v'])
Loading