Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 114 additions & 37 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from zarr.core.chunk_grids import (
SHARDED_INNER_CHUNK_MAX_BYTES,
ChunkGrid,
ChunkSpec,
_is_rectilinear_chunks,
as_regular_shape,
guess_chunks,
Expand Down Expand Up @@ -383,6 +384,13 @@ def __init__(
create_codec_pipeline(metadata=metadata_parsed, store=store_path.store),
)
object.__setattr__(self, "_transform", IndexTransform.from_shape(metadata_parsed.shape))
object.__setattr__(self, "_shape", self._transform.domain.shape)
# A freshly-opened array has the identity transform: input coord i maps to
# storage coord i over the full storage domain. Eager indexing on such an
# array can use the original (legacy) indexers directly, avoiding the
# transform-resolution overhead. Lazy views (created via _with_transform)
# carry a non-identity transform and must go through the transform path.
object.__setattr__(self, "_is_identity", True)

@classmethod
async def _create(
Expand Down Expand Up @@ -805,6 +813,10 @@ def _with_transform(self, transform: IndexTransform) -> AsyncArray[T_ArrayMetada
object.__setattr__(new, "_chunk_grid", self._chunk_grid)
object.__setattr__(new, "codec_pipeline", self.codec_pipeline)
object.__setattr__(new, "_transform", transform)
object.__setattr__(new, "_shape", transform.domain.shape)
object.__setattr__(
new, "_is_identity", _transform_is_identity(transform, self.metadata.shape)
)
return new

@property
Expand All @@ -825,7 +837,7 @@ def ndim(self) -> int:
int
The number of dimensions in the Array.
"""
return len(self.shape)
return len(self._shape)

@property
def shape(self) -> tuple[int, ...]:
Expand All @@ -836,7 +848,7 @@ def shape(self) -> tuple[int, ...]:
tuple
The shape of the Array.
"""
return self._transform.domain.shape
return self._shape

@property
def storage_shape(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1617,6 +1629,7 @@ async def _get_selection_t(
self.codec_pipeline,
prototype=prototype,
out=out,
chunk_grid=self._chunk_grid,
)

async def _set_selection_t(
Expand All @@ -1634,6 +1647,7 @@ async def _set_selection_t(
value,
self.codec_pipeline,
prototype=prototype,
chunk_grid=self._chunk_grid,
)

async def setitem(
Expand Down Expand Up @@ -2900,8 +2914,9 @@ def get_basic_selection(

if prototype is None:
prototype = default_buffer_prototype()
if fields is not None:
# Fall back to legacy path for structured dtype field selection
if fields is not None or self._async_array._is_identity:
# Eager (identity-transform) arrays and structured-dtype field
# selection use the original indexer path directly.
return sync(
self.async_array._get_selection(
BasicIndexer(selection, self.shape, self._chunk_grid),
Expand Down Expand Up @@ -3014,8 +3029,9 @@ def set_basic_selection(
"""
if prototype is None:
prototype = default_buffer_prototype()
if fields is not None:
# Fall back to legacy path for structured dtype field selection
if fields is not None or self._async_array._is_identity:
# Eager (identity-transform) arrays and structured-dtype field
# selection use the original indexer path directly.
indexer = BasicIndexer(selection, self.shape, self._chunk_grid)
sync(
self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
Expand Down Expand Up @@ -3150,8 +3166,9 @@ def get_orthogonal_selection(
"""
if prototype is None:
prototype = default_buffer_prototype()
if fields is not None or not is_basic_selection(selection):
# Fall back to legacy path for structured dtypes or advanced selections
if fields is not None or self._async_array._is_identity or not is_basic_selection(selection):
# Eager (identity) arrays, structured dtypes, and advanced selections
# use the original indexer path directly.
indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid)
return sync(
self.async_array._get_selection(
Expand Down Expand Up @@ -3273,8 +3290,9 @@ def set_orthogonal_selection(
"""
if prototype is None:
prototype = default_buffer_prototype()
if fields is not None or not is_basic_selection(selection):
# Fall back to legacy path for structured dtypes or advanced selections
if fields is not None or self._async_array._is_identity or not is_basic_selection(selection):
# Eager (identity) arrays, structured dtypes, and advanced selections
# use the original indexer path directly.
indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid)
sync(
self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
Expand Down Expand Up @@ -3367,7 +3385,7 @@ def get_mask_selection(

if prototype is None:
prototype = default_buffer_prototype()
if fields is not None:
if fields is not None or self._async_array._is_identity:
indexer = MaskIndexer(mask, self.shape, self._chunk_grid)
return sync(
self.async_array._get_selection(
Expand Down Expand Up @@ -3472,7 +3490,7 @@ def set_mask_selection(
"""
if prototype is None:
prototype = default_buffer_prototype()
if fields is not None:
if fields is not None or self._async_array._is_identity:
indexer = MaskIndexer(mask, self.shape, self._chunk_grid)
sync(
self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
Expand Down Expand Up @@ -3577,7 +3595,7 @@ def get_coordinate_selection(
"""
if prototype is None:
prototype = default_buffer_prototype()
if fields is not None:
if fields is not None or self._async_array._is_identity:
indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid)
out_array = sync(
self.async_array._get_selection(
Expand Down Expand Up @@ -3700,7 +3718,7 @@ def set_coordinate_selection(
# Normalize empty fields list to None
if not fields:
fields = None
if fields is not None:
if fields is not None or self._async_array._is_identity:
indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid)
if not is_scalar(value, self.dtype):
try:
Expand All @@ -3711,6 +3729,13 @@ def set_coordinate_selection(
value = np.array(value)
if hasattr(value, "shape") and len(value.shape) > 1:
value = np.array(value).reshape(-1)
if not is_scalar(value, self.dtype) and (
isinstance(value, NDArrayLike) and indexer.shape != value.shape
):
raise ValueError(
f"Attempting to set a selection of {indexer.sel_shape[0]} "
f"elements with an array of {value.shape[0]} elements."
)
sync(
self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
)
Expand Down Expand Up @@ -5582,17 +5607,18 @@ async def _nbytes_stored(
return await store_path.store.getsize_prefix(store_path.path)


def _get_chunk_spec(
def _array_spec_from_chunk_spec(
metadata: ArrayMetadata,
chunk_grid: ChunkGrid,
chunk_coords: tuple[int, ...],
spec: ChunkSpec,
array_config: ArrayConfig,
prototype: BufferPrototype,
) -> ArraySpec:
"""Build an ArraySpec for a single chunk using the ChunkGrid."""
spec = chunk_grid[chunk_coords]
if spec is None:
raise IndexError(f"Chunk coordinates {chunk_coords} are out of bounds.")
"""Build an ArraySpec from an already-resolved ChunkSpec.

Split out from :func:`_get_chunk_spec` so the transform read/write path can
resolve ``chunk_grid[chunk_coords]`` once per chunk and feed the same
``spec`` to both this and :func:`_is_complete_chunk`.
"""
return ArraySpec(
shape=spec.codec_shape,
dtype=metadata.dtype,
Expand All @@ -5602,25 +5628,66 @@ def _get_chunk_spec(
)


def _is_complete_chunk(
sub_transform: IndexTransform, chunk_grid: ChunkGrid, chunk_coords: tuple[int, ...]
) -> bool:
"""Check if a sub-transform covers an entire chunk."""
from zarr.core.transforms.output_map import ConstantMap, DimensionMap

def _get_chunk_spec(
metadata: ArrayMetadata,
chunk_grid: ChunkGrid,
chunk_coords: tuple[int, ...],
array_config: ArrayConfig,
prototype: BufferPrototype,
) -> ArraySpec:
"""Build an ArraySpec for a single chunk using the ChunkGrid."""
spec = chunk_grid[chunk_coords]
if spec is None:
raise IndexError(f"Chunk coordinates {chunk_coords} are out of bounds.")
return _array_spec_from_chunk_spec(metadata, spec, array_config, prototype)


def _transform_is_identity(transform: IndexTransform, storage_shape: tuple[int, ...]) -> bool:
"""Return True if ``transform`` is the identity over the full storage domain.

An identity transform maps input coordinate ``i`` to storage coordinate ``i``
across the array's whole storage shape (origin 0, unit stride, dimensions in
order). Such an array is an ordinary eager array — indexing it produces the
same coordinates the legacy indexers compute, so the legacy fast path is
safe. Any narrowing, striding, reordering, or fancy selection (i.e. a lazy
view) yields a non-identity transform that must go through the transform
resolver. Cheap: O(ndim), no array work.
"""
from zarr.core.transforms.output_map import DimensionMap

domain = transform.domain
ndim = len(storage_shape)
if domain.ndim != ndim or len(transform.output) != ndim:
return False
if domain.inclusive_min != (0,) * ndim or domain.exclusive_max != storage_shape:
return False
for i, m in enumerate(transform.output):
if not (
type(m) is DimensionMap and m.input_dimension == i and m.offset == 0 and m.stride == 1
):
return False
return True


def _is_complete_chunk(sub_transform: IndexTransform, spec: ChunkSpec) -> bool:
"""Check if a sub-transform covers an entire chunk.

``spec`` is the chunk's already-resolved :class:`ChunkSpec` (the caller looks
it up once and shares it with :func:`_array_spec_from_chunk_spec`).
"""
from zarr.core.transforms.output_map import ConstantMap, DimensionMap

shape = spec.shape
for out_dim, m in enumerate(sub_transform.output):
if isinstance(m, ConstantMap):
# A ConstantMap means a single element is selected along this output dimension,
# so the write does not cover the full chunk along this dimension.
chunk_dim_size = spec.shape[out_dim]
chunk_dim_size = shape[out_dim]
if chunk_dim_size > 1:
return False
continue # chunk dim size is 1, so selecting the single element is complete
if isinstance(m, DimensionMap):
chunk_dim_size = spec.shape[out_dim]
chunk_dim_size = shape[out_dim]
# Compute actual storage range: storage = offset + stride * input_coord
dim_lo = sub_transform.domain.inclusive_min[m.input_dimension]
dim_hi = sub_transform.domain.exclusive_max[m.input_dimension]
Expand All @@ -5645,9 +5712,11 @@ async def _get_selection_via_transform(
*,
prototype: BufferPrototype,
out: NDBuffer | None = None,
chunk_grid: ChunkGrid | None = None,
) -> NDArrayLikeOrScalar:
"""Read data using an IndexTransform."""
chunk_grid = ChunkGrid.from_metadata(metadata)
if chunk_grid is None:
chunk_grid = ChunkGrid.from_metadata(metadata)

# Get dtype (same logic as existing _get_selection)
if metadata.zarr_format == 2:
Expand Down Expand Up @@ -5693,16 +5762,18 @@ async def _get_selection_via_transform(
for chunk_coords, sub_transform, out_indices in iter_chunk_transforms(
transform, chunk_grid
):
chunk_spec = chunk_grid[chunk_coords]
if chunk_spec is None:
continue
chunk_sel, out_sel, da = sub_transform_to_selections(sub_transform, out_indices)
drop_axes = da # same for all chunks
is_complete = _is_complete_chunk(sub_transform, chunk_grid, chunk_coords)
batch_info.append(
(
store_path / metadata.encode_chunk_key(chunk_coords),
_get_chunk_spec(metadata, chunk_grid, chunk_coords, _config, prototype),
_array_spec_from_chunk_spec(metadata, chunk_spec, _config, prototype),
chunk_sel,
out_sel,
is_complete,
_is_complete_chunk(sub_transform, chunk_spec),
)
)

Expand Down Expand Up @@ -5743,9 +5814,11 @@ async def _set_selection_via_transform(
codec_pipeline: CodecPipeline,
*,
prototype: BufferPrototype,
chunk_grid: ChunkGrid | None = None,
) -> None:
"""Write data using an IndexTransform."""
chunk_grid = ChunkGrid.from_metadata(metadata)
if chunk_grid is None:
chunk_grid = ChunkGrid.from_metadata(metadata)

# Get dtype from metadata
if metadata.zarr_format == 2:
Expand Down Expand Up @@ -5821,16 +5894,18 @@ async def _set_selection_via_transform(
batch_info = []
drop_axes: tuple[int, ...] = ()
for chunk_coords, sub_transform, out_indices in iter_chunk_transforms(transform, chunk_grid):
chunk_spec = chunk_grid[chunk_coords]
if chunk_spec is None:
continue
chunk_sel, out_sel, da = sub_transform_to_selections(sub_transform, out_indices)
drop_axes = da # same for all chunks
is_complete = _is_complete_chunk(sub_transform, chunk_grid, chunk_coords)
batch_info.append(
(
store_path / metadata.encode_chunk_key(chunk_coords),
_get_chunk_spec(metadata, chunk_grid, chunk_coords, _config, prototype),
_array_spec_from_chunk_spec(metadata, chunk_spec, _config, prototype),
chunk_sel,
out_sel,
is_complete,
_is_complete_chunk(sub_transform, chunk_spec),
)
)

Expand Down Expand Up @@ -6404,6 +6479,8 @@ async def _delete_key(key: str) -> None:
object.__setattr__(array, "metadata", new_metadata)
object.__setattr__(array, "_chunk_grid", new_chunk_grid)
object.__setattr__(array, "_transform", IndexTransform.from_shape(new_shape))
object.__setattr__(array, "_shape", array._transform.domain.shape)
object.__setattr__(array, "_is_identity", True)


async def _append(
Expand Down
Loading