Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 8 additions & 91 deletions examples/custom_dtype/custom_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
import json
import sys
from pathlib import Path
from typing import ClassVar, Literal, Self, TypeGuard, overload
from typing import ClassVar, Literal, Self, TypeGuard

import ml_dtypes # necessary to add extra dtypes to NumPy
import numpy as np
import pytest

import zarr
from zarr.dtype import ZDType, check_dtype_spec_v2, data_type_registry
from zarr.dtype import ZDType, register_data_type
from zarr.errors import DataTypeValidationError
from zarr.types import JSON, DTypeConfig_V2, DTypeJSON, ZarrFormat
from zarr.types import JSON, ZarrFormat

# This is the int2 array data type
int2_dtype_cls = type(np.dtype("int2"))
Expand All @@ -39,11 +39,11 @@ class Int2(ZDType[int2_dtype_cls, int2_scalar_cls]):
NumPy array of type int2) and the int2 scalar type (the ``dtype`` of the scalar value inside an int2 array).
"""

# This field is used as the key for the data type in the internal data type registry, and also
# as the identifier for the data type when serializing the data type to disk for zarr v3
# This is the key for the data type in the internal data type registry, and also the identifier
# for the data type when serializing it to disk. For a parameter-free data type like this one,
# ZDType uses it as the entire Zarr V3 representation and as the Zarr V2 ``name`` -- so we don't
# need to write any JSON (de)serialization for the data type itself; the base class handles it.
_zarr_v3_name: ClassVar[Literal["int2"]] = "int2"
# this field will be used internally
_zarr_v2_name: ClassVar[Literal["int2"]] = "int2"

# we bind a class variable to the native data type class so we can create instances of it
dtype_cls = int2_dtype_cls
Expand All @@ -61,89 +61,6 @@ def to_native_dtype(self: Self) -> int2_dtype_cls:
"""Create an int2 dtype instance from this ZDType"""
return self.dtype_cls()

@classmethod
def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["int2"], None]]:
    """
    Type check for Zarr V2-flavored JSON.

    This checks that the input is a dict like this:

    .. code-block:: json

        {
            "name": "int2",
            "object_codec_id": null
        }

    Note that this representation differs from what the ``dtype`` field looks like in Zarr V2
    metadata. Specifically, whatever goes into the ``dtype`` field in metadata is assigned to
    the ``name`` field here.

    See the Zarr docs for more information about the JSON encoding for data types.
    """
    # The structural check runs first so the key lookups below are only evaluated when it passes.
    return (
        check_dtype_spec_v2(data) and data["name"] == "int2" and data["object_codec_id"] is None
    )

@classmethod
def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["int2"]]:
    """
    Type check for Zarr V3-flavored JSON.

    The input is accepted only when it compares equal to ``cls._zarr_v3_name``.
    """
    matches_v3_name = data == cls._zarr_v3_name
    return matches_v3_name

@classmethod
def _from_json_v2(cls, data: DTypeJSON) -> Self:
    """
    Create an instance of this ZDType from Zarr V2-flavored JSON.

    This first does a type check on the input with ``_check_json_v2``, and if that passes we
    create an instance of the ZDType.

    Raises
    ------
    DataTypeValidationError
        If the input does not pass ``_check_json_v2``.
    """
    if cls._check_json_v2(data):
        return cls()
    # The type check failed, so report what we received alongside the expected V2 name.
    msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}"
    raise DataTypeValidationError(msg)

@classmethod
def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self:
    """
    Build this ZDType from its Zarr V3-flavored JSON form.

    The input is validated with ``_check_json_v3`` before an instance is constructed.

    Raises
    ------
    DataTypeValidationError
        If the input does not pass ``_check_json_v3``.
    """
    # Guard clause: reject anything that is not the V3 representation.
    if not cls._check_json_v3(data):
        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
        raise DataTypeValidationError(msg)
    return cls()

@overload
def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["int2"], None]: ...

@overload
def to_json(self, zarr_format: Literal[3]) -> Literal["int2"]: ...

def to_json(
    self, zarr_format: ZarrFormat
) -> DTypeConfig_V2[Literal["int2"], None] | Literal["int2"]:
    """
    Serialize this ZDType to V2- or V3-flavored JSON.

    For ``zarr_format == 3`` the result is the string stored in ``_zarr_v3_name``.

    For ``zarr_format == 2`` the result is a dict which, encoded as JSON, looks like this:

    .. code-block:: json

        {
            "name": "int2",
            "object_codec_id": null
        }

    Raises
    ------
    ValueError
        If ``zarr_format`` is neither 2 nor 3.
    """
    # The two formats are mutually exclusive, so the order of these checks does not matter.
    if zarr_format == 3:
        return self._zarr_v3_name
    if zarr_format == 2:
        v2_spec = {"name": "int2", "object_codec_id": None}
        return v2_spec
    raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover

def _check_scalar(self, data: object) -> TypeGuard[int | ml_dtypes.int2]:
"""
Check if a python object is a valid int2-compatible scalar
Expand Down Expand Up @@ -209,7 +126,7 @@ def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes.


