Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pypi-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@a0af7a228712d6121d37aba47adf55c1332c9c2e # v6.2.0
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
name: Install Python
with:
python-version: "3.12"
Expand Down Expand Up @@ -61,7 +61,7 @@ jobs:
needs: build-artifacts
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@a0af7a228712d6121d37aba47adf55c1332c9c2e # v6.2.0
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
name: Install Python
with:
python-version: "3.12"
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ New Features
- Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit
all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`).
By `Alfonso Ladino <https://github.com/aladinor>`_.
- Added complex dtype support to FillValueCoder for the Zarr backend. (:pull:`11151`)
By `Max Jones <https://github.com/maxrjones>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
17 changes: 17 additions & 0 deletions properties/test_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hypothesis.extra.numpy as npst
import numpy as np
from hypothesis import given
from hypothesis import strategies as st

import xarray as xr
from xarray.coding.times import _parse_iso8601
Expand Down Expand Up @@ -48,6 +49,22 @@ def test_CFScaleOffset_coder_roundtrip(original) -> None:
xr.testing.assert_identical(original, roundtripped)


@given(
real=st.floats(allow_nan=True, allow_infinity=True),
imag=st.floats(allow_nan=True, allow_infinity=True),
dtype=st.sampled_from([np.complex64, np.complex128]),
)
def test_FillValueCoder_complex_roundtrip(real, imag, dtype) -> None:
from xarray.backends.zarr import FillValueCoder

value = dtype(complex(real, imag))
encoded = FillValueCoder.encode(value, np.dtype(dtype))
decoded = FillValueCoder.decode(encoded, np.dtype(dtype))
np.testing.assert_equal(
np.array(decoded, dtype=dtype), np.array(value, dtype=dtype)
)


@given(dt=datetimes())
def test_iso8601_decode(dt):
iso = dt.isoformat()
Expand Down
61 changes: 55 additions & 6 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,89 @@ class FillValueCoder:
"""

@classmethod
def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
def encode(
cls, value: int | float | complex | str | bytes, dtype: np.dtype[Any]
) -> Any:
if dtype.kind == "S":
# byte string, this implies that 'value' must also be `bytes` dtype.
assert isinstance(value, bytes)
if not isinstance(value, bytes):
raise TypeError(
f"Failed to encode fill_value: expected bytes for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64encode(value).decode()
elif dtype.kind == "b":
# boolean
return bool(value)
elif dtype.kind in "iu":
# todo: do we want to check for decimals?
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}"
)
return int(value)
elif dtype.kind == "f":
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64encode(struct.pack("<d", float(value))).decode()
elif dtype.kind == "c":
# complex - encode each component as base64, matching float encoding
if not isinstance(value, complex) and not np.issubdtype(
type(value), np.complexfloating
):
raise TypeError(
f"Failed to encode fill_value: expected complex for dtype {dtype}, got {type(value).__name__}"
)
return [
base64.standard_b64encode(
struct.pack("<d", float(value.real)) # type: ignore[union-attr]
).decode(),
base64.standard_b64encode(
struct.pack("<d", float(value.imag)) # type: ignore[union-attr]
).decode(),
]
elif dtype.kind == "U":
return str(value)
else:
raise ValueError(f"Failed to encode fill_value. Unsupported dtype {dtype}")

@classmethod
def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
def decode(
cls, value: int | float | str | bytes | list, dtype: str | np.dtype[Any]
):
if dtype == "string":
# zarr V3 string type
return str(value)
elif dtype == "bytes":
# zarr V3 bytes type
assert isinstance(value, str | bytes)
if not isinstance(value, str | bytes):
raise TypeError(
f"Failed to decode fill_value: expected str or bytes for dtype {dtype}, got {type(value).__name__}"
)
return base64.standard_b64decode(value)
np_dtype = np.dtype(dtype)
if np_dtype.kind == "f":
assert isinstance(value, str | bytes)
if not isinstance(value, str | bytes):
raise TypeError(
f"Failed to decode fill_value: expected str or bytes for dtype {np_dtype}, got {type(value).__name__}"
)
return struct.unpack("<d", base64.standard_b64decode(value))[0]
elif np_dtype.kind == "c":
# complex - decode each component from base64, matching float decoding
if not (isinstance(value, list | tuple) and len(value) == 2):
raise TypeError(
f"Failed to decode fill_value: expected a 2-element list for dtype {np_dtype}, got {type(value).__name__}"
)
real = struct.unpack("<d", base64.standard_b64decode(value[0]))[0]
imag = struct.unpack("<d", base64.standard_b64decode(value[1]))[0]
return complex(real, imag)
elif np_dtype.kind == "b":
return bool(value)
elif np_dtype.kind in "iu":
if not isinstance(value, int | float | np.integer | np.floating):
raise TypeError(
f"Failed to decode fill_value: expected int or float for dtype {np_dtype}, got {type(value).__name__}"
)
return int(value)
else:
raise ValueError(f"Failed to decode fill_value. Unsupported dtype {dtype}")
Expand Down
35 changes: 35 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7241,6 +7241,41 @@ def test_encode_zarr_attr_value() -> None:
assert actual3 == expected3


@requires_zarr
@pytest.mark.parametrize("dtype", [complex, np.complex64, np.complex128])
def test_fill_value_coder_complex(dtype) -> None:
"""Test that FillValueCoder round-trips complex fill values."""
from xarray.backends.zarr import FillValueCoder

for value in [dtype(1 + 2j), dtype(-3.5 + 4.5j), dtype(complex("nan+nanj"))]:
encoded = FillValueCoder.encode(value, np.dtype(dtype))
decoded = FillValueCoder.decode(encoded, np.dtype(dtype))
np.testing.assert_equal(np.array(decoded, dtype=dtype), np.array(value))


@requires_zarr
@pytest.mark.parametrize(
"value,dtype",
[
(np.float32(np.inf), np.float32),
(np.float32(-np.inf), np.float32),
(np.float64(np.inf), np.float64),
(np.float64(-np.inf), np.float64),
(np.float32(np.nan), np.float32),
(np.float64(np.nan), np.float64),
],
)
def test_fill_value_coder_inf_nan(value, dtype) -> None:
"""Test that FillValueCoder round-trips inf and nan fill values."""
from xarray.backends.zarr import FillValueCoder

encoded = FillValueCoder.encode(value, np.dtype(dtype))
decoded = FillValueCoder.decode(encoded, np.dtype(dtype))
np.testing.assert_equal(
np.array(decoded, dtype=dtype), np.array(value, dtype=dtype)
)


@requires_zarr
def test_extract_zarr_variable_encoding() -> None:
var = xr.Variable("x", [1, 2])
Expand Down