Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, MutableMapping
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass, replace
from enum import Enum
from functools import lru_cache
Expand Down Expand Up @@ -45,6 +45,7 @@
from zarr.core.dtype.npy.int import UInt64
from zarr.core.indexing import (
BasicIndexer,
ChunkProjection,
SelectorTuple,
_morton_order,
_morton_order_keys,
Expand Down Expand Up @@ -574,21 +575,26 @@ async def _encode_partial_single(
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
chunk_spec = self._get_chunk_spec(shard_spec)

shard_reader = await self._load_full_shard_maybe(
byte_getter=byte_setter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
)
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
# Use vectorized lookup for better performance
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))

indexer = list(
get_indexer(
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
)
)

if self._is_complete_shard_write(indexer, chunks_per_shard):
shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard))
else:
shard_reader = await self._load_full_shard_maybe(
byte_getter=byte_setter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
)
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
# Use vectorized lookup for better performance
shard_dict = shard_reader.to_dict_vectorized(
np.asarray(_morton_order(chunks_per_shard))
)

await self.codec_pipeline.write(
[
(
Expand Down Expand Up @@ -661,6 +667,16 @@ def _is_total_shard(
chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
)

def _is_complete_shard_write(
    self,
    indexed_chunks: Sequence[ChunkProjection],
    chunks_per_shard: tuple[int, ...],
) -> bool:
    """
    Return ``True`` when *indexed_chunks* fully overwrites the shard.

    A write is a complete-shard write iff the selection touches every chunk
    of the shard (checked via ``_is_total_shard``) and each touched chunk is
    itself written in full, so no existing shard data needs to be read back.
    """
    # Collect the coordinates of every chunk the selection touches.
    touched_coords = set()
    for chunk_coords, *_ in indexed_chunks:
        touched_coords.add(chunk_coords)
    if not self._is_total_shard(touched_coords, chunks_per_shard):
        return False
    # Every chunk must also be covered end-to-end by the selection.
    return all(is_complete_chunk for *_, is_complete_chunk in indexed_chunks)

async def _decode_shard_index(
self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
) -> _ShardIndex:
Expand Down
31 changes: 28 additions & 3 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,9 +2259,34 @@ def test_create_array_with_data_num_gets(
data = zarr.zeros(shape, dtype="int64")

zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type]
# one get for the metadata and one per shard.
# Note: we don't actually need one get per shard, but this is the current behavior
assert store.counter["get"] == 1 + num_shards
# One get for the metadata; full-shard writes should not read shard payloads.
assert store.counter["get"] == 1


@pytest.mark.parametrize(
    ("selection", "expected_gets"),
    [(slice(None), 0), (slice(1, 9), 1)],
)
def test_shard_write_num_gets(selection: slice, expected_gets: int) -> None:
    """
    Writing a whole shard must not read existing shard data from the store,
    while a partial-shard write must perform exactly one read.
    """
    logging_store = LoggingStore(store=MemoryStore())
    array = zarr.create_array(
        logging_store,
        shape=(10,),
        chunks=(1,),
        shards=(10,),
        dtype="int64",
        fill_value=-1,
    )
    # Populate the shard first so a later partial write has data to merge.
    array[:] = 0

    # Only the gets issued by the write under test should be counted.
    logging_store.counter.clear()

    array[selection] = 1

    assert logging_store.counter["get"] == expected_gets


@pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}])
Expand Down