From de1106cab6f73a0f3173ccba2d96a368d66b63c6 Mon Sep 17 00:00:00 2001
From: Rusty Conover
Date: Fri, 2 Jan 2026 15:44:33 -0500
Subject: [PATCH 1/6] feat: add scalar function support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for scalar functions - a new function type that transforms
input batches to single-column output with strict 1:1 row mapping.

Key features:
- ScalarFunctionGenerator: Generator-based base class with process()
- ScalarFunction: Callback-based API with compute() method
- Single-column output enforced at construction
- Row count validation (output must match input rows)
- Logging support via yield Message or self.log()
- No finalize phase (ends when input exhausted)

New files:
- vgi/scalar_function.py: Core scalar function classes
- vgi/examples/scalar.py: Example functions (double_column, add_columns, upper_case)
- tests/scalar/test_function.py: Comprehensive tests

Updated:
- Worker dispatch for ScalarFunctionGenerator
- Client.scalar_function() method for invoking scalar functions
- CLI --type option (auto/table/table-in-out/scalar)
- Testing utilities (ScalarFunctionTestClient, run_scalar_function, etc.)
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/scalar/__init__.py | 1 + tests/scalar/test_function.py | 320 ++++++++++++++++++++++ vgi/__init__.py | 21 +- vgi/client/cli.py | 56 +++- vgi/client/client.py | 233 ++++++++++++++++ vgi/examples/scalar.py | 112 ++++++++ vgi/examples/worker.py | 12 +- vgi/scalar_function.py | 484 ++++++++++++++++++++++++++++++++++ vgi/table_function.py | 10 + vgi/testing.py | 330 +++++++++++++++++++++++ vgi/worker.py | 112 +++++++- 11 files changed, 1673 insertions(+), 18 deletions(-) create mode 100644 tests/scalar/__init__.py create mode 100644 tests/scalar/test_function.py create mode 100644 vgi/examples/scalar.py create mode 100644 vgi/scalar_function.py diff --git a/tests/scalar/__init__.py b/tests/scalar/__init__.py new file mode 100644 index 0000000..487759c --- /dev/null +++ b/tests/scalar/__init__.py @@ -0,0 +1 @@ +"""Tests for scalar functions.""" diff --git a/tests/scalar/test_function.py b/tests/scalar/test_function.py new file mode 100644 index 0000000..333845b --- /dev/null +++ b/tests/scalar/test_function.py @@ -0,0 +1,320 @@ +"""Tests for scalar function base classes.""" + +from __future__ import annotations + +from typing import Any + +import pyarrow as pa +import pytest +import structlog + +from vgi.arguments import Arg +from vgi.function import Arguments, Invocation +from vgi.log import Level, Message +from vgi.scalar_function import ( + Output, + OutputGenerator, + ProtocolInput, + ScalarFunction, + ScalarFunctionGenerator, +) +from vgi.table_function import SchemaValidationError + + +def create_invocation(input_schema: pa.Schema) -> Invocation: + """Create a test invocation with the given input schema.""" + return Invocation( + function_name="test_function", + in_out_function_input_schema=input_schema, + correlation_id="test-correlation", + invocation_id=b"test-invocation", + arguments=Arguments(), + ) + + +class TestScalarFunctionGenerator: + """Tests for the 
generator-based ScalarFunctionGenerator.""" + + def test_basic_process(self) -> None: + """Test basic processing of batches.""" + + class DoubleColumn(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + _ = yield None + import pyarrow.compute as pc + + while True: + result = pc.multiply(batch.column("x"), 2) + output = pa.RecordBatch.from_arrays( + [result], schema=self.output_schema + ) + received = yield Output(output) + if received is None: + break + batch = received + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + logger = structlog.get_logger() + + func = DoubleColumn(invocation=invocation, logger=logger) + + # Run the protocol + generator = func.run() + next(generator) # Prime + + # Send a batch + input_batch = pa.RecordBatch.from_pydict( + {"x": [1, 2, 3]}, schema=input_schema + ) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 3 + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_requires_input_schema(self) -> None: + """Test that input schema is required.""" + + class TestFunc(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + _ = yield None + + invocation = Invocation( + function_name="test", + in_out_function_input_schema=None, + correlation_id="test", + invocation_id=None, + arguments=Arguments(), + ) + + with pytest.raises(ValueError, match="requires an input schema"): + TestFunc(invocation=invocation, logger=structlog.get_logger()) + + def test_requires_single_column_output(self) -> None: + """Test that output schema must have exactly one column.""" + + class TwoColumnOutput(ScalarFunctionGenerator): + @property + def 
output_schema(self) -> pa.Schema: + return pa.schema([("a", pa.int64()), ("b", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + _ = yield None + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + + with pytest.raises(SchemaValidationError, match="exactly 1 output column"): + TwoColumnOutput(invocation=invocation, logger=structlog.get_logger()) + + def test_log_message_support(self) -> None: + """Test that log messages can be yielded.""" + + class LoggingScalar(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + _ = yield None + import pyarrow.compute as pc + + while True: + yield Message(Level.INFO, f"Processing {batch.num_rows} rows") + result = pc.multiply(batch.column("x"), 2) + output = pa.RecordBatch.from_arrays( + [result], schema=self.output_schema + ) + received = yield Output(output) + if received is None: + break + batch = received + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = LoggingScalar(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict( + {"x": [1, 2, 3]}, schema=input_schema + ) + + # First yield should be the log message + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.log_message is not None + assert output.log_message.level == Level.INFO + assert "Processing 3 rows" in output.log_message.message + + # Re-send to get actual output + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.batch is not None + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + +class TestScalarFunction: + """Tests for the callback-based ScalarFunction.""" + + def test_basic_compute(self) -> None: + """Test basic compute() method.""" + + class 
DoubleColumn(ScalarFunction): + column = Arg[str](0) + + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + return pc.multiply(batch.column(self.column), 2) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = Invocation( + function_name="test", + in_out_function_input_schema=input_schema, + correlation_id="test", + invocation_id=None, + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + + func = DoubleColumn(invocation=invocation, logger=structlog.get_logger()) + + # Run the protocol + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict( + {"x": [1, 2, 3]}, schema=input_schema + ) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 3 + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_custom_output_name(self) -> None: + """Test custom output column name.""" + + class CustomName(ScalarFunction): + @property + def output_name(self) -> str: + return "doubled" + + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + return pc.multiply(batch.column("x"), 2) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + + func = CustomName(invocation=invocation, logger=structlog.get_logger()) + assert func.output_schema.names == ["doubled"] + + def test_log_method(self) -> None: + """Test self.log() method.""" + + class LoggingFunc(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column("x"), 2) + + input_schema = pa.schema([("x", 
pa.int64())]) + invocation = create_invocation(input_schema) + func = LoggingFunc(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict( + {"x": [1, 2, 3]}, schema=input_schema + ) + + # First yield should be the log message + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.log_message is not None + assert output.log_message.level == Level.INFO + assert "Processing 3 rows" in output.log_message.message + + # Re-send to get actual output + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.batch is not None + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_row_count_validation(self) -> None: + """Test that row count mismatch raises error.""" + + class WrongRowCount(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + # Return wrong number of rows + return pa.array([1, 2]) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = WrongRowCount(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict( + {"x": [1, 2, 3]}, schema=input_schema + ) + output = generator.send(ProtocolInput(batch=input_batch)) + + # Should have an exception log message + assert output.log_message is not None + assert output.log_message.level == Level.EXCEPTION + assert "same row count" in output.log_message.message.lower() + + def test_empty_batch(self) -> None: + """Test handling of empty batches.""" + + class DoubleFunc(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + return pc.multiply(batch.column("x"), 2) + + input_schema = pa.schema([("x", pa.int64())]) + 
invocation = create_invocation(input_schema) + func = DoubleFunc(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + # Empty batch + input_batch = pa.RecordBatch.from_pydict({"x": []}, schema=input_schema) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 0 diff --git a/vgi/__init__.py b/vgi/__init__.py index 6561b9e..bb52a1b 100644 --- a/vgi/__init__.py +++ b/vgi/__init__.py @@ -59,6 +59,8 @@ class MyWorker(Worker): TableInOutFunction - Callback-based API (recommended) TableInOutGeneratorFunction - Generator-based API (advanced) + ScalarFunction - Scalar function with compute() (single-column output) + ScalarFunctionGenerator - Scalar function with generator protocol Output - Output batch from process()/finalize() OutputGenerator - Type alias for process()/finalize() StreamingGenerator - Type alias for @streaming decorated methods @@ -110,12 +112,14 @@ class Meta: CLASS HIERARCHY --------------- vgi.function.Function - Base (max_processes, invocation_id) - └─ vgi.table_function.TableFunction - Adds cardinality hints, projection - └─ TableInOutGeneratorFunction - Full streaming (process/finalize) - └─ TableInOutFunction - Callback API (transform/finish) - ├─ AggregationFunction - Reduce to summary - ├─ FilterFunction - Row filtering - └─ MapFunction - Column transformation + └─ vgi.table_function.TableFunctionBase - Adds cardinality hints, projection + ├─ TableInOutGeneratorFunction - Full streaming (process/finalize) + │ └─ TableInOutFunction - Callback API (transform/finish) + │ ├─ AggregationFunction - Reduce to summary + │ ├─ FilterFunction - Row filtering + │ └─ MapFunction - Column transformation + └─ ScalarFunctionGenerator - Single-column output (1:1 rows) + └─ ScalarFunction - Callback API (compute) Examples -------- @@ -142,7 +146,9 @@ class Meta: TableInputValidationError, functions_to_arrow, ) +from vgi.scalar_function 
import ScalarFunction, ScalarFunctionGenerator from vgi.schema_utils import schema, schema_like +from vgi.table_function import RowCountMismatchError from vgi.table_in_out_function import ( Output, OutputGenerator, @@ -178,6 +184,9 @@ class Meta: "OutputGenerator", "ParameterInfo", "ResolvedMetadata", + "RowCountMismatchError", + "ScalarFunction", + "ScalarFunctionGenerator", "StreamingGenerator", "TableInOutFunction", "TableInOutGeneratorFunction", diff --git a/vgi/client/cli.py b/vgi/client/cli.py index cab7199..00f026c 100644 --- a/vgi/client/cli.py +++ b/vgi/client/cli.py @@ -1,4 +1,4 @@ -"""Command-line interface for the VGI client. +r"""Command-line interface for the VGI client. This module provides the CLI entry point for invoking VGI functions. @@ -12,6 +12,10 @@ vgi-client --function sequence --args '[100]' vgi-client --function range --args '[0, 10]' + # Scalar functions (with input, single-column output): + vgi-client --input data.parquet --function double_column \ + --args '["x"]' --type scalar + # Specify table input position (for functions where TableInput isn't first): vgi-client --input data.parquet --function transform --args '["prefix"]' \ --table-input-position 1 @@ -204,6 +208,16 @@ def _create_cli() -> Any: "Used to trace calls back to a specific attachment." ), ) + @click.option( + "--type", + "function_type", + type=click.Choice(["auto", "table", "table-in-out", "scalar"]), + default="auto", + help=( + "Function type. 'auto' (default) uses table-in-out if --input is provided, " + "otherwise table. Use 'scalar' for scalar functions." 
+ ), + ) def cli( input_file: str | None, output_file: str | None, @@ -216,6 +230,7 @@ def cli( max_workers: int | None, table_input_position: int | None, attach_id: str | None, + function_type: str, ) -> None: """Invoke a VGI function and display results.""" try: @@ -257,6 +272,18 @@ def cli( log.info("starting_server", function=function_name, server_path=server_path) + # Validate function_type requirements + if function_type == "scalar" and input_file is None: + raise click.ClickException("--type scalar requires --input to be specified") + if function_type == "table-in-out" and input_file is None: + raise click.ClickException( + "--type table-in-out requires --input to be specified" + ) + if function_type == "table" and input_file is not None: + raise click.ClickException( + "--type table does not accept --input (table functions have no input)" + ) + output_writer: OutputWriter | None = None try: with Client( @@ -265,16 +292,37 @@ def cli( max_workers=max_workers, attach_id=attach_id_bytes, ) as client: - if input_file is None: - # Table function (no input) - use table_function method + # Determine effective function type + if function_type == "auto": + effective_type = "table" if input_file is None else "table-in-out" + else: + effective_type = function_type + + if effective_type == "table": + # Table function (no input) log.info("invoking_table_function", function=function_name) output_iterator = client.table_function( function_name=function_name, arguments=Arguments(positional=positional_args, named={}), projection_ids=list(projection_ids) if projection_ids else None, ) + elif effective_type == "scalar": + # Scalar function (with input, single-column output) + assert input_file is not None # Validated earlier + log.info("invoking_scalar_function", function=function_name) + log.info("reading_input", file=input_file) + pf = pq.ParquetFile(input_file) + + output_iterator = client.scalar_function( + function_name=function_name, + 
arguments=Arguments(positional=positional_args, named={}), + input=pf.iter_batches(), + projection_ids=list(projection_ids) if projection_ids else None, + ) else: - # Table-in-out function - use table_in_out_function method + # Table-in-out function (with input) + assert input_file is not None # Validated earlier + log.info("invoking_table_in_out_function", function=function_name) log.info("reading_input", file=input_file) pf = pq.ParquetFile(input_file) diff --git a/vgi/client/client.py b/vgi/client/client.py index 02ca523..2792cfd 100644 --- a/vgi/client/client.py +++ b/vgi/client/client.py @@ -40,6 +40,7 @@ client.stop() : Stop the worker subprocess client.table_in_out_function() : Invoke a TableInOutGeneratorFunction and stream results client.table_function() : Invoke a TableFunctionGenerator and stream results +client.scalar_function() : Invoke a ScalarFunction and stream results client.get_worker_stderr() : Get captured stderr from worker See Also @@ -1804,3 +1805,235 @@ def table_function( yield from self._table_function_parallel( primary_output_reader=output_reader, ) + + def scalar_function( + self, + *, + function_name: str, + input: Iterator[pa.RecordBatch], + arguments: Arguments | None = None, + bind_result_callback: Callable[[pa.RecordBatch], None] | None = None, + projection_ids: list[int] | None = None, + ) -> Generator[pa.RecordBatch, None, None]: + """Invoke a scalar function on the worker and stream results. + + Scalar functions transform input batches to single-column output with + 1:1 row mapping. Unlike table_in_out_function, scalar functions have + no finalize phase - processing ends when input is exhausted. + + Processing flow: + 1. Reads the first input batch to determine the input schema + 2. Sends Invocation to worker and receives bind result + 3. Spawns additional workers if max_processes > 1 + 4. Distributes input batches to workers (round-robin for parallel mode) + 5. 
Collects output batches, handling HAVE_MORE_OUTPUT for log messages + 6. Returns when input is exhausted (no FINALIZE signal) + + For parallel processing (max_processes > 1), input batches are distributed + round-robin across workers using dedicated threads. Output order may not + match input order in parallel mode. + + Args: + function_name: Name of the function to invoke. Must exist in the + worker's registry. + input: Iterator yielding input RecordBatches. The first batch's + schema is used to initialize the IPC stream. If the iterator is + empty, the function is never invoked and no output is produced. + arguments: Optional Arguments container with positional and named + arguments to pass to the function. Defaults to empty Arguments(). + bind_result_callback: Optional callback invoked with the raw bind + result RecordBatch before processing begins. Useful for inspecting + output schema or max_processes. + projection_ids: Optional list of column indices for column projection. + Passed to the worker via GlobalStateInitInput. + + Yields: + Output RecordBatches from the function. Each output batch has a single + column and the same number of rows as its corresponding input batch. + In single-worker mode, output order corresponds to input order. + In parallel mode (max_processes > 1), output order is non-deterministic. + + Raises: + ClientError: If the client is not started, input iterator yields + non-RecordBatch objects, communication with the worker fails, + or the worker returns an unexpected status or exception. + + Example: + >>> with Client("vgi-example-worker") as client: + ... batches = [pa.RecordBatch.from_pydict({"x": [1, 2, 3]})] + ... for output in client.scalar_function( + ... function_name="double_column", + ... input=iter(batches), + ... arguments=Arguments(positional=[pa.scalar("x")]), + ... ): + ... 
print(output.to_pydict()) + {'result': [2, 4, 6]} + + """ + if arguments is None: + arguments = Arguments() + + if ( + self._proc is None + or self._stdin_sink is None + or self._stdout_buffered is None + ): + raise ClientError( + "Client not started. Call start() or use context manager." + ) + + # Get the first batch to determine schema and initialize + for input_batch in input: + if not isinstance(input_batch, pa.RecordBatch): + raise ClientError("Input iterator must yield RecordBatches") + + input_schema = input_batch.schema + data_writer, _ = self._initialize_function_stream( + function_name=function_name, + arguments=arguments, + input_schema=input_schema, + bind_result_callback=bind_result_callback, + projection_ids=projection_ids, + ) + + # Use parallel processing for all cases (handles both single and + # multi-worker) + assert data_writer is not None # set when input_schema is not None + yield from self._scalar_function_parallel( + input_batch=input_batch, + input_iterator=input, + data_writer=data_writer, + ) + return + + def _scalar_function_parallel( + self, + *, + input_batch: pa.RecordBatch, + input_iterator: Iterator[pa.RecordBatch], + data_writer: ipc.RecordBatchStreamWriter, + ) -> Generator[pa.RecordBatch, None, None]: + """Process scalar function batches across one or more workers using threads. + + Similar to _table_in_out_function_parallel but without finalization. + Handles both single-worker and multi-worker cases uniformly. + + Processing flow: + 1. Creates worker connection objects for primary + additional workers + 2. Starts one thread per worker running _worker_thread_loop + 3. Distributes input batches round-robin to worker input queues + 4. Signals end-of-input to all workers via None sentinel + 5. Collects all output batches from shared output queue + 6. Waits for worker threads to complete + 7. 
Closes all workers (no finalize phase) + + Args: + input_batch: The first input batch, already consumed from the + iterator by scalar_function(). + input_iterator: Iterator for remaining input batches. May be empty + if all input was in the first batch. + data_writer: IPC stream writer for the primary worker, already + initialized by _initialize_function_stream(). + + Yields: + Output RecordBatches from processing, in non-deterministic order for + multi-worker mode. When multiple batches are returned for a single + input (HAVE_MORE_OUTPUT for logs), they are combined into one batch. + + Raises: + ClientError: If a worker thread fails with an exception. + + """ + primary_worker = self._create_primary_worker(data_writer=data_writer) + all_workers = [primary_worker] + self._additional_workers + num_workers = len(all_workers) + + log.debug("starting_scalar_parallel_processing", num_workers=num_workers) + + # Create queues for each worker + input_queues: list[Queue[tuple[int, pa.RecordBatch] | None]] = [ + Queue() for _ in range(num_workers) + ] + output_queue: Queue[tuple[int, list[pa.RecordBatch]] | BaseException] = Queue() + + # Start worker threads + threads: list[threading.Thread] = [] + for i, worker in enumerate(all_workers): + thread = threading.Thread( + target=self._worker_thread_loop, + args=(worker, input_queues[i], output_queue), + daemon=True, + ) + thread.start() + threads.append(thread) + + # Distribute batches round-robin across workers + batch_index = 0 + batches_sent = 0 + + # Send first batch + worker_idx = batch_index % num_workers + input_queues[worker_idx].put((batch_index, input_batch)) + batches_sent += 1 + batch_index += 1 + + # Send remaining batches + for input_batch in input_iterator: + worker_idx = batch_index % num_workers + input_queues[worker_idx].put((batch_index, input_batch)) + batches_sent += 1 + batch_index += 1 + + # Signal end of input to all workers + for q in input_queues: + q.put(None) + + 
log.debug("scalar_all_batches_distributed", total_batches=batches_sent) + + # Collect outputs from all workers + # We expect batches_sent regular outputs + num_workers thread completion signals + outputs_expected = batches_sent + num_workers + outputs_received = 0 + + while outputs_received < outputs_expected: + result = output_queue.get() + + # Check for exceptions from worker threads + if isinstance(result, BaseException): + raise ClientError(f"Worker thread failed: {result}") from result + + batch_idx, output_batches = result + outputs_received += 1 + + # Combine output batches if needed + combined = self._combine_batches(output_batches) + if combined is not None: + yield combined + + log.debug( + "scalar_output_received", + batch_index=batch_idx, + outputs_received=outputs_received, + outputs_expected=outputs_expected, + ) + + self._join_threads(threads) + log.debug("all_scalar_worker_threads_complete") + + # Close all workers (no finalize for scalar functions) + # Close data writers to signal EOF to workers + for worker in all_workers: + if worker.data_writer is not None: + worker.data_writer.close() + + # Wait for secondary workers to exit + secondary_workers = all_workers[1:] + for worker in secondary_workers: + worker.proc.wait(timeout=self.PROCESS_WAIT_TIMEOUT) + log.debug( + "scalar_secondary_worker_exited", + worker_index=worker.worker_index, + returncode=worker.proc.returncode, + ) + + log.debug("scalar_parallel_processing_complete") diff --git a/vgi/examples/scalar.py b/vgi/examples/scalar.py new file mode 100644 index 0000000..021df37 --- /dev/null +++ b/vgi/examples/scalar.py @@ -0,0 +1,112 @@ +"""Example scalar function implementations. + +This module provides example scalar functions that transform input batches +to single-column output with 1:1 row mapping. 
+ +AVAILABLE FUNCTIONS +------------------- +DoubleColumnFunction - Doubles values in a numeric column +AddColumnsFunction - Adds two numeric columns +UpperCaseFunction - Converts string column to uppercase +""" + +from __future__ import annotations + +from typing import Any, cast + +import pyarrow as pa +import pyarrow.compute as pc + +from vgi.arguments import Arg +from vgi.scalar_function import ScalarFunction + +__all__ = [ + "DoubleColumnFunction", + "AddColumnsFunction", + "UpperCaseFunction", +] + + +class DoubleColumnFunction(ScalarFunction): + """Doubles values in a numeric column. + + Example: + Input: x=[1, 2, 3] + Args: column="x" + Output: result=[2, 4, 6] + + """ + + class Meta: + """Function metadata.""" + + name = "double_column" + description = "Doubles values in a numeric column" + + column = Arg[str](0, doc="Column name to double") + + @property + def output_type(self) -> pa.DataType: + """Return the type of the doubled column.""" + return cast(pa.DataType, self.input_schema.field(self.column).type) + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Double the values in the specified column.""" + return pc.multiply(batch.column(self.column), 2) + + +class AddColumnsFunction(ScalarFunction): + """Adds two numeric columns together. 
+ + Example: + Input: a=[1, 2, 3], b=[10, 20, 30] + Args: col1="a", col2="b" + Output: result=[11, 22, 33] + + """ + + class Meta: + """Function metadata.""" + + name = "add_columns" + description = "Adds two numeric columns" + + col1 = Arg[str](0, doc="First column name") + col2 = Arg[str](1, doc="Second column name") + + @property + def output_type(self) -> pa.DataType: + """Return the type of the first column.""" + return cast(pa.DataType, self.input_schema.field(self.col1).type) + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Add the two columns together.""" + return pc.add(batch.column(self.col1), batch.column(self.col2)) + + +class UpperCaseFunction(ScalarFunction): + """Converts a string column to uppercase. + + Example: + Input: name=["alice", "bob", "charlie"] + Args: column="name" + Output: result=["ALICE", "BOB", "CHARLIE"] + + """ + + class Meta: + """Function metadata.""" + + name = "upper_case" + description = "Converts string column to uppercase" + + column = Arg[str](0, doc="Column name to uppercase") + + @property + def output_type(self) -> pa.DataType: + """Return string type.""" + return pa.string() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Convert the column values to uppercase.""" + return pc.utf8_upper(batch.column(self.column)) diff --git a/vgi/examples/worker.py b/vgi/examples/worker.py index e7331e6..1bb7ea9 100644 --- a/vgi/examples/worker.py +++ b/vgi/examples/worker.py @@ -4,14 +4,20 @@ and listing function classes. Function names are derived from each class's metadata (Meta.name or snake_case of class name). 
-The worker supports both: +The worker supports: - TableInOutGeneratorFunction: Transforms input batches to output batches - TableFunctionGenerator: Generates output batches without input +- ScalarFunctionGenerator: Transforms input to single-column output (1:1 rows) Usage: vgi-example-worker """ +from vgi.examples.scalar import ( + AddColumnsFunction, + DoubleColumnFunction, + UpperCaseFunction, +) from vgi.examples.table import ( ConstantTableFunction, GeneratorExceptionFunction, @@ -59,6 +65,10 @@ class ExampleWorker(Worker): LoggingGeneratorFunction, PartitionedRangeFunction, ProjectedDataFunction, + # ScalarFunctionGenerator - transform to single-column output + DoubleColumnFunction, + AddColumnsFunction, + UpperCaseFunction, ] diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py new file mode 100644 index 0000000..6c2bd81 --- /dev/null +++ b/vgi/scalar_function.py @@ -0,0 +1,484 @@ +"""Base classes for scalar functions that transform input batches to single-column. + +Scalar functions transform input batches to single-column output. + +Scalar functions receive input batches and produce output batches where: +1. Output row count must exactly match input row count (1:1 mapping) +2. Output schema has exactly one column + +This module provides: +- ScalarFunctionGenerator: Generator-based base class (like TableInOutGeneratorFunction) +- ScalarFunction: Callback-based API with compute() method (like TableInOutFunction) + +Class Hierarchy: + TableFunctionBase (vgi.table_function) + └── ScalarFunctionGenerator (generator protocol, validates row count) + └── ScalarFunction (callback API with compute()) + +ScalarFunctionGenerator is useful for functions that need full generator control +including yielding log messages. For most use cases, use ScalarFunction with its +simpler compute() method. 
+""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any, final + +import pyarrow as pa +import structlog + +import vgi.function +import vgi.log +import vgi.table_function +from vgi.table_function import RowCountMismatchError, SchemaValidationError + +__all__ = [ + "ScalarFunctionGenerator", + "ScalarFunction", + "Output", + "OutputGenerator", + "ProtocolInput", +] + + +# Protocol types - reuse from table_in_out_function +from vgi.table_in_out_function import ( # noqa: E402 + Output, + OutputGenerator, + ProtocolInput, + ProtocolOutput, + _OutputStatus, +) + + +@dataclass(frozen=True, slots=True) +class _ScalarOutputComplete: + """Internal: Output with guaranteed non-None batch for scalar functions. + + Similar to _OutputComplete in table_in_out_function, but tracks the input + batch for row count validation. + """ + + batch: pa.RecordBatch + has_more: bool = False + log_message: vgi.log.Message | None = None + + @classmethod + def from_process_result( + cls, + source: vgi.log.Message | Output | None, + empty_batch: pa.RecordBatch, + ) -> _ScalarOutputComplete: + """Create from user's yield value. + + Args: + source: What the user yielded (Output, Message, or None). + empty_batch: Empty batch to substitute when needed. + + Returns: + Normalized output with guaranteed non-None batch. + + """ + if source is None: + return cls(batch=empty_batch) + if isinstance(source, vgi.log.Message): + return cls(batch=empty_batch, has_more=True, log_message=source) + # source is Output + return cls( + batch=source.batch if source.batch is not None else empty_batch, + has_more=source.has_more, + ) + + +class ScalarFunctionGenerator(vgi.table_function.TableFunctionBase): + """Base class for scalar functions with generator protocol. + + Scalar functions transform input batches to single-column output with + 1:1 row mapping. 
Unlike TableInOutGeneratorFunction, scalar functions: + - Have no finalize() phase + - Must produce exactly one output row per input row + - Must have exactly one column in output_schema + + Override process() for full generator control. Can yield Output or Message: + + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + _ = yield None # Priming yield + while True: + # Optional: yield log messages + yield Message(Level.INFO, f"Processing {batch.num_rows} rows") + + result_array = compute_result(batch) + output_batch = pa.RecordBatch.from_arrays( + [result_array], schema=self.output_schema + ) + batch = yield Output(output_batch) + if batch is None: + break + + METHODS TO OVERRIDE + ------------------- + output_schema -> pa.Schema (property) + Override to define the single-column output schema. + + process(batch: pa.RecordBatch) -> OutputGenerator + Generator that processes input batches. Must yield Output with + batch.num_rows matching input batch.num_rows. + + setup() -> None + Called before processing starts. Default: no-op. + + teardown() -> None + Called after processing completes. Default: no-op. + + AVAILABLE ATTRIBUTES + -------------------- + self.invocation: Invocation - The complete invocation request + self.input_schema: pa.Schema - Input schema (from invocation) + self.output_schema: pa.Schema - Property returning the output schema + self.empty_output_batch - Empty batch conforming to output_schema + """ + + def __init__( + self, + invocation: vgi.function.Invocation, + logger: structlog.stdlib.BoundLogger, + ): + """Initialize the scalar function with invocation data and logger.""" + super().__init__(invocation=invocation, logger=logger) + if invocation.in_out_function_input_schema is None: + raise ValueError( + f"{type(self).__name__} requires an input schema, but none was " + f"provided. ScalarFunction processes input batches and requires " + f"in_out_function_input_schema to be set in the Invocation." 
+ ) + # Validate single-column output at construction + if len(self.output_schema) != 1: + raise SchemaValidationError( + f"ScalarFunction must have exactly 1 output column, " + f"got {len(self.output_schema)}: {self.output_schema}" + ) + + @property + def input_schema(self) -> pa.Schema: + """Return the input schema from the invocation.""" + # Validated as non-None in __init__ + assert self.invocation.in_out_function_input_schema is not None + return self.invocation.in_out_function_input_schema + + def teardown(self) -> None: + """Release resources after processing completes. + + Override to release resources acquired in setup(). + Always called, even if an error occurred during processing. + """ + pass + + @final + def _validate_input_schema(self, batch: pa.RecordBatch) -> None: + """Validate that a batch conforms to the expected input schema.""" + if batch.schema != self.input_schema: + raise SchemaValidationError( + f"Input batch schema does not match expected input_schema. " + f"Expected: {self.input_schema}, got: {batch.schema}" + ) + + @final + def _validate_row_count( + self, output_batch: pa.RecordBatch, input_batch: pa.RecordBatch + ) -> None: + """Validate that output row count matches input row count.""" + if output_batch.num_rows != input_batch.num_rows: + raise RowCountMismatchError( + f"ScalarFunction output must have same row count as input. " + f"Input: {input_batch.num_rows}, Output: {output_batch.num_rows}" + ) + + @final + def _process_and_validate( + self, + generator: OutputGenerator, + input_batch: pa.RecordBatch, + ) -> _ScalarOutputComplete: + """Process a batch and validate schemas and row count. + + Args: + generator: The user's process() generator. + input_batch: The input RecordBatch to process. + + Returns: + _ScalarOutputComplete with validated output batch. + + Raises: + SchemaValidationError: If input or output batch schema doesn't match. + RowCountMismatchError: If output row count doesn't match input. 
+ + """ + self._validate_input_schema(input_batch) + result: _ScalarOutputComplete = _ScalarOutputComplete.from_process_result( + generator.send(input_batch), + self.empty_output_batch, + ) + self._validate_output_schema(result.batch) + # Only validate row count for actual output, not log messages + if result.log_message is None and result.batch.num_rows > 0: + self._validate_row_count(result.batch, input_batch) + return result + + @final + def _process_with_exception_handling( + self, + generator: OutputGenerator, + input_batch: pa.RecordBatch, + ) -> _ScalarOutputComplete: + """Process a batch with exception handling. + + Wraps _process_and_validate to catch exceptions and convert them + to _ScalarOutputComplete with an error log message. + """ + try: + return self._process_and_validate(generator, input_batch) + except Exception as e: + return _ScalarOutputComplete( + batch=self.empty_output_batch, + log_message=vgi.log.Message.from_exception(e), + ) + + @final + def _should_terminate(self, result: _ScalarOutputComplete) -> bool: + """Check if processing should terminate due to an exception.""" + return ( + result.log_message is not None + and result.log_message.level == vgi.log.Level.EXCEPTION + ) + + @abstractmethod + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + """Process input batches. + + Override this method to implement your scalar transformation. + The generator must yield Output with batch.num_rows matching + input batch.num_rows. + + Args: + batch: First input batch (subsequent batches via yield return). + + Yields: + Output: Batch with same row count as input. + Message: Log message (input will be re-sent). + None: No output (ready for next batch). + + """ + ... + + @final + def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: + """Run the scalar function protocol. Do not override. + + This generator implements the SETUP -> DATA -> TEARDOWN lifecycle. + No FINALIZE phase for scalar functions. 
+ + Protocol: + - Caller primes with next() or send(None) + - Caller sends ProtocolInput for each batch + - When input exhausted, generator closes (no FINALIZE signal needed) + """ + # Priming yield - caller calls next() or send(None) + input: ProtocolInput | None = yield ProtocolOutput( + batch=None, status=_OutputStatus.NEED_MORE_INPUT + ) + if input is None: + raise ValueError("Expected ProtocolInput, got None") + + # Acquire resources before processing + self.setup() + + generator = self.process(input.batch) + # Prime the process() generator past the initial yield + generator.send(None) + + try: + # DATA phase - no FINALIZE for scalar functions + while not input.is_finalize: + result = self._process_with_exception_handling(generator, input.batch) + + # Determine status based on result + has_more_output = result.has_more or result.log_message is not None + if has_more_output: + status = _OutputStatus.HAVE_MORE_OUTPUT + else: + status = _OutputStatus.NEED_MORE_INPUT + + input = yield ProtocolOutput( + batch=result.batch, + status=status, + log_message=result.log_message, + ) + if input is None: + raise ValueError("Expected ProtocolInput, got None") + if self._should_terminate(result): + return + + # When FINALIZE signal comes, just emit FINISHED (no finalize phase) + yield ProtocolOutput( + batch=self.empty_output_batch, status=_OutputStatus.FINISHED + ) + finally: + generator.close() + # Release resources after processing completes + self.teardown() + + +class ScalarFunction(ScalarFunctionGenerator): + """Simplified base class using compute() callback instead of generators. + + This class provides a simpler API for scalar functions. Instead of + implementing process() as a generator, you override compute() as a + regular method that returns a single Array. + + METHODS TO OVERRIDE + ------------------- + output_type -> pa.DataType (property) + Return the Arrow type for the output column. 
+ + compute(batch) -> pa.Array + Transform the input batch to a single output array. + Must return an array with exactly batch.num_rows elements. + + output_name -> str (property, optional) + Return the name of the output column. Default: "result" + + LOGGING + ------- + Call self.log(level, message) from compute() to emit log messages: + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column("x"), 2) + + Example: + ------- + class DoubleColumn(ScalarFunction): + column = Arg[str](0, doc="Column to double") + + @property + def output_type(self) -> pa.DataType: + return self.input_schema.field(self.column).type + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + return pc.multiply(batch.column(self.column), 2) + + """ + + # Message queue for log() method (same pattern as TableInOutFunction) + _pending_messages: list[vgi.log.Message] + + def __init__( + self, + invocation: vgi.function.Invocation, + logger: structlog.stdlib.BoundLogger, + ): + """Initialize the scalar function.""" + # Initialize pending messages before super().__init__ because + # output_schema property may be accessed during init + self._pending_messages = [] + super().__init__(invocation=invocation, logger=logger) + + def log(self, level: vgi.log.Level, message: str) -> None: + """Queue a log message to be emitted with the output. + + Messages are yielded before the compute() result. + + Args: + level: Log level (DEBUG, INFO, WARNING, ERROR). + message: Log message text. + + Example: + def compute(self, batch: pa.RecordBatch) -> pa.Array: + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column(self.column), 2) + + """ + self._pending_messages.append(vgi.log.Message(level=level, message=message)) + + @property + def output_name(self) -> str: + """Return the name of the output column. 
Override to customize.""" + return "result" + + @property + @abstractmethod + def output_type(self) -> pa.DataType: + """Return the Arrow type for the output column. + + Override this property to specify the output column type. + + Example: + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + # Or derive from input: + @property + def output_type(self) -> pa.DataType: + return self.input_schema.field(self.column).type + + """ + ... + + @property + @final + def output_schema(self) -> pa.Schema: + """Return single-column output schema. Do not override.""" + return pa.schema([pa.field(self.output_name, self.output_type)]) + + @abstractmethod + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Compute output array from input batch. + + Override this method to implement your scalar transformation. + + Args: + batch: Input RecordBatch. + + Returns: + Array with exactly batch.num_rows elements. + + Example: + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + return pc.multiply(batch.column("x"), 2) + + """ + ... + + @final + def _yield_pending_messages(self) -> OutputGenerator: + """Yield all pending log messages. Helper for process().""" + while self._pending_messages: + msg = self._pending_messages.pop(0) + _ = yield msg + + @final + def process(self, batch: pa.RecordBatch) -> OutputGenerator: + """Convert compute() to generator protocol. Do not override. + + This method implements the generator protocol by calling your compute() + method for each input batch. 
+ """ + _ = yield None # Priming yield + + while True: + result = self.compute(batch) + + # Yield any pending log messages first + yield from self._yield_pending_messages() + + # Create output batch from result array + output = pa.RecordBatch.from_arrays([result], schema=self.output_schema) + received = yield Output(output) + + if received is None: + break + batch = received diff --git a/vgi/table_function.py b/vgi/table_function.py index e8698c1..7cbfb6a 100644 --- a/vgi/table_function.py +++ b/vgi/table_function.py @@ -40,6 +40,7 @@ "OutputGenerator", "OutputSpec", "ProtocolOutput", + "RowCountMismatchError", "SchemaValidationError", "TableFunctionBase", "TableFunctionGenerator", @@ -55,6 +56,15 @@ class SchemaValidationError(Exception): """ +class RowCountMismatchError(Exception): + """Raised when scalar function output row count doesn't match input. + + Scalar functions must produce exactly one output row for each input row. + This error indicates the compute() method returned an array with the + wrong number of elements. + """ + + @dataclass(frozen=True, slots=True) class CardinalityInfo: """Cardinality hints for query optimization. 
diff --git a/vgi/testing.py b/vgi/testing.py index f2c8520..bf26aaf 100644 --- a/vgi/testing.py +++ b/vgi/testing.py @@ -84,6 +84,13 @@ from vgi.function import Arguments, Invocation from vgi.log import Level, Message +from vgi.scalar_function import ( + ProtocolInput as ScalarProtocolInput, +) +from vgi.scalar_function import ( + ScalarFunction, + ScalarFunctionGenerator, +) from vgi.table_function import ( GlobalStateInitInput, TableFunctionGenerator, @@ -103,12 +110,15 @@ "FunctionTestClient", "FunctionTestClientError", "TableFunctionTestClient", + "ScalarFunctionTestClient", "batch", "assert_function_output", "assert_function_logs", "run_function", "run_table_function", "assert_table_function_output", + "run_scalar_function", + "assert_scalar_function_output", ] @@ -895,3 +905,323 @@ def assert_table_function_output( ) return logs + + +# ============================================================================= +# Scalar Function Test Client and Helpers +# ============================================================================= + + +class ScalarFunctionTestClient: + """In-process client for testing ScalarFunction and ScalarFunctionGenerator. + + Scalar functions transform input batches to single-column output with 1:1 + row mapping. Unlike TableInOut functions, scalar functions have no finalize + phase. + + Example: + with ScalarFunctionTestClient(DoubleColumnFunction) as client: + outputs = list(client.scalar_function( + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + )) + + Attributes: + logs: List of log messages emitted during the last function call. + + """ + + def __init__( + self, + function_class: type[ScalarFunctionGenerator] | type[ScalarFunction], + ) -> None: + """Initialize the ScalarFunctionTestClient. + + Args: + function_class: The scalar function class to test (not an instance). 
+ + """ + self.function_class = function_class + self.logs: list[Message] = [] + self._logger: structlog.stdlib.BoundLogger = structlog.get_logger().bind( + component="scalar_test_client" + ) + + def __enter__(self) -> "ScalarFunctionTestClient": + """Enter context manager.""" + return self + + def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: + """Exit context manager.""" + pass + + def scalar_function( + self, + *, + input: Iterator[pa.RecordBatch], + arguments: Arguments | None = None, + bind_result_callback: Callable[[pa.RecordBatch], None] | None = None, + ) -> Generator[pa.RecordBatch, None, None]: + """Call the scalar function with the given input data. + + This method implements the VGI scalar function protocol directly in-process, + without any IPC or subprocess communication. + + Args: + input: Iterator yielding input RecordBatches. + arguments: Arguments container with positional and named arguments. + bind_result_callback: Optional callback invoked with the bind result. + + Yields: + Output RecordBatches from the function (single-column). + + Raises: + FunctionTestClientError: If the function raises an exception. 
+ + """ + # Clear logs from previous invocation + self.logs = [] + + if arguments is None: + arguments = Arguments() + + # Get first batch to determine input schema + try: + first_batch = next(input) + except StopIteration: + # No input batches - nothing to process + return + + input_schema = first_batch.schema + + # Create invocation + invocation_id = uuid.uuid4().bytes + invocation = Invocation( + function_name=self.function_class.__name__, + arguments=arguments, + in_out_function_input_schema=input_schema, + correlation_id="test", + invocation_id=invocation_id, + ) + + # Instantiate function + func = self.function_class(invocation=invocation, logger=self._logger) + + # Create bind result for callback + if bind_result_callback is not None: + bind_batch = pa.RecordBatch.from_pylist( + [ + { + "output_schema": func.output_schema.serialize().to_pybytes(), + "max_processes": func.max_processes(), + "invocation_id": invocation_id, + } + ], + schema=pa.schema( + cast( + list[tuple[str, pa.DataType]], + [ + ("output_schema", pa.binary()), + ("max_processes", pa.int64()), + ("invocation_id", pa.binary()), + ], + ) + ), + ) + bind_result_callback(bind_batch) + + # Get the run generator + generator = func.run() + + # Prime the generator + try: + priming_output = next(generator) + assert priming_output.status == _OutputStatus.NEED_MORE_INPUT + except StopIteration: + return + + # Process first batch + yield from self._process_scalar_batch(generator, first_batch) + + # Process remaining batches + for batch in input: + yield from self._process_scalar_batch(generator, batch) + + # No finalize for scalar functions - just close + generator.close() + + def _process_scalar_batch( + self, + generator: Generator[ProtocolOutput, ScalarProtocolInput | None, None], + batch: pa.RecordBatch, + ) -> Generator[pa.RecordBatch, None, None]: + """Process a single input batch, handling HAVE_MORE_OUTPUT for logs.""" + while True: + try: + output = 
generator.send(ScalarProtocolInput(batch=batch)) + except StopIteration: + return + + # Capture log message if present + if output.log_message is not None: + self.logs.append(output.log_message) + # Check for exception + if output.log_message.level == Level.EXCEPTION: + raise FunctionTestClientError(output.log_message.message) + + # Yield output batch if it has rows + if output.batch is not None and output.batch.num_rows > 0: + yield output.batch + + # Check status + if output.status == _OutputStatus.HAVE_MORE_OUTPUT: + # Re-send the same batch to get more output (log messages) + continue + elif output.status == _OutputStatus.NEED_MORE_INPUT: + # Ready for next input batch + break + elif output.status == _OutputStatus.FINISHED: + # Scalar function ended + return + else: + raise FunctionTestClientError(f"Unexpected status: {output.status}") + + +def run_scalar_function( + function: type[ScalarFunctionGenerator] | type[ScalarFunction], + input_batches: list[pa.RecordBatch], + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, +) -> tuple[list[pa.RecordBatch], list[Message]]: + """Run a scalar function and return outputs and logs. + + A convenience wrapper around ScalarFunctionTestClient for simple test cases. + + Args: + function: The scalar function class to test. + input_batches: List of input RecordBatches. + args: Optional positional arguments as a tuple. + kwargs: Optional named arguments as a dict. + + Returns: + Tuple of (output_batches, log_messages). + + Example: + outputs, logs = run_scalar_function( + DoubleColumnFunction, + [batch(x=[1, 2, 3])], + args=("x",), + ) + assert outputs[0].to_pydict() == {"result": [2, 4, 6]} + + """ + # Build Arguments from args/kwargs + positional: tuple[pa.Scalar[Any], ...] 
= () + named: dict[str, pa.Scalar[Any]] = {} + + if args: + positional = tuple(pa.scalar(a) for a in args) + if kwargs: + named = {k: pa.scalar(v) for k, v in kwargs.items()} + + arguments = Arguments(positional=positional, named=named) + + with ScalarFunctionTestClient(function) as client: + outputs = list( + client.scalar_function( + input=iter(input_batches), + arguments=arguments, + ) + ) + return outputs, client.logs + + +def assert_scalar_function_output( + function: type[ScalarFunctionGenerator] | type[ScalarFunction], + input: list[pa.RecordBatch], + expected: list[pa.RecordBatch], + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + check_order: bool = True, + msg: str | None = None, +) -> list[Message]: + """Assert that a scalar function produces expected output batches. + + Runs the function with the given input and compares output to expected batches. + Returns captured log messages for additional assertions. + + Args: + function: The scalar function class to test. + input: List of input RecordBatches. + expected: List of expected output RecordBatches. + args: Optional positional arguments as a tuple. + kwargs: Optional named arguments as a dict. + check_order: If True, order of output batches must match. Default True. + msg: Optional custom assertion message prefix. + + Returns: + List of log messages captured during execution. + + Raises: + AssertionError: If output doesn't match expected. + FunctionTestClientError: If the function raises an exception. 
+ + Examples: + # Double column test + assert_scalar_function_output( + DoubleColumnFunction, + input=[batch(x=[1, 2, 3])], + expected=[batch(result=[2, 4, 6])], + args=("x",), + ) + + # Add columns test + assert_scalar_function_output( + AddColumnsFunction, + input=[batch(a=[1, 2], b=[10, 20])], + expected=[batch(result=[11, 22])], + args=("a", "b"), + ) + + """ + outputs, logs = run_scalar_function( + function=function, + input_batches=input, + args=args, + kwargs=kwargs, + ) + + prefix = f"{msg}: " if msg else "" + + # Check batch count + if len(outputs) != len(expected): + actual_rows = [o.num_rows for o in outputs] + expected_rows = [e.num_rows for e in expected] + raise AssertionError( + f"{prefix}Expected {len(expected)} output batches, got {len(outputs)}. " + f"Output rows: {actual_rows}, Expected rows: {expected_rows}" + ) + + # Compare batches + if check_order: + for i, (actual, exp) in enumerate(zip(outputs, expected, strict=True)): + if not actual.equals(exp): + raise AssertionError( + f"{prefix}Batch {i} mismatch.\n" + f"Expected:\n{exp.to_pydict()}\n" + f"Got:\n{actual.to_pydict()}" + ) + else: + # Convert to sets of dicts for unordered comparison + actual_dicts = [b.to_pydict() for b in outputs] + expected_dicts = [b.to_pydict() for b in expected] + + for exp_dict in expected_dicts: + if exp_dict not in actual_dicts: + raise AssertionError( + f"{prefix}Expected batch not found in output.\n" + f"Expected:\n{exp_dict}\n" + f"Actual outputs:\n{actual_dicts}" + ) + + return logs diff --git a/vgi/worker.py b/vgi/worker.py index 6972849..574d56a 100644 --- a/vgi/worker.py +++ b/vgi/worker.py @@ -84,10 +84,12 @@ class MyWorker(Worker): OutputSpec, ) from vgi.ipc_utils import read_ipc_batch +from vgi.scalar_function import ScalarFunctionGenerator from vgi.table_function import TableFunctionGenerator from vgi.table_in_out_function import ( ProtocolInput, TableInOutGeneratorFunction, + _OutputStatus, ) @@ -278,6 +280,97 @@ def _read_init_data(self) -> 
pa.RecordBatch: """Read and parse the init data from stdin.""" return self._read_ipc_batch("init_data") + def _process_scalar_batches( + self, + instance: ScalarFunctionGenerator, + invocation: Invocation, + fn_log: structlog.stdlib.BoundLogger, + ) -> WorkerStats: + """Process data batches through a scalar function. + + Similar to _process_batches but simplified: + - No FINALIZE phase (ends when input exhausted) + - HAVE_MORE_OUTPUT only used for log messages (not multiple output batches) + + Returns: + WorkerStats with batch_count, total_input_rows, total_output_rows. + + """ + if invocation.global_init_identifier is None: + raise ValueError( + "global_init_identifier is required but was None. " + "This is an internal protocol error - the worker should have set " + "global_init_identifier after perform_init() completed successfully." + ) + generator = instance.run() + next(generator) # Prime the run() generator + + with ( + ipc.new_stream(cast(IOBase, sys.stdout), instance.output_schema) as writer, + ipc.open_stream(cast(IOBase, sys.stdin)) as data_reader, + ): + # Validate data stream schema matches expected input schema + if data_reader.schema != invocation.in_out_function_input_schema: + expected = invocation.in_out_function_input_schema + raise ValueError( + f"Data stream schema mismatch. 
Expected: {expected}, " + f"got: {data_reader.schema}" + ) + + batch_count = 0 + total_input_rows = 0 + total_output_rows = 0 + while True: + fn_log.debug("batch_waiting") + + try: + batch, metadata = data_reader.read_next_batch_with_custom_metadata() + except StopIteration: + fn_log.debug("input_stream_ended") + # Close the generator - no FINALIZE for scalar functions + generator.close() + break + + batch_count += 1 + total_input_rows += batch.num_rows + fn_log.debug( + "batch_received", + batch_index=batch_count, + input_rows=batch.num_rows, + ) + + protocol_input = ProtocolInput(batch=batch, metadata=metadata) + output = generator.send(protocol_input) + + # Handle log messages (HAVE_MORE_OUTPUT) + while output.status == _OutputStatus.HAVE_MORE_OUTPUT: + fn_log.debug("log_message_received", output=output) + assert output.batch is not None + writer.write_batch( + output.batch, custom_metadata=output.metadata(invocation) + ) + # Re-send same input to continue + output = generator.send(protocol_input) + + fn_log.debug("batch_processed", output=output) + assert output.batch is not None + output_rows = output.batch.num_rows + total_output_rows += output_rows + writer.write_batch( + output.batch, custom_metadata=output.metadata(invocation) + ) + fn_log.debug( + "batch_written", + batch_index=batch_count, + output_rows=output_rows, + status=output.status.value if output.status else None, + ) + return WorkerStats( + batch_count=batch_count, + total_input_rows=total_input_rows, + total_output_rows=total_output_rows, + ) + def _process_batches( self, instance: TableInOutGeneratorFunction, @@ -467,21 +560,26 @@ def run(self) -> None: instance.retrieve_init(invocation.global_init_identifier) # Dispatch to appropriate processing method based on function type. + # ScalarFunctionGenerator processes input batches to single-column output. # TableInOutGeneratorFunction reads input batches and produces output. # TableFunctionGenerator generates output without input batches. 
- # Note: Check TableInOutGeneratorFunction first since it's more specific - # (TableFunctionGenerator is a parent of TableInOutGeneratorFunction's parent). - if isinstance(instance, TableInOutGeneratorFunction): + # Note: Check ScalarFunctionGenerator first since it doesn't inherit from + # TableInOutGeneratorFunction, then TableInOutGeneratorFunction. + if isinstance(instance, ScalarFunctionGenerator): + stats = self._process_scalar_batches(instance, invocation, fn_log) + elif isinstance(instance, TableInOutGeneratorFunction): stats = self._process_batches(instance, invocation, fn_log) elif isinstance(instance, TableFunctionGenerator): stats = self._generate_batches(instance, invocation, fn_log) else: raise TypeError( f"Unsupported function type: {type(instance).__name__}. " - f"Functions must inherit from TableInOutGeneratorFunction (for " - f"functions that process input batches) or TableFunctionGenerator " - f"(for functions that generate output without input). " - f"See vgi.table_in_out_function and vgi.table_function modules." + f"Functions must inherit from ScalarFunctionGenerator (for " + f"scalar functions), TableInOutGeneratorFunction (for functions " + f"that process input batches), or TableFunctionGenerator (for " + f"functions that generate output without input). " + f"See vgi.scalar_function, vgi.table_in_out_function, and " + f"vgi.table_function modules." 
) fn_log.info( From a952fb4fa8c4a8e9cc918f0ec55efb33a9eeb32e Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Fri, 2 Jan 2026 15:47:33 -0500 Subject: [PATCH 2/6] style: apply ruff format to test file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/scalar/test_function.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/scalar/test_function.py b/tests/scalar/test_function.py index 333845b..903432f 100644 --- a/tests/scalar/test_function.py +++ b/tests/scalar/test_function.py @@ -68,9 +68,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: next(generator) # Prime # Send a batch - input_batch = pa.RecordBatch.from_pydict( - {"x": [1, 2, 3]}, schema=input_schema - ) + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) output = generator.send(ProtocolInput(batch=input_batch)) assert output.batch is not None @@ -146,9 +144,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: generator = func.run() next(generator) - input_batch = pa.RecordBatch.from_pydict( - {"x": [1, 2, 3]}, schema=input_schema - ) + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) # First yield should be the log message output = generator.send(ProtocolInput(batch=input_batch)) @@ -195,9 +191,7 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: generator = func.run() next(generator) - input_batch = pa.RecordBatch.from_pydict( - {"x": [1, 2, 3]}, schema=input_schema - ) + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) output = generator.send(ProtocolInput(batch=input_batch)) assert output.batch is not None @@ -248,9 +242,7 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: generator = func.run() next(generator) - input_batch = pa.RecordBatch.from_pydict( - {"x": [1, 2, 3]}, 
schema=input_schema - ) + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) # First yield should be the log message output = generator.send(ProtocolInput(batch=input_batch)) @@ -282,9 +274,7 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: generator = func.run() next(generator) - input_batch = pa.RecordBatch.from_pydict( - {"x": [1, 2, 3]}, schema=input_schema - ) + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) output = generator.send(ProtocolInput(batch=input_batch)) # Should have an exception log message From 54db81e84d07be31d90b144c46b5ef74c0fbc540 Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Fri, 2 Jan 2026 15:53:04 -0500 Subject: [PATCH 3/6] refactor: clean up scalar function docstrings and remove dead code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove mentions of HAVE_MORE_OUTPUT and FINALIZE from public docstrings - Remove dead code that handled FINALIZE signal (scalar functions don't use it) - Simplify run() method to use while True loop (exits via generator.close()) - Clean up internal comments 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- vgi/client/client.py | 13 ++++--------- vgi/scalar_function.py | 13 ++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/vgi/client/client.py b/vgi/client/client.py index 2792cfd..b6b85a8 100644 --- a/vgi/client/client.py +++ b/vgi/client/client.py @@ -1818,16 +1818,14 @@ def scalar_function( """Invoke a scalar function on the worker and stream results. Scalar functions transform input batches to single-column output with - 1:1 row mapping. Unlike table_in_out_function, scalar functions have - no finalize phase - processing ends when input is exhausted. + 1:1 row mapping. Processing ends when input is exhausted. Processing flow: 1. Reads the first input batch to determine the input schema 2. 
Sends Invocation to worker and receives bind result 3. Spawns additional workers if max_processes > 1 4. Distributes input batches to workers (round-robin for parallel mode) - 5. Collects output batches, handling HAVE_MORE_OUTPUT for log messages - 6. Returns when input is exhausted (no FINALIZE signal) + 5. Collects and yields output batches For parallel processing (max_processes > 1), input batches are distributed round-robin across workers using dedicated threads. Output order may not @@ -1915,7 +1913,6 @@ def _scalar_function_parallel( ) -> Generator[pa.RecordBatch, None, None]: """Process scalar function batches across one or more workers using threads. - Similar to _table_in_out_function_parallel but without finalization. Handles both single-worker and multi-worker cases uniformly. Processing flow: @@ -1925,7 +1922,7 @@ def _scalar_function_parallel( 4. Signals end-of-input to all workers via None sentinel 5. Collects all output batches from shared output queue 6. Waits for worker threads to complete - 7. Closes all workers (no finalize phase) + 7. Closes all workers Args: input_batch: The first input batch, already consumed from the @@ -1937,8 +1934,7 @@ def _scalar_function_parallel( Yields: Output RecordBatches from processing, in non-deterministic order for - multi-worker mode. When multiple batches are returned for a single - input (HAVE_MORE_OUTPUT for logs), they are combined into one batch. + multi-worker mode. Raises: ClientError: If a worker thread fails with an exception. 
@@ -2020,7 +2016,6 @@ def _scalar_function_parallel( self._join_threads(threads) log.debug("all_scalar_worker_threads_complete") - # Close all workers (no finalize for scalar functions) # Close data writers to signal EOF to workers for worker in all_workers: if worker.data_writer is not None: diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py index 6c2bd81..8f1ba4a 100644 --- a/vgi/scalar_function.py +++ b/vgi/scalar_function.py @@ -278,12 +278,12 @@ def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: """Run the scalar function protocol. Do not override. This generator implements the SETUP -> DATA -> TEARDOWN lifecycle. - No FINALIZE phase for scalar functions. + The generator is closed by the caller when input is exhausted. Protocol: - Caller primes with next() or send(None) - Caller sends ProtocolInput for each batch - - When input exhausted, generator closes (no FINALIZE signal needed) + - When input exhausted, caller closes the generator """ # Priming yield - caller calls next() or send(None) input: ProtocolInput | None = yield ProtocolOutput( @@ -300,8 +300,8 @@ def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: generator.send(None) try: - # DATA phase - no FINALIZE for scalar functions - while not input.is_finalize: + # DATA phase - process batches until generator is closed + while True: result = self._process_with_exception_handling(generator, input.batch) # Determine status based on result @@ -320,11 +320,6 @@ def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: raise ValueError("Expected ProtocolInput, got None") if self._should_terminate(result): return - - # When FINALIZE signal comes, just emit FINISHED (no finalize phase) - yield ProtocolOutput( - batch=self.empty_output_batch, status=_OutputStatus.FINISHED - ) finally: generator.close() # Release resources after processing completes From f67ba46e317599347f0f9145a603ef55eb0db4ab Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Fri, 
2 Jan 2026 16:17:44 -0500 Subject: [PATCH 4/6] test: add comprehensive tests for scalar functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add end-to-end tests via Client subprocess (tests/scalar/test_client.py) - TestScalarFunctionClient: 9 tests for basic operations, error handling - TestScalarFunctionParallel: 4 tests for parallel processing - Add CLI tests for --type scalar option (tests/client/test_cli.py) - 10 tests covering invocation, output formats, validation - Create scalar-specific ProtocolInput without unused is_finalize field - Update worker.py to use ScalarProtocolInput 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/client/test_cli.py | 285 +++++++++++++++++++++++++++++++++ tests/scalar/test_client.py | 303 ++++++++++++++++++++++++++++++++++++ vgi/scalar_function.py | 21 ++- vgi/worker.py | 3 +- 4 files changed, 609 insertions(+), 3 deletions(-) create mode 100644 tests/scalar/test_client.py diff --git a/tests/client/test_cli.py b/tests/client/test_cli.py index f1f9b69..a7ff650 100644 --- a/tests/client/test_cli.py +++ b/tests/client/test_cli.py @@ -580,3 +580,288 @@ def test_stdout_output_json(self, example_worker: str) -> None: assert result.exit_code == 0 # Should have JSON output assert "n" in result.output + + +class TestCLIScalarFunction: + """Tests for CLI scalar function invocation with --type scalar.""" + + @pytest.fixture + def scalar_input_parquet(self, tmp_path: Path) -> Path: + """Create a parquet file suitable for scalar function tests.""" + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3, 4, 5]}) + input_file = tmp_path / "scalar_input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + return input_file + + def test_scalar_function_invocation( + self, example_worker: str, scalar_input_parquet: Path + ) -> None: + """Invoke a scalar function with --type scalar.""" + runner = CliRunner() + result = 
runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + + def test_scalar_function_with_output_file( + self, example_worker: str, scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """Scalar function with output to file.""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "json", + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + assert output_file.exists() + lines = output_file.read_text().strip().split("\n") + # Should have 5 rows + assert len(lines) == 5 + # Verify first row is doubled + first_row = json.loads(lines[0]) + assert first_row["result"] == 2 + + def test_scalar_function_parquet_output( + self, example_worker: str, scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """Scalar function with parquet output.""" + output_file = tmp_path / "output.parquet" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "parquet", + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + # Verify parquet output + table = pq.read_table(str(output_file)) + assert table.num_rows == 5 + assert table.column_names == ["result"] + assert table.column("result").to_pylist() == [2, 4, 6, 8, 10] + + def test_scalar_type_requires_input(self, example_worker: str) -> None: + """--type scalar requires --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", 
+ example_worker, + ], + ) + assert result.exit_code != 0 + assert "requires --input" in result.output + + def test_table_in_out_type_requires_input(self, example_worker: str) -> None: + """--type table-in-out requires --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--function", + "echo", + "--type", + "table-in-out", + "--server", + example_worker, + ], + ) + assert result.exit_code != 0 + assert "requires --input" in result.output + + def test_table_type_rejects_input( + self, example_worker: str, scalar_input_parquet: Path + ) -> None: + """--type table does not accept --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--function", + "sequence", + "--args", + "[5]", + "--type", + "table", + "--server", + example_worker, + ], + ) + assert result.exit_code != 0 + assert "does not accept --input" in result.output + + def test_auto_type_with_input_uses_table_in_out( + self, example_worker: str, scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """--type auto with --input uses table-in-out (echo function).""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "json", + "--function", + "echo", + "--type", + "auto", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + # Echo should preserve the original column name "x" + content = output_file.read_text() + assert '"x"' in content + + def test_auto_type_without_input_uses_table( + self, example_worker: str, tmp_path: Path + ) -> None: + """--type auto without --input uses table function.""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--output", + str(output_file), + "--format", + "json", + "--function", + "sequence", + "--args", + "[3]", + "--type", + "auto", + "--server", + example_worker, + ], + ) 
+ assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 3 + + def test_scalar_with_add_columns( + self, example_worker: str, tmp_path: Path + ) -> None: + """Test add_columns scalar function via CLI.""" + # Create input with two columns + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + input_file = tmp_path / "input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(input_file), + "--output", + str(output_file), + "--format", + "json", + "--function", + "add_columns", + "--args", + '["a", "b"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 3 + # Verify sums + results = [json.loads(line)["result"] for line in lines] + assert results == [11, 22, 33] + + def test_scalar_with_upper_case( + self, example_worker: str, tmp_path: Path + ) -> None: + """Test upper_case scalar function via CLI.""" + batch = pa.RecordBatch.from_pydict({"name": ["alice", "bob"]}) + input_file = tmp_path / "input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(input_file), + "--output", + str(output_file), + "--format", + "json", + "--function", + "upper_case", + "--args", + '["name"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + results = [json.loads(line)["result"] for line in lines] + assert results == ["ALICE", "BOB"] diff --git a/tests/scalar/test_client.py b/tests/scalar/test_client.py new file mode 100644 index 0000000..1e19bb8 --- /dev/null +++ b/tests/scalar/test_client.py @@ -0,0 
+1,303 @@ +"""End-to-end tests for scalar functions via Client subprocess.""" + +from __future__ import annotations + +from typing import cast + +import pyarrow as pa +import pytest + +from vgi.client import Client +from vgi.client.client import ClientError +from vgi.function import Arguments + + +class TestScalarFunctionClient: + """Tests for scalar functions via Client subprocess.""" + + def test_double_column_basic(self, example_worker: str) -> None: + """Test basic scalar function via Client.""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [2, 4, 6]} + + def test_add_columns(self, example_worker: str) -> None: + """Test add_columns scalar function.""" + schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) + batch = pa.RecordBatch.from_pydict( + {"a": [1, 2, 3], "b": [10, 20, 30]}, schema=schema + ) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="add_columns", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [11, 22, 33]} + + def test_upper_case(self, example_worker: str) -> None: + """Test upper_case scalar function.""" + schema = pa.schema([("name", pa.string())]) + batch = pa.RecordBatch.from_pydict( + {"name": ["alice", "bob", "charlie"]}, schema=schema + ) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="upper_case", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("name"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": 
["ALICE", "BOB", "CHARLIE"]} + + def test_multiple_batches(self, example_worker: str) -> None: + """Test scalar function with multiple input batches.""" + schema = pa.schema([("x", pa.int64())]) + batch1 = pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema) + batch2 = pa.RecordBatch.from_pydict({"x": [3, 4, 5]}, schema=schema) + batch3 = pa.RecordBatch.from_pydict({"x": [6]}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch1, batch2, batch3]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get 3 output batches (one per input) + assert len(outputs) == 3 + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 6 + + # Verify the values (order may vary in parallel mode, but we're single-worker) + all_values: list[int] = [] + for batch in outputs: + all_values.extend(cast(list[int], batch.column("result").to_pylist())) + assert sorted(all_values) == [2, 4, 6, 8, 10, 12] + + def test_empty_batch(self, example_worker: str) -> None: + """Test scalar function with empty batch.""" + schema = pa.schema([("x", pa.int64())]) + empty_batch = pa.RecordBatch.from_pydict({"x": []}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([empty_batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get one output batch with zero rows + assert len(outputs) == 1 + assert outputs[0].num_rows == 0 + + def test_empty_iterator(self, example_worker: str) -> None: + """Test scalar function with no input batches.""" + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # No input means no output + assert len(outputs) == 0 + + def test_scalar_function_not_started_raises(self, 
example_worker: str) -> None: + """Calling scalar_function before start should raise ClientError.""" + client = Client(example_worker) + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1]}, schema=schema) + + with pytest.raises(ClientError, match="not started"): + list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + def test_large_batch(self, example_worker: str) -> None: + """Test scalar function with a large batch.""" + schema = pa.schema([("x", pa.int64())]) + large_data = list(range(10000)) + batch = pa.RecordBatch.from_pydict({"x": large_data}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 10000 + + # Verify first and last values + all_values = [] + for b in outputs: + all_values.extend(b.column("result").to_pylist()) + assert all_values[0] == 0 # 0 * 2 = 0 + assert all_values[-1] == 19998 # 9999 * 2 = 19998 + + def test_bind_result_callback(self, example_worker: str) -> None: + """Test that bind_result_callback is invoked.""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + bind_results: list[pa.RecordBatch] = [] + + def capture_bind_result(result: pa.RecordBatch) -> None: + bind_results.append(result) + + with Client(example_worker) as client: + list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + bind_result_callback=capture_bind_result, + ) + ) + + # Should have received bind result + assert len(bind_results) == 1 + bind_result = bind_results[0] + + # Verify bind result contains expected fields + assert "output_schema" in 
bind_result.schema.names + assert "max_processes" in bind_result.schema.names + + +class TestScalarFunctionParallel: + """Tests for scalar functions with parallel processing.""" + + def test_parallel_double_column(self, example_worker: str) -> None: + """Test scalar function with multiple workers.""" + schema = pa.schema([("x", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict( + {"x": list(range(i * 100, (i + 1) * 100))}, schema=schema + ) + for i in range(10) + ] + + with Client(example_worker, max_workers=4) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get all 1000 rows back + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 1000 + + # Verify all values are correctly doubled + all_values = set() + for batch in outputs: + all_values.update(batch.column("result").to_pylist()) + + expected = {i * 2 for i in range(1000)} + assert all_values == expected + + def test_parallel_add_columns(self, example_worker: str) -> None: + """Test add_columns with multiple workers.""" + schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict( + {"a": [i, i + 1, i + 2], "b": [100, 200, 300]}, schema=schema + ) + for i in range(20) + ] + + with Client(example_worker, max_workers=3) as client: + outputs = list( + client.scalar_function( + function_name="add_columns", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))), + ) + ) + + # Should get 60 rows total (20 batches * 3 rows) + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 60 + + def test_parallel_empty_batches_mixed(self, example_worker: str) -> None: + """Test parallel processing with mix of empty and non-empty batches.""" + schema = pa.schema([("x", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema), + 
pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty + pa.RecordBatch.from_pydict({"x": [3]}, schema=schema), + pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty + pa.RecordBatch.from_pydict({"x": [4, 5, 6]}, schema=schema), + ] + + with Client(example_worker, max_workers=2) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get 6 rows total (2 + 0 + 1 + 0 + 3) + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 6 + + # Verify values + all_values = set() + for batch in outputs: + all_values.update(batch.column("result").to_pylist()) + assert all_values == {2, 4, 6, 8, 10, 12} + + def test_parallel_single_batch(self, example_worker: str) -> None: + """Test parallel mode with just one batch (should still work).""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + with Client(example_worker, max_workers=4) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [2, 4, 6]} diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py index 8f1ba4a..d7d6a00 100644 --- a/vgi/scalar_function.py +++ b/vgi/scalar_function.py @@ -44,16 +44,33 @@ ] -# Protocol types - reuse from table_in_out_function +# Protocol types - reuse Output/OutputGenerator from table_in_out_function from vgi.table_in_out_function import ( # noqa: E402 Output, OutputGenerator, - ProtocolInput, ProtocolOutput, _OutputStatus, ) +@dataclass(frozen=True, slots=True) +class ProtocolInput: + """Input sent to the scalar function generator via send(). 
+ + This is a simplified version of table_in_out_function.ProtocolInput + without finalization support, since scalar functions don't have a + finalize phase. + + Attributes: + batch: The input RecordBatch to process. + metadata: Optional metadata from the IPC stream. + + """ + + batch: pa.RecordBatch + metadata: pa.KeyValueMetadata | None = None + + @dataclass(frozen=True, slots=True) class _ScalarOutputComplete: """Internal: Output with guaranteed non-None batch for scalar functions. diff --git a/vgi/worker.py b/vgi/worker.py index 574d56a..025a0d0 100644 --- a/vgi/worker.py +++ b/vgi/worker.py @@ -84,6 +84,7 @@ class MyWorker(Worker): OutputSpec, ) from vgi.ipc_utils import read_ipc_batch +from vgi.scalar_function import ProtocolInput as ScalarProtocolInput from vgi.scalar_function import ScalarFunctionGenerator from vgi.table_function import TableFunctionGenerator from vgi.table_in_out_function import ( @@ -339,7 +340,7 @@ def _process_scalar_batches( input_rows=batch.num_rows, ) - protocol_input = ProtocolInput(batch=batch, metadata=metadata) + protocol_input = ScalarProtocolInput(batch=batch, metadata=metadata) output = generator.send(protocol_input) # Handle log messages (HAVE_MORE_OUTPUT) From 2be9090d43e6ee9fb182c8c4e0652b8a6fe58865 Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Sat, 3 Jan 2026 22:49:24 -0500 Subject: [PATCH 5/6] cleanups for scalar functions --- CLAUDE.md | 76 +++- tests/client/test_cli.py | 8 +- tests/scalar/test_function.py | 53 +-- .../generator/test_constant_table_function.py | 7 +- .../generator/test_projected_data_function.py | 11 +- .../table/generator/test_sequence_function.py | 7 +- tests/table/test_function.py | 17 +- tests/table_in_out/test_function.py | 8 +- tests/test_metadata.py | 14 +- tests/test_patterns.py | 5 +- tests/test_protocol_classes.py | 46 +- tests/test_schema_utils.py | 14 +- tests/test_worker.py | 40 +- vgi/__init__.py | 26 +- vgi/arguments.py | 2 + vgi/client/cli.py | 1 - vgi/client/client.py | 61 ++- 
vgi/examples/table.py | 4 +- vgi/function.py | 403 +++++++++++++++--- vgi/scalar_function.py | 143 ++----- vgi/table_function.py | 240 +++++------ vgi/table_in_out_function.py | 22 +- vgi/testing.py | 54 ++- vgi/worker.py | 137 ++++-- 24 files changed, 881 insertions(+), 518 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index e5d687a..c13e192 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -40,6 +40,9 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne │ ▼ │ │ ┌───────────────────────────────────────────────────────────────┐ │ │ │ Worker Process │ │ +│ │ SCALAR FUNCTION (ScalarFunction) │ │ +│ │ - compute(batch): Transform each row to single output column │ │ +│ │ OR │ │ │ │ TABLE FUNCTION (TableFunctionGenerator) │ │ │ │ - process(): Generator yielding output batches (no input) │ │ │ │ OR │ │ @@ -52,6 +55,7 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne | Type | Base Class | Input | Use Case | |------|------------|-------|----------| +| **Scalar Function** | `ScalarFunction` | Batches | Per-row transforms (1:1 row mapping, single output column) | | **Table Function** | `TableFunctionGenerator` | None | Generate data (sequences, ranges) | | **Table-In-Out Function** | `TableInOutFunction` | Batches | Transform, filter, aggregate | @@ -59,6 +63,7 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne - **Worker** (`vgi/worker.py`): Subprocess that hosts functions - **Client** (`vgi/client/client.py`): Spawns workers, streams data +- **ScalarFunction** (`vgi/scalar_function.py`): Base for scalar functions - **TableFunctionGenerator** (`vgi/table_function.py`): Base for table functions - **TableInOutFunction** (`vgi/table_in_out_function.py`): Base for table-in-out functions @@ -67,7 +72,8 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne ``` vgi/ __init__.py # Package exports - function.py # Invocation, OutputSpec, Arguments, 
GlobalInitResult + function.py # Invocation, OutputSpec, Arguments, FunctionType + scalar_function.py # ScalarFunction, ScalarFunctionGenerator table_function.py # TableFunctionGenerator, CardinalityInfo, Output table_in_out_function.py # TableInOutFunction, TableInOutGeneratorFunction metadata.py # Function metadata for introspection @@ -76,6 +82,7 @@ vgi/ client/ client.py # Client class examples/ + scalar.py # Example scalar functions table.py # Example table functions table_in_out.py # Example table-in-out functions worker.py # ExampleWorker with registry @@ -89,6 +96,32 @@ vgi-client --input data.parquet --function echo --server vgi-example-worker vgi-client --input data.parquet --function sum_all_columns --server vgi-example-worker ``` +## Creating a Scalar Function (Per-Row Transform) + +```python +import pyarrow as pa +import pyarrow.compute as pc +from vgi import ScalarFunction, Arg + +class DoubleColumn(ScalarFunction): + """Double the value in a specified column.""" + + column = Arg[str](0, doc="Column to double") + + @property + def output_type(self) -> pa.DataType: + # Output type matches input column type + return self.input_schema.field(self.column).type + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + return pc.multiply(batch.column(self.column), 2) +``` + +### Key Constraints for Scalar Functions: +- **1:1 row mapping**: Output must have exactly the same number of rows as input +- **Single column output**: Output schema has exactly one column named "result" +- **No finalize phase**: All processing happens in compute() + ## Creating a Table-In-Out Function (Recommended) ```python @@ -182,6 +215,9 @@ if __name__ == "__main__": ### Imports ```python +# Scalar Functions (per-row transform) +from vgi import ScalarFunction, Arg, Worker + # Table Functions (no input) from vgi import TableFunctionGenerator, Output, Arg, Worker @@ -221,6 +257,17 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"}) ### Method Override Summary 
+**ScalarFunction:** + +| Method | When to Override | Default | +|--------|------------------|---------| +| `output_type` | Define output column type | Required | +| `compute(batch)` | Transform batch to single array | Required | +| `setup()` | Acquire resources | No-op | +| `teardown()` | Release resources | No-op | + +**TableInOutFunction:** + | Method | When to Override | Default | |--------|------------------|---------| | `output_schema` | Change output columns | Returns input_schema | @@ -232,17 +279,22 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"}) ### Pattern Decision Tree ``` -Need to implement a VGI function? -│ -├─ Does the function receive input data? -│ │ -│ ├─ NO → Use TableFunctionGenerator -│ │ Override process() to yield Output batches -│ │ -│ └─ YES → Use TableInOutFunction -│ ├─ Transform each batch? → Override transform() -│ ├─ Aggregate results? → Accumulate in transform(), emit in finish() -│ └─ Need generator control? → See docs/generator-api.md +How will your function be used in SQL? + +1. SELECT my_func(col1, col2) FROM table + → SCALAR FUNCTION: Returns one value per input row + → Use ScalarFunction, override output_type and compute() + → Example: upper(), abs(), concat() + +2. SELECT * FROM my_func(args) + → TABLE FUNCTION: Generates rows from arguments (no input table) + → Use TableFunctionGenerator, override process() + → Example: range(), read_csv(), glob() + +3. 
SELECT * FROM my_func(args, (SELECT * FROM input_table)) + → TABLE-IN-OUT FUNCTION: Transforms input rows to output rows + → Use TableInOutFunction, override transform() and optionally finish() + → Example: filtering, enrichment, aggregation ``` ## Additional Documentation diff --git a/tests/client/test_cli.py b/tests/client/test_cli.py index a7ff650..ea33f93 100644 --- a/tests/client/test_cli.py +++ b/tests/client/test_cli.py @@ -795,9 +795,7 @@ def test_auto_type_without_input_uses_table( lines = output_file.read_text().strip().split("\n") assert len(lines) == 3 - def test_scalar_with_add_columns( - self, example_worker: str, tmp_path: Path - ) -> None: + def test_scalar_with_add_columns(self, example_worker: str, tmp_path: Path) -> None: """Test add_columns scalar function via CLI.""" # Create input with two columns batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) @@ -832,9 +830,7 @@ def test_scalar_with_add_columns( results = [json.loads(line)["result"] for line in lines] assert results == [11, 22, 33] - def test_scalar_with_upper_case( - self, example_worker: str, tmp_path: Path - ) -> None: + def test_scalar_with_upper_case(self, example_worker: str, tmp_path: Path) -> None: """Test upper_case scalar function via CLI.""" batch = pa.RecordBatch.from_pydict({"name": ["alice", "bob"]}) input_file = tmp_path / "input.parquet" diff --git a/tests/scalar/test_function.py b/tests/scalar/test_function.py index 903432f..1422b87 100644 --- a/tests/scalar/test_function.py +++ b/tests/scalar/test_function.py @@ -9,23 +9,23 @@ import structlog from vgi.arguments import Arg -from vgi.function import Arguments, Invocation +from vgi.function import Arguments, Invocation, InvocationType, SchemaValidationError from vgi.log import Level, Message from vgi.scalar_function import ( Output, - OutputGenerator, ProtocolInput, ScalarFunction, ScalarFunctionGenerator, + ScalarOutputGenerator, ) -from vgi.table_function import SchemaValidationError def 
create_invocation(input_schema: pa.Schema) -> Invocation: """Create a test invocation with the given input schema.""" return Invocation( function_name="test_function", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.SCALAR, correlation_id="test-correlation", invocation_id=b"test-invocation", arguments=Arguments(), @@ -43,8 +43,8 @@ class DoubleColumn(ScalarFunctionGenerator): def output_schema(self) -> pa.Schema: return pa.schema([("result", pa.int64())]) - def process(self, batch: pa.RecordBatch) -> OutputGenerator: - _ = yield None + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) # Priming yield import pyarrow.compute as pc while True: @@ -83,12 +83,13 @@ class TestFunc(ScalarFunctionGenerator): def output_schema(self) -> pa.Schema: return pa.schema([("result", pa.int64())]) - def process(self, batch: pa.RecordBatch) -> OutputGenerator: - _ = yield None + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) invocation = Invocation( function_name="test", - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.SCALAR, correlation_id="test", invocation_id=None, arguments=Arguments(), @@ -105,8 +106,8 @@ class TwoColumnOutput(ScalarFunctionGenerator): def output_schema(self) -> pa.Schema: return pa.schema([("a", pa.int64()), ("b", pa.int64())]) - def process(self, batch: pa.RecordBatch) -> OutputGenerator: - _ = yield None + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) input_schema = pa.schema([("x", pa.int64())]) invocation = create_invocation(input_schema) @@ -122,8 +123,8 @@ class LoggingScalar(ScalarFunctionGenerator): def output_schema(self) -> pa.Schema: return pa.schema([("result", pa.int64())]) - def process(self, batch: pa.RecordBatch) -> OutputGenerator: - _ = yield None + def 
process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) # Priming yield import pyarrow.compute as pc while True: @@ -179,7 +180,8 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: input_schema = pa.schema([("x", pa.int64())]) invocation = Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.SCALAR, correlation_id="test", invocation_id=None, arguments=Arguments(positional=(pa.scalar("x"),)), @@ -198,29 +200,6 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: assert output.batch.num_rows == 3 assert output.batch.column("result").to_pylist() == [2, 4, 6] - def test_custom_output_name(self) -> None: - """Test custom output column name.""" - - class CustomName(ScalarFunction): - @property - def output_name(self) -> str: - return "doubled" - - @property - def output_type(self) -> pa.DataType: - return pa.int64() - - def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: - import pyarrow.compute as pc - - return pc.multiply(batch.column("x"), 2) - - input_schema = pa.schema([("x", pa.int64())]) - invocation = create_invocation(input_schema) - - func = CustomName(invocation=invocation, logger=structlog.get_logger()) - assert func.output_schema.names == ["doubled"] - def test_log_method(self) -> None: """Test self.log() method.""" diff --git a/tests/table/generator/test_constant_table_function.py b/tests/table/generator/test_constant_table_function.py index 94681b5..3d489d0 100644 --- a/tests/table/generator/test_constant_table_function.py +++ b/tests/table/generator/test_constant_table_function.py @@ -29,14 +29,15 @@ def test_cardinality(self) -> None: """Cardinality should always be 1.""" import structlog - from vgi.function import Arguments, Invocation + from vgi.function import Arguments, Invocation, InvocationType invocation = Invocation( function_name="constant_table", - 
arguments=Arguments(positional=(pa.scalar(42),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(positional=(pa.scalar(42),)), ) func = ConstantTableFunction( invocation=invocation, diff --git a/tests/table/generator/test_projected_data_function.py b/tests/table/generator/test_projected_data_function.py index d30b632..9e84ad2 100644 --- a/tests/table/generator/test_projected_data_function.py +++ b/tests/table/generator/test_projected_data_function.py @@ -152,15 +152,16 @@ def test_output_schema_reflects_projection(self) -> None: """The output_schema property should reflect the projection.""" import structlog - from vgi.function import Invocation - from vgi.table_function import GlobalStateInitInput + from vgi.function import Invocation, InvocationType + from vgi.table_function import TableFunctionInitInput invocation = Invocation( function_name="projected_data", - arguments=Arguments(positional=(pa.scalar(10),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(positional=(pa.scalar(10),)), ) func = ProjectedDataFunction( invocation=invocation, @@ -171,7 +172,7 @@ def test_output_schema_reflects_projection(self) -> None: assert func.output_schema == ProjectedDataFunction.FULL_SCHEMA # After setting init_data with projection, should return projected schema - func.init_data = GlobalStateInitInput(projection_ids=[0, 2]) + func.init_data = TableFunctionInitInput(projection_ids=[0, 2]) schema = func.output_schema assert len(schema) == 2 assert schema.names == ["id", "value"] diff --git a/tests/table/generator/test_sequence_function.py b/tests/table/generator/test_sequence_function.py index f3e8963..780689c 100644 --- a/tests/table/generator/test_sequence_function.py +++ b/tests/table/generator/test_sequence_function.py @@ -30,14 +30,15 @@ def 
test_cardinality(self) -> None: """Cardinality should match requested count.""" import structlog - from vgi.function import Arguments, Invocation + from vgi.function import Arguments, Invocation, InvocationType invocation = Invocation( function_name="sequence", - arguments=Arguments(positional=(pa.scalar(100),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(positional=(pa.scalar(100),)), ) func = SequenceFunction( invocation=invocation, diff --git a/tests/table/test_function.py b/tests/table/test_function.py index 566c561..e531681 100644 --- a/tests/table/test_function.py +++ b/tests/table/test_function.py @@ -7,7 +7,7 @@ import structlog from tests.utils import make_schema -from vgi.function import Arguments, Invocation +from vgi.function import Arguments, Invocation, InvocationType from vgi.table_function import ( CardinalityInfo, Output, @@ -194,10 +194,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = NoCardinalityFunction( invocation=invocation, @@ -219,10 +220,11 @@ def cardinality(self) -> CardinalityInfo: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = CardinalityFunction( invocation=invocation, @@ -311,10 +313,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = 
TestFunction( invocation=invocation, diff --git a/tests/table_in_out/test_function.py b/tests/table_in_out/test_function.py index d624da8..78ae2bd 100644 --- a/tests/table_in_out/test_function.py +++ b/tests/table_in_out/test_function.py @@ -4,7 +4,7 @@ import pyarrow.compute as pc import structlog -from vgi.function import Arg, Arguments, Invocation +from vgi.function import Arg, Arguments, Invocation, InvocationType from vgi.ipc_utils import RecordBatchState from vgi.log import Level from vgi.table_in_out_function import ( @@ -18,7 +18,8 @@ def make_invocation(input_schema: pa.Schema) -> Invocation: """Create a minimal Invocation for testing.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", arguments=Arguments(), @@ -31,7 +32,8 @@ def make_invocation_with_args( """Create an Invocation with positional arguments.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", arguments=Arguments(positional=tuple(pa.scalar(v) for v in positional)), diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e1d1cd8..d2b8d63 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -211,7 +211,7 @@ class TestMaxWorkersIntegration: def test_max_workers_used(self) -> None: """max_processes() returns Meta.max_workers when defined.""" from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class LimitedFunction(TableInOutFunction): class Meta: @@ -221,10 +221,11 @@ class Meta: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", 
invocation_id=b"test", + arguments=Arguments(), ) import structlog @@ -234,17 +235,18 @@ class Meta: def test_default_max_workers(self) -> None: """max_processes() returns default when max_workers not defined.""" from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class UnlimitedFunction(TableInOutFunction): data: TableInput = Arg[TableInput](0, doc="Input table") # type: ignore[assignment] invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) import structlog diff --git a/tests/test_patterns.py b/tests/test_patterns.py index 4857f41..d7dff6a 100644 --- a/tests/test_patterns.py +++ b/tests/test_patterns.py @@ -8,7 +8,7 @@ import pyarrow.compute as pc import structlog -from vgi.function import Arg, Arguments, Invocation +from vgi.function import Arg, Arguments, Invocation, InvocationType from vgi.log import Level from vgi.table_in_out_function_patterns import ( AggregationFunction, @@ -22,7 +22,8 @@ def make_invocation(input_schema: pa.Schema) -> Invocation: """Create a minimal Invocation for testing.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", arguments=Arguments(), diff --git a/tests/test_protocol_classes.py b/tests/test_protocol_classes.py index 3693a4a..c6f22dc 100644 --- a/tests/test_protocol_classes.py +++ b/tests/test_protocol_classes.py @@ -15,12 +15,13 @@ ArgumentValidationError, GlobalInitResult, Invocation, + InvocationType, ) from vgi.log import Level, Message from vgi.table_function import ( CardinalityInfo, - GlobalStateInitInput, OutputSpec, + TableFunctionInitInput, ) @@ -180,10 +181,11 @@ def 
test_basic_round_trip(self) -> None: """Basic Invocation should serialize and deserialize correctly.""" original = Invocation( function_name="test_function", - arguments=Arguments(positional=(pa.scalar(42),), named={}), - in_out_function_input_schema=make_schema([pa.field("col1", pa.int64())]), + input_schema=make_schema([pa.field("col1", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="test-123", invocation_id=b"bind-id-bytes", + arguments=Arguments(positional=(pa.scalar(42),), named={}), ) serialized = original.serialize() @@ -200,10 +202,7 @@ def test_basic_round_trip(self) -> None: assert deserialized.function_name == original.function_name assert deserialized.correlation_id == original.correlation_id assert deserialized.invocation_id == original.invocation_id - assert ( - deserialized.in_out_function_input_schema - == original.in_out_function_input_schema - ) + assert deserialized.input_schema == original.input_schema assert len(deserialized.arguments.positional) == 1 assert deserialized.arguments.positional[0] is not None assert deserialized.arguments.positional[0].as_py() == 42 @@ -212,7 +211,8 @@ def test_nullmake_schema(self) -> None: """Invocation with null input schema should round-trip correctly.""" original = Invocation( function_name="scalar_function", - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="", invocation_id=None, ) @@ -225,7 +225,7 @@ def test_nullmake_schema(self) -> None: deserialized = Invocation.deserialize(batch) assert deserialized.function_name == "scalar_function" - assert deserialized.in_out_function_input_schema is None + assert deserialized.input_schema is None assert deserialized.invocation_id is None def test_complexmake_schema(self) -> None: @@ -242,7 +242,8 @@ def test_complexmake_schema(self) -> None: original = Invocation( function_name="complex_function", - in_out_function_input_schema=complex_schema, + input_schema=complex_schema, + 
function_type=InvocationType.TABLE, correlation_id="complex-test", invocation_id=b"complex-bind", ) @@ -254,7 +255,7 @@ def test_complexmake_schema(self) -> None: batch = reader.read_next_batch() deserialized = Invocation.deserialize(batch) - assert deserialized.in_out_function_input_schema == complex_schema + assert deserialized.input_schema == complex_schema def test_deserialize_empty_batch_raises(self) -> None: """Deserializing empty batch should raise ValueError.""" @@ -264,7 +265,7 @@ def test_deserialize_empty_batch_raises(self) -> None: [ pa.field("function_name", pa.string()), pa.field("arguments", pa.struct([])), - pa.field("in_out_function_input_schema", pa.binary()), + pa.field("input_schema", pa.binary()), pa.field("invocation_id", pa.binary()), pa.field("correlation_id", pa.string()), ] @@ -281,14 +282,14 @@ def test_deserialize_multi_row_batch_raises(self) -> None: { "function_name": "fn1", "arguments": {}, - "in_out_function_input_schema": None, + "input_schema": None, "invocation_id": None, "correlation_id": "", }, { "function_name": "fn2", "arguments": {}, - "in_out_function_input_schema": None, + "input_schema": None, "invocation_id": None, "correlation_id": "", }, @@ -302,7 +303,8 @@ def test_with_global_init_identifier(self) -> None: """Test that with_global_init_identifier creates a new Invocation.""" original = Invocation( function_name="test", - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=None, global_init_identifier=None, @@ -517,7 +519,7 @@ class TestGlobalStateInitInput: def test_basic_round_trip(self) -> None: """GlobalStateInitInput should serialize and deserialize correctly.""" - original = GlobalStateInitInput(projection_ids=[0, 2, 4]) + original = TableFunctionInitInput(projection_ids=[0, 2, 4]) serialized = original.serialize() assert isinstance(serialized, bytes) @@ -526,39 +528,39 @@ def test_basic_round_trip(self) -> None: reader = 
ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids == [0, 2, 4] def test_null_projection_ids(self) -> None: """GlobalStateInitInput with null projection_ids should round-trip.""" - original = GlobalStateInitInput(projection_ids=None) + original = TableFunctionInitInput(projection_ids=None) serialized = original.serialize() from pyarrow import ipc reader = ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids is None def test_empty_projection_ids(self) -> None: """GlobalStateInitInput with empty list should round-trip.""" - original = GlobalStateInitInput(projection_ids=[]) + original = TableFunctionInitInput(projection_ids=[]) serialized = original.serialize() from pyarrow import ipc reader = ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids == [] def test_default_value(self) -> None: """GlobalStateInitInput default should have None projection_ids.""" - default = GlobalStateInitInput() + default = TableFunctionInitInput() assert default.projection_ids is None diff --git a/tests/test_schema_utils.py b/tests/test_schema_utils.py index 568f5a6..6760233 100644 --- a/tests/test_schema_utils.py +++ b/tests/test_schema_utils.py @@ -267,7 +267,7 @@ def test_schema_with_function(self) -> None: from vgi import TableInOutFunction from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class TestFunction(TableInOutFunction): @property @@ -277,10 +277,11 @@ def output_schema(self) -> pa.Schema: # Verify the schema is created correctly 
when function is used invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([pa.field("x", pa.int64())]), + input_schema=pa.schema([pa.field("x", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) logger = structlog.get_logger() func = TestFunction(invocation=invocation, logger=logger) @@ -299,7 +300,7 @@ def test_schema_like_with_function(self) -> None: from vgi import TableInOutFunction from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class TestFunction(TableInOutFunction): @property @@ -318,10 +319,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) logger = structlog.get_logger() func = TestFunction(invocation=invocation, logger=logger) diff --git a/tests/test_worker.py b/tests/test_worker.py index fcb904c..336ed23 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -5,7 +5,7 @@ from vgi import Arg, TableInOutFunction, TableInput from vgi.arguments import Arguments -from vgi.function import Invocation +from vgi.function import Invocation, InvocationType from vgi.worker import Worker @@ -25,10 +25,11 @@ class Meta: invocation = Invocation( function_name="single", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) result = Worker._match_function(invocation, [SingleFunction]) @@ -70,7 +71,8 @@ class Meta: inv0 = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + 
input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -80,7 +82,8 @@ class Meta: inv1 = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5),)), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -90,7 +93,8 @@ class Meta: inv2 = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5), pa.scalar(10))), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -121,7 +125,8 @@ class Meta: inv_with = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5),)), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -132,7 +137,8 @@ class Meta: inv_without = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -166,7 +172,8 @@ class Meta: inv_format = Invocation( function_name="func", arguments=Arguments(positional=(), named={"format": pa.scalar("json")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -176,7 +183,8 @@ class Meta: inv_sep = Invocation( function_name="func", arguments=Arguments(positional=(), named={"separator": pa.scalar(",")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -198,7 +206,8 @@ class Meta: inv = Invocation( function_name="func", 
arguments=Arguments(positional=(pa.scalar(1), pa.scalar(2), pa.scalar(3))), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -230,7 +239,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -254,7 +264,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=(), named={"unknown": pa.scalar("x")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -277,7 +288,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=(), named=None), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) diff --git a/vgi/__init__.py b/vgi/__init__.py index bb52a1b..0aff109 100644 --- a/vgi/__init__.py +++ b/vgi/__init__.py @@ -48,7 +48,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: from vgi import Worker class MyWorker(Worker): - registry = {"my_function": MyFunction} + functions = [MyFunction] if __name__ == "__main__": MyWorker().run() @@ -112,14 +112,15 @@ class Meta: CLASS HIERARCHY --------------- vgi.function.Function - Base (max_processes, invocation_id) - └─ vgi.table_function.TableFunctionBase - Adds cardinality hints, projection - ├─ TableInOutGeneratorFunction - Full streaming (process/finalize) - │ └─ TableInOutFunction - Callback API (transform/finish) - │ ├─ AggregationFunction - Reduce to summary - │ ├─ FilterFunction - Row filtering - │ └─ MapFunction - Column transformation - └─ ScalarFunctionGenerator - Single-column output (1:1 rows) - └─ ScalarFunction 
- Callback API (compute) + ├─ vgi.table_function.TableFunctionBase - Adds cardinality hints, projection + │ ├─ TableFunctionGenerator - Generate output without input + │ └─ TableInOutGeneratorFunction - Full streaming (process/finalize) + │ └─ TableInOutFunction - Callback API (transform/finish) + │ ├─ AggregationFunction - Reduce to summary + │ ├─ FilterFunction - Row filtering + │ └─ MapFunction - Column transformation + └─ ScalarFunctionGenerator - Single-column output (1:1 rows) + └─ ScalarFunction - Callback API (compute) Examples -------- @@ -146,7 +147,11 @@ class Meta: TableInputValidationError, functions_to_arrow, ) -from vgi.scalar_function import ScalarFunction, ScalarFunctionGenerator +from vgi.scalar_function import ( + ScalarFunction, + ScalarFunctionGenerator, + ScalarOutputGenerator, +) from vgi.schema_utils import schema, schema_like from vgi.table_function import RowCountMismatchError from vgi.table_in_out_function import ( @@ -187,6 +192,7 @@ class Meta: "RowCountMismatchError", "ScalarFunction", "ScalarFunctionGenerator", + "ScalarOutputGenerator", "StreamingGenerator", "TableInOutFunction", "TableInOutGeneratorFunction", diff --git a/vgi/arguments.py b/vgi/arguments.py index 0a18e1e..440cbd7 100644 --- a/vgi/arguments.py +++ b/vgi/arguments.py @@ -601,6 +601,8 @@ def _validate(self, value: ArgT) -> None: valid_range = self._describe_valid_range() # Numeric range validation + # Note: type: ignore needed because ArgT is generic - comparisons only valid + # for numeric types, but we can't express "ArgT when constraints are set" if self.ge is not None and value < self.ge: # type: ignore[operator] raise ArgumentValidationError( f"Argument '{arg_name}' is too small.", diff --git a/vgi/client/cli.py b/vgi/client/cli.py index 00f026c..44ba14f 100644 --- a/vgi/client/cli.py +++ b/vgi/client/cli.py @@ -317,7 +317,6 @@ def cli( function_name=function_name, arguments=Arguments(positional=positional_args, named={}), input=pf.iter_batches(), - 
projection_ids=list(projection_ids) if projection_ids else None, ) else: # Table-in-out function (with input) diff --git a/vgi/client/client.py b/vgi/client/client.py index b6b85a8..458be0c 100644 --- a/vgi/client/client.py +++ b/vgi/client/client.py @@ -71,11 +71,13 @@ from vgi.function import ( Arguments, + FunctionInitInput, GlobalInitResult, Invocation, + InvocationType, ) from vgi.ipc_utils import IPCError, read_ipc_batch -from vgi.table_function import GlobalStateInitInput +from vgi.table_function import TableFunctionInitInput # Configure structlog to write to stderr structlog.configure( @@ -370,6 +372,7 @@ def _initialize_stream_common( function_name: str, arguments: Arguments, input_schema: pa.Schema | None, + function_type: InvocationType, bind_result_callback: Callable[[pa.RecordBatch], None] | None, projection_ids: list[int] | None, ) -> tuple[_BindResult, GlobalInitResult, Invocation]: @@ -381,7 +384,7 @@ def _initialize_stream_common( 3. Invokes bind_result_callback if provided 4. Validates protocol version compatibility 5. Applies CPU/max_workers limits to max_processes - 6. Sends GlobalStateInitInput (with projection_ids) + 6. Sends init data (FunctionInitInput or TableFunctionInitInput) 7. Reads GlobalInitResult (shared state identifier for parallel workers) 8. Creates an Invocation with global_init_identifier for additional workers @@ -392,10 +395,13 @@ def _initialize_stream_common( to pass to the function. input_schema: Schema of input batches for table-in-out functions, or None for table functions that generate output without input. + function_type: Type of function being invoked (SCALAR or TABLE). + Determines what init data format is sent to the worker. bind_result_callback: Optional callback invoked with the raw bind result RecordBatch. Called before further processing. projection_ids: Optional list of column indices to project in the - output. Passed to the worker via GlobalStateInitInput. + output. 
Passed to the worker via TableFunctionInitInput (ignored + for scalar functions). Returns: A tuple of (bind_result, global_init_result, request_with_init): @@ -407,7 +413,7 @@ def _initialize_stream_common( Raises: ClientError: If the worker process is not started, or if reading the bind result or init result fails. - OSError: If writing the Invocation or GlobalStateInitInput fails. + OSError: If writing the Invocation or init data fails. ClientError: If the worker activated unsupported features. """ @@ -422,10 +428,11 @@ def _initialize_stream_common( initial_request = Invocation( function_name=function_name, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=function_type, correlation_id=self.correlation_id, invocation_id=None, + arguments=arguments, client_features=client_features, attach_id=self._attach_id, ) @@ -470,15 +477,17 @@ def _initialize_stream_common( ), ) - # Send global state init input - global_state_info_serialized_bytes = GlobalStateInitInput( - projection_ids=projection_ids - ).serialize() + if initial_request.function_type == InvocationType.SCALAR: + # Scalar functions use empty init input + init_serialized_bytes = FunctionInitInput().serialize() + else: + # Table functions (generator and table-in-out) use TableFunctionInitInput + init_serialized_bytes = TableFunctionInitInput( + projection_ids=projection_ids + ).serialize() - if self._stdin_sink.write(global_state_info_serialized_bytes) != len( - global_state_info_serialized_bytes - ): - raise OSError("Failed to write global state init input record batch") + if self._stdin_sink.write(init_serialized_bytes) != len(init_serialized_bytes): + raise OSError("Failed to write init record batch") # Read init result log.debug("reading_init_result") @@ -496,11 +505,12 @@ def _initialize_stream_common( # Create request with init for additional workers request_with_init = Invocation( function_name=function_name, - arguments=arguments, - 
in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=function_type, correlation_id=self.correlation_id, invocation_id=bind_result.invocation_id, global_init_identifier=global_init_result, + arguments=arguments, ) return bind_result, global_init_result, request_with_init @@ -580,6 +590,7 @@ def _initialize_function_stream( function_name: str, arguments: Arguments, input_schema: pa.Schema | None, + function_type: InvocationType, bind_result_callback: Callable[[pa.RecordBatch], None] | None, projection_ids: list[int] | None, ) -> tuple[ipc.RecordBatchStreamWriter | None, ipc.RecordBatchStreamReader | None]: @@ -603,6 +614,7 @@ def _initialize_function_stream( arguments: Arguments container with positional and named arguments. input_schema: Schema of input batches for table-in-out functions, or None for table functions. + function_type: Type of function being invoked (SCALAR or TABLE). bind_result_callback: Optional callback invoked with the raw bind result RecordBatch. projection_ids: Optional list of column indices to project. @@ -625,6 +637,7 @@ def _initialize_function_stream( function_name=function_name, arguments=arguments, input_schema=input_schema, + function_type=function_type, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1073,9 +1086,11 @@ def _process_batch_on_worker( output_batches.append(output_batch) + # status is None for scalar functions (which don't emit status) + # and means we're done with this batch if status == b"HAVE_MORE_OUTPUT": continue - elif status == b"NEED_MORE_INPUT": + elif status == b"NEED_MORE_INPUT" or status is None: break else: raise ClientError( @@ -1398,7 +1413,7 @@ def table_in_out_function( result RecordBatch before processing begins. Useful for inspecting output schema, max_processes, or cardinality hints. projection_ids: Optional list of column indices for column projection. - Passed to the worker via GlobalStateInitInput. 
+ Passed to the worker via TableFunctionInitInput. Yields: Output RecordBatches from the function. In single-worker mode, output @@ -1443,6 +1458,7 @@ def table_in_out_function( function_name=function_name, arguments=arguments, input_schema=input_schema, + function_type=InvocationType.TABLE, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1759,7 +1775,7 @@ def table_function( result RecordBatch before processing begins. Useful for inspecting output schema, max_processes, or cardinality hints. projection_ids: Optional list of column indices for column projection. - Passed to the worker via GlobalStateInitInput. + Passed to the worker via TableFunctionInitInput. Yields: Output RecordBatches from the function. In parallel mode @@ -1794,6 +1810,7 @@ def table_function( function_name=function_name, arguments=arguments, input_schema=None, + function_type=InvocationType.TABLE, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1813,7 +1830,6 @@ def scalar_function( input: Iterator[pa.RecordBatch], arguments: Arguments | None = None, bind_result_callback: Callable[[pa.RecordBatch], None] | None = None, - projection_ids: list[int] | None = None, ) -> Generator[pa.RecordBatch, None, None]: """Invoke a scalar function on the worker and stream results. @@ -1842,8 +1858,6 @@ def scalar_function( bind_result_callback: Optional callback invoked with the raw bind result RecordBatch before processing begins. Useful for inspecting output schema or max_processes. - projection_ids: Optional list of column indices for column projection. - Passed to the worker via GlobalStateInitInput. Yields: Output RecordBatches from the function. 
Each output batch has a single @@ -1890,8 +1904,9 @@ def scalar_function( function_name=function_name, arguments=arguments, input_schema=input_schema, + function_type=InvocationType.SCALAR, bind_result_callback=bind_result_callback, - projection_ids=projection_ids, + projection_ids=None, # Scalar functions don't use projection ) # Use parallel processing for all cases (handles both single and diff --git a/vgi/examples/table.py b/vgi/examples/table.py index 9bfbd52..4e4a87b 100644 --- a/vgi/examples/table.py +++ b/vgi/examples/table.py @@ -24,10 +24,10 @@ from vgi.metadata import FunctionExample from vgi.table_function import ( CardinalityInfo, - GlobalStateInitInput, Output, OutputGenerator, TableFunctionGenerator, + TableFunctionInitInput, ) __all__ = [ @@ -476,7 +476,7 @@ def cardinality(self) -> CardinalityInfo: def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: """Populate the work queue with range chunks.""" # Parse init data and store in init_storage - self.init_data = GlobalStateInitInput.deserialize(init_input) + self.init_data = TableFunctionInitInput.deserialize(init_input) self.init_identifier = self.init_storage.create(self.init_data.serialize()) # Create work items for each chunk of the range diff --git a/vgi/function.py b/vgi/function.py index f2127bc..620fb61 100644 --- a/vgi/function.py +++ b/vgi/function.py @@ -2,21 +2,24 @@ This module defines the foundational classes used during function binding in the VGI protocol. When a client invokes a function, it sends Invocation -describing the function name, arguments, and input schema. The worker -returns an OutputSpec describing the output schema and execution hints. +describing the function name, arguments, input schema, and function type. +The worker returns an OutputSpec describing the output schema and execution hints. Classes: + InvocationType: Enum distinguishing scalar vs table invocation types. Arguments: Container for positional and named function arguments. 
- Invocation: Complete function invocation request (name, args, schema). + Invocation: Complete function invocation request (name, args, schema, type). OutputSpec: Result from binding a function (output schema, etc). + Function: Base class for all VGI functions. The Invocation and OutputSpec are serialized to Arrow IPC format for transmission between client and worker processes. See Also: + vgi.scalar_function: Scalar functions with 1:1 row transforms. + vgi.table_function: Table functions with cardinality hints. + vgi.table_in_out_function: Streaming table functions for batch transforms. vgi.log: LogLevel and LogMessage for function diagnostics. - vgi.table_function: Extended bind results with cardinality hints. - vgi.table_in_out_function: Streaming table functions built on these primitives. """ @@ -24,7 +27,16 @@ import sqlite3 import uuid from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, Self, TypeVar +from enum import Enum +from functools import cached_property +from typing import ( + Any, + ClassVar, + Protocol, + Self, + TypeVar, + final, +) import pyarrow as pa import structlog @@ -34,25 +46,176 @@ from vgi.log import Level, Message from vgi.metadata import MetadataMixin, ResolvedMetadata -if TYPE_CHECKING: - pass - __all__ = [ "Arg", "ArgumentValidationError", "Arguments", "Function", + "InvocationType", "GlobalInitResult", "Level", "Message", "OutputSpec", "Invocation", + "SchemaValidationError", "Serializable", "SqliteInitStorage", "SqliteWorkerStateStorage", ] +class SchemaValidationError(Exception): + """Raised when a batch schema doesn't match the expected schema. + + This error is raised by the framework during input/output validation. + It indicates a programming error where a batch doesn't conform to the + declared schema. 
+ + The error message includes detailed information about what differs: + - Missing fields (in expected but not in actual) + - Extra fields (in actual but not in expected) + - Type mismatches (same field name, different types) + - Field order differences + + Attributes: + expected: The expected Arrow schema. + actual: The actual Arrow schema that was received. + context: Description of where the validation occurred. + + """ + + def __init__( + self, + message: str, + *, + expected: "pa.Schema | None" = None, + actual: "pa.Schema | None" = None, + context: str = "", + ) -> None: + """Initialize with schema comparison details. + + Args: + message: Base error message. + expected: The expected Arrow schema. + actual: The actual Arrow schema. + context: Where the error occurred (e.g., "output from transform()"). + + """ + self.expected = expected + self.actual = actual + self.context = context + + if expected is not None and actual is not None: + full_message = self._build_detailed_message(message, expected, actual) + else: + full_message = message + + super().__init__(full_message) + + def _build_detailed_message( + self, base_message: str, expected: "pa.Schema", actual: "pa.Schema" + ) -> str: + """Build a detailed message showing exactly what differs.""" + lines = [base_message, ""] + + if self.context: + lines.append(f" Context: {self.context}") + lines.append("") + + # Build field maps for comparison + expected_fields = {f.name: f for f in expected} + actual_fields = {f.name: f for f in actual} + + expected_names = set(expected_fields.keys()) + actual_names = set(actual_fields.keys()) + + # Find differences + missing = expected_names - actual_names + extra = actual_names - expected_names + common = expected_names & actual_names + + # Check for type mismatches in common fields + type_mismatches = [] + for name in common: + exp_field = expected_fields[name] + act_field = actual_fields[name] + if exp_field.type != act_field.type: + type_mismatches.append((name, 
exp_field.type, act_field.type)) + elif exp_field.nullable != act_field.nullable: + exp_null = "nullable" if exp_field.nullable else "non-nullable" + act_null = "nullable" if act_field.nullable else "non-nullable" + type_mismatches.append((name, exp_null, act_null)) + + # Check for order differences (only if names match but order differs) + order_differs = False + if not missing and not extra and not type_mismatches: + expected_order = [f.name for f in expected] + actual_order = [f.name for f in actual] + if expected_order != actual_order: + order_differs = True + + # Report missing fields + if missing: + lines.append(" Missing fields (expected but not found):") + for name in sorted(missing): + field = expected_fields[name] + lines.append(f" - {name}: {field.type}") + + # Report extra fields + if extra: + lines.append(" Extra fields (found but not expected):") + for name in sorted(extra): + field = actual_fields[name] + lines.append(f" - {name}: {field.type}") + + # Report type mismatches + if type_mismatches: + lines.append(" Type mismatches:") + for name, exp_type, act_type in type_mismatches: + lines.append(f" - {name}: expected {exp_type}, got {act_type}") + + # Report order differences + if order_differs: + lines.append(" Field order differs:") + lines.append(f" Expected: {[f.name for f in expected]}") + lines.append(f" Actual: {[f.name for f in actual]}") + + # Summary of schemas + lines.append("") + lines.append(" Expected schema:") + for field in expected: + nullable = " (nullable)" if field.nullable else "" + lines.append(f" {field.name}: {field.type}{nullable}") + + lines.append(" Actual schema:") + for field in actual: + nullable = " (nullable)" if field.nullable else "" + lines.append(f" {field.name}: {field.type}{nullable}") + + return "\n".join(lines) + + +class InvocationType(Enum): + """Type of VGI invocation for protocol dispatch. + + Used by the client to determine the correct init data format to send + to the worker. 
Scalar functions use FunctionInitInput (no projection), + while table functions use TableFunctionInitInput (with projection support). + + Note: This is distinct from vgi.metadata.FunctionType which is used for + DuckDB catalog registration and includes AGGREGATE. + + Attributes: + SCALAR: Scalar function that transforms input batches to single-column output. + TABLE: Table function (either generator or table-in-out) that produces + multi-column output. + + """ + + SCALAR = "scalar" + TABLE = "table" + + class Serializable(Protocol): """Protocol for objects that can be serialized to/from bytes. @@ -176,21 +339,23 @@ class Invocation: Invocation encapsulates all information needed to bind and execute a function: the function name, its arguments, the expected input schema (for table - functions), and identifiers for logging and correlation. + functions), the function type, and identifiers for logging and correlation. This is serialized to Arrow IPC format and sent as the first message when the client connects to a worker subprocess. Attributes: function_name: Name of the function to invoke, must exist in worker registry. - arguments: Positional and named arguments passed to the function. - in_out_function_input_schema: Arrow schema of input data (required for - Function, None for scalar functions or functions that don't - process input tables). + input_schema: Arrow schema of input data. Required for table-in-out and + scalar functions that process input batches. None for table functions + that generate output without input. + function_type: Type of function being invoked (SCALAR or TABLE). Used by + the client to determine the correct init data format to send. correlation_id: String identifier for logging and correlation purposes. invocation_id: Unique bytes identifying this function binding. Used to correlate multiple parallel workers processing the same logical call. global_init_identifier: Optional result from global initialization phase. 
+ arguments: Positional and named arguments passed to the function. client_features: Feature flags supported by the client. The worker will respond with active_features in OutputSpec indicating which features will be used for this invocation. @@ -201,16 +366,18 @@ class Invocation: Example: invocation = Invocation( function_name="sum_columns", - arguments=Arguments(positional=("col1", "col2")), - in_out_function_input_schema=pa.schema([pa.field("col1", pa.int64())]), + input_schema=pa.schema([pa.field("col1", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="request-123", invocation_id=None, # Set by worker after binding + arguments=Arguments(positional=("col1", "col2")), ) """ function_name: str - in_out_function_input_schema: pa.Schema | None + input_schema: pa.Schema | None + function_type: InvocationType correlation_id: str # The unique identifier for the call, typically this may be a uuid. @@ -248,11 +415,12 @@ def serialize(self) -> bytes: { "function_name": self.function_name, "arguments": args_dict, - "in_out_function_input_schema": ( - self.in_out_function_input_schema.serialize().to_pybytes() - if self.in_out_function_input_schema + "input_schema": ( + self.input_schema.serialize().to_pybytes() + if self.input_schema else None ), + "function_type": self.function_type.value, "invocation_id": self.invocation_id, "correlation_id": self.correlation_id, GlobalInitResult._IDENTIFIER_FIELD_NAME: ( @@ -268,9 +436,8 @@ def serialize(self) -> bytes: [ pa.field("function_name", pa.string(), nullable=False), pa.field("arguments", args_struct_type, nullable=True), - pa.field( - "in_out_function_input_schema", pa.binary(), nullable=True - ), + pa.field("input_schema", pa.binary(), nullable=True), + pa.field("function_type", pa.string(), nullable=False), pa.field("invocation_id", pa.binary(), nullable=True), pa.field("correlation_id", pa.string(), nullable=False), pa.field( @@ -303,7 +470,8 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": 
required_fields = [ "function_name", "arguments", - "in_out_function_input_schema", + "input_schema", + "function_type", "invocation_id", "correlation_id", ] @@ -311,11 +479,12 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": data, "Invocation", required_fields=required_fields ) - in_out_function_input_schema = None - if first_row["in_out_function_input_schema"] is not None: - in_out_function_input_schema = pa.ipc.read_schema( - pa.py_buffer(first_row["in_out_function_input_schema"]) - ) + input_schema = None + if first_row["input_schema"] is not None: + input_schema = pa.ipc.read_schema(pa.py_buffer(first_row["input_schema"])) + + # Parse function_type from string value + function_type = InvocationType(first_row["function_type"]) # Parse global_init_identifier - only create GlobalInitResult if field exists # and has a non-None value @@ -339,8 +508,9 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": return Invocation( function_name=first_row["function_name"], + input_schema=input_schema, + function_type=function_type, arguments=Arguments.decode(data.column("arguments")[0]), - in_out_function_input_schema=in_out_function_input_schema, invocation_id=first_row["invocation_id"], correlation_id=first_row["correlation_id"], global_init_identifier=global_init_identifier, @@ -787,7 +957,54 @@ def cleanup_queue(self, invocation_id: bytes) -> int: conn.close() -class Function(MetadataMixin): +class FunctionInitInput: + """Input sent to initialize global state for a Function. + + This is the base init input class for functions that don't require + any initialization data (like scalar functions). It serializes to + an empty single-row batch. + """ + + def serialize(self) -> bytes: + """Serialize FunctionInitInput to bytes. + + Creates a single-row batch with an empty schema. The batch must have + exactly 1 row so that deserialize can access row 0. 
+ """ + # Create a batch with 1 row using a struct array approach + struct_array: pa.StructArray = pa.array([{}], type=pa.struct([])) # type: ignore[assignment] + batch = pa.RecordBatch.from_struct_array(struct_array) + return vgi.ipc_utils.serialize_record_batch(batch) + + @classmethod + def deserialize(cls, _batch: pa.RecordBatch) -> Self: + """Deserialize FunctionInitInput from a RecordBatch. + + Args: + _batch: RecordBatch (unused - FunctionInitInput has no fields). + + Returns: + New FunctionInitInput instance. + + """ + return cls() + + @classmethod + def deserialize_bytes(cls, data: bytes) -> Self: + """Deserialize FunctionInitInput from bytes. + + Args: + data: Serialized bytes. + + Returns: + New FunctionInitInput instance. + + """ + batch = vgi.ipc_utils.deserialize_record_batch(data) + return cls.deserialize(batch) + + +class Function[T: FunctionInitInput](MetadataMixin): """Base class for all VGI functions. Functions are instantiated with Invocation describing the invocation, @@ -830,12 +1047,16 @@ class Meta: init_storage: ClassVar[SqliteInitStorage] = SqliteInitStorage() state_storage: ClassVar[SqliteWorkerStateStorage] = SqliteWorkerStateStorage() + # Cache for resolved metadata + _metadata_cache: ClassVar[ResolvedMetadata | None] = None + # The unique identifier for init data in storage. Set by perform_init() # or retrieve_init(). Used to correlate parallel workers and for state storage. init_identifier: bytes | None = None - # Cache for resolved metadata - _metadata_cache: ClassVar[ResolvedMetadata | None] = None + # This is the init data that may have been read. + InitDataCls: type[T] + init_data: T | None = None def __init__( self, @@ -887,27 +1108,6 @@ def create_invocation_id(self) -> bytes: """ return uuid.uuid4().bytes - def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: - """Perform any global initialization required before processing.
- - This method is called once per worker process before any data - batches are processed. Override to set up shared resources, load - models, or perform expensive setup tasks. - - Args: - init_input: An initial RecordBatch that may contain configuration - or context information for initialization. - - """ - # If there is an id supplied, detect it so it will be passed on. - if GlobalInitResult.has_identifier(init_input): - return GlobalInitResult.deserialize(init_input) - - return GlobalInitResult() - - def retrieve_init(self, init_input: GlobalInitResult) -> None: - """Retrieve init data from storage (default does nothing).""" - @property def output_schema(self) -> pa.Schema: """Return the output schema (must be implemented by subclass).""" @@ -1062,3 +1262,98 @@ def process(self) -> OutputGenerator: "retrieve_init() for secondary workers." ) return self.state_storage.dequeue_work(self.init_identifier) + + @final + @cached_property + def empty_output_batch(self) -> pa.RecordBatch: + """Return an empty batch conforming to output_schema. Cached.""" + output_schema = self.output_schema + return pa.RecordBatch.from_arrays( + [pa.array([], type=field.type) for field in output_schema], + schema=output_schema, + ) + + @final + def _validate_output_schema(self, batch: pa.RecordBatch) -> None: + """Validate that a batch conforms to the expected output schema.""" + if batch.schema != self.output_schema: + raise SchemaValidationError( + "Output batch schema does not match expected output_schema.", + expected=self.output_schema, + actual=batch.schema, + context=f"output from {type(self).__name__}", + ) + + @property + def input_schema(self) -> pa.Schema: + """Return the input schema from the invocation. + + This property is available for functions that receive input batches + (ScalarFunction, TableInOutFunction). For TableFunctionGenerator, + the invocation.input_schema is None. + + Raises: + ValueError: If invocation.input_schema is None. 
+ + """ + if self.invocation.input_schema is None: + raise ValueError( + "input_schema is not available for this function type. " + "TableFunctionGenerator does not receive input batches." + ) + return self.invocation.input_schema + + @final + def _validate_input_schema(self, batch: pa.RecordBatch) -> None: + """Validate that a batch conforms to the expected input schema.""" + if batch.schema != self.input_schema: + raise SchemaValidationError( + "Input batch schema does not match expected input_schema.", + expected=self.input_schema, + actual=batch.schema, + context=f"input to {type(self).__name__}", + ) + + def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: + """Perform a new init call and store it in the storage.""" + self.init_data = self.InitDataCls.deserialize(init_input) + assert self.init_data is not None + self.init_identifier = self.init_storage.create(self.init_data.serialize()) + return GlobalInitResult(self.init_identifier) + + def retrieve_init(self, init_input: GlobalInitResult) -> None: + """Retrieve and store init data from the storage.""" + if init_input.global_init_identifier is None: + raise ValueError( + "global_init_identifier is required but was None. " + "This indicates the GlobalInitResult was not properly initialized. " + "Ensure perform_init() returns a GlobalInitResult with a valid " + "identifier." + ) + self.init_identifier = init_input.global_init_identifier + self.init_data = self.InitDataCls.deserialize_bytes( + self.init_storage.get(self.init_identifier) + ) + + def setup(self) -> None: + """Acquire resources before processing starts. + + Override to acquire resources like database connections, file handles, + or external service clients. Called after init_data is available. 
+ + Available at this point: + - self.init_data: The init data (FunctionInitInput or + TableFunctionInitInput) + - self.init_identifier: Storage key for distributed state + - self.invocation: The complete invocation request + + """ + pass + + def teardown(self) -> None: + """Release resources after processing completes. + + Always called, even if an error occurred during processing. + + """ + pass diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py index d7d6a00..3e73573 100644 --- a/vgi/scalar_function.py +++ b/vgi/scalar_function.py @@ -3,15 +3,16 @@ Scalar functions transform input batches to single-column output. Scalar functions receive input batches and produce output batches where: -1. Output row count must exactly match input row count (1:1 mapping) +1. Each Output yield must have row count matching input (1:1 mapping) 2. Output schema has exactly one column +3. Message yields (for logging) are exempt from row count validation This module provides: - ScalarFunctionGenerator: Generator-based base class (like TableInOutGeneratorFunction) - ScalarFunction: Callback-based API with compute() method (like TableInOutFunction) Class Hierarchy: - TableFunctionBase (vgi.table_function) + Function (vgi.function) └── ScalarFunctionGenerator (generator protocol, validates row count) └── ScalarFunction (callback API with compute()) @@ -32,25 +33,21 @@ import vgi.function import vgi.log -import vgi.table_function -from vgi.table_function import RowCountMismatchError, SchemaValidationError +from vgi.function import SchemaValidationError +from vgi.table_function import Output, ProtocolOutput, RowCountMismatchError __all__ = [ "ScalarFunctionGenerator", "ScalarFunction", "Output", - "OutputGenerator", + "ScalarOutputGenerator", "ProtocolInput", ] -# Protocol types - reuse Output/OutputGenerator from table_in_out_function -from vgi.table_in_out_function import ( # noqa: E402 - Output, - OutputGenerator, - ProtocolOutput, - _OutputStatus, -) +# Scalar functions must 
always produce output - None is not valid +# (unlike OutputGenerator which allows None for buffering/aggregation) +ScalarOutputGenerator = Generator[vgi.log.Message | Output, pa.RecordBatch | None, None] @dataclass(frozen=True, slots=True) @@ -73,44 +70,36 @@ class ProtocolInput: @dataclass(frozen=True, slots=True) class _ScalarOutputComplete: - """Internal: Output with guaranteed non-None batch for scalar functions. - - Similar to _OutputComplete in table_in_out_function, but tracks the input - batch for row count validation. - """ + """Internal: Output with guaranteed non-None batch for scalar functions.""" batch: pa.RecordBatch - has_more: bool = False log_message: vgi.log.Message | None = None @classmethod def from_process_result( cls, - source: vgi.log.Message | Output | None, + source: vgi.log.Message | Output, empty_batch: pa.RecordBatch, ) -> _ScalarOutputComplete: """Create from user's yield value. Args: - source: What the user yielded (Output, Message, or None). - empty_batch: Empty batch to substitute when needed. + source: What the user yielded (Output or Message). + empty_batch: Empty batch to substitute when yielding Message. Returns: Normalized output with guaranteed non-None batch. """ - if source is None: - return cls(batch=empty_batch) if isinstance(source, vgi.log.Message): - return cls(batch=empty_batch, has_more=True, log_message=source) + return cls(batch=empty_batch, log_message=source) # source is Output return cls( batch=source.batch if source.batch is not None else empty_batch, - has_more=source.has_more, ) -class ScalarFunctionGenerator(vgi.table_function.TableFunctionBase): +class ScalarFunctionGenerator(vgi.function.Function[vgi.function.FunctionInitInput]): """Base class for scalar functions with generator protocol. 
Scalar functions transform input batches to single-column output with @@ -119,10 +108,11 @@ class ScalarFunctionGenerator(vgi.table_function.TableFunctionBase): - Must produce exactly one output row per input row - Must have exactly one column in output_schema - Override process() for full generator control. Can yield Output or Message: + Override process() for full generator control. Must yield Output or Message + (unlike table-in-out functions, yielding None is not allowed): - def process(self, batch: pa.RecordBatch) -> OutputGenerator: - _ = yield None # Priming yield + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) # Priming yield while True: # Optional: yield log messages yield Message(Level.INFO, f"Processing {batch.num_rows} rows") @@ -140,15 +130,8 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: output_schema -> pa.Schema (property) Override to define the single-column output schema. - process(batch: pa.RecordBatch) -> OutputGenerator - Generator that processes input batches. Must yield Output with - batch.num_rows matching input batch.num_rows. - - setup() -> None - Called before processing starts. Default: no-op. - - teardown() -> None - Called after processing completes. Default: no-op. + process(batch: pa.RecordBatch) -> ScalarOutputGenerator + Generator that processes input batches. Must yield Output or Message. 
AVAILABLE ATTRIBUTES -------------------- @@ -158,6 +141,8 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: self.empty_output_batch - Empty batch conforming to output_schema """ + InitDataCls = vgi.function.FunctionInitInput + def __init__( self, invocation: vgi.function.Invocation, @@ -165,42 +150,24 @@ def __init__( ): """Initialize the scalar function with invocation data and logger.""" super().__init__(invocation=invocation, logger=logger) - if invocation.in_out_function_input_schema is None: + if invocation.input_schema is None: raise ValueError( f"{type(self).__name__} requires an input schema, but none was " f"provided. ScalarFunction processes input batches and requires " - f"in_out_function_input_schema to be set in the Invocation." + f"input_schema to be set in the Invocation." ) # Validate single-column output at construction if len(self.output_schema) != 1: + cols = [f.name for f in self.output_schema] raise SchemaValidationError( f"ScalarFunction must have exactly 1 output column, " - f"got {len(self.output_schema)}: {self.output_schema}" + f"but output_schema has {len(self.output_schema)} columns.\n\n" + f" Columns found: {cols}\n\n" + f" Scalar functions transform each input row to a single value.\n" + f" If you need multiple output columns, use TableInOutFunction." ) - @property - def input_schema(self) -> pa.Schema: - """Return the input schema from the invocation.""" - # Validated as non-None in __init__ - assert self.invocation.in_out_function_input_schema is not None - return self.invocation.in_out_function_input_schema - - def teardown(self) -> None: - """Release resources after processing completes. - - Override to release resources acquired in setup(). - Always called, even if an error occurred during processing. 
- """ - pass - - @final - def _validate_input_schema(self, batch: pa.RecordBatch) -> None: - """Validate that a batch conforms to the expected input schema.""" - if batch.schema != self.input_schema: - raise SchemaValidationError( - f"Input batch schema does not match expected input_schema. " - f"Expected: {self.input_schema}, got: {batch.schema}" - ) + # input_schema property and _validate_input_schema inherited from Function @final def _validate_row_count( @@ -209,14 +176,16 @@ def _validate_row_count( """Validate that output row count matches input row count.""" if output_batch.num_rows != input_batch.num_rows: raise RowCountMismatchError( - f"ScalarFunction output must have same row count as input. " - f"Input: {input_batch.num_rows}, Output: {output_batch.num_rows}" + "Scalar function output must have same row count as input.", + input_rows=input_batch.num_rows, + output_rows=output_batch.num_rows, + function_name=type(self).__name__, ) @final def _process_and_validate( self, - generator: OutputGenerator, + generator: ScalarOutputGenerator, input_batch: pa.RecordBatch, ) -> _ScalarOutputComplete: """Process a batch and validate schemas and row count. @@ -239,15 +208,15 @@ def _process_and_validate( self.empty_output_batch, ) self._validate_output_schema(result.batch) - # Only validate row count for actual output, not log messages - if result.log_message is None and result.batch.num_rows > 0: + # Validate row count for actual output (not log messages) + if result.log_message is None: self._validate_row_count(result.batch, input_batch) return result @final def _process_with_exception_handling( self, - generator: OutputGenerator, + generator: ScalarOutputGenerator, input_batch: pa.RecordBatch, ) -> _ScalarOutputComplete: """Process a batch with exception handling. 
@@ -272,12 +241,10 @@ def _should_terminate(self, result: _ScalarOutputComplete) -> bool: ) @abstractmethod - def process(self, batch: pa.RecordBatch) -> OutputGenerator: + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: """Process input batches. Override this method to implement your scalar transformation. - The generator must yield Output with batch.num_rows matching - input batch.num_rows. Args: batch: First input batch (subsequent batches via yield return). @@ -285,7 +252,6 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: Yields: Output: Batch with same row count as input. Message: Log message (input will be re-sent). - None: No output (ready for next batch). """ ... @@ -303,9 +269,7 @@ def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: - When input exhausted, caller closes the generator """ # Priming yield - caller calls next() or send(None) - input: ProtocolInput | None = yield ProtocolOutput( - batch=None, status=_OutputStatus.NEED_MORE_INPUT - ) + input: ProtocolInput | None = yield ProtocolOutput(batch=None) if input is None: raise ValueError("Expected ProtocolInput, got None") @@ -321,16 +285,8 @@ def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]: while True: result = self._process_with_exception_handling(generator, input.batch) - # Determine status based on result - has_more_output = result.has_more or result.log_message is not None - if has_more_output: - status = _OutputStatus.HAVE_MORE_OUTPUT - else: - status = _OutputStatus.NEED_MORE_INPUT - input = yield ProtocolOutput( batch=result.batch, - status=status, log_message=result.log_message, ) if input is None: @@ -359,9 +315,6 @@ class ScalarFunction(ScalarFunctionGenerator): Transform the input batch to a single output array. Must return an array with exactly batch.num_rows elements. - output_name -> str (property, optional) - Return the name of the output column. 
Default: "result" - LOGGING ------- Call self.log(level, message) from compute() to emit log messages: @@ -415,11 +368,6 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array: """ self._pending_messages.append(vgi.log.Message(level=level, message=message)) - @property - def output_name(self) -> str: - """Return the name of the output column. Override to customize.""" - return "result" - @property @abstractmethod def output_type(self) -> pa.DataType: @@ -444,7 +392,7 @@ def output_type(self) -> pa.DataType: @final def output_schema(self) -> pa.Schema: """Return single-column output schema. Do not override.""" - return pa.schema([pa.field(self.output_name, self.output_type)]) + return pa.schema([pa.field("result", self.output_type)]) @abstractmethod def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: @@ -466,20 +414,21 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: ... @final - def _yield_pending_messages(self) -> OutputGenerator: + def _yield_pending_messages(self) -> ScalarOutputGenerator: """Yield all pending log messages. Helper for process().""" while self._pending_messages: msg = self._pending_messages.pop(0) _ = yield msg @final - def process(self, batch: pa.RecordBatch) -> OutputGenerator: + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: """Convert compute() to generator protocol. Do not override. This method implements the generator protocol by calling your compute() method for each input batch. 
""" - _ = yield None # Priming yield + # Priming yield + _ = yield Output(self.empty_output_batch) while True: result = self.compute(batch) diff --git a/vgi/table_function.py b/vgi/table_function.py index 7cbfb6a..5596285 100644 --- a/vgi/table_function.py +++ b/vgi/table_function.py @@ -23,7 +23,6 @@ from collections.abc import Generator from dataclasses import dataclass -from functools import cached_property from typing import Any, final import pyarrow as pa @@ -35,35 +34,98 @@ __all__ = [ "CardinalityInfo", - "GlobalStateInitInput", + "TableFunctionInitInput", "Output", "OutputGenerator", "OutputSpec", "ProtocolOutput", "RowCountMismatchError", - "SchemaValidationError", "TableFunctionBase", "TableFunctionGenerator", ] -class SchemaValidationError(Exception): - """Raised when a batch schema doesn't match the expected schema. - - This error is raised by the framework during input/output validation. - It indicates a programming error where a batch doesn't conform to the - declared schema. - """ - - class RowCountMismatchError(Exception): """Raised when scalar function output row count doesn't match input. Scalar functions must produce exactly one output row for each input row. This error indicates the compute() method returned an array with the wrong number of elements. + + Attributes: + input_rows: Number of rows in the input batch. + output_rows: Number of rows in the output batch. + function_name: Name of the function that produced the mismatch. + """ + def __init__( + self, + message: str, + *, + input_rows: int | None = None, + output_rows: int | None = None, + function_name: str = "", + ) -> None: + """Initialize with row count details. + + Args: + message: Base error message. + input_rows: Number of input rows. + output_rows: Number of output rows. + function_name: Name of the function class. 
+ + """ + self.input_rows = input_rows + self.output_rows = output_rows + self.function_name = function_name + + if input_rows is not None and output_rows is not None: + full_message = self._build_detailed_message( + message, input_rows, output_rows + ) + else: + full_message = message + + super().__init__(full_message) + + def _build_detailed_message( + self, base_message: str, input_rows: int, output_rows: int + ) -> str: + """Build a detailed, helpful error message.""" + lines = [base_message, ""] + + if self.function_name: + lines.append(f" Function: {self.function_name}") + + lines.append(f" Input rows: {input_rows}") + lines.append(f" Output rows: {output_rows}") + + # Provide specific guidance based on the mismatch type + lines.append("") + if output_rows < input_rows: + lines.append(" Problem: Output has fewer rows than input.") + lines.append("") + lines.append(" Possible causes:") + lines.append(" - compute() is filtering rows (not allowed)") + lines.append(" - compute() is aggregating (use TableInOutFunction)") + lines.append(" - Bug in array construction") + lines.append("") + lines.append(" If you need to filter or aggregate rows:") + lines.append(" Use TableInOutFunction instead of ScalarFunction.") + else: + lines.append(" Problem: Output has more rows than input.") + lines.append("") + lines.append(" Possible causes:") + lines.append(" - compute() is expanding rows (not allowed)") + lines.append(" - compute() is unnesting arrays") + lines.append(" - Bug in array construction") + lines.append("") + lines.append(" If you need to expand rows (1→N mapping):") + lines.append(" Use TableInOutFunction instead of ScalarFunction.") + + return "\n".join(lines) + @dataclass(frozen=True, slots=True) class CardinalityInfo: @@ -92,44 +154,6 @@ class CardinalityInfo: max: int | None -@dataclass(frozen=True, slots=True) -class GlobalStateInitInput: - """Input sent to initialize global state for a TableFunction. 
- - Attributes: - projection_ids: Optional list of column indices to project, or None for all. - - Note: - For parallel execution, functions should use the work queue pattern - via enqueue_work() and dequeue_work() methods on the Function base class - instead of static partitioning. - - """ - - projection_ids: list[int] | None = None - - def serialize(self) -> bytes: - """Serialize GlobalStateInitInput to bytes.""" - batch = pa.RecordBatch.from_arrays( - [pa.array([self.projection_ids], type=pa.list_(pa.int32()))], - schema=pa.schema([pa.field("projection_ids", pa.list_(pa.int32()))]), - ) - return vgi.ipc_utils.serialize_record_batch(batch) - - @staticmethod - def deserialize(batch: pa.RecordBatch) -> "GlobalStateInitInput": - """Deserialize GlobalStateInitInput from a RecordBatch.""" - values = batch.to_pylist()[0] - # Handle backward compatibility: ignore extra fields - return GlobalStateInitInput(projection_ids=values.get("projection_ids")) - - @staticmethod - def deserialize_bytes(data: bytes) -> "GlobalStateInitInput": - """Deserialize GlobalStateInitInput from bytes.""" - batch = vgi.ipc_utils.deserialize_record_batch(data) - return GlobalStateInitInput.deserialize(batch) - - @dataclass(frozen=True, slots=True) class OutputSpec(vgi.function.OutputSpec): """Extended bind result for table functions with cardinality information. @@ -282,15 +306,49 @@ def from_process_result(cls, process_result: "_OutputComplete") -> "ProtocolOutp ) -class TableFunctionBase(vgi.function.Function): +@dataclass(frozen=True, slots=True) +class TableFunctionInitInput(vgi.function.FunctionInitInput): + """Input sent to initialize global state for a TableFunction. + + Attributes: + projection_ids: Optional list of column indices to project, or None for all. + + Note: + For parallel execution, functions should use the work queue pattern + via enqueue_work() and dequeue_work() methods on the Function base class + instead of static partitioning. 
+ + """ + + projection_ids: list[int] | None = None + + def serialize(self) -> bytes: + """Serialize TableFunctionInitInput to bytes.""" + batch = pa.RecordBatch.from_arrays( + [pa.array([self.projection_ids], type=pa.list_(pa.int32()))], + schema=pa.schema([pa.field("projection_ids", pa.list_(pa.int32()))]), + ) + return vgi.ipc_utils.serialize_record_batch(batch) + + @classmethod + def deserialize(cls, batch: pa.RecordBatch) -> "TableFunctionInitInput": + """Deserialize TableFunctionInitInput from a RecordBatch.""" + values = batch.to_pylist()[0] + # Handle backward compatibility: ignore extra fields + return cls(projection_ids=values.get("projection_ids")) + + @classmethod + def deserialize_bytes(cls, data: bytes) -> "TableFunctionInitInput": + """Deserialize TableFunctionInitInput from bytes.""" + batch = vgi.ipc_utils.deserialize_record_batch(data) + return cls.deserialize(batch) + + +class TableFunctionBase(vgi.function.Function[TableFunctionInitInput]): """Base class for table functions with cardinality and schema validation. Extends Function with: - Cardinality hints for query optimization - - Output schema validation - - Setup/teardown lifecycle hooks - - Empty output batch helper (cached) - - Init data storage and retrieval - Projection pushdown support This class is not meant to be used directly. Subclass either: @@ -298,7 +356,7 @@ class TableFunctionBase(vgi.function.Function): - TableInOutGeneratorFunction: For functions that transform input batches Attributes: - init_data: GlobalStateInitInput with projection info (set after init) + init_data: TableFunctionInitInput with projection info (set after init) empty_output_batch: Cached empty batch conforming to output_schema See Also: @@ -307,8 +365,8 @@ class TableFunctionBase(vgi.function.Function): """ - # This is the init data that may be been read. 
- init_data: GlobalStateInitInput | None = None + InitDataCls = TableFunctionInitInput + init_data: TableFunctionInitInput | None = None def __init__( self, @@ -326,28 +384,6 @@ def __init__( """ super().__init__(invocation=invocation, logger=logger) - def setup(self) -> None: - """Acquire resources before processing starts. - - Override to acquire resources like database connections, file handles, - or external service clients. Called after init_data is available. - - Available at this point: - - self.init_data: The GlobalStateInitInput with projection info - - self.init_identifier: Storage key for distributed state - - self.invocation: The complete invocation request - - """ - pass - - def teardown(self) -> None: - """Release resources after processing completes. - - Always called, even if an error occurred during processing. - - """ - pass - def cardinality(self) -> CardinalityInfo | None: """Return optional cardinality estimate for the output. @@ -360,50 +396,6 @@ def cardinality(self) -> CardinalityInfo | None: """ return None - @final - @cached_property - def empty_output_batch(self) -> pa.RecordBatch: - """Return an empty batch conforming to output_schema. Cached.""" - output_schema = self.output_schema - return pa.RecordBatch.from_arrays( - [pa.array([], type=field.type) for field in output_schema], - schema=output_schema, - ) - - @final - def _validate_output_schema(self, batch: pa.RecordBatch) -> None: - """Validate that a batch conforms to the expected output schema.""" - if batch.schema != self.output_schema: - raise SchemaValidationError( - f"Output batch schema does not match expected output_schema. 
" - f"Expected: {self.output_schema}, got: {batch.schema}" - ) - - @property - def output_schema(self) -> pa.Schema: - """Return the output schema (default: passthrough input schema).""" - raise NotImplementedError("Subclasses must implement output_schema property") - - def perform_init(self, init_input: pa.RecordBatch) -> vgi.function.GlobalInitResult: - """Perform a new init call and store it in the storage.""" - self.init_data = GlobalStateInitInput.deserialize(init_input) - self.init_identifier = self.init_storage.create(self.init_data.serialize()) - return vgi.function.GlobalInitResult(self.init_identifier) - - def retrieve_init(self, init_input: vgi.function.GlobalInitResult) -> None: - """Retrieve and store init data from the storage.""" - if init_input.global_init_identifier is None: - raise ValueError( - "global_init_identifier is required but was None. " - "This indicates the GlobalInitResult was not properly initialized. " - "Ensure perform_init() returns a GlobalInitResult with a valid " - "identifier." - ) - self.init_identifier = init_input.global_init_identifier - self.init_data = GlobalStateInitInput.deserialize_bytes( - self.init_storage.get(self.init_identifier) - ) - def apply_projection(self, schema: pa.Schema) -> pa.Schema: """Apply any projection specified in the init data to the schema. 
diff --git a/vgi/table_in_out_function.py b/vgi/table_in_out_function.py index 9720629..7288609 100644 --- a/vgi/table_in_out_function.py +++ b/vgi/table_in_out_function.py @@ -523,7 +523,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: invocation = vgi.function.Invocation( function_name="my_function", arguments=vgi.function.Arguments(positional=[], named={}), - in_out_function_input_schema=input_schema, + input_schema=input_schema, correlation_id="", invocation_id=None, ) @@ -553,21 +553,16 @@ def __init__( ): """Initialize the function with invocation data and logger.""" super().__init__(invocation=invocation, logger=logger) - if invocation.in_out_function_input_schema is None: + if invocation.input_schema is None: raise ValueError( f"{type(self).__name__} requires an input schema, but none was " f"provided. TableInOutGeneratorFunction processes input batches and " - f"requires in_out_function_input_schema to be set in the Invocation. " + f"requires input_schema to be set in the Invocation. " f"If your function generates output without input, inherit from " f"TableFunctionGenerator instead." ) - @property - def input_schema(self) -> pa.Schema: - """Return the input schema from the invocation.""" - # Validated as non-None in __init__ - assert self.invocation.in_out_function_input_schema is not None - return self.invocation.in_out_function_input_schema + # input_schema property inherited from Function def teardown(self) -> None: """Release resources after processing completes. @@ -586,14 +581,7 @@ def output_schema(self) -> pa.Schema: """Return the output schema (default: passthrough input schema).""" return self.input_schema - @final - def _validate_input_schema(self, batch: pa.RecordBatch) -> None: - """Validate that a batch conforms to the expected input schema.""" - if batch.schema != self.input_schema: - raise vgi.table_function.SchemaValidationError( - f"Input batch schema does not match expected input_schema. 
" - f"Expected: {self.input_schema}, got: {batch.schema}" - ) + # _validate_input_schema inherited from Function @final def _process_and_validate( diff --git a/vgi/testing.py b/vgi/testing.py index bf26aaf..df85a71 100644 --- a/vgi/testing.py +++ b/vgi/testing.py @@ -82,7 +82,7 @@ import structlog import structlog.stdlib -from vgi.function import Arguments, Invocation +from vgi.function import Arguments, Invocation, InvocationType from vgi.log import Level, Message from vgi.scalar_function import ( ProtocolInput as ScalarProtocolInput, @@ -92,11 +92,11 @@ ScalarFunctionGenerator, ) from vgi.table_function import ( - GlobalStateInitInput, - TableFunctionGenerator, + ProtocolOutput as TableProtocolOutput, ) from vgi.table_function import ( - ProtocolOutput as TableProtocolOutput, + TableFunctionGenerator, + TableFunctionInitInput, ) from vgi.table_in_out_function import ( ProtocolInput, @@ -213,10 +213,11 @@ def table_in_out_function( invocation_id = uuid.uuid4().bytes invocation = Invocation( function_name=self.function_class.__name__, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=invocation_id, + arguments=arguments, ) # Instantiate function @@ -245,8 +246,8 @@ def table_in_out_function( ) bind_result_callback(bind_batch) - # Perform init with GlobalStateInitInput - init_input = GlobalStateInitInput(projection_ids=projection_ids) + # Perform init with TableFunctionInitInput + init_input = TableFunctionInitInput(projection_ids=projection_ids) init_batch = pa.RecordBatch.from_arrays( [pa.array([init_input.projection_ids], type=pa.list_(pa.int32()))], schema=pa.schema([pa.field("projection_ids", pa.list_(pa.int32()))]), @@ -423,17 +424,18 @@ def table_function( invocation_id = uuid.uuid4().bytes invocation = Invocation( function_name=self.function_class.__name__, - arguments=arguments, - in_out_function_input_schema=None, + input_schema=None, + 
function_type=InvocationType.TABLE, correlation_id="test", invocation_id=invocation_id, + arguments=arguments, ) # Instantiate function func = self.function_class(invocation=invocation, logger=self._logger) - # Perform init with GlobalStateInitInput - init_input = GlobalStateInitInput(projection_ids=projection_ids) + # Perform init with TableFunctionInitInput + init_input = TableFunctionInitInput(projection_ids=projection_ids) init_batch = pa.RecordBatch.from_arrays( [pa.array([init_input.projection_ids], type=pa.list_(pa.int32()))], schema=pa.schema([pa.field("projection_ids", pa.list_(pa.int32()))]), @@ -998,10 +1000,11 @@ def scalar_function( invocation_id = uuid.uuid4().bytes invocation = Invocation( function_name=self.function_class.__name__, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.SCALAR, correlation_id="test", invocation_id=invocation_id, + arguments=arguments, ) # Instantiate function @@ -1035,8 +1038,7 @@ def scalar_function( # Prime the generator try: - priming_output = next(generator) - assert priming_output.status == _OutputStatus.NEED_MORE_INPUT + next(generator) # Priming output is discarded except StopIteration: return @@ -1052,10 +1054,10 @@ def scalar_function( def _process_scalar_batch( self, - generator: Generator[ProtocolOutput, ScalarProtocolInput | None, None], + generator: Generator[TableProtocolOutput, ScalarProtocolInput | None, None], batch: pa.RecordBatch, ) -> Generator[pa.RecordBatch, None, None]: - """Process a single input batch, handling HAVE_MORE_OUTPUT for logs.""" + """Process a single input batch, handling log messages.""" while True: try: output = generator.send(ScalarProtocolInput(batch=batch)) @@ -1068,23 +1070,15 @@ def _process_scalar_batch( # Check for exception if output.log_message.level == Level.EXCEPTION: raise FunctionTestClientError(output.log_message.message) + # Re-send the same batch to get actual output after log + continue # 
Yield output batch if it has rows if output.batch is not None and output.batch.num_rows > 0: yield output.batch - # Check status - if output.status == _OutputStatus.HAVE_MORE_OUTPUT: - # Re-send the same batch to get more output (log messages) - continue - elif output.status == _OutputStatus.NEED_MORE_INPUT: - # Ready for next input batch - break - elif output.status == _OutputStatus.FINISHED: - # Scalar function ended - return - else: - raise FunctionTestClientError(f"Unexpected status: {output.status}") + # No log message means we're done with this batch + break def run_scalar_function( diff --git a/vgi/worker.py b/vgi/worker.py index 025a0d0..29567c7 100644 --- a/vgi/worker.py +++ b/vgi/worker.py @@ -5,12 +5,15 @@ SUPPORTED FUNCTION TYPES ------------------------ -The worker supports two function types, dispatched based on class inheritance: +The worker supports three function types, dispatched based on class inheritance: -1. TableInOutGeneratorFunction: Reads input batches, produces output batches. +1. ScalarFunctionGenerator: Transforms input batches to single-column output + with 1:1 row mapping. Use for per-row computations like add(), upper(), etc. + +2. TableInOutGeneratorFunction: Reads input batches, produces output batches. Use for transforming, filtering, or aggregating input data. -2. TableFunctionGenerator: Generates output batches without reading input. +3. TableFunctionGenerator: Generates output batches without reading input. Use for data generation functions like sequence(), range(), random_sample(). QUICK START @@ -18,9 +21,14 @@ Create a worker by subclassing Worker and listing your functions: from vgi.worker import Worker + from vgi.scalar_function import ScalarFunction from vgi.table_in_out_function import TableInOutGeneratorFunction from vgi.table_function import TableFunctionGenerator + class DoubleColumn(ScalarFunction): + # Single-column output with 1:1 row mapping + ... 
+ class EchoFunction(TableInOutGeneratorFunction): # Transforms input batches ... @@ -30,7 +38,7 @@ class SequenceFunction(TableFunctionGenerator): ... class MyWorker(Worker): - functions = [EchoFunction, SequenceFunction] + functions = [DoubleColumn, EchoFunction, SequenceFunction] if __name__ == "__main__": MyWorker().run() @@ -38,11 +46,19 @@ class MyWorker(Worker): Function names are derived from metadata (Meta.name or class name converted to snake_case). No manual name mapping required. +PROTOCOL FLOW (ScalarFunctionGenerator) +--------------------------------------- +1. Read Invocation: function name, arguments, input schema +2. Write OutputSpec: output schema, max_processes, invocation_id +3. Read/write FunctionInitInput/GlobalInitResult for initialization +4. Stream: read input batches -> compute -> write single-column output batches + (ends when input exhausted, no FINALIZE phase) + PROTOCOL FLOW (TableInOutGeneratorFunction) ------------------------------------------- 1. Read Invocation: function name, arguments, input schema 2. Write OutputSpec: output schema, max_processes, invocation_id -3. Read/write GlobalStateInitInput/GlobalInitResult for initialization +3. Read/write TableFunctionInitInput/GlobalInitResult for initialization 4. Stream: read input batches -> process -> write output batches 5. Finalize: receive FINALIZE signal -> emit final results @@ -50,7 +66,7 @@ class MyWorker(Worker): -------------------------------------- 1. Read Invocation: function name, arguments (no input schema) 2. Write OutputSpec: output schema, max_processes, invocation_id -3. Read/write GlobalStateInitInput/GlobalInitResult for initialization +3. Read/write TableFunctionInitInput/GlobalInitResult for initialization 4. 
Generate: produce output batches until generator exhausted KEY CLASSES @@ -71,7 +87,7 @@ class MyWorker(Worker): from collections.abc import Sequence from dataclasses import dataclass from io import IOBase -from typing import cast +from typing import Any, cast import pyarrow as pa import structlog @@ -82,6 +98,7 @@ class MyWorker(Worker): Function, Invocation, OutputSpec, + SchemaValidationError, ) from vgi.ipc_utils import read_ipc_batch from vgi.scalar_function import ProtocolInput as ScalarProtocolInput @@ -90,7 +107,6 @@ class MyWorker(Worker): from vgi.table_in_out_function import ( ProtocolInput, TableInOutGeneratorFunction, - _OutputStatus, ) @@ -131,11 +147,11 @@ class MyWorker(Worker): """ - functions: Sequence[type[Function]] = [] - _registry: dict[str, list[type[Function]]] | None = None + functions: Sequence[type[Function[Any]]] = [] + _registry: dict[str, list[type[Function[Any]]]] | None = None @classmethod - def _build_registry(cls) -> dict[str, list[type[Function]]]: + def _build_registry(cls) -> dict[str, list[type[Function[Any]]]]: """Build function name -> list of classes mapping from functions list. Multiple functions can share the same name if they have different @@ -144,7 +160,7 @@ def _build_registry(cls) -> dict[str, list[type[Function]]]: if cls._registry is not None: return cls._registry - registry: dict[str, list[type[Function]]] = {} + registry: dict[str, list[type[Function[Any]]]] = {} for func_cls in cls.functions: meta = func_cls.get_metadata() if meta.name not in registry: @@ -157,8 +173,8 @@ def _build_registry(cls) -> dict[str, list[type[Function]]]: @staticmethod def _match_function( invocation: Invocation, - candidates: Sequence[type[Function]], - ) -> type[Function]: + candidates: Sequence[type[Function[Any]]], + ) -> type[Function[Any]]: """Find the function that matches the invocation's arguments. 
Compares the invocation's positional and named arguments against each @@ -179,7 +195,7 @@ def _match_function( num_positional = len(args.positional) named_keys = set(args.named.keys()) if args.named else set() - matches: list[type[Function]] = [] + matches: list[type[Function[Any]]] = [] for func_cls in candidates: meta = func_cls.get_metadata() @@ -242,6 +258,50 @@ def _match_function( return matches[0] + @staticmethod + def _suggest_similar_names(name: str, candidates: list[str]) -> list[str]: + """Find function names similar to the given name. + + Uses prefix matching, substring matching, and character overlap to + suggest likely alternatives for typos. + + Args: + name: The unknown function name. + candidates: List of valid function names. + + Returns: + List of similar names, sorted by relevance. + + """ + if not candidates: + return [] + + name_lower = name.lower() + scored: list[tuple[int, str]] = [] + + for candidate in candidates: + candidate_lower = candidate.lower() + + # Exact prefix match (highest priority) + if candidate_lower.startswith(name_lower): + scored.append((0, candidate)) + elif name_lower.startswith(candidate_lower): + scored.append((1, candidate)) + # Substring matches + elif name_lower in candidate_lower or candidate_lower in name_lower: + scored.append((2, candidate)) + else: + # Character overlap score (for typos) + name_chars = set(name_lower) + candidate_chars = set(candidate_lower) + overlap = len(name_chars & candidate_chars) + # Require at least half the characters to match + if overlap > len(name_lower) // 2: + scored.append((10 - overlap, candidate)) + + scored.sort(key=lambda x: (x[0], x[1])) + return [candidate for _, candidate in scored] + def __init__(self) -> None: """Initialize the worker with structured logging.""" structlog.configure( @@ -311,11 +371,12 @@ def _process_scalar_batches( ipc.open_stream(cast(IOBase, sys.stdin)) as data_reader, ): # Validate data stream schema matches expected input schema - if 
data_reader.schema != invocation.in_out_function_input_schema: - expected = invocation.in_out_function_input_schema - raise ValueError( - f"Data stream schema mismatch. Expected: {expected}, " - f"got: {data_reader.schema}" + if data_reader.schema != invocation.input_schema: + raise SchemaValidationError( + "Data stream schema does not match expected input schema.", + expected=invocation.input_schema, + actual=data_reader.schema, + context="input stream to scalar function", ) batch_count = 0 @@ -343,8 +404,8 @@ def _process_scalar_batches( protocol_input = ScalarProtocolInput(batch=batch, metadata=metadata) output = generator.send(protocol_input) - # Handle log messages (HAVE_MORE_OUTPUT) - while output.status == _OutputStatus.HAVE_MORE_OUTPUT: + # Handle log messages (indicated by log_message being set) + while output.log_message is not None: fn_log.debug("log_message_received", output=output) assert output.batch is not None writer.write_batch( @@ -364,7 +425,6 @@ def _process_scalar_batches( "batch_written", batch_index=batch_count, output_rows=output_rows, - status=output.status.value if output.status else None, ) return WorkerStats( batch_count=batch_count, @@ -402,11 +462,12 @@ def _process_batches( ipc.open_stream(cast(IOBase, sys.stdin)) as data_reader, ): # Validate data stream schema matches expected input schema - if data_reader.schema != invocation.in_out_function_input_schema: - expected = invocation.in_out_function_input_schema - raise ValueError( - f"Data stream schema mismatch. 
Expected: {expected}, " - f"got: {data_reader.schema}" + if data_reader.schema != invocation.input_schema: + raise SchemaValidationError( + "Data stream schema does not match expected input schema.", + expected=invocation.input_schema, + actual=data_reader.schema, + context="input stream to table-in-out function", ) batch_count = 0 @@ -448,7 +509,6 @@ def _process_batches( "batch_written", batch_index=batch_count, output_rows=output_rows, - status=output.status.value if output.status else None, ) return WorkerStats( batch_count=batch_count, @@ -509,16 +569,25 @@ def run(self) -> None: fn_log = self.log.bind(function=invocation.function_name) fn_log.info("init_received", arguments=invocation.arguments) - fn_log.debug( - "input_schema_parsed", schema=str(invocation.in_out_function_input_schema) - ) + fn_log.debug("input_schema_parsed", schema=str(invocation.input_schema)) registry = self._build_registry() if invocation.function_name not in registry: available = sorted(registry.keys()) - raise ValueError( - f"Unknown function: {invocation.function_name}. 
Available: {available}" + suggestions = self._suggest_similar_names( + invocation.function_name, available ) + msg_lines = [ + f"Unknown function: '{invocation.function_name}'", + "", + ] + if suggestions: + msg_lines.append(" Did you mean:") + for suggestion in suggestions[:3]: + msg_lines.append(f" - {suggestion}") + msg_lines.append("") + msg_lines.append(f" Available functions: {available}") + raise ValueError("\n".join(msg_lines)) candidates = registry[invocation.function_name] func_cls = self._match_function(invocation, candidates) From 6da56f17f7970aa1699ba48bbcdb73d0ab041a02 Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Sat, 3 Jan 2026 22:59:57 -0500 Subject: [PATCH 6/6] ty fixes --- pyproject.toml | 5 +++++ vgi/examples/scalar.py | 4 ++-- vgi/table_function.py | 6 +++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3bd59d8..6c39f3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,11 @@ warn_unused_ignores = true module = "structlog.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +# These files have type: ignore comments for ty that mypy doesn't need +module = ["vgi.examples.scalar", "vgi.table_function"] +warn_unused_ignores = false + [tool.ty.environment] python-version = "3.12" diff --git a/vgi/examples/scalar.py b/vgi/examples/scalar.py index 021df37..5a92d64 100644 --- a/vgi/examples/scalar.py +++ b/vgi/examples/scalar.py @@ -52,7 +52,7 @@ def output_type(self) -> pa.DataType: def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: """Double the values in the specified column.""" - return pc.multiply(batch.column(self.column), 2) + return pc.multiply(batch.column(self.column), 2) # type: ignore[no-matching-overload] class AddColumnsFunction(ScalarFunction): @@ -109,4 +109,4 @@ def output_type(self) -> pa.DataType: def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: """Convert the column values to uppercase.""" - return pc.utf8_upper(batch.column(self.column)) 
+ return pc.utf8_upper(batch.column(self.column)) # type: ignore[no-matching-overload] diff --git a/vgi/table_function.py b/vgi/table_function.py index 5596285..4f6a68d 100644 --- a/vgi/table_function.py +++ b/vgi/table_function.py @@ -23,7 +23,7 @@ from collections.abc import Generator from dataclasses import dataclass -from typing import Any, final +from typing import Any, Self, final import pyarrow as pa import structlog @@ -331,14 +331,14 @@ def serialize(self) -> bytes: return vgi.ipc_utils.serialize_record_batch(batch) @classmethod - def deserialize(cls, batch: pa.RecordBatch) -> "TableFunctionInitInput": + def deserialize(cls, batch: pa.RecordBatch) -> Self: # type: ignore[override] """Deserialize TableFunctionInitInput from a RecordBatch.""" values = batch.to_pylist()[0] # Handle backward compatibility: ignore extra fields return cls(projection_ids=values.get("projection_ids")) @classmethod - def deserialize_bytes(cls, data: bytes) -> "TableFunctionInitInput": + def deserialize_bytes(cls, data: bytes) -> Self: """Deserialize TableFunctionInitInput from bytes.""" batch = vgi.ipc_utils.deserialize_record_batch(data) return cls.deserialize(batch)