diff --git a/changes/3285.feature.md b/changes/3285.feature.md new file mode 100644 index 0000000000..82232f8610 --- /dev/null +++ b/changes/3285.feature.md @@ -0,0 +1,6 @@ +JSON metadata validation now delegates to ``msgspec.convert`` for the type +coercions it supports (``Literal`` membership, ``int`` / ``bool`` strictness, +list-to-tuple), replacing the per-field hand-written ``parse_*`` logic. A small +fallback validates the recursive JSON values msgspec cannot, now with an +explicit nesting-depth limit, and a latent generator-exhaustion bug in +``parse_storage_transformers`` is fixed. See #3285. diff --git a/pyproject.toml b/pyproject.toml index 02e66c67e8..6cfb7d52cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ 'google-crc32c>=1.5', 'typing_extensions>=4.14', 'donfig>=0.8', + 'msgspec>=0.19', ] dynamic = [ @@ -271,6 +272,7 @@ extra-dependencies = [ 'typing_extensions==4.14.*', 'donfig==0.8.*', 'obstore==0.5.*', + 'msgspec==0.19.*', ] [tool.hatch.envs.default] diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 087de716fc..d0458a91a0 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -104,27 +104,30 @@ class BloscCname(metaclass=_DeprecatedStrEnumMeta): def parse_typesize(data: JSON) -> int: - if isinstance(data, int): - if data > 0: - return data - else: - raise ValueError( - f"Value must be greater than 0. Got {data}, which is less or equal to 0." - ) - raise TypeError(f"Value must be an int. Got {type(data)} instead.") + from zarr.core.json_parse import parse_field + + parsed: int = parse_field(data, int, "typesize", error=TypeError) + if parsed > 0: + return parsed + else: + raise ValueError( + f"Value must be greater than 0. Got {parsed}, which is less or equal to 0." + ) # todo: real validation def parse_clevel(data: JSON) -> int: - if isinstance(data, int): - return data - raise TypeError(f"Value should be an int. 
Got {type(data)} instead.") + from zarr.core.json_parse import parse_field + + parsed: int = parse_field(data, int, "clevel", error=TypeError) + return parsed def parse_blocksize(data: JSON) -> int: - if isinstance(data, int): - return data - raise TypeError(f"Value should be an int. Got {type(data)} instead.") + from zarr.core.json_parse import parse_field + + parsed: int = parse_field(data, int, "blocksize", error=TypeError) + return parsed def _parse_cname(data: object) -> BloscCnameLiteral: diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index b8591748f7..052d1e928d 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -19,13 +19,14 @@ def parse_gzip_level(data: JSON) -> int: - if not isinstance(data, (int)): - raise TypeError(f"Expected int, got {type(data)}") - if data not in range(10): + from zarr.core.json_parse import parse_field + + parsed: int = parse_field(data, int, "level", error=TypeError) + if parsed not in range(10): raise ValueError( - f"Expected an integer from the inclusive range (0, 9). Got {data} instead." + f"Expected an integer from the inclusive range (0, 9). Got {parsed} instead." ) - return data + return parsed @dataclass(frozen=True) diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index f93c25a3c7..92c97f20b1 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -21,17 +21,19 @@ def parse_zstd_level(data: JSON) -> int: - if isinstance(data, int): - if data >= 23: - raise ValueError(f"Value must be less than or equal to 22. Got {data} instead.") - return data - raise TypeError(f"Got value with type {type(data)}, but expected an int.") + from zarr.core.json_parse import parse_field + + parsed: int = parse_field(data, int, "level", error=TypeError) + if parsed >= 23: + raise ValueError(f"Value must be less than or equal to 22. 
Got {parsed} instead.") + return parsed def parse_checksum(data: JSON) -> bool: - if isinstance(data, bool): - return data - raise TypeError(f"Expected bool. Got {type(data)}.") + from zarr.core.json_parse import parse_field + + parsed: bool = parse_field(data, bool, "checksum", error=TypeError) + return parsed @dataclass(frozen=True) diff --git a/src/zarr/core/chunk_key_encodings.py b/src/zarr/core/chunk_key_encodings.py index 098f2c8981..7055f3930b 100644 --- a/src/zarr/core/chunk_key_encodings.py +++ b/src/zarr/core/chunk_key_encodings.py @@ -19,9 +19,9 @@ def parse_separator(data: JSON) -> SeparatorLiteral: - if data not in (".", "/"): - raise ValueError(f"Expected an '.' or '/' separator. Got {data} instead.") - return cast("SeparatorLiteral", data) + from zarr.core.json_parse import parse_field + + return cast("SeparatorLiteral", parse_field(data, Literal[".", "/"], "separator")) class ChunkKeyEncodingParams(TypedDict): diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 20664e553e..4281decbf5 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -124,12 +124,15 @@ def parse_enum[E: Enum](data: object, cls: type[E]) -> E: def parse_name(data: JSON, expected: str | None = None) -> str: - if isinstance(data, str): - if expected is None or data == expected: - return data - raise ValueError(f"Expected '{expected}'. Got {data} instead.") - else: - raise TypeError(f"Expected a string, got an instance of {type(data)}.") + from zarr.core.json_parse import convert + + try: + data = cast("str", convert(data, str)) + except (ValueError, TypeError) as exc: + raise TypeError(f"Expected a string, got an instance of {type(data)}.") from exc + if expected is None or data == expected: + return data + raise ValueError(f"Expected '{expected}'. 
Got {data} instead.") def parse_configuration(data: JSON) -> JSON: @@ -204,15 +207,15 @@ def parse_fill_value(data: Any) -> Any: def parse_order(data: Any) -> Literal["C", "F"]: - if data in ("C", "F"): - return cast("Literal['C', 'F']", data) - raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.") + from zarr.core.json_parse import parse_field + + return cast("Literal['C', 'F']", parse_field(data, Literal["C", "F"], "order")) def parse_bool(data: Any) -> bool: - if isinstance(data, bool): - return data - raise ValueError(f"Expected bool, got {data} instead.") + from zarr.core.json_parse import convert + + return cast("bool", convert(data, bool)) def parse_int(data: Any) -> int: diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 08d2a50ace..9752a58870 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -151,7 +151,6 @@ def enable_gpu(self) -> ConfigSet: def parse_indexing_order(data: Any) -> Literal["C", "F"]: - if data in ("C", "F"): - return cast("Literal['C', 'F']", data) - msg = f"Expected one of ('C', 'F'), got {data} instead." - raise ValueError(msg) + from zarr.core.json_parse import parse_field + + return cast("Literal['C', 'F']", parse_field(data, Literal["C", "F"], "order")) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 52eaa3e144..4bf26e4b31 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -85,18 +85,19 @@ def parse_zarr_format(data: Any) -> ZarrFormat: """Parse the zarr_format field from metadata.""" - if data in (2, 3): - return cast("ZarrFormat", data) - msg = f"Invalid zarr_format. Expected one of 2 or 3. Got {data}." 
- raise ValueError(msg) + from zarr.core.json_parse import parse_field + + return cast("ZarrFormat", parse_field(data, Literal[2, 3], "zarr_format")) def parse_node_type(data: Any) -> NodeType: """Parse the node_type field from metadata.""" - if data in ("array", "group"): - return cast("Literal['array', 'group']", data) - msg = f"Invalid value for 'node_type'. Expected 'array' or 'group'. Got '{data}'." - raise MetadataValidationError(msg) + from zarr.core.json_parse import parse_field + + return cast( + "Literal['array', 'group']", + parse_field(data, Literal["array", "group"], "node_type", error=MetadataValidationError), + ) # todo: convert None to empty dict diff --git a/src/zarr/core/json_parse.py b/src/zarr/core/json_parse.py new file mode 100644 index 0000000000..c07479b146 --- /dev/null +++ b/src/zarr/core/json_parse.py @@ -0,0 +1,91 @@ +"""Helpers for validating JSON-decoded metadata. + +Most JSON metadata validation is delegated to :func:`msgspec.convert`, which +handles the type coercions Zarr needs (``Literal`` membership, ``int``/``bool`` +strictness, list-to-tuple, ``TypedDict`` with ``NotRequired``). :func:`convert` +is a thin wrapper that translates :class:`msgspec.ValidationError` into a +``ValueError``, which :func:`parse_field` re-raises with field context. + +msgspec cannot handle two things in Zarr's metadata types: + +* the recursive ``JSON`` / ``JSONValue`` aliases, which it rejects at + schema-build time, and +* PEP 728 ``extra_items=`` extension fields, which it silently drops. + +:func:`validate_json_value` is the small hand-written fallback for the first of +those. See https://github.com/zarr-developers/zarr-python/issues/3285. 
+""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Final, cast + +import msgspec + +if TYPE_CHECKING: + from zarr.core.common import JSON + +__all__ = ["MAX_JSON_DEPTH", "convert", "parse_field", "validate_json_value"] + +MAX_JSON_DEPTH: Final = 64 +"""Maximum nesting depth accepted by :func:`validate_json_value`.""" + + +def _type_name(type_: Any) -> str: + return getattr(type_, "__name__", None) or str(type_).replace("typing.", "") + + +def convert(value: object, type_: Any, *, strict: bool = True) -> Any: + """Validate and coerce ``value`` against ``type_`` via :func:`msgspec.convert`. + + On a mismatch msgspec raises :class:`msgspec.ValidationError`; this re-raises + a plain, field-agnostic ``ValueError`` naming the expected type, so callers + can add their own field context (see :func:`parse_field`). + """ + try: + return msgspec.convert(value, type_, strict=strict) + except msgspec.ValidationError as exc: + raise ValueError(f"Expected instance of {_type_name(type_)}, got {value!r}.") from exc + + +def parse_field( + data: object, type_: Any, field: str, *, error: type[Exception] = ValueError +) -> Any: + """Validate ``data`` for metadata field ``field`` against ``type_``. + + Wraps :func:`convert` and, on failure, re-raises ``error`` with field + context, chaining the underlying type error. This keeps the + ``convert``-then-re-raise pattern in one place rather than repeating it in + every per-field parser. + """ + try: + return convert(data, type_) + except ValueError as exc: + raise error(f"Failed to parse input for {field!r}.") from exc + + +def validate_json_value(value: object, *, max_depth: int = MAX_JSON_DEPTH, _depth: int = 0) -> JSON: + """Check that ``value`` is a JSON value and return it unchanged. 
+ + msgspec cannot build a schema for Zarr's recursive ``JSON`` / ``JSONValue`` + aliases, so this covers the fields typed that way (``attributes``, + ``fill_value``, extension-field values). Unlike the previous per-field + parsers it also enforces ``max_depth``: a pathologically nested document + could otherwise exhaust the interpreter stack. + """ + if _depth > max_depth: + raise ValueError(f"JSON value nesting exceeds the maximum depth of {max_depth}.") + if value is None or isinstance(value, (bool, int, float, str)): + return cast("JSON", value) + if isinstance(value, (list, tuple)): + for item in value: + validate_json_value(item, max_depth=max_depth, _depth=_depth + 1) + return cast("JSON", value) + if isinstance(value, Mapping): + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError(f"JSON object keys must be str, got {type(key).__name__}.") + validate_json_value(item, max_depth=max_depth, _depth=_depth + 1) + return cast("JSON", value) + raise TypeError(f"Value {value!r} is not a valid JSON value.") diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index ac32521239..a613842106 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -278,9 +278,11 @@ def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: def parse_zarr_format(data: object) -> Literal[2]: - if data == 2: - return 2 - raise ValueError(f"Invalid value. Expected 2. Got {data}.") + from typing import Literal + + from zarr.core.json_parse import parse_field + + return cast("Literal[2]", parse_field(data, Literal[2], "zarr_format")) def parse_filters(data: object) -> tuple[Numcodec, ...] | None: diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 9eaccc5076..a165a27586 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -47,17 +47,20 @@ def parse_zarr_format(data: object) -> Literal[3]: - if data == 3: - return 3 - msg = f"Invalid value for 'zarr_format'. Expected '3'. 
Got '{data}'." - raise MetadataValidationError(msg) + from zarr.core.json_parse import parse_field + + return cast( + "Literal[3]", parse_field(data, Literal[3], "zarr_format", error=MetadataValidationError) + ) def parse_node_type_array(data: object) -> Literal["array"]: - if data == "array": - return "array" - msg = f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'." - raise NodeTypeValidationError(msg) + from zarr.core.json_parse import parse_field + + return cast( + 'Literal["array"]', + parse_field(data, Literal["array"], "node_type", error=NodeTypeValidationError), + ) def parse_codecs(data: object) -> tuple[Codec, ...]: @@ -130,11 +133,12 @@ def parse_storage_transformers(data: object) -> tuple[dict[str, JSON], ...]: """ if data is None: return () - if isinstance(data, Iterable): - if len(tuple(data)) >= 1: - return data # type: ignore[return-value] - else: - return () + if isinstance(data, Iterable) and not isinstance(data, (str, bytes)): + # Materialise once. The previous implementation called ``len(tuple(data))`` + # and then returned ``data`` itself, which exhausted (and discarded) a + # one-shot iterable and could return a value typed as a tuple that was not + # actually a tuple. + return tuple(data) raise TypeError( f"Invalid storage_transformers. Expected an iterable of dicts. Got {type(data)} instead." 
) @@ -610,6 +614,8 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: + from zarr.core.json_parse import validate_json_value + # make a copy because we are modifying the dict _data = data.copy() @@ -656,7 +662,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: chunk_grid=_data_typed["chunk_grid"], # type: ignore[arg-type] chunk_key_encoding=_data_typed["chunk_key_encoding"], # type: ignore[arg-type] codecs=_data_typed["codecs"], - attributes=_data_typed.get("attributes", {}), # type: ignore[arg-type] + attributes=validate_json_value(_data_typed.get("attributes", {})), # type: ignore[arg-type] dimension_names=_data_typed.get("dimension_names", None), fill_value=fill_value_parsed, data_type=data_type, diff --git a/tests/test_common.py b/tests/test_common.py index 2fe0743e14..2f1e7c8522 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -9,8 +9,10 @@ from zarr.core.common import ( ANY_ACCESS_MODE, AccessModeLiteral, + parse_bool, parse_int, parse_name, + parse_order, parse_shapelike, product, ) @@ -69,10 +71,28 @@ def test_parse_name_valid(data: tuple[Any, Any]) -> None: @pytest.mark.parametrize("data", [0, 1, "hello", "f"]) def test_parse_indexing_order_invalid(data: Any) -> None: - with pytest.raises(ValueError, match="Expected one of"): + with pytest.raises(ValueError, match="Failed to parse input for 'order'"): parse_indexing_order(data) +@pytest.mark.parametrize("data", [0, 1, "hello", "f"]) +def test_parse_order_invalid(data: Any) -> None: + with pytest.raises(ValueError, match="Failed to parse input for 'order'"): + parse_order(data) + + +@pytest.mark.parametrize("data", [0, 1, "true", None, [True]]) +def test_parse_bool_invalid(data: Any) -> None: + """Non-bool values are rejected with a ValueError.""" + with pytest.raises(ValueError, match="Expected instance of bool"): + parse_bool(data) + + +@pytest.mark.parametrize("data", [True, False]) +def 
test_parse_bool_valid(data: bool) -> None: + assert parse_bool(data) is data + + @pytest.mark.parametrize("data", ["1", 1.0, True, False, None, [1], (1,)]) def test_parse_int_invalid(data: Any) -> None: """Non-int values (including bools, which are int subclasses) are rejected.""" diff --git a/tests/test_json_parse.py b/tests/test_json_parse.py new file mode 100644 index 0000000000..da723119aa --- /dev/null +++ b/tests/test_json_parse.py @@ -0,0 +1,122 @@ +"""Tests for :mod:`zarr.core.json_parse`. + +``convert`` delegates JSON type coercion to :func:`msgspec.convert` (translating +``msgspec.ValidationError`` into ``ValueError``); ``validate_json_value`` is the +hand-written fallback for the recursive ``JSON`` alias msgspec cannot build, +including a nesting-depth limit. The final group is a regression test for the +``parse_storage_transformers`` fix that motivated the depth limit work. +""" + +from __future__ import annotations + +from typing import Literal + +import pytest + +from zarr.core.json_parse import MAX_JSON_DEPTH, convert, parse_field, validate_json_value +from zarr.core.metadata.v3 import parse_storage_transformers + + +class TestConvert: + def test_literal(self) -> None: + assert convert(3, Literal[3]) == 3 + assert convert("array", Literal["array", "group"]) == "array" + + def test_literal_rejects_non_member(self) -> None: + with pytest.raises(ValueError, match="Expected instance of"): + convert(4, Literal[3]) + with pytest.raises(ValueError, match="Expected instance of"): + convert("Q", Literal["C", "F"]) + + def test_sequence_coerced_to_tuple(self) -> None: + assert convert([1, 2, 3], tuple[int, ...]) == (1, 2, 3) + assert convert([1, 2], tuple[int, int]) == (1, 2) + + def test_int(self) -> None: + assert convert(5, int) == 5 + + def test_bool_int_strictness(self) -> None: + # bool is an int subclass, but the two must not be interchangeable. 
+ with pytest.raises(ValueError): + convert(True, int) + with pytest.raises(ValueError): + convert(1, bool) + # ... and True must not satisfy Literal[1]. + with pytest.raises(ValueError): + convert(True, Literal[1]) + + +class TestParseField: + def test_valid_passthrough(self) -> None: + assert parse_field(3, Literal[3], "zarr_format") == 3 + + def test_wraps_with_field_context(self) -> None: + with pytest.raises(ValueError, match="Failed to parse input for 'zarr_format'"): + parse_field(4, Literal[3], "zarr_format") + + def test_custom_error_type_and_chaining(self) -> None: + class MyError(ValueError): + pass + + with pytest.raises(MyError, match="Failed to parse input for 'node_type'") as exc_info: + parse_field(5, Literal["array"], "node_type", error=MyError) + # the generic type error is chained as the cause + assert isinstance(exc_info.value.__cause__, ValueError) + + +class TestValidateJsonValue: + @pytest.mark.parametrize("value", [None, True, 1, 1.5, "s"]) + def test_primitives(self, value: object) -> None: + assert validate_json_value(value) is value + + def test_nested(self) -> None: + value = {"a": [1, 2.0, "x", True, None], "b": {"c": [{}]}} + assert validate_json_value(value) is value + + def test_rejects_non_str_keys(self) -> None: + with pytest.raises(TypeError, match="keys must be str"): + validate_json_value({1: "x"}) + + def test_rejects_non_json_leaf(self) -> None: + with pytest.raises(TypeError, match="not a valid JSON value"): + validate_json_value(object()) + with pytest.raises(TypeError, match="not a valid JSON value"): + validate_json_value({"a": object()}) + + def test_depth_limit(self) -> None: + def nest(depth: int) -> object: + v: object = "leaf" + for _ in range(depth): + v = {"k": v} + return v + + # At the limit it passes; one level deeper it is rejected. This bound is + # new behavior the previous per-field parsers never had. 
+ assert validate_json_value(nest(MAX_JSON_DEPTH)) is not None + with pytest.raises(ValueError, match="maximum depth"): + validate_json_value(nest(MAX_JSON_DEPTH + 1)) + + +class TestStorageTransformersRegression: + """`parse_storage_transformers` used to call `len(tuple(data))` and then + return `data` itself, exhausting a one-shot iterable and returning a value + typed as a tuple but not actually a tuple.""" + + def test_none(self) -> None: + assert parse_storage_transformers(None) == () + + def test_empty(self) -> None: + assert parse_storage_transformers([]) == () + + def test_list_returns_tuple(self) -> None: + result = parse_storage_transformers([{"a": 1}]) + assert result == ({"a": 1},) + assert isinstance(result, tuple) + + def test_generator_not_exhausted(self) -> None: + result = parse_storage_transformers(iter([{"a": 1}, {"b": 2}])) + assert result == ({"a": 1}, {"b": 2}) + + def test_non_iterable_rejected(self) -> None: + with pytest.raises(TypeError, match="Expected an iterable"): + parse_storage_transformers(5) diff --git a/tests/test_metadata/test_v2.py b/tests/test_metadata/test_v2.py index d1a1ca00b4..d0560aacea 100644 --- a/tests/test_metadata/test_v2.py +++ b/tests/test_metadata/test_v2.py @@ -31,7 +31,7 @@ def test_parse_zarr_format_valid() -> None: @pytest.mark.parametrize("data", [None, 1, 3, 4, 5, "3"]) def test_parse_zarr_format_invalid(data: Any) -> None: - with pytest.raises(ValueError, match=f"Invalid value. Expected 2. Got {data}"): + with pytest.raises(ValueError, match="Failed to parse input for 'zarr_format'"): parse_zarr_format(data)