Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ io = [
"cftime",
"pooch",
]
arrow = ["pyarrow"]
etc = ["sparse>=0.15"]
parallel = ["dask[complete]"]
viz = ["cartopy>=0.24", "matplotlib>=3.10", "nc-time-axis", "seaborn"]
Expand Down Expand Up @@ -157,6 +158,7 @@ module = [
"opt_einsum.*",
"pint.*",
"pooch.*",
"polars.*",
"pyarrow.*",
"pydap.*",
"seaborn.*",
Expand Down
27 changes: 27 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,33 @@ def __init__(

self._close = None

def __arrow_c_stream__(self, requested_schema: Any = None) -> Any:
    """Export this array through the Arrow PyCapsule stream interface.

    The array is flattened in C (row-major) order into a long-format
    ``pyarrow.Table``: one column per dimension, holding that dimension's
    coordinate value for every element, plus one column with the data
    values. The data column is named ``self.name``, falling back to
    ``"values"`` when the name is ``None`` or otherwise falsy.

    Parameters
    ----------
    requested_schema : optional
        Forwarded unchanged to ``pyarrow.Table.__arrow_c_stream__``.

    Returns
    -------
    Whatever ``pyarrow.Table.__arrow_c_stream__`` returns for the table
    built here.

    Raises
    ------
    ImportError
        If pyarrow cannot be imported.
    """
    try:
        import pyarrow as pa
    except ImportError:
        raise ImportError(
            "pyarrow is required to export via the Arrow PyCapsule Interface."
        ) from None

    values = self._variable.values
    dims = self._variable.dims

    # Expand every dimension coordinate to the full array shape; "ij"
    # indexing keeps the axis order of ``dims`` and ``copy=False`` keeps the
    # grids as views until they are flattened below.
    coord_arrays = (self.coords[dim].values for dim in dims)
    grids = np.meshgrid(*coord_arrays, indexing="ij", copy=False)

    # ``ravel()`` flattens in C order and returns a contiguous 1-D array,
    # copying only when needed, so neither the grids nor ``values`` need a
    # separate ``np.ascontiguousarray`` pass first.
    columns: dict[Hashable, pa.Array] = {
        dim: pa.array(grid.ravel()) for dim, grid in zip(dims, grids, strict=True)
    }

    # NOTE: when the data name equals a dimension name, this assignment
    # replaces that dimension's coordinate column with the data values.
    columns[self.name or "values"] = pa.array(values.ravel())

    table = pa.table(columns)
    return table.__arrow_c_stream__(requested_schema)

@classmethod
def _construct_direct(
cls,
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def _importorskip(
has_iris, requires_iris = _importorskip("iris")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_pyarrow, requires_pyarrow = _importorskip("pyarrow")
has_polars, requires_polars = _importorskip("polars")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
134 changes: 134 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
requires_iris,
requires_numexpr,
requires_pint,
requires_polars,
requires_pyarrow,
requires_scipy,
requires_sparse,
Expand Down Expand Up @@ -7673,3 +7674,136 @@ def test_unstack_index_var() -> None:
name="x",
)
assert_identical(actual, expected)


class TestArrowPyCapsule:
    """Tests for exporting a DataArray via the Arrow PyCapsule Interface.

    Each test hands a DataArray directly to an Arrow consumer
    (``pyarrow.table`` or ``polars.from_arrow``) and checks the resulting
    long-format table: one column per dimension plus one data column.
    """

    @requires_pyarrow
    def test_pyarrow_table_1d(self) -> None:
        import pyarrow as pa

        da = xr.DataArray(
            [1.0, 2.0, 3.0],
            dims=["x"],
            coords={"x": [10, 20, 30]},
            name="temperature",
        )
        table = pa.table(da)

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"x", "temperature"}
        assert table.num_rows == 3
        assert table.schema.field("x").type == pa.int64()
        assert table.schema.field("temperature").type == pa.float64()
        np.testing.assert_array_equal(table["x"].to_pylist(), [10, 20, 30])
        np.testing.assert_array_equal(table["temperature"].to_pylist(), [1.0, 2.0, 3.0])

    @requires_pyarrow
    def test_pyarrow_table_2d(self) -> None:
        import pyarrow as pa

        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )
        table = pa.table(da)

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"x", "y", "data"}
        assert table.num_rows == 6
        assert table.schema.field("x").type == pa.int64()
        assert table.schema.field("y").type == pa.int64()
        assert table.schema.field("data").type == pa.float64()
        np.testing.assert_array_equal(
            table["data"].to_pylist(), list(np.arange(6, dtype=float))
        )
        # Coordinates are expanded in row-major order, matching the data:
        # x repeats for each y, y cycles for each x.
        np.testing.assert_array_equal(table["x"].to_pylist(), [0, 0, 0, 1, 1, 1])
        np.testing.assert_array_equal(
            table["y"].to_pylist(), [10, 20, 30, 10, 20, 30]
        )

    @requires_pyarrow
    def test_pyarrow_table_transposed(self) -> None:
        import pyarrow as pa

        # Transposing yields a non-C-contiguous view of the underlying data;
        # the exported rows must follow the transposed (y, x) order.
        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        ).transpose("y", "x")
        table = pa.table(da)

        assert table.num_rows == 6
        np.testing.assert_array_equal(
            table["y"].to_pylist(), [10, 10, 20, 20, 30, 30]
        )
        np.testing.assert_array_equal(table["x"].to_pylist(), [0, 1, 0, 1, 0, 1])
        np.testing.assert_array_equal(
            table["data"].to_pylist(), [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
        )

    @requires_pyarrow
    def test_data_array_unnamed_variable(self) -> None:
        import pyarrow as pa

        da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]})
        table = pa.table(da)

        # An unnamed DataArray exports its data under the fallback name.
        assert "values" in table.column_names

    @requires_polars
    def test_polars_dataframe_1d(self) -> None:
        import polars as pl

        da = xr.DataArray(
            [1.0, 2.0, 3.0],
            dims=["x"],
            coords={"x": [10, 20, 30]},
            name="temperature",
        )
        df = pl.from_arrow(da)

        assert isinstance(df, pl.DataFrame)
        assert set(df.columns) == {"x", "temperature"}
        assert len(df) == 3
        np.testing.assert_array_equal(df["x"].to_list(), [10, 20, 30])
        np.testing.assert_array_equal(df["temperature"].to_list(), [1.0, 2.0, 3.0])

    @requires_polars
    def test_polars_dataframe_2d(self) -> None:
        import polars as pl

        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )
        df = pl.from_arrow(da)

        assert isinstance(df, pl.DataFrame)
        assert set(df.columns) == {"x", "y", "data"}
        assert len(df) == 6
        np.testing.assert_array_equal(
            df["data"].to_list(), list(np.arange(6, dtype=float))
        )
        # x repeats for each y: [0,0,0,1,1,1]
        np.testing.assert_array_equal(df["x"].to_list(), [0, 0, 0, 1, 1, 1])
        # y cycles for each x: [10,20,30,10,20,30]
        np.testing.assert_array_equal(df["y"].to_list(), [10, 20, 30, 10, 20, 30])

    @requires_dask
    @requires_pyarrow
    def test_dask_dataarray(self) -> None:
        import dask.array as da
        import pyarrow as pa

        dask_da = xr.DataArray(
            da.from_array(np.arange(6, dtype=float).reshape(2, 3)),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )

        table = pa.table(dask_da)
        assert isinstance(table, pa.Table)
        np.testing.assert_array_equal(
            table["data"].to_pylist(), list(np.arange(6, dtype=float))
        )

    @requires_polars
    @requires_pyarrow
    def test_polars_pyarrow_consistent(self) -> None:
        import polars as pl
        import pyarrow as pa

        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )
        pa_table = pa.table(da)
        pl_df = pl.from_arrow(da)

        # Both consumers read the same exported stream, so every column
        # must agree value-for-value.
        for col in pa_table.column_names:
            np.testing.assert_array_equal(
                pa_table[col].to_pylist(), pl_df[col].to_list()
            )
Loading