# after defining dtype class, it must be registered with the data type registry so zarr can use it
data_type_registry.register(Int2._zarr_v3_name, Int2)
register_data_type(Int2)


# this parametrized function will create arrays in zarr v2 and v3 using our new data type
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def enable_gpu(self) -> ConfigSet:
defaults=[
{
"default_zarr_format": 3,
# How to interpret data type metadata. "compatible" makes a best-effort attempt to
# read wrong-but-parsable data type metadata (e.g. a Zarr V2 ``">u1"`` typestring,
# which NumPy accepts but normalizes to ``"|u1"``). "strict" accepts only
# spec-compliant, canonical data type metadata.
"data_type_resolution": "compatible",
"array": {
"order": "C",
"write_empty_chunks": False,
Expand Down
36 changes: 25 additions & 11 deletions src/zarr/core/dtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@
VariableLengthUTF8,
VariableLengthUTF8JSON_V2,
)
from zarr.core.dtype.registry import DataTypeRegistry
from zarr.core.dtype.registry import (
DataTypeRegistry,
data_type_registry,
load_data_type_entrypoints,
match_dtype,
match_json,
register_data_type,
unregister_data_type,
)
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType

__all__ = [
Expand Down Expand Up @@ -97,12 +105,15 @@
"VariableLengthUTF8JSON_V2",
"ZDType",
"data_type_registry",
"load_data_type_entrypoints",
"match_dtype",
"match_json",
"parse_data_type",
"parse_dtype",
"register_data_type",
"unregister_data_type",
]

data_type_registry = DataTypeRegistry()

IntegerDType = Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
INTEGER_DTYPE: Final = Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64

Expand Down Expand Up @@ -146,18 +157,21 @@
VariableLengthBytes,
)

# These are aliases for variable-length UTF-8 strings
# We handle them when a user requests a data type instead of using NumPy's dtype inference because
# the default NumPy behavior -- to inspect the user-provided array data and choose
# an appropriately sized U dtype -- is unworkable for Zarr.
VLEN_UTF8_ALIAS: Final = ("str", str, "string")
# These are aliases for variable-length UTF-8 strings: the python ``str`` type plus the data
# type's declared Zarr V3 names ("string" and "str"). We handle them when a user requests a data
# type instead of using NumPy's dtype inference because the default NumPy behavior -- to inspect
# the user-provided array data and choose an appropriately sized U dtype -- is unworkable for Zarr.
VLEN_UTF8_ALIAS: Final = (str, *VariableLengthUTF8._zarr_v3_names())

# This type models inputs that can be coerced to a ZDType
type ZDTypeLike = npt.DTypeLike | ZDType[TBaseDType, TBaseScalar] | Mapping[str, JSON] | str

for dtype in ANY_DTYPE:
# mypy does not know that all the elements of ANY_DTYPE are subclasses of ZDType
data_type_registry.register(dtype._zarr_v3_name, dtype) # type: ignore[arg-type]
register_data_type(dtype) # type: ignore[arg-type]

# Register any data types advertised by third-party packages via entry points.
load_data_type_entrypoints()


# TODO: find a better name for this function
Expand All @@ -174,7 +188,7 @@ def get_data_type_from_native_dtype(dtype: npt.DTypeLike) -> ZDType[TBaseDType,
na_dtype = np.dtype(dtype)
else:
na_dtype = dtype
return data_type_registry.match_dtype(dtype=na_dtype)
return match_dtype(na_dtype)


def get_data_type_from_json(
Expand All @@ -184,7 +198,7 @@ def get_data_type_from_json(
Given a JSON representation of a data type and a Zarr format version,
attempt to create a ZDType instance from the registered ZDType classes.
"""
return data_type_registry.match_json(dtype_spec, zarr_format=zarr_format)
return match_json(dtype_spec, zarr_format=zarr_format)


def parse_data_type(
Expand Down
135 changes: 4 additions & 131 deletions src/zarr/core/dtype/npy/bool.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Literal, Self, TypeGuard, overload
from typing import TYPE_CHECKING, ClassVar, Literal, Self

import numpy as np

from zarr.core.dtype.common import (
DTypeConfig_V2,
DTypeJSON,
HasItemSize,
check_dtype_spec_v2,
)
from zarr.core.dtype.wrapper import TBaseDType, ZDType
from zarr.core.dtype.npy.common import NumpyNativeDTypeV2
from zarr.errors import DataTypeValidationError

if TYPE_CHECKING:
from zarr.core.common import JSON, ZarrFormat
from zarr.core.dtype.wrapper import TBaseDType


@dataclass(frozen=True, kw_only=True, slots=True)
class Bool(ZDType[np.dtypes.BoolDType, np.bool_], HasItemSize):
class Bool(NumpyNativeDTypeV2[np.dtypes.BoolDType, np.bool_], HasItemSize):
"""
A Zarr data type for arrays containing booleans.

Expand All @@ -45,7 +43,6 @@ class Bool(ZDType[np.dtypes.BoolDType, np.bool_], HasItemSize):
"""

_zarr_v3_name: ClassVar[Literal["bool"]] = "bool"
_zarr_v2_name: ClassVar[Literal["|b1"]] = "|b1"
dtype_cls = np.dtypes.BoolDType

@classmethod
Expand Down Expand Up @@ -85,130 +82,6 @@ def to_native_dtype(self: Self) -> np.dtypes.BoolDType:
"""
return self.dtype_cls()

@classmethod
def _check_json_v2(
    cls,
    data: DTypeJSON,
) -> TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]:
    """
    Determine whether the input is the Zarr V2 JSON representation of a Bool.

    Parameters
    ----------
    data : DTypeJSON
        The JSON data to check.

    Returns
    -------
    ``TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]``
        Truthy only when ``data`` passes ``check_dtype_spec_v2``, its ``"name"`` equals
        ``cls._zarr_v2_name``, and its ``"object_codec_id"`` is None.
    """
    # The key lookups are only evaluated once the structural check has passed.
    is_v2_spec = check_dtype_spec_v2(data)
    return is_v2_spec and data["name"] == cls._zarr_v2_name and data["object_codec_id"] is None

@classmethod
def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["bool"]]:
    """
    Determine whether the input is the Zarr V3 JSON representation of this class.

    Parameters
    ----------
    data : DTypeJSON
        The JSON data to check.

    Returns
    -------
    bool
        The result of comparing ``data`` for equality with ``cls._zarr_v3_name``.
    """
    matches_v3_name = data == cls._zarr_v3_name
    return matches_v3_name

@classmethod
def _from_json_v2(cls, data: DTypeJSON) -> Self:
    """
    Build a Bool from Zarr V2-flavored JSON.

    Parameters
    ----------
    data : DTypeJSON
        The JSON data.

    Returns
    -------
    Bool
        An instance of Bool.

    Raises
    ------
    DataTypeValidationError
        If the input does not pass ``_check_json_v2``.
    """
    # Guard clause: reject anything that is not the V2 representation.
    if not cls._check_json_v2(data):
        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}"
        raise DataTypeValidationError(msg)
    return cls()

@classmethod
def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self:
    """
    Build a Bool from Zarr V3-flavored JSON.

    Parameters
    ----------
    data : DTypeJSON
        The JSON data.

    Returns
    -------
    Bool
        An instance of Bool.

    Raises
    ------
    DataTypeValidationError
        If the input does not pass ``_check_json_v3``.
    """
    # Guard clause: reject anything that is not the V3 representation.
    if not cls._check_json_v3(data):
        msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}"
        raise DataTypeValidationError(msg)
    return cls()

@overload
def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["|b1"], None]: ...

@overload
def to_json(self, zarr_format: Literal[3]) -> Literal["bool"]: ...

def to_json(
    self, zarr_format: ZarrFormat
) -> DTypeConfig_V2[Literal["|b1"], None] | Literal["bool"]:
    """
    Produce the JSON representation of this Bool for the requested Zarr format.

    Parameters
    ----------
    zarr_format : ZarrFormat
        The Zarr format version (2 or 3).

    Returns
    -------
    ``DTypeConfig_V2[Literal["|b1"], None] | Literal["bool"]``
        For format 2, a dict holding ``_zarr_v2_name`` under ``"name"`` and None under
        ``"object_codec_id"``; for format 3, the string ``_zarr_v3_name``.

    Raises
    ------
    ValueError
        If the zarr_format is not 2 or 3.
    """
    # The two formats are mutually exclusive, so the order of these checks does not matter.
    if zarr_format == 3:
        return self._zarr_v3_name
    if zarr_format == 2:
        v2_spec = {"name": self._zarr_v2_name, "object_codec_id": None}
        return v2_spec
    raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}")  # pragma: no cover

def _check_scalar(self, data: object) -> bool:
"""
Check if the input can be cast to a boolean scalar.
Expand Down
Loading