diff --git a/changes/3004.feature.md b/changes/3004.feature.md new file mode 100644 index 0000000000..1855acec89 --- /dev/null +++ b/changes/3004.feature.md @@ -0,0 +1,7 @@ +Optimizes reading multiple chunks from a shard. Reads of nearby chunks within +the same shard are coalesced to reduce the number of calls to the store. +After any coalescing, the resulting byte ranges are read in parallel. + +Coalescing respects two config options. Reads are coalesced if there are fewer +than `sharding.read.coalesce_max_gap_bytes` bytes between chunks and the total +size of the coalesced read is no more than `sharding.read.coalesce_max_bytes`. \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md index e42ba0edf1..cb7de3ed07 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -64,7 +64,7 @@ hatch env show # list all available environments To verify that your development environment is working, you can run the unit tests for one of the test environments, e.g.: ```bash -hatch env run --env test.py3.12-2.2-optional run-pytest +hatch env run --env test.py3.13-optional run-pytest ``` ### Creating a branch diff --git a/docs/user-guide/config.md b/docs/user-guide/config.md index 21fe9b5def..572a9db12c 100644 --- a/docs/user-guide/config.md +++ b/docs/user-guide/config.md @@ -33,6 +33,7 @@ Configuration options include the following: - Async and threading options, e.g. `async.concurrency` and `threading.max_workers` - Selections of implementations of codecs, codec pipelines and buffers - Enabling GPU support with `zarr.config.enable_gpu()`. See GPU support for more. +- Control request merging when reading multiple chunks from the same shard with `sharding.read.coalesce_max_gap_bytes` and `sharding.read.coalesce_max_bytes` For selecting custom implementations of codecs, pipelines, buffers and ndbuffers, first register the implementations in the registry and then select them in the config. diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 8124ea44ea..86ef8dcd85 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -5,7 +5,7 @@ from enum import Enum from functools import lru_cache from operator import itemgetter -from typing import TYPE_CHECKING, Any, NamedTuple, cast +from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import numpy.typing as npt @@ -37,11 +37,13 @@ from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.core.common import ( ShapeLike, + concurrent_map, parse_enum, parse_named_configuration, parse_shapelike, product, ) +from zarr.core.config import config from zarr.core.dtype.npy.int import UInt64 from zarr.core.indexing import ( BasicIndexer, @@ -114,9 +116,7 @@ class _ShardIndex(NamedTuple): @property def chunks_per_shard(self) -> tuple[int, ...]: - result = tuple(self.offsets_and_lengths.shape[0:-1]) - # The cast is required until https://github.com/numpy/numpy/pull/27211 is merged - return cast("tuple[int, ...]", result) + return tuple(self.offsets_and_lengths.shape[0:-1]) def _localize_chunk(self, chunk_coords: tuple[int, ...]) -> tuple[int, ...]: return tuple( @@ -220,9 +220,19 @@ def __iter__(self) -> Iterator[tuple[int, ...]]: return c_order_iter(self.index.offsets_and_lengths.shape[:-1]) +@dataclass(frozen=True) +class _ChunkCoordsByteSlice: + """Holds a core.indexing.ChunkProjection.chunk_coords and its byte range in a serialized shard.""" + + chunk_coords: tuple[int, ...] 
+ byte_slice: slice + + @dataclass(frozen=True) class ShardingCodec( - ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin + ArrayBytesCodec, + ArrayBytesCodecPartialDecodeMixin, + ArrayBytesCodecPartialEncodeMixin, ): """Sharding codec""" @@ -400,32 +410,31 @@ async def _decode_partial_single( all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} # reading bytes of all requested chunks - shard_dict: ShardMapping = {} + shard_dict_maybe: ShardMapping | None = {} if self._is_total_shard(all_chunk_coords, chunks_per_shard): # read entire shard shard_dict_maybe = await self._load_full_shard_maybe( - byte_getter=byte_getter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, + byte_getter, chunk_spec.prototype, chunks_per_shard ) - if shard_dict_maybe is None: - return None - shard_dict = shard_dict_maybe else: # read some chunks within the shard - shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) - if shard_index is None: - return None - shard_dict = {} - for chunk_coords in all_chunk_coords: - chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) - if chunk_byte_slice: - chunk_bytes = await byte_getter.get( - prototype=chunk_spec.prototype, - byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), - ) - if chunk_bytes: - shard_dict[chunk_coords] = chunk_bytes + max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes") + coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes") + async_concurrency = config.get("async.concurrency") + + shard_dict_maybe = await self._load_partial_shard_maybe( + byte_getter, + chunk_spec.prototype, + chunks_per_shard, + all_chunk_coords, + max_gap_bytes, + coalesce_max_bytes, + async_concurrency, + ) + + if shard_dict_maybe is None: + return None + shard_dict = shard_dict_maybe # decoding chunks and writing them into the output buffer await self.codec_pipeline.read( @@ -509,7 +518,9 @@ async def _encode_partial_single( indexer = list( get_indexer( - selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) + selection, + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), ) ) @@ -624,7 +635,8 @@ def _shard_index_size(self, chunks_per_shard: tuple[int, ...]) -> int: get_pipeline_class() .from_codecs(self.index_codecs) .compute_encoded_size( - 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + 16 * product(chunks_per_shard), + self._get_index_chunk_spec(chunks_per_shard), ) ) @@ -669,7 +681,8 @@ async def _load_shard_index_maybe( ) else: index_bytes = await byte_getter.get( - prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size) + prototype=numpy_buffer_prototype(), + byte_range=SuffixByteRequest(shard_index_size), ) if index_bytes is not None: return await self._decode_shard_index(index_bytes, chunks_per_shard) @@ -693,6 +706,115 @@ async def _load_full_shard_maybe( else None ) + async def _load_partial_shard_maybe( + self, + byte_getter: ByteGetter, + prototype: BufferPrototype, + chunks_per_shard: tuple[int, ...], + all_chunk_coords: set[tuple[int, ...]], + max_gap_bytes: int, + coalesce_max_bytes: int, + async_concurrency: int, + ) -> ShardMapping | None: + """ + Read chunks from `byte_getter` for the case where the read is less than a full shard. + Returns a mapping of chunk coordinates to bytes or None. 
+ + Reads are coalesced if there are fewer than `max_gap_bytes` bytes between chunks + and the total size of the coalesced read is no more than `coalesce_max_bytes`. + """ + shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) + if shard_index is None: + return None # shard index read failure, the ByteGetter returned None + + chunks = [ + _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice)) + for chunk_coords in all_chunk_coords + # Drop chunks where index lookup fails + # e.g. empty chunks when write_empty_chunks = False + if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) + ] + + groups = self._coalesce_chunks(chunks, max_gap_bytes, coalesce_max_bytes) + + shard_dicts = await concurrent_map( + [(group, byte_getter, prototype) for group in groups], + self._get_group_bytes, + async_concurrency, + ) + + shard_dict: ShardMutableMapping = {} + for d in shard_dicts: + # can be None if the ByteGetter returned None when reading chunk data + if d is not None: + shard_dict.update(d) + + return shard_dict + + def _coalesce_chunks( + self, + chunks: list[_ChunkCoordsByteSlice], + max_gap_bytes: int, + coalesce_max_bytes: int, + ) -> list[list[_ChunkCoordsByteSlice]]: + """ + Combine chunks from a single shard into groups that should be read together + in a single request to the store. + """ + sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start) + + if len(sorted_chunks) == 0: + return [] + + groups = [] + current_group = [sorted_chunks[0]] + + for chunk in sorted_chunks[1:]: + gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop + size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start + if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes: + current_group.append(chunk) + else: + groups.append(current_group) + current_group = [chunk] + + groups.append(current_group) + + return groups + + async def _get_group_bytes( + self, + group: list[_ChunkCoordsByteSlice], + byte_getter: ByteGetter, + prototype: BufferPrototype, + ) -> ShardMapping | None: + """ + Reads a possibly coalesced group of one or more chunks from a shard. + Returns a mapping of chunk coordinates to bytes. + """ + # _coalesce_chunks ensures that the group is not empty. + group_start = group[0].byte_slice.start + group_end = group[-1].byte_slice.stop + + # A single call to retrieve the bytes for the entire group. + group_bytes = await byte_getter.get( + prototype=prototype, + byte_range=RangeByteRequest(group_start, group_end), + ) + if group_bytes is None: + return None + + # Extract the bytes corresponding to each chunk in group from group_bytes. 
+ shard_dict = {} + for chunk in group: + chunk_slice = slice( + chunk.byte_slice.start - group_start, + chunk.byte_slice.stop - group_start, + ) + shard_dict[chunk.chunk_coords] = group_bytes[chunk_slice] + + return shard_dict + def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..0911601d69 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -100,6 +100,12 @@ def enable_gpu(self) -> ConfigSet: }, "async": {"concurrency": 10, "timeout": None}, "threading": {"max_workers": None}, + "sharding": { + "read": { + "coalesce_max_bytes": 100 * 2**20, # 100MiB + "coalesce_max_gap_bytes": 2**20, # 1MiB + } + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 7eb4deccbf..3946eaef50 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -1,6 +1,7 @@ import pickle import re from typing import Any +from unittest.mock import AsyncMock import numpy as np import numpy.typing as npt @@ -10,7 +11,7 @@ import zarr.api import zarr.api.asynchronous from zarr import Array -from zarr.abc.store import Store +from zarr.abc.store import RangeByteRequest, Store, SuffixByteRequest from zarr.codecs import ( BloscCodec, ShardingCodec, @@ -112,7 +113,9 @@ def test_sharding_scalar( indirect=["array_fixture"], ) def test_sharding_partial( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -148,7 +151,9 @@ def test_sharding_partial( indirect=["array_fixture"], ) def test_sharding_partial_readwrite( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -180,7 +185,9 @@ def test_sharding_partial_readwrite( @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_partial_read( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -199,6 +206,301 @@ def test_sharding_partial_read( assert np.all(read_data == 1) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize("coalesce_reads", [True, False]) +def test_sharding_multiple_chunks_partial_shard_read( + store: Store, index_location: ShardingCodecIndexLocation, coalesce_reads: bool +) -> None: + array_shape = (16, 64) + shard_shape = (8, 32) + chunk_shape = (2, 4) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + if coalesce_reads: + # 1MiB, enough to coalesce all chunks within a shard in this example + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20}) + else: + # disable coalescing + 
zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Reads 3 (2 full, 1 partial) chunks each from 2 shards (a subset of both shards) + # for a total of 6 chunks accessed + assert np.allclose(a[0, 22:42], np.arange(22, 42, dtype="float32")) + + if coalesce_reads: + # 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard) + assert store_mock.get.call_count == 4 + else: + # 2 shard index requests + 6 chunks + assert store_mock.get.call_count == 8 + + for method, args, kwargs in store_mock.method_calls: + assert method == "get" + assert args[0].startswith("c/") # get from a chunk + assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) + + store_mock.reset_mock() + + # Reads 4 chunks from both shards along dimension 0 for a total of 8 chunks accessed + assert np.allclose(a[:, 0], np.arange(0, data.size, array_shape[1], dtype="float32")) + + if coalesce_reads: + # 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard) + assert store_mock.get.call_count == 4 + else: + # 2 shard index requests + 8 chunks + assert store_mock.get.call_count == 10 + + for method, args, kwargs in store_mock.method_calls: + assert method == "get" + assert args[0].startswith("c/") # get from a chunk + assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize("coalesce_reads", [True, False]) +def test_sharding_duplicate_read_indexes( + store: Store, index_location: ShardingCodecIndexLocation, coalesce_reads: bool +) -> None: + """ + Check that coalesce optimization parses the grouped reads back out correctly + when there are multiple reads for the same index. 
+ """ + array_shape = (15,) + shard_shape = (8,) + chunk_shape = (2,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + if coalesce_reads: + # 1MiB, enough to coalesce all chunks within a shard in this example + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20}) + else: + # disable coalescing + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=-1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Read the same index multiple times, from two chunks that can be coalesced + indexer = [8, 8, 12, 12] + assert np.array_equal(a[indexer], data[indexer]) + + if coalesce_reads: + # 1 shard index request + 1 coalesced read + assert store_mock.get.call_count == 2 + else: + # 1 shard index request + 2 chunks + assert store_mock.get.call_count == 3 + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - some, but not all, chunks in the last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have fewer entries than chunks in the shard + """ + # array with mixed empty and non-empty chunks in second shard + data = np.array([ + # shard 0. full 8 elements, all chunks have some non-fill data + 0, 1, 2, 3, 4, 5, 6, 7, + # shard 1. 
6 elements (< shard shape) + 2, 0, # chunk 0, written + -9, -9, # chunk 1, all fill, not written + 4, 5 # chunk 2, written + ], dtype="int32") # fmt: off + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (8,), "index_location": index_location}, + dtype="int32", + fill_value=-9, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - all chunks in last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have no entries + """ + fill_value = -99 + shard_size = 8 + data = np.arange(14, dtype="int32") + data[shard_size:] = fill_value # 2nd shard is all fill value + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (shard_size,), "index_location": index_location}, + dtype="int32", + fill_value=fill_value, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load the bytes of the shard's chunk index fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + # loading the index is the first call to .get() so returning None will simulate an index load failure + store_mock.get.return_value = None + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_chunk_slice_fails( + store: Store, + index_location: ShardingCodecIndexLocation, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test fill value is returned when looking up a chunk's byte slice within a shard fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + monkeypatch.setattr( + "zarr.codecs.sharding._ShardIndex.get_chunk_slice", + lambda self, chunk_coords: None, + ) + + a = zarr.create_array( + StorePath(store), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + 
compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__chunk_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load a chunk's bytes fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Set up store mock after array creation to only modify calls during array indexing + # Succeed on first call (index load), fail on subsequent calls (chunk loads) + async def first_success_then_fail(*args: Any, **kwargs: Any) -> Any: + if store_mock.get.call_count == 1: + return await store.get(*args, **kwargs) + else: + return None + + store_mock.get.reset_mock() + store_mock.get.side_effect = first_success_then_fail + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + @pytest.mark.parametrize( "array_fixture", [ @@ -209,7 +511,9 @@ def test_sharding_partial_read( @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_partial_overwrite( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture[:10, :10, :10] spath = StorePath(store) @@ -326,7 +630,6 @@ def test_nested_sharding_create_array( filters=None, compressors=None, ) - print(a.metadata.to_dict()) a[:, :, :] = data @@ -386,7 +689,6 @@ async def test_delete_empty_shards(store: Store) -> None: compressors=None, fill_value=1, ) - print(a.metadata.to_dict()) await _AsyncArrayProxy(a)[:, :].set(np.zeros((16, 16))) await _AsyncArrayProxy(a)[8:, :].set(np.ones((8, 16))) await _AsyncArrayProxy(a)[:, 8:].set(np.ones((16, 8))) @@ -431,7 +733,6 @@ async def test_sharding_with_empty_inner_chunk( ) data[:4, :4] = fill_value await a.setitem(..., data) - print("read data") data_read = await a.getitem(...) 
assert np.array_equal(data_read, data) @@ -443,7 +744,9 @@ async def test_sharding_with_empty_inner_chunk( ) @pytest.mark.parametrize("chunks_per_shard", [(5, 2), (2, 5), (5, 5)]) async def test_sharding_with_chunks_per_shard( - store: Store, index_location: ShardingCodecIndexLocation, chunks_per_shard: tuple[int] + store: Store, + index_location: ShardingCodecIndexLocation, + chunks_per_shard: tuple[int], ) -> None: chunk_shape = (2, 1) shape = tuple(x * y for x, y in zip(chunks_per_shard, chunk_shape, strict=False)) diff --git a/tests/test_codecs/test_sharding_unit.py b/tests/test_codecs/test_sharding_unit.py new file mode 100644 index 0000000000..cc5df22b67 --- /dev/null +++ b/tests/test_codecs/test_sharding_unit.py @@ -0,0 +1,664 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from zarr.codecs.sharding import ( + MAX_UINT_64, + ShardingCodec, + _ChunkCoordsByteSlice, + _ShardIndex, + _ShardReader, +) +from zarr.core.buffer import default_buffer_prototype +from zarr.core.buffer.cpu import Buffer + +if TYPE_CHECKING: + from zarr.abc.store import ByteRequest + from zarr.core.buffer import BufferPrototype + + +# ============================================================================ +# _ShardIndex tests +# ============================================================================ + + +def test_shard_index_create_empty() -> None: + """Test that create_empty creates an index filled with MAX_UINT_64.""" + chunks_per_shard = (2, 3) + index = _ShardIndex.create_empty(chunks_per_shard) + + assert index.chunks_per_shard == chunks_per_shard + assert index.offsets_and_lengths.shape == (2, 3, 2) + assert index.offsets_and_lengths.dtype == np.dtype("<u8") + assert np.all(index.offsets_and_lengths == MAX_UINT_64) + + +def test_shard_index_create_empty_1d() -> None: + """Test create_empty with 1D chunks_per_shard.""" + chunks_per_shard = (4,) + index = _ShardIndex.create_empty(chunks_per_shard) + + assert index.chunks_per_shard == chunks_per_shard + assert index.offsets_and_lengths.shape == (4, 2) + + +def test_shard_index_is_all_empty_true() -> None: + """Test is_all_empty returns True for a freshly created empty index.""" + index = _ShardIndex.create_empty((2, 2)) + assert index.is_all_empty() is True + + +def test_shard_index_is_all_empty_false() -> None: + """Test is_all_empty returns False when at least one chunk is set.""" + index = _ShardIndex.create_empty((2, 2)) + index.set_chunk_slice((0, 0), slice(0, 100)) + assert index.is_all_empty() is False + + +def test_shard_index_get_chunk_slice_empty() -> None: + """Test get_chunk_slice returns None for empty chunks.""" + index = _ShardIndex.create_empty((2, 2)) + assert index.get_chunk_slice((0, 0)) is None + assert index.get_chunk_slice((1, 1)) is None + + +def test_shard_index_get_chunk_slice_set() -> None: + """Test get_chunk_slice returns correct (start, end) tuple after setting.""" + index = _ShardIndex.create_empty((2, 2)) + index.set_chunk_slice((0, 1), slice(100, 200)) + + result = index.get_chunk_slice((0, 1)) + assert result == (100, 200) + + +def test_shard_index_set_chunk_slice() -> None: + """Test set_chunk_slice correctly sets offset and length.""" + index = _ShardIndex.create_empty((3, 3)) + + # Set a chunk slice + index.set_chunk_slice((1, 2), slice(50, 150)) + + # Verify the underlying array + assert index.offsets_and_lengths[1, 2, 0] == 50 # offset + assert index.offsets_and_lengths[1, 2, 1] == 100 # length (150 - 50) + + +def test_shard_index_set_chunk_slice_none() -> None: + """Test set_chunk_slice with None marks chunk as 
empty.""" + index = _ShardIndex.create_empty((2, 2)) + + # First set a value + index.set_chunk_slice((0, 0), slice(0, 100)) + assert index.get_chunk_slice((0, 0)) == (0, 100) + + # Then clear it + index.set_chunk_slice((0, 0), None) + assert index.get_chunk_slice((0, 0)) is None + assert index.offsets_and_lengths[0, 0, 0] == MAX_UINT_64 + assert index.offsets_and_lengths[0, 0, 1] == MAX_UINT_64 + + +def test_shard_index_get_full_chunk_map() -> None: + """Test get_full_chunk_map returns correct boolean array.""" + index = _ShardIndex.create_empty((2, 3)) + + # Set some chunks + index.set_chunk_slice((0, 0), slice(0, 10)) + index.set_chunk_slice((1, 2), slice(10, 20)) + + chunk_map = index.get_full_chunk_map() + + assert chunk_map.shape == (2, 3) + assert chunk_map.dtype == np.bool_ + assert chunk_map[0, 0] is np.True_ + assert chunk_map[0, 1] is np.False_ + assert chunk_map[0, 2] is np.False_ + assert chunk_map[1, 0] is np.False_ + assert chunk_map[1, 1] is np.False_ + assert chunk_map[1, 2] is np.True_ + + +def test_shard_index_localize_chunk() -> None: + """Test _localize_chunk maps global coords to local shard coords via modulo.""" + index = _ShardIndex.create_empty((2, 3)) + + # Within bounds - should return same coords + assert index._localize_chunk((0, 0)) == (0, 0) + assert index._localize_chunk((1, 2)) == (1, 2) + + # Out of bounds - should wrap via modulo + assert index._localize_chunk((2, 0)) == (0, 0) # 2 % 2 = 0 + assert index._localize_chunk((3, 5)) == (1, 2) # 3 % 2 = 1, 5 % 3 = 2 + assert index._localize_chunk((4, 6)) == (0, 0) # 4 % 2 = 0, 6 % 3 = 0 + + +def test_shard_index_is_dense_true() -> None: + """Test is_dense returns True when chunks are contiguously packed.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set chunks contiguously: [0-100), [100-200) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(100, 200)) + + assert index.is_dense(chunk_byte_length) is True + + +def test_shard_index_is_dense_false_duplicate_offsets() -> None: + """Test is_dense returns False when chunks have duplicate offsets.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set both chunks to same offset (duplicate) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(0, 100)) + + assert index.is_dense(chunk_byte_length) is False + + +def test_shard_index_is_dense_false_wrong_alignment() -> None: + """Test is_dense returns False when chunks are not aligned to chunk_byte_length.""" + index = _ShardIndex.create_empty((2,)) + chunk_byte_length = 100 + + # Set chunks not aligned: [0-100), [150-250) + index.set_chunk_slice((0,), slice(0, 100)) + index.set_chunk_slice((1,), slice(150, 250)) + + assert index.is_dense(chunk_byte_length) is False + + +def test_shard_index_is_dense_with_empty_chunks() -> None: + """Test is_dense handles empty chunks correctly.""" + index = _ShardIndex.create_empty((3,)) + chunk_byte_length = 100 + + # Only set first and third chunk, skip middle + index.set_chunk_slice((0,), slice(0, 100)) + # (1,) is empty + index.set_chunk_slice((2,), slice(100, 200)) + + # Should still be dense since only non-empty chunks are considered + assert index.is_dense(chunk_byte_length) is True + + +# ============================================================================ +# _coalesce_chunks tests +# ============================================================================ + + +def test_coalesce_chunks_empty_list() -> None: + """Test _coalesce_chunks returns empty list for 
empty input.""" + codec = ShardingCodec(chunk_shape=(8,)) + result = codec._coalesce_chunks([], max_gap_bytes=100, coalesce_max_bytes=1000) + assert result == [] + + +def test_coalesce_chunks_single_chunk() -> None: + """Test _coalesce_chunks returns single group for single chunk.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + + result = codec._coalesce_chunks([chunk], max_gap_bytes=100, coalesce_max_bytes=1000) + + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0] == chunk + + +def test_coalesce_chunks_adjacent_small_gap() -> None: + """Test adjacent chunks with small gap are coalesced.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(110, 210)) # 10 byte gap + + result = codec._coalesce_chunks([chunk0, chunk1], max_gap_bytes=20, coalesce_max_bytes=1000) + + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][0] == chunk0 + assert result[0][1] == chunk1 + + +def test_coalesce_chunks_distant_large_gap() -> None: + """Test chunks with large gap are not coalesced.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(500, 600)) # 400 byte gap + + result = codec._coalesce_chunks([chunk0, chunk1], max_gap_bytes=100, coalesce_max_bytes=1000) + + assert len(result) == 2 + assert result[0] == [chunk0] + assert result[1] == [chunk1] + + +def test_coalesce_chunks_disabled_negative_gap() -> None: + """Test coalescing is disabled when max_gap_bytes is negative (like -1).""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(100, 200)) # Adjacent! 
+ + result = codec._coalesce_chunks([chunk0, chunk1], max_gap_bytes=-1, coalesce_max_bytes=1000) + + # Even adjacent chunks should not be coalesced + assert len(result) == 2 + + +def test_coalesce_chunks_exceeds_max_bytes() -> None: + """Test chunks are split when total size exceeds coalesce_max_bytes.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(100, 200)) + chunk2 = _ChunkCoordsByteSlice(chunk_coords=(2,), byte_slice=slice(200, 300)) + + # Total would be 300 bytes, but max is 250 + result = codec._coalesce_chunks( + [chunk0, chunk1, chunk2], max_gap_bytes=100, coalesce_max_bytes=250 + ) + + # First two chunks (200 bytes) should be coalesced, third separate + assert len(result) == 2 + assert len(result[0]) == 2 + assert result[0][0] == chunk0 + assert result[0][1] == chunk1 + assert result[1] == [chunk2] + + +def test_coalesce_chunks_unsorted_input() -> None: + """Test chunks are sorted by byte_slice.start before coalescing.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(200, 300)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(0, 100)) + chunk2 = _ChunkCoordsByteSlice(chunk_coords=(2,), byte_slice=slice(100, 200)) + + # Input is out of order + result = codec._coalesce_chunks( + [chunk0, chunk1, chunk2], max_gap_bytes=100, coalesce_max_bytes=1000 + ) + + # All should be coalesced and in sorted order + assert len(result) == 1 + assert len(result[0]) == 3 + assert result[0][0] == chunk1 # slice(0, 100) + assert result[0][1] == chunk2 # slice(100, 200) + assert result[0][2] == chunk0 # slice(200, 300) + + +def test_coalesce_chunks_mixed_coalescing() -> None: + """Test mixed scenario with some chunks coalesced and some separate.""" + codec = ShardingCodec(chunk_shape=(8,)) + # Group 1: chunks at 0-100, 100-200 (adjacent) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(100, 200)) + # Gap of 300 bytes + # Group 2: chunks at 500-600, 600-700 (adjacent) + chunk2 = _ChunkCoordsByteSlice(chunk_coords=(2,), byte_slice=slice(500, 600)) + chunk3 = _ChunkCoordsByteSlice(chunk_coords=(3,), byte_slice=slice(600, 700)) + + result = codec._coalesce_chunks( + [chunk0, chunk1, chunk2, chunk3], max_gap_bytes=100, coalesce_max_bytes=1000 + ) + + assert len(result) == 2 + assert len(result[0]) == 2 + assert result[0][0] == chunk0 + assert result[0][1] == chunk1 + assert len(result[1]) == 2 + assert result[1][0] == chunk2 + assert result[1][1] == chunk3 + + +def test_coalesce_chunks_boundary_gap_equals_max() -> None: + """Test boundary condition where gap equals max_gap_bytes (should NOT coalesce).""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(150, 250)) # 50 byte gap + + # Gap is exactly max_gap_bytes, condition is `gap < max_gap_bytes` so should NOT coalesce + result = codec._coalesce_chunks([chunk0, chunk1], max_gap_bytes=50, coalesce_max_bytes=1000) + + assert len(result) == 2 + + +def test_coalesce_chunks_boundary_gap_less_than_max() -> None: + """Test boundary condition where gap is just under max_gap_bytes (should coalesce).""" + codec = ShardingCodec(chunk_shape=(8,)) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), 
byte_slice=slice(0, 100)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(149, 249)) # 49 byte gap + + result = codec._coalesce_chunks([chunk0, chunk1], max_gap_bytes=50, coalesce_max_bytes=1000) + + assert len(result) == 1 + + +# ============================================================================ +# _get_group_bytes tests +# ============================================================================ + + +@dataclass +class MockByteGetter: + """Mock ByteGetter for testing _get_group_bytes.""" + + data: bytes + return_none: bool = False + + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + if self.return_none: + return None + if byte_range is None: + return Buffer.from_bytes(self.data) + # For RangeByteRequest, extract start and end + start = getattr(byte_range, "start", 0) + end = getattr(byte_range, "end", len(self.data)) + return Buffer.from_bytes(self.data[start:end]) + + +async def test_get_group_bytes_single_chunk() -> None: + """Test _get_group_bytes extracts single chunk correctly.""" + codec = ShardingCodec(chunk_shape=(8,)) + data = b"0123456789" * 10 # 100 bytes + byte_getter = MockByteGetter(data=data) + + chunk = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(10, 30)) + group = [chunk] + + result = await codec._get_group_bytes(group, byte_getter, default_buffer_prototype()) + + assert result is not None + assert (0,) in result + chunk_buf = result[(0,)] + assert chunk_buf is not None + assert chunk_buf.as_numpy_array().tobytes() == data[10:30] + + +async def test_get_group_bytes_multiple_chunks() -> None: + """Test _get_group_bytes extracts multiple chunks with correct offsets.""" + codec = ShardingCodec(chunk_shape=(8,)) + data = b"0123456789" * 10 # 100 bytes + byte_getter = MockByteGetter(data=data) + + # Two chunks: [10, 30) and [30, 50) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(10, 30)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(30, 50)) + group = [chunk0, chunk1] + + result = await codec._get_group_bytes(group, byte_getter, default_buffer_prototype()) + + assert result is not None + assert len(result) == 2 + chunk0_buf = result[(0,)] + chunk1_buf = result[(1,)] + assert chunk0_buf is not None + assert chunk1_buf is not None + assert chunk0_buf.as_numpy_array().tobytes() == data[10:30] + assert chunk1_buf.as_numpy_array().tobytes() == data[30:50] + + +async def test_get_group_bytes_with_gap() -> None: + """Test _get_group_bytes handles chunks with gaps correctly.""" + codec = ShardingCodec(chunk_shape=(8,)) + data = b"0123456789" * 10 # 100 bytes + byte_getter = MockByteGetter(data=data) + + # Two chunks with a gap: [10, 20) and [40, 60) + chunk0 = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(10, 20)) + chunk1 = _ChunkCoordsByteSlice(chunk_coords=(1,), byte_slice=slice(40, 60)) + group = [chunk0, chunk1] + + result = await codec._get_group_bytes(group, byte_getter, default_buffer_prototype()) + + assert result is not None + assert len(result) == 2 + # The byte_getter.get is called with range [10, 60), then sliced + chunk0_buf = result[(0,)] + chunk1_buf = result[(1,)] + assert chunk0_buf is not None + assert chunk1_buf is not None + assert chunk0_buf.as_numpy_array().tobytes() == data[10:20] + assert chunk1_buf.as_numpy_array().tobytes() == data[40:60] + + +async def test_get_group_bytes_returns_none_on_failed_read() -> None: + """Test _get_group_bytes returns None when ByteGetter.get returns None.""" + 
codec = ShardingCodec(chunk_shape=(8,)) + byte_getter = MockByteGetter(data=b"", return_none=True) + + chunk = _ChunkCoordsByteSlice(chunk_coords=(0,), byte_slice=slice(0, 100)) + group = [chunk] + + result = await codec._get_group_bytes(group, byte_getter, default_buffer_prototype()) + + assert result is None + + +# ============================================================================ +# _load_partial_shard_maybe tests +# ============================================================================ + + +@dataclass +class MockByteGetterWithIndex: + """Mock ByteGetter that can return a shard index and chunk data.""" + + index_data: bytes | None + chunk_data: bytes | None + call_count: int = 0 + + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + self.call_count += 1 + # First call is typically for the index + if self.call_count == 1: + if self.index_data is None: + return None + return Buffer.from_bytes(self.index_data) + # Subsequent calls are for chunk data + if self.chunk_data is None: + return None + if byte_range is None: + return Buffer.from_bytes(self.chunk_data) + # For RangeByteRequest, extract start and end + start = getattr(byte_range, "start", 0) + end = getattr(byte_range, "end", len(self.chunk_data)) + return Buffer.from_bytes(self.chunk_data[start:end]) + + +async def test_load_partial_shard_maybe_index_load_fails() -> None: + """Test _load_partial_shard_maybe returns None when index load fails.""" + codec = ShardingCodec(chunk_shape=(8,)) + byte_getter = MockByteGetterWithIndex(index_data=None, chunk_data=None) + + chunks_per_shard = (2,) + all_chunk_coords: set[tuple[int, ...]] = {(0,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + max_gap_bytes=100, + coalesce_max_bytes=1000, + async_concurrency=1, + ) + + assert result is None + + +async def test_load_partial_shard_maybe_with_empty_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test _load_partial_shard_maybe skips chunks where get_chunk_slice returns None.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + # Create an index where chunk (1,) is empty (returns None from get_chunk_slice) + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice((0,), slice(0, 100)) + # (1,) is intentionally left empty + index.set_chunk_slice((2,), slice(100, 200)) + index.set_chunk_slice((3,), slice(200, 300)) + + # Mock _load_shard_index_maybe on the class to return our custom index + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] 
+ ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + # Create byte getter with chunk data + chunk_data = b"x" * 300 + byte_getter = MockByteGetter(data=chunk_data) + + # Request chunks including the empty one + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + max_gap_bytes=1000, + coalesce_max_bytes=10000, + async_concurrency=1, + ) + + assert result is not None + # Only chunks (0,) and (2,) should be in result, (1,) is empty and skipped + assert (0,) in result + assert (1,) not in result # Empty chunk should be skipped + assert (2,) in result + + +async def test_load_partial_shard_maybe_all_chunks_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test _load_partial_shard_maybe returns empty dict when all requested chunks are empty.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + + # Create an empty index (all chunks empty) + index = _ShardIndex.create_empty(chunks_per_shard) + + # Mock _load_shard_index_maybe on the class to return our empty index + async def mock_load_index( + self: ShardingCodec, byte_getter: MockByteGetter, cps: tuple[int, ...] + ) -> _ShardIndex: + return index + + monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index) + + byte_getter = MockByteGetter(data=b"") + + # Request some chunks - all will be empty + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)} + + result = await codec._load_partial_shard_maybe( + byte_getter=byte_getter, + prototype=default_buffer_prototype(), + chunks_per_shard=chunks_per_shard, + all_chunk_coords=all_chunk_coords, + max_gap_bytes=1000, + coalesce_max_bytes=10000, + async_concurrency=1, + ) + + assert result is not None + assert len(result) == 0 # All chunks were empty, so result is empty dict + + +# ============================================================================ +# Supporting class tests (_ShardReader, _is_total_shard, _ChunkCoordsByteSlice) +# ============================================================================ + + +def test_chunk_coords_byte_slice() -> None: + """Test _ChunkCoordsByteSlice dataclass.""" + chunk = _ChunkCoordsByteSlice(chunk_coords=(1, 2, 3), byte_slice=slice(100, 200)) + + assert chunk.chunk_coords == (1, 2, 3) + assert chunk.byte_slice == slice(100, 200) + assert chunk.byte_slice.start == 100 + assert chunk.byte_slice.stop == 200 + + +def test_shard_reader_create_empty() -> None: + """Test _ShardReader.create_empty creates reader with empty index.""" + chunks_per_shard = (2, 3) + reader = _ShardReader.create_empty(chunks_per_shard) + + assert reader.index.is_all_empty() + assert len(reader.buf) == 0 + assert len(reader) == 6 # 2 * 3 + + +def test_shard_reader_iteration() -> None: + """Test _ShardReader iteration yields all chunk coordinates.""" + chunks_per_shard = (2, 2) + reader = _ShardReader.create_empty(chunks_per_shard) + + coords = list(reader) + + assert len(coords) == 4 + assert (0, 0) in coords + assert (0, 1) in coords + assert (1, 0) in coords + assert (1, 1) in coords + + +def test_shard_reader_getitem_raises_for_empty() -> None: + """Test _ShardReader.__getitem__ raises KeyError for empty chunks.""" + chunks_per_shard = (2,) + reader = _ShardReader.create_empty(chunks_per_shard) + + with pytest.raises(KeyError): + _ = reader[(0,)] + + +def 
test_is_total_shard_full() -> None: + """Test _is_total_shard returns True when all chunk coords are present.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (0, 1), (1, 0), (1, 1)} + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True + + +def test_is_total_shard_partial() -> None: + """Test _is_total_shard returns False for partial chunk coords.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = {(0, 0), (1, 1)} # Missing (0, 1) and (1, 0) + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is False + + +def test_is_total_shard_empty() -> None: + """Test _is_total_shard returns False for empty chunk coords.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (2, 2) + all_chunk_coords: set[tuple[int, ...]] = set() + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is False + + +def test_is_total_shard_1d() -> None: + """Test _is_total_shard works with 1D shards.""" + codec = ShardingCodec(chunk_shape=(8,)) + chunks_per_shard = (4,) + all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,), (3,)} + + assert codec._is_total_shard(all_chunk_coords, chunks_per_shard) is True + + # Partial + partial_coords: set[tuple[int, ...]] = {(0,), (2,)} + assert codec._is_total_shard(partial_coords, chunks_per_shard) is False diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..ad900f874f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -97,6 +97,12 @@ def test_config_defaults_set() -> None: }, "buffer": "zarr.buffer.cpu.Buffer", "ndbuffer": "zarr.buffer.cpu.NDBuffer", + "sharding": { + "read": { + "coalesce_max_bytes": 100 * 2**20, # 100 MiB + "coalesce_max_gap_bytes": 2**20, # 1 MiB + } + }, } ] ) @@ -105,6 +111,8 @@ def test_config_defaults_set() -> None: assert config.get("async.timeout") is None assert config.get("codec_pipeline.batch_size") == 1 assert config.get("json_indent") == 2 + assert config.get("sharding.read.coalesce_max_bytes") == 100 * 2**20 # 100 MiB + assert config.get("sharding.read.coalesce_max_gap_bytes") == 2**20 # 1 MiB @pytest.mark.parametrize(
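
To make the behaviour this patch adds concrete, here is a minimal usage sketch (not part of the patch; the store, shapes and selection are illustrative and mirror `test_sharding_multiple_chunks_partial_shard_read`):

```python
import numpy as np
import zarr
from zarr.storage import MemoryStore

# A sharded array: each (8, 32) shard holds 32 inner chunks of shape (2, 4).
store = MemoryStore()
z = zarr.create_array(
    store,
    shape=(16, 64),
    chunks=(2, 4),
    shards={"shape": (8, 32), "index_location": "end"},
    dtype="float32",
    fill_value=0,
)
z[:] = np.arange(16 * 64, dtype="float32").reshape(16, 64)

# This selection touches three neighbouring chunks in each of two shards.
# With coalescing enabled (the default), each shard is served by one shard-index
# read plus a single merged range request, instead of one request per chunk.
part = z[0, 22:42]
assert np.array_equal(part, np.arange(22, 42, dtype="float32"))
```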
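The two new config keys trade request count against over-read bytes; the defaults added in `src/zarr/core/config.py` are a 1 MiB maximum gap and a 100 MiB cap per merged request. A sketch of adjusting them (the negative-gap value is the same trick the tests use to switch coalescing off):

```python
import zarr

# Merge nearby chunk reads within a shard when the gap between them is under 1 MiB,
# and cap each merged request at 100 MiB (these happen to be the defaults).
zarr.config.set(
    {
        "sharding.read.coalesce_max_gap_bytes": 2**20,
        "sharding.read.coalesce_max_bytes": 100 * 2**20,
    }
)

# A negative gap effectively disables coalescing, because the
# `gap_to_chunk < max_gap_bytes` check in `_coalesce_chunks` can never pass.
zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1})
```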
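The grouping itself is a greedy pass over the requested chunks sorted by byte offset: a chunk joins the current group while the gap to the previous chunk is under `coalesce_max_gap_bytes` and the span of the would-be group stays under `coalesce_max_bytes`; each group is then fetched with a single `RangeByteRequest` from the first chunk's start to the last chunk's stop, and the individual chunks are sliced back out of the returned buffer. A small illustration using the private helpers (internal API, shown only to make the rule concrete):

```python
from zarr.codecs.sharding import ShardingCodec, _ChunkCoordsByteSlice

codec = ShardingCodec(chunk_shape=(8,))
chunks = [
    _ChunkCoordsByteSlice((0,), slice(0, 100)),
    _ChunkCoordsByteSlice((1,), slice(120, 220)),      # 20-byte gap to chunk (0,)
    _ChunkCoordsByteSlice((2,), slice(5_000, 5_100)),   # far away: starts a new group
]
groups = codec._coalesce_chunks(chunks, max_gap_bytes=64, coalesce_max_bytes=1_000)
assert [len(g) for g in groups] == [2, 1]

# Group 0 becomes one request for bytes [0, 220); the 20 gap bytes are read but
# discarded when each chunk is sliced back out relative to the group start.
for group in groups:
    print(f"one request for bytes [{group[0].byte_slice.start}, {group[-1].byte_slice.stop})")
```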
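Note that `_decode_partial_single` only takes this path when the request covers a strict subset of the shard's chunks; when every chunk is requested, the whole shard is still fetched in one read as before. The dispatch is `_is_total_shard`, which the new unit tests exercise directly, e.g.:

```python
from zarr.codecs.sharding import ShardingCodec

codec = ShardingCodec(chunk_shape=(8,))
chunks_per_shard = (2, 2)

# All four chunks requested -> one full-shard read.
assert codec._is_total_shard({(0, 0), (0, 1), (1, 0), (1, 1)}, chunks_per_shard)
# A subset -> partial path: load the shard index, coalesce, read byte ranges.
assert not codec._is_total_shard({(0, 0), (1, 1)}, chunks_per_shard)
```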
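Chunks that were never written (for example with `write_empty_chunks=False`) have no entry in the shard index, so `get_chunk_slice` returns `None`; `_load_partial_shard_maybe` drops them before grouping, and the read falls back to the fill value for those chunks. A minimal sketch of that index behaviour:

```python
from zarr.codecs.sharding import _ShardIndex

index = _ShardIndex.create_empty((4,))
index.set_chunk_slice((0,), slice(0, 100))  # chunk (0,) was written
# chunk (1,) was never written, so it keeps the MAX_UINT_64 sentinel

assert index.get_chunk_slice((0,)) == (0, 100)
assert index.get_chunk_slice((1,)) is None  # dropped before coalescing
```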