diff --git a/doc/source/conf.py b/doc/source/conf.py index 49a062315362..ff1c99b62c49 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -92,6 +92,18 @@ myst_heading_anchors = 3 +# Make broken internal references into build time errors. +# See https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-nitpicky +# for more information. :py:class: references are ignored due to false positives +# arising from type annotations. See https://github.com/ray-project/ray/pull/46103 +# for additional context. +nitpicky = True +nitpick_ignore_regex = [ + ("py:class", ".*"), + # Workaround for https://github.com/sphinx-doc/sphinx/issues/10974 + ("py:obj", "ray\.data\.datasource\.datasink\.WriteReturnType"), +] + # Cache notebook outputs in _build/.jupyter_cache # To prevent notebook execution, set this to "off". To force re-execution, set this to # "force". To cache previous runs, set this to "cache". diff --git a/doc/source/data/api/input_output.rst b/doc/source/data/api/input_output.rst index d6b496990ef6..bd29acbe98f3 100644 --- a/doc/source/data/api/input_output.rst +++ b/doc/source/data/api/input_output.rst @@ -168,6 +168,53 @@ Databricks read_databricks_tables +Delta Sharing +------------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_delta_sharing_tables + +Hudi +---- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_hudi + +Iceberg +------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_iceberg + Dataset.write_iceberg + +Lance +----- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_lance + Dataset.write_lance + +ClickHouse +---------- + +.. 
autosummary:: + :nosignatures: + :toctree: doc/ + + read_clickhouse + Dask ---- @@ -270,6 +317,8 @@ Datasink API datasource.RowBasedFileDatasink datasource.BlockBasedFileDatasink datasource.FileBasedDatasource + datasource.WriteResult + datasource.WriteReturnType Partitioning API ---------------- diff --git a/python/ray/data/_internal/execution/interfaces/task_context.py b/python/ray/data/_internal/execution/interfaces/task_context.py index 99431125a0ad..094faf2440e0 100644 --- a/python/ray/data/_internal/execution/interfaces/task_context.py +++ b/python/ray/data/_internal/execution/interfaces/task_context.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional from ray.data._internal.progress_bar import ProgressBar @@ -39,3 +39,6 @@ class TaskContext: # The target maximum number of bytes to include in the task's output block. target_max_block_size: Optional[int] = None + + # Additional keyword arguments passed to the task. + kwargs: Dict[str, Any] = field(default_factory=dict) diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index acf811f32a73..d8c88a242549 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -212,7 +212,12 @@ def _dispatch_tasks(self): num_returns="streaming", name=self.name, **self._ray_actor_task_remote_args, - ).remote(DataContext.get_current(), ctx, *input_blocks) + ).remote( + DataContext.get_current(), + ctx, + *input_blocks, + **self.get_map_task_kwargs(), + ) def _task_done_callback(actor_to_return): # Return the actor that was running the task to the pool. 
@@ -401,12 +406,14 @@ def submit( data_context: DataContext, ctx: TaskContext, *blocks: Block, + **kwargs: Dict[str, Any], ) -> Iterator[Union[Block, List[BlockMetadata]]]: yield from _map_task( self._map_transformer, data_context, ctx, *blocks, + **kwargs, ) def __repr__(self): diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 6f9992faf530..29079d2c9553 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -85,6 +85,25 @@ def __init__( # too-large blocks, which may reduce parallelism for # the subsequent operator. self._additional_split_factor = None + # Callback functions that generate additional task kwargs + # for the map task. + self._map_task_kwargs_fns: List[Callable[[], Dict[str, Any]]] = [] + + def add_map_task_kwargs_fn(self, map_task_kwargs_fn: Callable[[], Dict[str, Any]]): + """Add a callback function that generates additional kwargs for the map tasks. + In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`. + """ + self._map_task_kwargs_fns.append(map_task_kwargs_fn) + + def get_map_task_kwargs(self) -> Dict[str, Any]: + """Get the kwargs for the map task. + Subclasses should pass the returned kwargs to the map tasks. + In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`. + """ + kwargs = {} + for fn in self._map_task_kwargs_fns: + kwargs.update(fn()) + return kwargs def get_additional_split_factor(self) -> int: if self._additional_split_factor is None: @@ -402,6 +421,7 @@ def _map_task( data_context: DataContext, ctx: TaskContext, *blocks: Block, + **kwargs: Dict[str, Any], ) -> Iterator[Union[Block, List[BlockMetadata]]]: """Remote function for a single operator task. @@ -415,6 +435,7 @@ def _map_task( as the last generator return. 
""" DataContext._set_current(data_context) + ctx.kwargs.update(kwargs) stats = BlockExecStats.builder() map_transformer.set_target_max_block_size(ctx.target_max_block_size) for b_out in map_transformer.apply_transform(iter(blocks), ctx): diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 2d84dd1bc111..a0c5dc3de733 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -76,7 +76,8 @@ def _add_bundled_input(self, bundle: RefBundle): self._map_transformer_ref, data_context, ctx, - *input_blocks, + *bundle.block_refs, + **self.get_map_task_kwargs(), ) self._submit_data_task(gen, bundle) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index d555db84d6c3..79d664371aae 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -1,3 +1,4 @@ +import itertools from typing import List, Optional, Tuple # TODO(Clark): Remove compute dependency once we delete the legacy compute. @@ -311,6 +312,11 @@ def _get_fused_map_operator( min_rows_per_bundle=min_rows_per_bundled_input, ray_remote_args=ray_remote_args, ) + op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators) + for map_task_kwargs_fn in itertools.chain( + up_op._map_task_kwargs_fns, down_op._map_task_kwargs_fns + ): + op.add_map_task_kwargs_fn(map_task_kwargs_fn) # Build a map logical operator to be used as a reference for further fusion. 
# TODO(Scott): This is hacky, remove this once we push fusion to be purely based diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py index c33e831fde0b..ab61ea90d7b6 100644 --- a/python/ray/data/_internal/planner/plan_write_op.py +++ b/python/ray/data/_internal/planner/plan_write_op.py @@ -1,4 +1,7 @@ -from typing import Callable, Iterator, Union +import itertools +from typing import Callable, Iterator, List, Union + +from pandas import DataFrame from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.execution.interfaces import PhysicalOperator @@ -9,41 +12,82 @@ MapTransformer, ) from ray.data._internal.logical.operators.write_operator import Write -from ray.data.block import Block -from ray.data.datasource.datasink import Datasink +from ray.data.block import Block, BlockAccessor +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.datasource import Datasource +def gen_datasink_write_result( + write_result_blocks: List[Block], +) -> WriteResult: + assert all( + isinstance(block, DataFrame) and len(block) == 1 + for block in write_result_blocks + ) + total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks) + total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks) + + write_returns = [result["write_return"][0] for result in write_result_blocks] + return WriteResult(total_num_rows, total_size_bytes, write_returns) + + def generate_write_fn( datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args ) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]: - # If the write op succeeds, the resulting Dataset is a list of - # arbitrary objects (one object per write task). Otherwise, an error will - # be raised. The Datasource can handle execution outcomes with the - # on_write_complete() and on_write_failed(). 
- def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: + """Writes the blocks to the given datasink or legacy datasource. + + Outputs the original blocks to be written.""" + # Create a copy of the iterator, so we can return the original blocks. + it1, it2 = itertools.tee(blocks, 2) if isinstance(datasink_or_legacy_datasource, Datasink): - write_result = datasink_or_legacy_datasource.write(blocks, ctx) - else: - write_result = datasink_or_legacy_datasource.write( - blocks, ctx, **write_args + ctx.kwargs["_datasink_write_return"] = datasink_or_legacy_datasource.write( + it1, ctx ) + else: + datasink_or_legacy_datasource.write(it1, ctx, **write_args) + + return it2 + + return fn + + +def generate_collect_write_stats_fn() -> Callable[ + [Iterator[Block], TaskContext], Iterator[Block] +]: + # If the write op succeeds, the resulting Dataset is a list of + # one Block which contains stats/metrics about the write. + # Otherwise, an error will be raised. The Datasource can handle + # execution outcomes with `on_write_complete()` and `on_write_failed()`. + def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: + """Handles stats collection for block writes.""" + block_accessors = [BlockAccessor.for_block(block) for block in blocks] + total_num_rows = sum(ba.num_rows() for ba in block_accessors) + total_size_bytes = sum(ba.size_bytes() for ba in block_accessors) # NOTE: Write tasks can return anything, so we need to wrap it in a valid block # type. 
import pandas as pd - block = pd.DataFrame({"write_result": [write_result]}) - return [block] + block = pd.DataFrame( + { + "num_rows": [total_num_rows], + "size_bytes": [total_size_bytes], + "write_return": [ctx.kwargs.get("_datasink_write_return", None)], + } + ) + return iter([block]) return fn def plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOperator: write_fn = generate_write_fn(op._datasink_or_legacy_datasource, **op._write_args) + collect_stats_fn = generate_collect_write_stats_fn() # Create a MapTransformer for a write operator transform_fns = [ BlockMapTransformFn(write_fn), + BlockMapTransformFn(collect_stats_fn), ] map_transformer = MapTransformer(transform_fns) return MapOperator.create( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 780fb4acd957..73591229b92d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -31,9 +31,11 @@ from ray.data._internal.block_list import BlockList from ray.data._internal.compute import ComputeStrategy from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data.datasource.iceberg_datasink import IcebergDatasink from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.legacy_compat import _block_list_to_bundles +from ray.data._internal.execution.util import memory_string from ray.data._internal.iterator.iterator_impl import DataIteratorImpl from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator from ray.data._internal.lazy_block_list import LazyBlockList @@ -60,6 +62,7 @@ from ray.data._internal.pandas_block import PandasBlockSchema from ray.data._internal.plan import ExecutionPlan from ray.data._internal.planner.exchange.sort_task_spec import SortKey +from ray.data._internal.planner.plan_write_op import gen_datasink_write_result from ray.data._internal.remote_fn import cached_remote_fn from 
ray.data._internal.split import _get_num_rows, _split_at_indices from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager @@ -2845,6 +2848,61 @@ def write_json( concurrency=concurrency, ) + @ConsumptionAPI + @PublicAPI(stability="alpha") + def write_iceberg( + self, + table_identifier: str, + catalog_kwargs: Optional[Dict[str, Any]] = None, + snapshot_properties: Optional[Dict[str, str]] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to an Iceberg table. + + .. tip:: + For more details on PyIceberg, see + - URI: https://py.iceberg.apache.org/ + + Examples: + .. testcode:: + :skipif: True + + import ray + import pandas as pd + docs = [{"title": "Iceberg data sink test"} for key in range(4)] + ds = ray.data.from_pandas(pd.DataFrame(docs)) + ds.write_iceberg( + table_identifier="db_name.table_name", + catalog_kwargs={"name": "default", "type": "sql"} + ) + + Args: + table_identifier: Fully qualified table identifier (``db_name.table_name``) + catalog_kwargs: Optional arguments to pass to PyIceberg's catalog.load_catalog() + function (e.g., name, type, etc.). For the function definition, see + `pyiceberg catalog + <https://py.iceberg.apache.org/reference/pyiceberg/catalog/\ + #pyiceberg.catalog.load_catalog>`_. + snapshot_properties: Custom properties written to the snapshot when committing + to an Iceberg table. + ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. + concurrency: The maximum number of Ray tasks to run concurrently. Set this + to control the number of tasks to run concurrently. This doesn't change the + total number of tasks run. By default, concurrency is dynamically + decided based on the available resources. 
+ """ + + datasink = IcebergDatasink( + table_identifier, catalog_kwargs, snapshot_properties + ) + + self.write_datasink( + datasink, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + ) + @PublicAPI(stability="alpha") @ConsumptionAPI def write_images( @@ -3575,18 +3633,22 @@ def write_datasink( logical_plan = LogicalPlan(write_op) try: - import pandas as pd datasink.on_write_start() self._write_ds = Dataset(plan, logical_plan).materialize() - blocks = ray.get(self._write_ds._plan.execute().get_blocks()) - assert all( - isinstance(block, pd.DataFrame) and len(block) == 1 for block in blocks + # TODO: Get and handle the blocks with an iterator instead of getting + # everything in a blocking way, so some blocks can be freed earlier. + raw_write_results = ray.get(self._write_ds._plan.execute().block_refs) + write_result = gen_datasink_write_result(raw_write_results) + logger.info( + "Data sink %s finished. %d rows and %s data written.", + datasink.get_name(), + write_result.num_rows, + memory_string(write_result.size_bytes), ) - write_results = [block["write_result"][0] for block in blocks] + datasink.on_write_complete(write_result) - datasink.on_write_complete(write_results) except Exception as e: datasink.on_write_failed(e) raise diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index 5f950ec99001..5176f0ea4c7c 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -8,6 +8,13 @@ from ray.data.datasource.csv_datasink import _CSVDatasink from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.datasource.datasink import Datasink, DummyOutputDatasink +from ray.data.datasource.sql_datasource import Connection +from ray.data.datasource.datasink import ( + Datasink, + DummyOutputDatasink, + WriteResult, + WriteReturnType, +) from ray.data.datasource.datasource import ( Datasource, RandomIntRowDatasource, @@ -113,4 +120,6 @@ "_WebDatasetDatasink", 
"WebDatasetDatasource", "_S3FileSystemWrapper", + "WriteResult", + "WriteReturnType", ] diff --git a/python/ray/data/datasource/bigquery_datasink.py b/python/ray/data/datasource/bigquery_datasink.py index 33550f0791cb..000962be29be 100644 --- a/python/ray/data/datasource/bigquery_datasink.py +++ b/python/ray/data/datasource/bigquery_datasink.py @@ -3,7 +3,7 @@ import tempfile import time import uuid -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import pyarrow.parquet as pq @@ -20,7 +20,7 @@ RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11 -class _BigQueryDatasink(Datasink): +class _BigQueryDatasink(Datasink[None]): def __init__( self, project_id: str, @@ -70,7 +70,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: def _write_single_block(block: Block, project_id: str, dataset: str) -> None: from google.api_core import exceptions from google.cloud import bigquery @@ -127,5 +127,3 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None: for block in blocks ] ) - - return "ok" diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index f77c3b93f3d2..666264c8f6e8 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -1,13 +1,34 @@ -from typing import Any, Iterable, List, Optional +import logging +from dataclasses import dataclass +from typing import Generic, Iterable, List, Optional, TypeVar import ray from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.util.annotations import DeveloperAPI +logger = logging.getLogger(__name__) + + +WriteReturnType = TypeVar("WriteReturnType") +"""Generic type for the return value of `Datasink.write`.""" + + +@dataclass +@DeveloperAPI +class WriteResult(Generic[WriteReturnType]): + """Aggregated result of the Datasink write operations.""" + + # Total number of written rows. 
+ num_rows: int + # Total size in bytes of written data. + size_bytes: int + # All returned values of `Datasink.write`. + write_returns: List[WriteReturnType] + @DeveloperAPI -class Datasink: +class Datasink(Generic[WriteReturnType]): """Interface for defining write-related logic. If you want to write data to something that isn't built-in, subclass this class @@ -26,7 +47,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> WriteReturnType: """Write blocks. This is used by a single write task. Args: @@ -34,12 +55,13 @@ def write( ctx: ``TaskContext`` for the write task. Returns: - A user-defined output. Can be anything, and the returned value is passed to - :meth:`~Datasink.on_write_complete`. + Result of this write task. When the entire write operator finishes, + all returned values will be passed as `WriteResult.write_returns` + to `Datasink.on_write_complete`. """ raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: + def on_write_complete(self, write_result: WriteResult[WriteReturnType]): """Callback for when a write job completes. This can be used to "commit" a write output. This method must @@ -47,7 +69,8 @@ def on_write_complete(self, write_results: List[Any]) -> None: method fails, then ``on_write_failed()`` is called. Args: - write_results: The objects returned by every :meth:`~Datasink.write` task. + write_result: Aggregated result of the + Write operator, containing write results and stats. """ pass @@ -89,7 +112,7 @@ def num_rows_per_write(self) -> Optional[int]: @DeveloperAPI -class DummyOutputDatasink(Datasink): +class DummyOutputDatasink(Datasink[None]): """An example implementation of a writable datasource for testing. 
Examples: >>> import ray @@ -110,10 +133,9 @@ def __init__(self): self.rows_written = 0 self.enabled = True - def write(self, block: Block) -> str: + def write(self, block: Block) -> None: block = BlockAccessor.for_block(block) self.rows_written += block.num_rows() - return "ok" def get_rows_written(self): return self.rows_written @@ -127,17 +149,15 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: tasks = [] if not self.enabled: raise ValueError("disabled") for b in blocks: tasks.append(self.data_sink.write.remote(b)) ray.get(tasks) - return "ok" - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results + def on_write_complete(self, write_result: WriteResult[None]): self.num_ok += 1 def on_write_failed(self, error: Exception) -> None: diff --git a/python/ray/data/datasource/file_datasink.py b/python/ray/data/datasource/file_datasink.py index c4bf3fb6c867..6266e1fedbfa 100644 --- a/python/ray/data/datasource/file_datasink.py +++ b/python/ray/data/datasource/file_datasink.py @@ -1,5 +1,5 @@ import posixpath -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional from ray._private.utils import _add_creatable_buckets_param_if_s3_uri from ray.data._internal.dataset_logger import DatasetLogger @@ -8,8 +8,8 @@ from ray.data._internal.util import _is_local_scheme, call_with_retry from ray.data.block import Block, BlockAccessor from ray.data.context import DataContext +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.block_path_provider import BlockWritePathProvider -from ray.data.datasource.datasink import Datasink from ray.data.datasource.filename_provider import ( FilenameProvider, _DefaultFilenameProvider, @@ -27,7 +27,7 @@ WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32 -class _FileDatasink(Datasink): +class _FileDatasink(Datasink[None]): def __init__( 
self, path: str, @@ -106,7 +106,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: builder = DelegatingBlockBuilder() for block in blocks: builder.add_block(block) @@ -114,22 +114,17 @@ def write( block_accessor = BlockAccessor.for_block(block) if block_accessor.num_rows() == 0: - logger.get_logger().warning(f"Skipped writing empty block to {self.path}") - return "skip" + logger.warning(f"Skipped writing empty block to {self.path}") + return self.write_block(block_accessor, 0, ctx) - # TODO: decide if we want to return richer object when the task - # succeeds. - return "ok" def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: - if not self.has_created_dir: - return - - if all(write_results == "skip" for write_results in write_results): + def on_write_complete(self, write_result: WriteResult[None]): + # If no rows were written, we can delete the directory. 
+ if self.has_created_dir and write_result.num_rows == 0: self.filesystem.delete_dir(self.path) @property @@ -183,13 +178,15 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): ) write_path = posixpath.join(self.path, filename) - def write_row_to_path(): + def write_row_to_path(row, write_path): with self.open_output_stream(write_path) as file: self.write_row_to_file(row, file) logger.get_logger(log_to_stdout=False).debug(f"Writing {write_path} file.") call_with_retry( - write_row_to_path, + lambda row=row, write_path=write_path: write_row_to_path( + row, write_path + ), description=f"write '{write_path}'", match=DataContext.get_current().write_file_retry_on_errors, max_attempts=WRITE_FILE_MAX_ATTEMPTS, diff --git a/python/ray/data/datasource/iceberg_datasink.py b/python/ray/data/datasource/iceberg_datasink.py new file mode 100644 index 000000000000..7f2ae9964f7a --- /dev/null +++ b/python/ray/data/datasource/iceberg_datasink.py @@ -0,0 +1,161 @@ +""" +Module to write a Ray Dataset into an iceberg table, by using the Ray Datasink API. +""" +import logging + +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional + +from ray.data.datasource.datasink import Datasink +from ray.util.annotations import DeveloperAPI +from ray.data.block import BlockAccessor, Block +from ray.data._internal.execution.interfaces import TaskContext +from ray.data.datasource.datasink import WriteResult +import uuid + +if TYPE_CHECKING: + from pyiceberg.catalog import Catalog + from pyiceberg.manifest import DataFile + + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class IcebergDatasink(Datasink[List["DataFile"]]): + """ + Iceberg datasink to write a Ray Dataset into an existing Iceberg table. This module + heavily uses PyIceberg to write to iceberg table. All the routines in this class override + `ray.data.Datasink`. 
+ + """ + + def __init__( + self, + table_identifier: str, + catalog_kwargs: Optional[Dict[str, Any]] = None, + snapshot_properties: Optional[Dict[str, str]] = None, + ): + """ + Initialize the IcebergDatasink + + Args: + table_identifier: The identifier of the table to write to, e.g. `default.taxi_dataset` + catalog_kwargs: Optional arguments to use when setting up the Iceberg + catalog + snapshot_properties: Custom properties written to the snapshot when committing + to an Iceberg table, e.g. {"commit_time": "2021-01-01T00:00:00Z"} + """ + + from pyiceberg.io import FileIO + from pyiceberg.table import Transaction + from pyiceberg.table.metadata import TableMetadata + + self.table_identifier = table_identifier + self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {} + self._snapshot_properties = ( + snapshot_properties if snapshot_properties is not None else {} + ) + + if "name" in self._catalog_kwargs: + self._catalog_name = self._catalog_kwargs.pop("name") + else: + self._catalog_name = "default" + + self._uuid: str = None + self._io: FileIO = None + self._txn: Transaction = None + self._table_metadata: TableMetadata = None + + # The Iceberg transaction is not picklable because of its table and catalog properties, + # so we exclude the transaction object when pickling and restore it as None when unpickling. + def __getstate__(self) -> dict: + """Exclude `_txn` during pickling.""" + state = self.__dict__.copy() + del state["_txn"] + return state + + def __setstate__(self, state: dict) -> None: + self.__dict__.update(state) + self._txn = None + + def _get_catalog(self) -> "Catalog": + from pyiceberg import catalog + + return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs) + + def on_write_start(self) -> None: + """Prepare for the transaction""" + from pyiceberg.table import PropertyUtil, TableProperties + + catalog = self._get_catalog() + table = catalog.load_table(self.table_identifier) + self._txn = table.transaction() + 
self._io = self._txn._table.io + self._table_metadata = self._txn.table_metadata + self._uuid = uuid.uuid4() + + if unsupported_partitions := [ + field + for field in self._table_metadata.spec().fields + if not field.transform.supports_pyarrow_transform + ]: + raise ValueError( + f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." + ) + + self._manifest_merge_enabled = PropertyUtil.property_as_bool( + self._table_metadata.properties, + TableProperties.MANIFEST_MERGE_ENABLED, + TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, + ) + + def write( + self, blocks: Iterable[Block], ctx: TaskContext + ) -> WriteResult[List["DataFile"]]: + from pyiceberg.io.pyarrow import ( + _check_pyarrow_schema_compatible, + _dataframe_to_data_files, + ) + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE + from pyiceberg.utils.config import Config + + data_files_list: WriteResult[List["DataFile"]] = [] + for block in blocks: + pa_table = BlockAccessor.for_block(block).to_arrow() + + downcast_ns_timestamp_to_us = ( + Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + ) + _check_pyarrow_schema_compatible( + self._table_metadata.schema(), + provided_schema=pa_table.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + ) + + if pa_table.shape[0] <= 0: + continue + + data_files = _dataframe_to_data_files( + self._table_metadata, pa_table, self._io, self._uuid + ) + data_files_list.extend(data_files) + + return data_files_list + + def on_write_complete(self, write_result: WriteResult[List["DataFile"]]): + update_snapshot = self._txn.update_snapshot( + snapshot_properties=self._snapshot_properties + ) + append_method = ( + update_snapshot.merge_append + if self._manifest_merge_enabled + else update_snapshot.fast_append + ) + + with append_method() as append_files: + append_files.commit_uuid = self._uuid + for data_files in write_result.write_returns: + for data_file in 
data_files: + append_files.append_data_file(data_file) + + self._txn.commit_transaction() diff --git a/python/ray/data/datasource/mongo_datasink.py b/python/ray/data/datasource/mongo_datasink.py index f2c20355a272..5dca4baf189d 100644 --- a/python/ray/data/datasource/mongo_datasink.py +++ b/python/ray/data/datasource/mongo_datasink.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Iterable +from typing import Iterable from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.execution.interfaces import TaskContext @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class _MongoDatasink(Datasink): +class _MongoDatasink(Datasink[None]): def __init__(self, uri: str, database: str, collection: str) -> None: _check_import(self, module="pymongo", package="pymongo") _check_import(self, module="pymongoarrow", package="pymongoarrow") @@ -24,7 +24,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pymongo _validate_database_collection_exist( @@ -44,5 +44,3 @@ def write_block(uri: str, database: str, collection: str, block: Block): block = builder.build() write_block(self.uri, self.database, self.collection, block) - - return "ok" diff --git a/python/ray/data/datasource/parquet_datasink.py b/python/ray/data/datasource/parquet_datasink.py index a8e085e5e0f3..834ff845be61 100644 --- a/python/ray/data/datasource/parquet_datasink.py +++ b/python/ray/data/datasource/parquet_datasink.py @@ -57,13 +57,13 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pyarrow.parquet as pq blocks = list(blocks) if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks): - return "skip" + return filename = self.filename_provider.get_filename_for_block( blocks[0], ctx.task_idx, 0 @@ -90,8 +90,6 @@ def write_blocks_to_path(): max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, ) - return "ok" - @property def num_rows_per_write(self) -> 
Optional[int]: return self.num_rows_per_file diff --git a/python/ray/data/datasource/sql_datasink.py b/python/ray/data/datasource/sql_datasink.py index f29480ae6b1c..f80a8d8127b8 100644 --- a/python/ray/data/datasource/sql_datasink.py +++ b/python/ray/data/datasource/sql_datasink.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable +from typing import Callable, Iterable from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor @@ -6,7 +6,7 @@ from ray.data.datasource.sql_datasource import Connection, _connect -class _SQLDatasink(Datasink): +class _SQLDatasink(Datasink[None]): _MAX_ROWS_PER_WRITE = 128 @@ -18,7 +18,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: with _connect(self.connection_factory) as cursor: for block in blocks: block_accessor = BlockAccessor.for_block(block) @@ -33,5 +33,3 @@ def write( if values: cursor.executemany(self.sql, values) - - return "ok" diff --git a/python/ray/data/tests/test_bigquery.py b/python/ray/data/tests/test_bigquery.py index e26d2a383437..69ba4cd380da 100644 --- a/python/ray/data/tests/test_bigquery.py +++ b/python/ray/data/tests/test_bigquery.py @@ -1,5 +1,7 @@ +from typing import Iterator from unittest import mock +import pandas as pd import pyarrow as pa import pytest from google.api_core import exceptions, operation @@ -8,7 +10,11 @@ from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream import ray -from ray.data.datasource import BigQueryDatasource, _BigQueryDatasink +from ray.data.datasource.bigquery_datasink import BigQueryDatasink +from ray.data.datasource.bigquery_datasource import BigQueryDatasource +from ray.data._internal.execution.interfaces.task_context import TaskContext +from ray.data._internal.planner.plan_write_op import generate_collect_write_stats_fn +from ray.data.block import Block from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa 
from ray.tests.conftest import * # noqa @@ -18,7 +24,6 @@ _TEST_BQ_TABLE_ID = "mocktable" _TEST_BQ_DATASET = _TEST_BQ_DATASET_ID + "." + _TEST_BQ_TABLE_ID _TEST_BQ_TEMP_DESTINATION = _TEST_GCP_PROJECT_ID + ".tempdataset.temptable" -_TEST_DISPLAY_NAME = "display_name" @pytest.fixture(autouse=True) @@ -196,6 +201,9 @@ def test_create_reader_table_not_found(self): class TestWriteBigQuery: """Tests for BigQuery Write.""" + def _extract_write_result(self, stats: Iterator[Block]): + return dict(next(stats).iloc[0]) + def test_write(self, ray_get_mock): bq_datasink = _BigQueryDatasink( project_id=_TEST_GCP_PROJECT_ID, @@ -203,11 +211,24 @@ def test_write(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + ctx = TaskContext(1) + bq_datasink.write( blocks=[block], - ctx=None, + ctx=ctx, + ) + + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], ctx) + pd.testing.assert_frame_equal( + next(stats), + pd.DataFrame( + { + "num_rows": [4], + "size_bytes": [32], + "write_return": [None], + } + ), ) - assert status == "ok" def test_write_dataset_exists(self, ray_get_mock): bq_datasink = _BigQueryDatasink( @@ -216,11 +237,23 @@ def test_write_dataset_exists(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + ctx = TaskContext(1) + bq_datasink.write( blocks=[block], - ctx=None, + ctx=ctx, + ) + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], ctx) + pd.testing.assert_frame_equal( + next(stats), + pd.DataFrame( + { + "num_rows": [4], + "size_bytes": [32], + "write_return": [None], + } + ), ) - assert status == "ok" if __name__ == "__main__": diff --git a/python/ray/data/tests/test_datasink.py b/python/ray/data/tests/test_datasink.py index 772078490601..8b5eaff9f2c1 100644 --- a/python/ray/data/tests/test_datasink.py +++ 
b/python/ray/data/tests/test_datasink.py @@ -1,20 +1,117 @@ -from typing import Any, Iterable +from dataclasses import dataclass +from typing import Iterable, List +import numpy import pytest import ray from ray.data._internal.execution.interfaces import TaskContext -from ray.data.block import Block +from ray.data.block import Block, BlockAccessor from ray.data.datasource import Datasink +from ray.data.datasource.datasink import DummyOutputDatasink, WriteResult + + +def test_write_datasink(ray_start_regular_shared): + output = DummyOutputDatasink() + ds = ray.data.range(10, override_num_blocks=2) + ds.write_datasink(output) + assert output.num_ok == 1 + assert output.num_failed == 0 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 + + output.enabled = False + ds = ray.data.range(10, override_num_blocks=2) + with pytest.raises(ValueError): + ds.write_datasink(output, ray_remote_args={"max_retries": 0}) + assert output.num_ok == 1 + assert output.num_failed == 1 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 + + +class NodeLoggerOutputDatasink(Datasink[None]): + """A writable datasource that logs node IDs of write tasks, for testing.""" + + def __init__(self): + @ray.remote + class DataSink: + def __init__(self): + self.rows_written = 0 + self.node_ids = set() + + def write(self, node_id: str, block: Block) -> str: + block = BlockAccessor.for_block(block) + self.rows_written += block.num_rows() + self.node_ids.add(node_id) + + def get_rows_written(self): + return self.rows_written + + def get_node_ids(self): + return self.node_ids + + self.data_sink = DataSink.remote() + self.num_ok = 0 + self.num_failed = 0 + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + data_sink = self.data_sink + + def write(b): + node_id = ray.get_runtime_context().get_node_id() + return data_sink.write.remote(node_id, b) + + tasks = [] + for b in blocks: + tasks.append(write(b)) + ray.get(tasks) + + def 
on_write_complete(self, write_result: WriteResult[None]): + self.num_ok += 1 + + def on_write_failed(self, error: Exception) -> None: + self.num_failed += 1 + + +def test_write_datasink_ray_remote_args(ray_start_cluster): + ray.shutdown() + cluster = ray_start_cluster + cluster.add_node( + resources={"foo": 100}, + num_cpus=1, + ) + cluster.add_node(resources={"bar": 100}, num_cpus=1) + + ray.init(cluster.address) + + @ray.remote + def get_node_id(): + return ray.get_runtime_context().get_node_id() + + bar_node_id = ray.get(get_node_id.options(resources={"bar": 1}).remote()) + + output = NodeLoggerOutputDatasink() + ds = ray.data.range(100, override_num_blocks=10) + # Pin write tasks to node with "bar" resource. + ds.write_datasink(output, ray_remote_args={"resources": {"bar": 1}}) + assert output.num_ok == 1 + assert output.num_failed == 0 + assert ray.get(output.data_sink.get_rows_written.remote()) == 100 + + node_ids = ray.get(output.data_sink.get_node_ids.remote()) + assert node_ids == {bar_node_id} @pytest.mark.parametrize("num_rows_per_write", [5, 10, 50]) def test_num_rows_per_write(tmp_path, ray_start_regular_shared, num_rows_per_write): - class MockDatasink(Datasink): + class MockDatasink(Datasink[None]): def __init__(self, num_rows_per_write): self._num_rows_per_write = num_rows_per_write - def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any: + def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None: assert sum(len(block) for block in blocks) == self._num_rows_per_write @property @@ -26,6 +123,51 @@ def num_rows_per_write(self): ) +def test_write_result(ray_start_regular_shared): + """Test the write_result argument in `on_write_complete`.""" + + @dataclass + class CustomWriteResult: + + ids: List[int] + + class CustomDatasink(Datasink[CustomWriteResult]): + def __init__(self) -> None: + self.ids = [] + self.num_rows = 0 + self.size_bytes = 0 + + def write(self, blocks: Iterable[Block], ctx: TaskContext): + ids = [] + for b in 
blocks: + ids.extend(b["id"].to_pylist()) + return CustomWriteResult(ids=ids) + + def on_write_complete(self, write_result: WriteResult[CustomWriteResult]): + ids = [] + for result in write_result.write_returns: + ids.extend(result.ids) + self.ids = sorted(ids) + self.num_rows = write_result.num_rows + self.size_bytes = write_result.size_bytes + + num_items = 100 + size_bytes_per_row = 1000 + + def map_fn(row): + row["data"] = numpy.zeros(size_bytes_per_row, dtype=numpy.int8) + return row + + ds = ray.data.range(num_items).map(map_fn) + + datasink = CustomDatasink() + ds.write_datasink(datasink) + + assert datasink.ids == list(range(num_items)) + assert datasink.num_rows == num_items + assert datasink.size_bytes == pytest.approx(num_items * size_bytes_per_row, rel=0.1) + + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_formats.py b/python/ray/data/tests/test_formats.py index 8039c173e4b8..63ce2052eb06 100644 --- a/python/ray/data/tests/test_formats.py +++ b/python/ray/data/tests/test_formats.py @@ -1,5 +1,5 @@ import os -from typing import Any, Iterable, List +import sys import pandas as pd import pyarrow as pa @@ -11,9 +11,7 @@ import ray from ray._private.test_utils import wait_for_condition -from ray.data._internal.execution.interfaces import TaskContext -from ray.data.block import Block, BlockAccessor -from ray.data.datasource import Datasink, DummyOutputDatasink +from ray.data.block import BlockAccessor from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa @@ -144,23 +142,6 @@ def test_read_example_data(ray_start_regular_shared, tmp_path): ] -def test_write_datasink(ray_start_regular_shared): - output = DummyOutputDatasink() - ds = ray.data.range(10, override_num_blocks=2) - ds.write_datasink(output) - assert output.num_ok == 1 - assert output.num_failed == 0 - assert 
ray.get(output.data_sink.get_rows_written.remote()) == 10 - - output.enabled = False - ds = ray.data.range(10, override_num_blocks=2) - with pytest.raises(ValueError): - ds.write_datasink(output, ray_remote_args={"max_retries": 0}) - assert output.num_ok == 1 - assert output.num_failed == 1 - assert ray.get(output.data_sink.get_rows_written.remote()) == 10 - - def test_from_tf(ray_start_regular_shared): import tensorflow as tf import tensorflow_datasets as tfds @@ -205,86 +186,6 @@ def __iter__(self): assert actual_data == expected_data -class NodeLoggerOutputDatasink(Datasink): - """A writable datasource that logs node IDs of write tasks, for testing.""" - - def __init__(self): - @ray.remote - class DataSink: - def __init__(self): - self.rows_written = 0 - self.node_ids = set() - - def write(self, node_id: str, block: Block) -> str: - block = BlockAccessor.for_block(block) - self.rows_written += block.num_rows() - self.node_ids.add(node_id) - return "ok" - - def get_rows_written(self): - return self.rows_written - - def get_node_ids(self): - return self.node_ids - - self.data_sink = DataSink.remote() - self.num_ok = 0 - self.num_failed = 0 - - def write( - self, - blocks: Iterable[Block], - ctx: TaskContext, - ) -> Any: - data_sink = self.data_sink - - def write(b): - node_id = ray.get_runtime_context().get_node_id() - return data_sink.write.remote(node_id, b) - - tasks = [] - for b in blocks: - tasks.append(write(b)) - ray.get(tasks) - return "ok" - - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results - self.num_ok += 1 - - def on_write_failed(self, error: Exception) -> None: - self.num_failed += 1 - - -def test_write_datasink_ray_remote_args(ray_start_cluster): - ray.shutdown() - cluster = ray_start_cluster - cluster.add_node( - resources={"foo": 100}, - num_cpus=1, - ) - cluster.add_node(resources={"bar": 100}, num_cpus=1) - - ray.init(cluster.address) - - @ray.remote - def get_node_id(): 
- return ray.get_runtime_context().get_node_id() - - bar_node_id = ray.get(get_node_id.options(resources={"bar": 1}).remote()) - - output = NodeLoggerOutputDatasink() - ds = ray.data.range(100, override_num_blocks=10) - # Pin write tasks to node with "bar" resource. - ds.write_datasink(output, ray_remote_args={"resources": {"bar": 1}}) - assert output.num_ok == 1 - assert output.num_failed == 0 - assert ray.get(output.data_sink.get_rows_written.remote()) == 100 - - node_ids = ray.get(output.data_sink.get_node_ids.remote()) - assert node_ids == {bar_node_id} - - def test_read_s3_file_error(shutdown_only, s3_path): dummy_path = s3_path + "_dummy" error_message = "Please check that file exists and has properly configured access." diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index c962a50a6b56..974d6aa1e2e5 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -16,6 +16,7 @@ PhysicalOperator, RefBundle, ) +from ray.data._internal.execution.interfaces.task_context import TaskContext from ray.data._internal.execution.operators.actor_pool_map_operator import ( ActorPoolMapOperator, )