diff --git a/CLAUDE.md b/CLAUDE.md index e5d687a..c13e192 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -40,6 +40,9 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne │ ▼ │ │ ┌───────────────────────────────────────────────────────────────┐ │ │ │ Worker Process │ │ +│ │ SCALAR FUNCTION (ScalarFunction) │ │ +│ │ - compute(batch): Transform each row to single output column │ │ +│ │ OR │ │ │ │ TABLE FUNCTION (TableFunctionGenerator) │ │ │ │ - process(): Generator yielding output batches (no input) │ │ │ │ OR │ │ @@ -52,6 +55,7 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne | Type | Base Class | Input | Use Case | |------|------------|-------|----------| +| **Scalar Function** | `ScalarFunction` | Batches | Per-row transforms (1:1 row mapping, single output column) | | **Table Function** | `TableFunctionGenerator` | None | Generate data (sequences, ranges) | | **Table-In-Out Function** | `TableInOutFunction` | Batches | Transform, filter, aggregate | @@ -59,6 +63,7 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne - **Worker** (`vgi/worker.py`): Subprocess that hosts functions - **Client** (`vgi/client/client.py`): Spawns workers, streams data +- **ScalarFunction** (`vgi/scalar_function.py`): Base for scalar functions - **TableFunctionGenerator** (`vgi/table_function.py`): Base for table functions - **TableInOutFunction** (`vgi/table_in_out_function.py`): Base for table-in-out functions @@ -67,7 +72,8 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne ``` vgi/ __init__.py # Package exports - function.py # Invocation, OutputSpec, Arguments, GlobalInitResult + function.py # Invocation, OutputSpec, Arguments, FunctionType + scalar_function.py # ScalarFunction, ScalarFunctionGenerator table_function.py # TableFunctionGenerator, CardinalityInfo, Output table_in_out_function.py # TableInOutFunction, TableInOutGeneratorFunction metadata.py # 
Function metadata for introspection @@ -76,6 +82,7 @@ vgi/ client/ client.py # Client class examples/ + scalar.py # Example scalar functions table.py # Example table functions table_in_out.py # Example table-in-out functions worker.py # ExampleWorker with registry @@ -89,6 +96,32 @@ vgi-client --input data.parquet --function echo --server vgi-example-worker vgi-client --input data.parquet --function sum_all_columns --server vgi-example-worker ``` +## Creating a Scalar Function (Per-Row Transform) + +```python +import pyarrow as pa +import pyarrow.compute as pc +from vgi import ScalarFunction, Arg + +class DoubleColumn(ScalarFunction): + """Double the value in a specified column.""" + + column = Arg[str](0, doc="Column to double") + + @property + def output_type(self) -> pa.DataType: + # Output type matches input column type + return self.input_schema.field(self.column).type + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + return pc.multiply(batch.column(self.column), 2) +``` + +### Key Constraints for Scalar Functions: +- **1:1 row mapping**: Output must have exactly the same number of rows as input +- **Single column output**: Output schema has exactly one column named "result" +- **No finalize phase**: All processing happens in compute() + ## Creating a Table-In-Out Function (Recommended) ```python @@ -182,6 +215,9 @@ if __name__ == "__main__": ### Imports ```python +# Scalar Functions (per-row transform) +from vgi import ScalarFunction, Arg, Worker + # Table Functions (no input) from vgi import TableFunctionGenerator, Output, Arg, Worker @@ -221,6 +257,17 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"}) ### Method Override Summary +**ScalarFunction:** + +| Method | When to Override | Default | +|--------|------------------|---------| +| `output_type` | Define output column type | Required | +| `compute(batch)` | Transform batch to single array | Required | +| `setup()` | Acquire resources | No-op | +| `teardown()` | Release 
resources | No-op | + +**TableInOutFunction:** + | Method | When to Override | Default | |--------|------------------|---------| | `output_schema` | Change output columns | Returns input_schema | @@ -232,17 +279,22 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"}) ### Pattern Decision Tree ``` -Need to implement a VGI function? -│ -├─ Does the function receive input data? -│ │ -│ ├─ NO → Use TableFunctionGenerator -│ │ Override process() to yield Output batches -│ │ -│ └─ YES → Use TableInOutFunction -│ ├─ Transform each batch? → Override transform() -│ ├─ Aggregate results? → Accumulate in transform(), emit in finish() -│ └─ Need generator control? → See docs/generator-api.md +How will your function be used in SQL? + +1. SELECT my_func(col1, col2) FROM table + → SCALAR FUNCTION: Returns one value per input row + → Use ScalarFunction, override output_type and compute() + → Example: upper(), abs(), concat() + +2. SELECT * FROM my_func(args) + → TABLE FUNCTION: Generates rows from arguments (no input table) + → Use TableFunctionGenerator, override process() + → Example: range(), read_csv(), glob() + +3. 
SELECT * FROM my_func(args, (SELECT * FROM input_table)) + → TABLE-IN-OUT FUNCTION: Transforms input rows to output rows + → Use TableInOutFunction, override transform() and optionally finish() + → Example: filtering, enrichment, aggregation ``` ## Additional Documentation diff --git a/pyproject.toml b/pyproject.toml index 3bd59d8..6c39f3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,11 @@ warn_unused_ignores = true module = "structlog.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +# These files have type: ignore comments for ty that mypy doesn't need +module = ["vgi.examples.scalar", "vgi.table_function"] +warn_unused_ignores = false + [tool.ty.environment] python-version = "3.12" diff --git a/tests/client/test_cli.py b/tests/client/test_cli.py index f1f9b69..ea33f93 100644 --- a/tests/client/test_cli.py +++ b/tests/client/test_cli.py @@ -580,3 +580,284 @@ def test_stdout_output_json(self, example_worker: str) -> None: assert result.exit_code == 0 # Should have JSON output assert "n" in result.output + + +class TestCLIScalarFunction: + """Tests for CLI scalar function invocation with --type scalar.""" + + @pytest.fixture + def scalar_input_parquet(self, tmp_path: Path) -> Path: + """Create a parquet file suitable for scalar function tests.""" + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3, 4, 5]}) + input_file = tmp_path / "scalar_input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + return input_file + + def test_scalar_function_invocation( + self, example_worker: str, scalar_input_parquet: Path + ) -> None: + """Invoke a scalar function with --type scalar.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + + def test_scalar_function_with_output_file( + self, example_worker: str, 
scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """Scalar function with output to file.""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "json", + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + assert output_file.exists() + lines = output_file.read_text().strip().split("\n") + # Should have 5 rows + assert len(lines) == 5 + # Verify first row is doubled + first_row = json.loads(lines[0]) + assert first_row["result"] == 2 + + def test_scalar_function_parquet_output( + self, example_worker: str, scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """Scalar function with parquet output.""" + output_file = tmp_path / "output.parquet" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "parquet", + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + # Verify parquet output + table = pq.read_table(str(output_file)) + assert table.num_rows == 5 + assert table.column_names == ["result"] + assert table.column("result").to_pylist() == [2, 4, 6, 8, 10] + + def test_scalar_type_requires_input(self, example_worker: str) -> None: + """--type scalar requires --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--function", + "double_column", + "--args", + '["x"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code != 0 + assert "requires --input" in result.output + + def test_table_in_out_type_requires_input(self, example_worker: str) -> None: + """--type table-in-out requires --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + 
"--function", + "echo", + "--type", + "table-in-out", + "--server", + example_worker, + ], + ) + assert result.exit_code != 0 + assert "requires --input" in result.output + + def test_table_type_rejects_input( + self, example_worker: str, scalar_input_parquet: Path + ) -> None: + """--type table does not accept --input.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--function", + "sequence", + "--args", + "[5]", + "--type", + "table", + "--server", + example_worker, + ], + ) + assert result.exit_code != 0 + assert "does not accept --input" in result.output + + def test_auto_type_with_input_uses_table_in_out( + self, example_worker: str, scalar_input_parquet: Path, tmp_path: Path + ) -> None: + """--type auto with --input uses table-in-out (echo function).""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(scalar_input_parquet), + "--output", + str(output_file), + "--format", + "json", + "--function", + "echo", + "--type", + "auto", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + # Echo should preserve the original column name "x" + content = output_file.read_text() + assert '"x"' in content + + def test_auto_type_without_input_uses_table( + self, example_worker: str, tmp_path: Path + ) -> None: + """--type auto without --input uses table function.""" + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--output", + str(output_file), + "--format", + "json", + "--function", + "sequence", + "--args", + "[3]", + "--type", + "auto", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 3 + + def test_scalar_with_add_columns(self, example_worker: str, tmp_path: Path) -> None: + """Test add_columns scalar function via CLI.""" + # Create input with two columns + batch = 
pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + input_file = tmp_path / "input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(input_file), + "--output", + str(output_file), + "--format", + "json", + "--function", + "add_columns", + "--args", + '["a", "b"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 3 + # Verify sums + results = [json.loads(line)["result"] for line in lines] + assert results == [11, 22, 33] + + def test_scalar_with_upper_case(self, example_worker: str, tmp_path: Path) -> None: + """Test upper_case scalar function via CLI.""" + batch = pa.RecordBatch.from_pydict({"name": ["alice", "bob"]}) + input_file = tmp_path / "input.parquet" + pq.write_table(pa.Table.from_batches([batch]), str(input_file)) + + output_file = tmp_path / "output.jsonl" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "--input", + str(input_file), + "--output", + str(output_file), + "--format", + "json", + "--function", + "upper_case", + "--args", + '["name"]', + "--type", + "scalar", + "--server", + example_worker, + ], + ) + assert result.exit_code == 0 + lines = output_file.read_text().strip().split("\n") + results = [json.loads(line)["result"] for line in lines] + assert results == ["ALICE", "BOB"] diff --git a/tests/scalar/__init__.py b/tests/scalar/__init__.py new file mode 100644 index 0000000..487759c --- /dev/null +++ b/tests/scalar/__init__.py @@ -0,0 +1 @@ +"""Tests for scalar functions.""" diff --git a/tests/scalar/test_client.py b/tests/scalar/test_client.py new file mode 100644 index 0000000..1e19bb8 --- /dev/null +++ b/tests/scalar/test_client.py @@ -0,0 +1,303 @@ +"""End-to-end tests for scalar functions via Client subprocess.""" + +from __future__ import 
annotations + +from typing import cast + +import pyarrow as pa +import pytest + +from vgi.client import Client +from vgi.client.client import ClientError +from vgi.function import Arguments + + +class TestScalarFunctionClient: + """Tests for scalar functions via Client subprocess.""" + + def test_double_column_basic(self, example_worker: str) -> None: + """Test basic scalar function via Client.""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [2, 4, 6]} + + def test_add_columns(self, example_worker: str) -> None: + """Test add_columns scalar function.""" + schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) + batch = pa.RecordBatch.from_pydict( + {"a": [1, 2, 3], "b": [10, 20, 30]}, schema=schema + ) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="add_columns", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [11, 22, 33]} + + def test_upper_case(self, example_worker: str) -> None: + """Test upper_case scalar function.""" + schema = pa.schema([("name", pa.string())]) + batch = pa.RecordBatch.from_pydict( + {"name": ["alice", "bob", "charlie"]}, schema=schema + ) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="upper_case", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("name"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": ["ALICE", "BOB", "CHARLIE"]} + + def test_multiple_batches(self, example_worker: str) -> None: + """Test 
scalar function with multiple input batches.""" + schema = pa.schema([("x", pa.int64())]) + batch1 = pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema) + batch2 = pa.RecordBatch.from_pydict({"x": [3, 4, 5]}, schema=schema) + batch3 = pa.RecordBatch.from_pydict({"x": [6]}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch1, batch2, batch3]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get 3 output batches (one per input) + assert len(outputs) == 3 + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 6 + + # Verify the values (order may vary in parallel mode, but we're single-worker) + all_values: list[int] = [] + for batch in outputs: + all_values.extend(cast(list[int], batch.column("result").to_pylist())) + assert sorted(all_values) == [2, 4, 6, 8, 10, 12] + + def test_empty_batch(self, example_worker: str) -> None: + """Test scalar function with empty batch.""" + schema = pa.schema([("x", pa.int64())]) + empty_batch = pa.RecordBatch.from_pydict({"x": []}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([empty_batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get one output batch with zero rows + assert len(outputs) == 1 + assert outputs[0].num_rows == 0 + + def test_empty_iterator(self, example_worker: str) -> None: + """Test scalar function with no input batches.""" + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # No input means no output + assert len(outputs) == 0 + + def test_scalar_function_not_started_raises(self, example_worker: str) -> None: + """Calling scalar_function before start should raise ClientError.""" + client 
= Client(example_worker) + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1]}, schema=schema) + + with pytest.raises(ClientError, match="not started"): + list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + def test_large_batch(self, example_worker: str) -> None: + """Test scalar function with a large batch.""" + schema = pa.schema([("x", pa.int64())]) + large_data = list(range(10000)) + batch = pa.RecordBatch.from_pydict({"x": large_data}, schema=schema) + + with Client(example_worker) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 10000 + + # Verify first and last values + all_values = [] + for b in outputs: + all_values.extend(b.column("result").to_pylist()) + assert all_values[0] == 0 # 0 * 2 = 0 + assert all_values[-1] == 19998 # 9999 * 2 = 19998 + + def test_bind_result_callback(self, example_worker: str) -> None: + """Test that bind_result_callback is invoked.""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + bind_results: list[pa.RecordBatch] = [] + + def capture_bind_result(result: pa.RecordBatch) -> None: + bind_results.append(result) + + with Client(example_worker) as client: + list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + bind_result_callback=capture_bind_result, + ) + ) + + # Should have received bind result + assert len(bind_results) == 1 + bind_result = bind_results[0] + + # Verify bind result contains expected fields + assert "output_schema" in bind_result.schema.names + assert "max_processes" in bind_result.schema.names + + +class 
TestScalarFunctionParallel: + """Tests for scalar functions with parallel processing.""" + + def test_parallel_double_column(self, example_worker: str) -> None: + """Test scalar function with multiple workers.""" + schema = pa.schema([("x", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict( + {"x": list(range(i * 100, (i + 1) * 100))}, schema=schema + ) + for i in range(10) + ] + + with Client(example_worker, max_workers=4) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get all 1000 rows back + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 1000 + + # Verify all values are correctly doubled + all_values = set() + for batch in outputs: + all_values.update(batch.column("result").to_pylist()) + + expected = {i * 2 for i in range(1000)} + assert all_values == expected + + def test_parallel_add_columns(self, example_worker: str) -> None: + """Test add_columns with multiple workers.""" + schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict( + {"a": [i, i + 1, i + 2], "b": [100, 200, 300]}, schema=schema + ) + for i in range(20) + ] + + with Client(example_worker, max_workers=3) as client: + outputs = list( + client.scalar_function( + function_name="add_columns", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))), + ) + ) + + # Should get 60 rows total (20 batches * 3 rows) + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 60 + + def test_parallel_empty_batches_mixed(self, example_worker: str) -> None: + """Test parallel processing with mix of empty and non-empty batches.""" + schema = pa.schema([("x", pa.int64())]) + batches = [ + pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema), + pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty + pa.RecordBatch.from_pydict({"x": [3]}, 
schema=schema), + pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty + pa.RecordBatch.from_pydict({"x": [4, 5, 6]}, schema=schema), + ] + + with Client(example_worker, max_workers=2) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter(batches), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + # Should get 6 rows total (2 + 0 + 1 + 0 + 3) + total_rows = sum(b.num_rows for b in outputs) + assert total_rows == 6 + + # Verify values + all_values = set() + for batch in outputs: + all_values.update(batch.column("result").to_pylist()) + assert all_values == {2, 4, 6, 8, 10, 12} + + def test_parallel_single_batch(self, example_worker: str) -> None: + """Test parallel mode with just one batch (should still work).""" + schema = pa.schema([("x", pa.int64())]) + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema) + + with Client(example_worker, max_workers=4) as client: + outputs = list( + client.scalar_function( + function_name="double_column", + input=iter([batch]), + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + ) + + assert len(outputs) == 1 + assert outputs[0].to_pydict() == {"result": [2, 4, 6]} diff --git a/tests/scalar/test_function.py b/tests/scalar/test_function.py new file mode 100644 index 0000000..1422b87 --- /dev/null +++ b/tests/scalar/test_function.py @@ -0,0 +1,289 @@ +"""Tests for scalar function base classes.""" + +from __future__ import annotations + +from typing import Any + +import pyarrow as pa +import pytest +import structlog + +from vgi.arguments import Arg +from vgi.function import Arguments, Invocation, InvocationType, SchemaValidationError +from vgi.log import Level, Message +from vgi.scalar_function import ( + Output, + ProtocolInput, + ScalarFunction, + ScalarFunctionGenerator, + ScalarOutputGenerator, +) + + +def create_invocation(input_schema: pa.Schema) -> Invocation: + """Create a test invocation with the given input schema.""" + return 
Invocation( + function_name="test_function", + input_schema=input_schema, + function_type=InvocationType.SCALAR, + correlation_id="test-correlation", + invocation_id=b"test-invocation", + arguments=Arguments(), + ) + + +class TestScalarFunctionGenerator: + """Tests for the generator-based ScalarFunctionGenerator.""" + + def test_basic_process(self) -> None: + """Test basic processing of batches.""" + + class DoubleColumn(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) # Priming yield + import pyarrow.compute as pc + + while True: + result = pc.multiply(batch.column("x"), 2) + output = pa.RecordBatch.from_arrays( + [result], schema=self.output_schema + ) + received = yield Output(output) + if received is None: + break + batch = received + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + logger = structlog.get_logger() + + func = DoubleColumn(invocation=invocation, logger=logger) + + # Run the protocol + generator = func.run() + next(generator) # Prime + + # Send a batch + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 3 + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_requires_input_schema(self) -> None: + """Test that input schema is required.""" + + class TestFunc(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) + + invocation = Invocation( + function_name="test", + input_schema=None, + function_type=InvocationType.SCALAR, + correlation_id="test", + 
invocation_id=None, + arguments=Arguments(), + ) + + with pytest.raises(ValueError, match="requires an input schema"): + TestFunc(invocation=invocation, logger=structlog.get_logger()) + + def test_requires_single_column_output(self) -> None: + """Test that output schema must have exactly one column.""" + + class TwoColumnOutput(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("a", pa.int64()), ("b", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + + with pytest.raises(SchemaValidationError, match="exactly 1 output column"): + TwoColumnOutput(invocation=invocation, logger=structlog.get_logger()) + + def test_log_message_support(self) -> None: + """Test that log messages can be yielded.""" + + class LoggingScalar(ScalarFunctionGenerator): + @property + def output_schema(self) -> pa.Schema: + return pa.schema([("result", pa.int64())]) + + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + _ = yield Output(self.empty_output_batch) # Priming yield + import pyarrow.compute as pc + + while True: + yield Message(Level.INFO, f"Processing {batch.num_rows} rows") + result = pc.multiply(batch.column("x"), 2) + output = pa.RecordBatch.from_arrays( + [result], schema=self.output_schema + ) + received = yield Output(output) + if received is None: + break + batch = received + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = LoggingScalar(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) + + # First yield should be the log message + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.log_message is not None + assert 
output.log_message.level == Level.INFO + assert "Processing 3 rows" in output.log_message.message + + # Re-send to get actual output + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.batch is not None + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + +class TestScalarFunction: + """Tests for the callback-based ScalarFunction.""" + + def test_basic_compute(self) -> None: + """Test basic compute() method.""" + + class DoubleColumn(ScalarFunction): + column = Arg[str](0) + + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + return pc.multiply(batch.column(self.column), 2) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = Invocation( + function_name="test", + input_schema=input_schema, + function_type=InvocationType.SCALAR, + correlation_id="test", + invocation_id=None, + arguments=Arguments(positional=(pa.scalar("x"),)), + ) + + func = DoubleColumn(invocation=invocation, logger=structlog.get_logger()) + + # Run the protocol + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 3 + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_log_method(self) -> None: + """Test self.log() method.""" + + class LoggingFunc(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column("x"), 2) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = LoggingFunc(invocation=invocation, logger=structlog.get_logger()) + + 
generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) + + # First yield should be the log message + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.log_message is not None + assert output.log_message.level == Level.INFO + assert "Processing 3 rows" in output.log_message.message + + # Re-send to get actual output + output = generator.send(ProtocolInput(batch=input_batch)) + assert output.batch is not None + assert output.batch.column("result").to_pylist() == [2, 4, 6] + + def test_row_count_validation(self) -> None: + """Test that row count mismatch raises error.""" + + class WrongRowCount(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + # Return wrong number of rows + return pa.array([1, 2]) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = WrongRowCount(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + next(generator) + + input_batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=input_schema) + output = generator.send(ProtocolInput(batch=input_batch)) + + # Should have an exception log message + assert output.log_message is not None + assert output.log_message.level == Level.EXCEPTION + assert "same row count" in output.log_message.message.lower() + + def test_empty_batch(self) -> None: + """Test handling of empty batches.""" + + class DoubleFunc(ScalarFunction): + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + import pyarrow.compute as pc + + return pc.multiply(batch.column("x"), 2) + + input_schema = pa.schema([("x", pa.int64())]) + invocation = create_invocation(input_schema) + func = DoubleFunc(invocation=invocation, logger=structlog.get_logger()) + + generator = func.run() + 
next(generator) + + # Empty batch + input_batch = pa.RecordBatch.from_pydict({"x": []}, schema=input_schema) + output = generator.send(ProtocolInput(batch=input_batch)) + + assert output.batch is not None + assert output.batch.num_rows == 0 diff --git a/tests/table/generator/test_constant_table_function.py b/tests/table/generator/test_constant_table_function.py index 94681b5..3d489d0 100644 --- a/tests/table/generator/test_constant_table_function.py +++ b/tests/table/generator/test_constant_table_function.py @@ -29,14 +29,15 @@ def test_cardinality(self) -> None: """Cardinality should always be 1.""" import structlog - from vgi.function import Arguments, Invocation + from vgi.function import Arguments, Invocation, InvocationType invocation = Invocation( function_name="constant_table", - arguments=Arguments(positional=(pa.scalar(42),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(positional=(pa.scalar(42),)), ) func = ConstantTableFunction( invocation=invocation, diff --git a/tests/table/generator/test_projected_data_function.py b/tests/table/generator/test_projected_data_function.py index d30b632..9e84ad2 100644 --- a/tests/table/generator/test_projected_data_function.py +++ b/tests/table/generator/test_projected_data_function.py @@ -152,15 +152,16 @@ def test_output_schema_reflects_projection(self) -> None: """The output_schema property should reflect the projection.""" import structlog - from vgi.function import Invocation - from vgi.table_function import GlobalStateInitInput + from vgi.function import Invocation, InvocationType + from vgi.table_function import TableFunctionInitInput invocation = Invocation( function_name="projected_data", - arguments=Arguments(positional=(pa.scalar(10),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + 
arguments=Arguments(positional=(pa.scalar(10),)), ) func = ProjectedDataFunction( invocation=invocation, @@ -171,7 +172,7 @@ def test_output_schema_reflects_projection(self) -> None: assert func.output_schema == ProjectedDataFunction.FULL_SCHEMA # After setting init_data with projection, should return projected schema - func.init_data = GlobalStateInitInput(projection_ids=[0, 2]) + func.init_data = TableFunctionInitInput(projection_ids=[0, 2]) schema = func.output_schema assert len(schema) == 2 assert schema.names == ["id", "value"] diff --git a/tests/table/generator/test_sequence_function.py b/tests/table/generator/test_sequence_function.py index f3e8963..780689c 100644 --- a/tests/table/generator/test_sequence_function.py +++ b/tests/table/generator/test_sequence_function.py @@ -30,14 +30,15 @@ def test_cardinality(self) -> None: """Cardinality should match requested count.""" import structlog - from vgi.function import Arguments, Invocation + from vgi.function import Arguments, Invocation, InvocationType invocation = Invocation( function_name="sequence", - arguments=Arguments(positional=(pa.scalar(100),)), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(positional=(pa.scalar(100),)), ) func = SequenceFunction( invocation=invocation, diff --git a/tests/table/test_function.py b/tests/table/test_function.py index 566c561..e531681 100644 --- a/tests/table/test_function.py +++ b/tests/table/test_function.py @@ -7,7 +7,7 @@ import structlog from tests.utils import make_schema -from vgi.function import Arguments, Invocation +from vgi.function import Arguments, Invocation, InvocationType from vgi.table_function import ( CardinalityInfo, Output, @@ -194,10 +194,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + 
function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = NoCardinalityFunction( invocation=invocation, @@ -219,10 +220,11 @@ def cardinality(self) -> CardinalityInfo: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = CardinalityFunction( invocation=invocation, @@ -311,10 +313,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) func = TestFunction( invocation=invocation, diff --git a/tests/table_in_out/test_function.py b/tests/table_in_out/test_function.py index d624da8..78ae2bd 100644 --- a/tests/table_in_out/test_function.py +++ b/tests/table_in_out/test_function.py @@ -4,7 +4,7 @@ import pyarrow.compute as pc import structlog -from vgi.function import Arg, Arguments, Invocation +from vgi.function import Arg, Arguments, Invocation, InvocationType from vgi.ipc_utils import RecordBatchState from vgi.log import Level from vgi.table_in_out_function import ( @@ -18,7 +18,8 @@ def make_invocation(input_schema: pa.Schema) -> Invocation: """Create a minimal Invocation for testing.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", arguments=Arguments(), @@ -31,7 +32,8 @@ def make_invocation_with_args( """Create an Invocation with positional arguments.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", 
invocation_id=b"test", arguments=Arguments(positional=tuple(pa.scalar(v) for v in positional)), diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e1d1cd8..d2b8d63 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -211,7 +211,7 @@ class TestMaxWorkersIntegration: def test_max_workers_used(self) -> None: """max_processes() returns Meta.max_workers when defined.""" from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class LimitedFunction(TableInOutFunction): class Meta: @@ -221,10 +221,11 @@ class Meta: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) import structlog @@ -234,17 +235,18 @@ class Meta: def test_default_max_workers(self) -> None: """max_processes() returns default when max_workers not defined.""" from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class UnlimitedFunction(TableInOutFunction): data: TableInput = Arg[TableInput](0, doc="Input table") # type: ignore[assignment] invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) import structlog diff --git a/tests/test_patterns.py b/tests/test_patterns.py index 4857f41..d7dff6a 100644 --- a/tests/test_patterns.py +++ b/tests/test_patterns.py @@ -8,7 +8,7 @@ import pyarrow.compute as pc import structlog -from vgi.function import Arg, Arguments, Invocation +from vgi.function import Arg, Arguments, Invocation, InvocationType from vgi.log import Level from vgi.table_in_out_function_patterns import ( 
AggregationFunction, @@ -22,7 +22,8 @@ def make_invocation(input_schema: pa.Schema) -> Invocation: """Create a minimal Invocation for testing.""" return Invocation( function_name="test", - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", arguments=Arguments(), diff --git a/tests/test_protocol_classes.py b/tests/test_protocol_classes.py index 3693a4a..c6f22dc 100644 --- a/tests/test_protocol_classes.py +++ b/tests/test_protocol_classes.py @@ -15,12 +15,13 @@ ArgumentValidationError, GlobalInitResult, Invocation, + InvocationType, ) from vgi.log import Level, Message from vgi.table_function import ( CardinalityInfo, - GlobalStateInitInput, OutputSpec, + TableFunctionInitInput, ) @@ -180,10 +181,11 @@ def test_basic_round_trip(self) -> None: """Basic Invocation should serialize and deserialize correctly.""" original = Invocation( function_name="test_function", - arguments=Arguments(positional=(pa.scalar(42),), named={}), - in_out_function_input_schema=make_schema([pa.field("col1", pa.int64())]), + input_schema=make_schema([pa.field("col1", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="test-123", invocation_id=b"bind-id-bytes", + arguments=Arguments(positional=(pa.scalar(42),), named={}), ) serialized = original.serialize() @@ -200,10 +202,7 @@ def test_basic_round_trip(self) -> None: assert deserialized.function_name == original.function_name assert deserialized.correlation_id == original.correlation_id assert deserialized.invocation_id == original.invocation_id - assert ( - deserialized.in_out_function_input_schema - == original.in_out_function_input_schema - ) + assert deserialized.input_schema == original.input_schema assert len(deserialized.arguments.positional) == 1 assert deserialized.arguments.positional[0] is not None assert deserialized.arguments.positional[0].as_py() == 42 @@ -212,7 +211,8 @@ def test_nullmake_schema(self) -> 
None: """Invocation with null input schema should round-trip correctly.""" original = Invocation( function_name="scalar_function", - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="", invocation_id=None, ) @@ -225,7 +225,7 @@ def test_nullmake_schema(self) -> None: deserialized = Invocation.deserialize(batch) assert deserialized.function_name == "scalar_function" - assert deserialized.in_out_function_input_schema is None + assert deserialized.input_schema is None assert deserialized.invocation_id is None def test_complexmake_schema(self) -> None: @@ -242,7 +242,8 @@ def test_complexmake_schema(self) -> None: original = Invocation( function_name="complex_function", - in_out_function_input_schema=complex_schema, + input_schema=complex_schema, + function_type=InvocationType.TABLE, correlation_id="complex-test", invocation_id=b"complex-bind", ) @@ -254,7 +255,7 @@ def test_complexmake_schema(self) -> None: batch = reader.read_next_batch() deserialized = Invocation.deserialize(batch) - assert deserialized.in_out_function_input_schema == complex_schema + assert deserialized.input_schema == complex_schema def test_deserialize_empty_batch_raises(self) -> None: """Deserializing empty batch should raise ValueError.""" @@ -264,7 +265,7 @@ def test_deserialize_empty_batch_raises(self) -> None: [ pa.field("function_name", pa.string()), pa.field("arguments", pa.struct([])), - pa.field("in_out_function_input_schema", pa.binary()), + pa.field("input_schema", pa.binary()), pa.field("invocation_id", pa.binary()), pa.field("correlation_id", pa.string()), ] @@ -281,14 +282,14 @@ def test_deserialize_multi_row_batch_raises(self) -> None: { "function_name": "fn1", "arguments": {}, - "in_out_function_input_schema": None, + "input_schema": None, "invocation_id": None, "correlation_id": "", }, { "function_name": "fn2", "arguments": {}, - "in_out_function_input_schema": None, + "input_schema": None, "invocation_id": None, 
"correlation_id": "", }, @@ -302,7 +303,8 @@ def test_with_global_init_identifier(self) -> None: """Test that with_global_init_identifier creates a new Invocation.""" original = Invocation( function_name="test", - in_out_function_input_schema=None, + input_schema=None, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=None, global_init_identifier=None, @@ -517,7 +519,7 @@ class TestGlobalStateInitInput: def test_basic_round_trip(self) -> None: """GlobalStateInitInput should serialize and deserialize correctly.""" - original = GlobalStateInitInput(projection_ids=[0, 2, 4]) + original = TableFunctionInitInput(projection_ids=[0, 2, 4]) serialized = original.serialize() assert isinstance(serialized, bytes) @@ -526,39 +528,39 @@ def test_basic_round_trip(self) -> None: reader = ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids == [0, 2, 4] def test_null_projection_ids(self) -> None: """GlobalStateInitInput with null projection_ids should round-trip.""" - original = GlobalStateInitInput(projection_ids=None) + original = TableFunctionInitInput(projection_ids=None) serialized = original.serialize() from pyarrow import ipc reader = ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids is None def test_empty_projection_ids(self) -> None: """GlobalStateInitInput with empty list should round-trip.""" - original = GlobalStateInitInput(projection_ids=[]) + original = TableFunctionInitInput(projection_ids=[]) serialized = original.serialize() from pyarrow import ipc reader = ipc.open_stream(serialized) batch = reader.read_next_batch() - deserialized = GlobalStateInitInput.deserialize(batch) + deserialized = 
TableFunctionInitInput.deserialize(batch) assert deserialized.projection_ids == [] def test_default_value(self) -> None: """GlobalStateInitInput default should have None projection_ids.""" - default = GlobalStateInitInput() + default = TableFunctionInitInput() assert default.projection_ids is None diff --git a/tests/test_schema_utils.py b/tests/test_schema_utils.py index 568f5a6..6760233 100644 --- a/tests/test_schema_utils.py +++ b/tests/test_schema_utils.py @@ -267,7 +267,7 @@ def test_schema_with_function(self) -> None: from vgi import TableInOutFunction from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class TestFunction(TableInOutFunction): @property @@ -277,10 +277,11 @@ def output_schema(self) -> pa.Schema: # Verify the schema is created correctly when function is used invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([pa.field("x", pa.int64())]), + input_schema=pa.schema([pa.field("x", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) logger = structlog.get_logger() func = TestFunction(invocation=invocation, logger=logger) @@ -299,7 +300,7 @@ def test_schema_like_with_function(self) -> None: from vgi import TableInOutFunction from vgi.arguments import Arguments - from vgi.function import Invocation + from vgi.function import Invocation, InvocationType class TestFunction(TableInOutFunction): @property @@ -318,10 +319,11 @@ def output_schema(self) -> pa.Schema: invocation = Invocation( function_name="test", - arguments=Arguments(), - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) logger = structlog.get_logger() func = TestFunction(invocation=invocation, logger=logger) diff --git a/tests/test_worker.py 
b/tests/test_worker.py index fcb904c..336ed23 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -5,7 +5,7 @@ from vgi import Arg, TableInOutFunction, TableInput from vgi.arguments import Arguments -from vgi.function import Invocation +from vgi.function import Invocation, InvocationType from vgi.worker import Worker @@ -25,10 +25,11 @@ class Meta: invocation = Invocation( function_name="single", - arguments=Arguments(), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", + arguments=Arguments(), ) result = Worker._match_function(invocation, [SingleFunction]) @@ -70,7 +71,8 @@ class Meta: inv0 = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -80,7 +82,8 @@ class Meta: inv1 = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5),)), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -90,7 +93,8 @@ class Meta: inv2 = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5), pa.scalar(10))), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -121,7 +125,8 @@ class Meta: inv_with = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(5),)), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -132,7 +137,8 @@ class Meta: inv_without = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + 
input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -166,7 +172,8 @@ class Meta: inv_format = Invocation( function_name="func", arguments=Arguments(positional=(), named={"format": pa.scalar("json")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -176,7 +183,8 @@ class Meta: inv_sep = Invocation( function_name="func", arguments=Arguments(positional=(), named={"separator": pa.scalar(",")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -198,7 +206,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=(pa.scalar(1), pa.scalar(2), pa.scalar(3))), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -230,7 +239,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=()), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -254,7 +264,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=(), named={"unknown": pa.scalar("x")}), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) @@ -277,7 +288,8 @@ class Meta: inv = Invocation( function_name="func", arguments=Arguments(positional=(), named=None), - in_out_function_input_schema=pa.schema([]), + input_schema=pa.schema([]), + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=b"test", ) diff --git a/vgi/__init__.py b/vgi/__init__.py index 6561b9e..0aff109 100644 --- 
a/vgi/__init__.py +++ b/vgi/__init__.py @@ -48,7 +48,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: from vgi import Worker class MyWorker(Worker): - registry = {"my_function": MyFunction} + functions = [MyFunction] if __name__ == "__main__": MyWorker().run() @@ -59,6 +59,8 @@ class MyWorker(Worker): TableInOutFunction - Callback-based API (recommended) TableInOutGeneratorFunction - Generator-based API (advanced) + ScalarFunction - Scalar function with compute() (single-column output) + ScalarFunctionGenerator - Scalar function with generator protocol Output - Output batch from process()/finalize() OutputGenerator - Type alias for process()/finalize() StreamingGenerator - Type alias for @streaming decorated methods @@ -110,12 +112,15 @@ class Meta: CLASS HIERARCHY --------------- vgi.function.Function - Base (max_processes, invocation_id) - └─ vgi.table_function.TableFunction - Adds cardinality hints, projection - └─ TableInOutGeneratorFunction - Full streaming (process/finalize) - └─ TableInOutFunction - Callback API (transform/finish) - ├─ AggregationFunction - Reduce to summary - ├─ FilterFunction - Row filtering - └─ MapFunction - Column transformation + ├─ vgi.table_function.TableFunctionBase - Adds cardinality hints, projection + │ ├─ TableFunctionGenerator - Generate output without input + │ └─ TableInOutGeneratorFunction - Full streaming (process/finalize) + │ └─ TableInOutFunction - Callback API (transform/finish) + │ ├─ AggregationFunction - Reduce to summary + │ ├─ FilterFunction - Row filtering + │ └─ MapFunction - Column transformation + └─ ScalarFunctionGenerator - Single-column output (1:1 rows) + └─ ScalarFunction - Callback API (compute) Examples -------- @@ -142,7 +147,13 @@ class Meta: TableInputValidationError, functions_to_arrow, ) +from vgi.scalar_function import ( + ScalarFunction, + ScalarFunctionGenerator, + ScalarOutputGenerator, +) from vgi.schema_utils import schema, schema_like +from vgi.table_function import 
RowCountMismatchError from vgi.table_in_out_function import ( Output, OutputGenerator, @@ -178,6 +189,10 @@ class Meta: "OutputGenerator", "ParameterInfo", "ResolvedMetadata", + "RowCountMismatchError", + "ScalarFunction", + "ScalarFunctionGenerator", + "ScalarOutputGenerator", "StreamingGenerator", "TableInOutFunction", "TableInOutGeneratorFunction", diff --git a/vgi/arguments.py b/vgi/arguments.py index 0a18e1e..440cbd7 100644 --- a/vgi/arguments.py +++ b/vgi/arguments.py @@ -601,6 +601,8 @@ def _validate(self, value: ArgT) -> None: valid_range = self._describe_valid_range() # Numeric range validation + # Note: type: ignore needed because ArgT is generic - comparisons only valid + # for numeric types, but we can't express "ArgT when constraints are set" if self.ge is not None and value < self.ge: # type: ignore[operator] raise ArgumentValidationError( f"Argument '{arg_name}' is too small.", diff --git a/vgi/client/cli.py b/vgi/client/cli.py index cab7199..44ba14f 100644 --- a/vgi/client/cli.py +++ b/vgi/client/cli.py @@ -1,4 +1,4 @@ -"""Command-line interface for the VGI client. +r"""Command-line interface for the VGI client. This module provides the CLI entry point for invoking VGI functions. @@ -12,6 +12,10 @@ vgi-client --function sequence --args '[100]' vgi-client --function range --args '[0, 10]' + # Scalar functions (with input, single-column output): + vgi-client --input data.parquet --function double_column \ + --args '["x"]' --type scalar + # Specify table input position (for functions where TableInput isn't first): vgi-client --input data.parquet --function transform --args '["prefix"]' \ --table-input-position 1 @@ -204,6 +208,16 @@ def _create_cli() -> Any: "Used to trace calls back to a specific attachment." ), ) + @click.option( + "--type", + "function_type", + type=click.Choice(["auto", "table", "table-in-out", "scalar"]), + default="auto", + help=( + "Function type. 'auto' (default) uses table-in-out if --input is provided, " + "otherwise table. 
Use 'scalar' for scalar functions." + ), + ) def cli( input_file: str | None, output_file: str | None, @@ -216,6 +230,7 @@ def cli( max_workers: int | None, table_input_position: int | None, attach_id: str | None, + function_type: str, ) -> None: """Invoke a VGI function and display results.""" try: @@ -257,6 +272,18 @@ def cli( log.info("starting_server", function=function_name, server_path=server_path) + # Validate function_type requirements + if function_type == "scalar" and input_file is None: + raise click.ClickException("--type scalar requires --input to be specified") + if function_type == "table-in-out" and input_file is None: + raise click.ClickException( + "--type table-in-out requires --input to be specified" + ) + if function_type == "table" and input_file is not None: + raise click.ClickException( + "--type table does not accept --input (table functions have no input)" + ) + output_writer: OutputWriter | None = None try: with Client( @@ -265,16 +292,36 @@ def cli( max_workers=max_workers, attach_id=attach_id_bytes, ) as client: - if input_file is None: - # Table function (no input) - use table_function method + # Determine effective function type + if function_type == "auto": + effective_type = "table" if input_file is None else "table-in-out" + else: + effective_type = function_type + + if effective_type == "table": + # Table function (no input) log.info("invoking_table_function", function=function_name) output_iterator = client.table_function( function_name=function_name, arguments=Arguments(positional=positional_args, named={}), projection_ids=list(projection_ids) if projection_ids else None, ) + elif effective_type == "scalar": + # Scalar function (with input, single-column output) + assert input_file is not None # Validated earlier + log.info("invoking_scalar_function", function=function_name) + log.info("reading_input", file=input_file) + pf = pq.ParquetFile(input_file) + + output_iterator = client.scalar_function( + function_name=function_name, 
+ arguments=Arguments(positional=positional_args, named={}), + input=pf.iter_batches(), + ) else: - # Table-in-out function - use table_in_out_function method + # Table-in-out function (with input) + assert input_file is not None # Validated earlier + log.info("invoking_table_in_out_function", function=function_name) log.info("reading_input", file=input_file) pf = pq.ParquetFile(input_file) diff --git a/vgi/client/client.py b/vgi/client/client.py index 02ca523..458be0c 100644 --- a/vgi/client/client.py +++ b/vgi/client/client.py @@ -40,6 +40,7 @@ client.stop() : Stop the worker subprocess client.table_in_out_function() : Invoke a TableInOutGeneratorFunction and stream results client.table_function() : Invoke a TableFunctionGenerator and stream results +client.scalar_function() : Invoke a ScalarFunction and stream results client.get_worker_stderr() : Get captured stderr from worker See Also @@ -70,11 +71,13 @@ from vgi.function import ( Arguments, + FunctionInitInput, GlobalInitResult, Invocation, + InvocationType, ) from vgi.ipc_utils import IPCError, read_ipc_batch -from vgi.table_function import GlobalStateInitInput +from vgi.table_function import TableFunctionInitInput # Configure structlog to write to stderr structlog.configure( @@ -369,6 +372,7 @@ def _initialize_stream_common( function_name: str, arguments: Arguments, input_schema: pa.Schema | None, + function_type: InvocationType, bind_result_callback: Callable[[pa.RecordBatch], None] | None, projection_ids: list[int] | None, ) -> tuple[_BindResult, GlobalInitResult, Invocation]: @@ -380,7 +384,7 @@ def _initialize_stream_common( 3. Invokes bind_result_callback if provided 4. Validates protocol version compatibility 5. Applies CPU/max_workers limits to max_processes - 6. Sends GlobalStateInitInput (with projection_ids) + 6. Sends init data (FunctionInitInput or TableFunctionInitInput) 7. Reads GlobalInitResult (shared state identifier for parallel workers) 8. 
Creates an Invocation with global_init_identifier for additional workers @@ -391,10 +395,13 @@ def _initialize_stream_common( to pass to the function. input_schema: Schema of input batches for table-in-out functions, or None for table functions that generate output without input. + function_type: Type of function being invoked (SCALAR or TABLE). + Determines what init data format is sent to the worker. bind_result_callback: Optional callback invoked with the raw bind result RecordBatch. Called before further processing. projection_ids: Optional list of column indices to project in the - output. Passed to the worker via GlobalStateInitInput. + output. Passed to the worker via TableFunctionInitInput (ignored + for scalar functions). Returns: A tuple of (bind_result, global_init_result, request_with_init): @@ -406,7 +413,7 @@ def _initialize_stream_common( Raises: ClientError: If the worker process is not started, or if reading the bind result or init result fails. - OSError: If writing the Invocation or GlobalStateInitInput fails. + OSError: If writing the Invocation or init data fails. ClientError: If the worker activated unsupported features. 
""" @@ -421,10 +428,11 @@ def _initialize_stream_common( initial_request = Invocation( function_name=function_name, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=function_type, correlation_id=self.correlation_id, invocation_id=None, + arguments=arguments, client_features=client_features, attach_id=self._attach_id, ) @@ -469,15 +477,17 @@ def _initialize_stream_common( ), ) - # Send global state init input - global_state_info_serialized_bytes = GlobalStateInitInput( - projection_ids=projection_ids - ).serialize() + if initial_request.function_type == InvocationType.SCALAR: + # Scalar functions use empty init input + init_serialized_bytes = FunctionInitInput().serialize() + else: + # Table functions (generator and table-in-out) use TableFunctionInitInput + init_serialized_bytes = TableFunctionInitInput( + projection_ids=projection_ids + ).serialize() - if self._stdin_sink.write(global_state_info_serialized_bytes) != len( - global_state_info_serialized_bytes - ): - raise OSError("Failed to write global state init input record batch") + if self._stdin_sink.write(init_serialized_bytes) != len(init_serialized_bytes): + raise OSError("Failed to write init record batch") # Read init result log.debug("reading_init_result") @@ -495,11 +505,12 @@ def _initialize_stream_common( # Create request with init for additional workers request_with_init = Invocation( function_name=function_name, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=function_type, correlation_id=self.correlation_id, invocation_id=bind_result.invocation_id, global_init_identifier=global_init_result, + arguments=arguments, ) return bind_result, global_init_result, request_with_init @@ -579,6 +590,7 @@ def _initialize_function_stream( function_name: str, arguments: Arguments, input_schema: pa.Schema | None, + function_type: InvocationType, bind_result_callback: 
Callable[[pa.RecordBatch], None] | None, projection_ids: list[int] | None, ) -> tuple[ipc.RecordBatchStreamWriter | None, ipc.RecordBatchStreamReader | None]: @@ -602,6 +614,7 @@ def _initialize_function_stream( arguments: Arguments container with positional and named arguments. input_schema: Schema of input batches for table-in-out functions, or None for table functions. + function_type: Type of function being invoked (SCALAR or TABLE). bind_result_callback: Optional callback invoked with the raw bind result RecordBatch. projection_ids: Optional list of column indices to project. @@ -624,6 +637,7 @@ def _initialize_function_stream( function_name=function_name, arguments=arguments, input_schema=input_schema, + function_type=function_type, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1072,9 +1086,11 @@ def _process_batch_on_worker( output_batches.append(output_batch) + # status is None for scalar functions (which don't emit status) + # and means we're done with this batch if status == b"HAVE_MORE_OUTPUT": continue - elif status == b"NEED_MORE_INPUT": + elif status == b"NEED_MORE_INPUT" or status is None: break else: raise ClientError( @@ -1397,7 +1413,7 @@ def table_in_out_function( result RecordBatch before processing begins. Useful for inspecting output schema, max_processes, or cardinality hints. projection_ids: Optional list of column indices for column projection. - Passed to the worker via GlobalStateInitInput. + Passed to the worker via TableFunctionInitInput. Yields: Output RecordBatches from the function. In single-worker mode, output @@ -1442,6 +1458,7 @@ def table_in_out_function( function_name=function_name, arguments=arguments, input_schema=input_schema, + function_type=InvocationType.TABLE, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1758,7 +1775,7 @@ def table_function( result RecordBatch before processing begins. 
Useful for inspecting output schema, max_processes, or cardinality hints. projection_ids: Optional list of column indices for column projection. - Passed to the worker via GlobalStateInitInput. + Passed to the worker via TableFunctionInitInput. Yields: Output RecordBatches from the function. In parallel mode @@ -1793,6 +1810,7 @@ def table_function( function_name=function_name, arguments=arguments, input_schema=None, + function_type=InvocationType.TABLE, bind_result_callback=bind_result_callback, projection_ids=projection_ids, ) @@ -1804,3 +1822,228 @@ def table_function( yield from self._table_function_parallel( primary_output_reader=output_reader, ) + + def scalar_function( + self, + *, + function_name: str, + input: Iterator[pa.RecordBatch], + arguments: Arguments | None = None, + bind_result_callback: Callable[[pa.RecordBatch], None] | None = None, + ) -> Generator[pa.RecordBatch, None, None]: + """Invoke a scalar function on the worker and stream results. + + Scalar functions transform input batches to single-column output with + 1:1 row mapping. Processing ends when input is exhausted. + + Processing flow: + 1. Reads the first input batch to determine the input schema + 2. Sends Invocation to worker and receives bind result + 3. Spawns additional workers if max_processes > 1 + 4. Distributes input batches to workers (round-robin for parallel mode) + 5. Collects and yields output batches + + For parallel processing (max_processes > 1), input batches are distributed + round-robin across workers using dedicated threads. Output order may not + match input order in parallel mode. + + Args: + function_name: Name of the function to invoke. Must exist in the + worker's registry. + input: Iterator yielding input RecordBatches. Must yield at least one + batch. The first batch's schema is used to initialize the IPC + stream. If the iterator is empty, no output is produced. 
+ arguments: Optional Arguments container with positional and named + arguments to pass to the function. Defaults to empty Arguments(). + bind_result_callback: Optional callback invoked with the raw bind + result RecordBatch before processing begins. Useful for inspecting + output schema or max_processes. + + Yields: + Output RecordBatches from the function. Each output batch has a single + column and the same number of rows as its corresponding input batch. + In single-worker mode, output order corresponds to input order. + In parallel mode (max_processes > 1), output order is non-deterministic. + + Raises: + ClientError: If the client is not started, input iterator yields + non-RecordBatch objects, communication with the worker fails, + or the worker returns an unexpected status or exception. + + Example: + >>> with Client("vgi-example-worker") as client: + ... batches = [pa.RecordBatch.from_pydict({"x": [1, 2, 3]})] + ... for output in client.scalar_function( + ... function_name="double_column", + ... input=iter(batches), + ... arguments=Arguments(positional=[pa.scalar("x")]), + ... ): + ... print(output.to_pydict()) + {'result': [2, 4, 6]} + + """ + if arguments is None: + arguments = Arguments() + + if ( + self._proc is None + or self._stdin_sink is None + or self._stdout_buffered is None + ): + raise ClientError( + "Client not started. Call start() or use context manager." 
+ ) + + # Get the first batch to determine schema and initialize + for input_batch in input: + if not isinstance(input_batch, pa.RecordBatch): + raise ClientError("Input iterator must yield RecordBatches") + + input_schema = input_batch.schema + data_writer, _ = self._initialize_function_stream( + function_name=function_name, + arguments=arguments, + input_schema=input_schema, + function_type=InvocationType.SCALAR, + bind_result_callback=bind_result_callback, + projection_ids=None, # Scalar functions don't use projection + ) + + # Use parallel processing for all cases (handles both single and + # multi-worker) + assert data_writer is not None # set when input_schema is not None + yield from self._scalar_function_parallel( + input_batch=input_batch, + input_iterator=input, + data_writer=data_writer, + ) + return + + def _scalar_function_parallel( + self, + *, + input_batch: pa.RecordBatch, + input_iterator: Iterator[pa.RecordBatch], + data_writer: ipc.RecordBatchStreamWriter, + ) -> Generator[pa.RecordBatch, None, None]: + """Process scalar function batches across one or more workers using threads. + + Handles both single-worker and multi-worker cases uniformly. + + Processing flow: + 1. Creates worker connection objects for primary + additional workers + 2. Starts one thread per worker running _worker_thread_loop + 3. Distributes input batches round-robin to worker input queues + 4. Signals end-of-input to all workers via None sentinel + 5. Collects all output batches from shared output queue + 6. Waits for worker threads to complete + 7. Closes all workers + + Args: + input_batch: The first input batch, already consumed from the + iterator by scalar_function(). + input_iterator: Iterator for remaining input batches. May be empty + if all input was in the first batch. + data_writer: IPC stream writer for the primary worker, already + initialized by _initialize_function_stream(). 
+ + Yields: + Output RecordBatches from processing, in non-deterministic order for + multi-worker mode. + + Raises: + ClientError: If a worker thread fails with an exception. + + """ + primary_worker = self._create_primary_worker(data_writer=data_writer) + all_workers = [primary_worker] + self._additional_workers + num_workers = len(all_workers) + + log.debug("starting_scalar_parallel_processing", num_workers=num_workers) + + # Create queues for each worker + input_queues: list[Queue[tuple[int, pa.RecordBatch] | None]] = [ + Queue() for _ in range(num_workers) + ] + output_queue: Queue[tuple[int, list[pa.RecordBatch]] | BaseException] = Queue() + + # Start worker threads + threads: list[threading.Thread] = [] + for i, worker in enumerate(all_workers): + thread = threading.Thread( + target=self._worker_thread_loop, + args=(worker, input_queues[i], output_queue), + daemon=True, + ) + thread.start() + threads.append(thread) + + # Distribute batches round-robin across workers + batch_index = 0 + batches_sent = 0 + + # Send first batch + worker_idx = batch_index % num_workers + input_queues[worker_idx].put((batch_index, input_batch)) + batches_sent += 1 + batch_index += 1 + + # Send remaining batches + for input_batch in input_iterator: + worker_idx = batch_index % num_workers + input_queues[worker_idx].put((batch_index, input_batch)) + batches_sent += 1 + batch_index += 1 + + # Signal end of input to all workers + for q in input_queues: + q.put(None) + + log.debug("scalar_all_batches_distributed", total_batches=batches_sent) + + # Collect outputs from all workers + # We expect batches_sent regular outputs + num_workers thread completion signals + outputs_expected = batches_sent + num_workers + outputs_received = 0 + + while outputs_received < outputs_expected: + result = output_queue.get() + + # Check for exceptions from worker threads + if isinstance(result, BaseException): + raise ClientError(f"Worker thread failed: {result}") from result + + batch_idx, 
output_batches = result + outputs_received += 1 + + # Combine output batches if needed + combined = self._combine_batches(output_batches) + if combined is not None: + yield combined + + log.debug( + "scalar_output_received", + batch_index=batch_idx, + outputs_received=outputs_received, + outputs_expected=outputs_expected, + ) + + self._join_threads(threads) + log.debug("all_scalar_worker_threads_complete") + + # Close data writers to signal EOF to workers + for worker in all_workers: + if worker.data_writer is not None: + worker.data_writer.close() + + # Wait for secondary workers to exit + secondary_workers = all_workers[1:] + for worker in secondary_workers: + worker.proc.wait(timeout=self.PROCESS_WAIT_TIMEOUT) + log.debug( + "scalar_secondary_worker_exited", + worker_index=worker.worker_index, + returncode=worker.proc.returncode, + ) + + log.debug("scalar_parallel_processing_complete") diff --git a/vgi/examples/scalar.py b/vgi/examples/scalar.py new file mode 100644 index 0000000..5a92d64 --- /dev/null +++ b/vgi/examples/scalar.py @@ -0,0 +1,112 @@ +"""Example scalar function implementations. + +This module provides example scalar functions that transform input batches +to single-column output with 1:1 row mapping. + +AVAILABLE FUNCTIONS +------------------- +DoubleColumnFunction - Doubles values in a numeric column +AddColumnsFunction - Adds two numeric columns +UpperCaseFunction - Converts string column to uppercase +""" + +from __future__ import annotations + +from typing import Any, cast + +import pyarrow as pa +import pyarrow.compute as pc + +from vgi.arguments import Arg +from vgi.scalar_function import ScalarFunction + +__all__ = [ + "DoubleColumnFunction", + "AddColumnsFunction", + "UpperCaseFunction", +] + + +class DoubleColumnFunction(ScalarFunction): + """Doubles values in a numeric column. 
+ + Example: + Input: x=[1, 2, 3] + Args: column="x" + Output: result=[2, 4, 6] + + """ + + class Meta: + """Function metadata.""" + + name = "double_column" + description = "Doubles values in a numeric column" + + column = Arg[str](0, doc="Column name to double") + + @property + def output_type(self) -> pa.DataType: + """Return the type of the doubled column.""" + return cast(pa.DataType, self.input_schema.field(self.column).type) + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Double the values in the specified column.""" + return pc.multiply(batch.column(self.column), 2) # type: ignore[no-matching-overload] + + +class AddColumnsFunction(ScalarFunction): + """Adds two numeric columns together. + + Example: + Input: a=[1, 2, 3], b=[10, 20, 30] + Args: col1="a", col2="b" + Output: result=[11, 22, 33] + + """ + + class Meta: + """Function metadata.""" + + name = "add_columns" + description = "Adds two numeric columns" + + col1 = Arg[str](0, doc="First column name") + col2 = Arg[str](1, doc="Second column name") + + @property + def output_type(self) -> pa.DataType: + """Return the type of the first column.""" + return cast(pa.DataType, self.input_schema.field(self.col1).type) + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Add the two columns together.""" + return pc.add(batch.column(self.col1), batch.column(self.col2)) + + +class UpperCaseFunction(ScalarFunction): + """Converts a string column to uppercase. 
+ + Example: + Input: name=["alice", "bob", "charlie"] + Args: column="name" + Output: result=["ALICE", "BOB", "CHARLIE"] + + """ + + class Meta: + """Function metadata.""" + + name = "upper_case" + description = "Converts string column to uppercase" + + column = Arg[str](0, doc="Column name to uppercase") + + @property + def output_type(self) -> pa.DataType: + """Return string type.""" + return pa.string() + + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Convert the column values to uppercase.""" + return pc.utf8_upper(batch.column(self.column)) # type: ignore[no-matching-overload] diff --git a/vgi/examples/table.py b/vgi/examples/table.py index 9bfbd52..4e4a87b 100644 --- a/vgi/examples/table.py +++ b/vgi/examples/table.py @@ -24,10 +24,10 @@ from vgi.metadata import FunctionExample from vgi.table_function import ( CardinalityInfo, - GlobalStateInitInput, Output, OutputGenerator, TableFunctionGenerator, + TableFunctionInitInput, ) __all__ = [ @@ -476,7 +476,7 @@ def cardinality(self) -> CardinalityInfo: def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: """Populate the work queue with range chunks.""" # Parse init data and store in init_storage - self.init_data = GlobalStateInitInput.deserialize(init_input) + self.init_data = TableFunctionInitInput.deserialize(init_input) self.init_identifier = self.init_storage.create(self.init_data.serialize()) # Create work items for each chunk of the range diff --git a/vgi/examples/worker.py b/vgi/examples/worker.py index e7331e6..1bb7ea9 100644 --- a/vgi/examples/worker.py +++ b/vgi/examples/worker.py @@ -4,14 +4,20 @@ and listing function classes. Function names are derived from each class's metadata (Meta.name or snake_case of class name). 
-The worker supports both: +The worker supports: - TableInOutGeneratorFunction: Transforms input batches to output batches - TableFunctionGenerator: Generates output batches without input +- ScalarFunctionGenerator: Transforms input to single-column output (1:1 rows) Usage: vgi-example-worker """ +from vgi.examples.scalar import ( + AddColumnsFunction, + DoubleColumnFunction, + UpperCaseFunction, +) from vgi.examples.table import ( ConstantTableFunction, GeneratorExceptionFunction, @@ -59,6 +65,10 @@ class ExampleWorker(Worker): LoggingGeneratorFunction, PartitionedRangeFunction, ProjectedDataFunction, + # ScalarFunctionGenerator - transform to single-column output + DoubleColumnFunction, + AddColumnsFunction, + UpperCaseFunction, ] diff --git a/vgi/function.py b/vgi/function.py index f2127bc..620fb61 100644 --- a/vgi/function.py +++ b/vgi/function.py @@ -2,21 +2,24 @@ This module defines the foundational classes used during function binding in the VGI protocol. When a client invokes a function, it sends Invocation -describing the function name, arguments, and input schema. The worker -returns an OutputSpec describing the output schema and execution hints. +describing the function name, arguments, input schema, and function type. +The worker returns an OutputSpec describing the output schema and execution hints. Classes: + InvocationType: Enum distinguishing scalar vs table invocation types. Arguments: Container for positional and named function arguments. - Invocation: Complete function invocation request (name, args, schema). + Invocation: Complete function invocation request (name, args, schema, type). OutputSpec: Result from binding a function (output schema, etc). + Function: Base class for all VGI functions. The Invocation and OutputSpec are serialized to Arrow IPC format for transmission between client and worker processes. See Also: + vgi.scalar_function: Scalar functions with 1:1 row transforms. + vgi.table_function: Table functions with cardinality hints. 
+ vgi.table_in_out_function: Streaming table functions for batch transforms. vgi.log: LogLevel and LogMessage for function diagnostics. - vgi.table_function: Extended bind results with cardinality hints. - vgi.table_in_out_function: Streaming table functions built on these primitives. """ @@ -24,7 +27,16 @@ import sqlite3 import uuid from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, Self, TypeVar +from enum import Enum +from functools import cached_property +from typing import ( + Any, + ClassVar, + Protocol, + Self, + TypeVar, + final, +) import pyarrow as pa import structlog @@ -34,25 +46,176 @@ from vgi.log import Level, Message from vgi.metadata import MetadataMixin, ResolvedMetadata -if TYPE_CHECKING: - pass - __all__ = [ "Arg", "ArgumentValidationError", "Arguments", "Function", + "InvocationType", "GlobalInitResult", "Level", "Message", "OutputSpec", "Invocation", + "SchemaValidationError", "Serializable", "SqliteInitStorage", "SqliteWorkerStateStorage", ] +class SchemaValidationError(Exception): + """Raised when a batch schema doesn't match the expected schema. + + This error is raised by the framework during input/output validation. + It indicates a programming error where a batch doesn't conform to the + declared schema. + + The error message includes detailed information about what differs: + - Missing fields (in expected but not in actual) + - Extra fields (in actual but not in expected) + - Type mismatches (same field name, different types) + - Field order differences + + Attributes: + expected: The expected Arrow schema. + actual: The actual Arrow schema that was received. + context: Description of where the validation occurred. + + """ + + def __init__( + self, + message: str, + *, + expected: "pa.Schema | None" = None, + actual: "pa.Schema | None" = None, + context: str = "", + ) -> None: + """Initialize with schema comparison details. + + Args: + message: Base error message. 
+ expected: The expected Arrow schema. + actual: The actual Arrow schema. + context: Where the error occurred (e.g., "output from transform()"). + + """ + self.expected = expected + self.actual = actual + self.context = context + + if expected is not None and actual is not None: + full_message = self._build_detailed_message(message, expected, actual) + else: + full_message = message + + super().__init__(full_message) + + def _build_detailed_message( + self, base_message: str, expected: "pa.Schema", actual: "pa.Schema" + ) -> str: + """Build a detailed message showing exactly what differs.""" + lines = [base_message, ""] + + if self.context: + lines.append(f" Context: {self.context}") + lines.append("") + + # Build field maps for comparison + expected_fields = {f.name: f for f in expected} + actual_fields = {f.name: f for f in actual} + + expected_names = set(expected_fields.keys()) + actual_names = set(actual_fields.keys()) + + # Find differences + missing = expected_names - actual_names + extra = actual_names - expected_names + common = expected_names & actual_names + + # Check for type mismatches in common fields + type_mismatches = [] + for name in common: + exp_field = expected_fields[name] + act_field = actual_fields[name] + if exp_field.type != act_field.type: + type_mismatches.append((name, exp_field.type, act_field.type)) + elif exp_field.nullable != act_field.nullable: + exp_null = "nullable" if exp_field.nullable else "non-nullable" + act_null = "nullable" if act_field.nullable else "non-nullable" + type_mismatches.append((name, exp_null, act_null)) + + # Check for order differences (only if names match but order differs) + order_differs = False + if not missing and not extra and not type_mismatches: + expected_order = [f.name for f in expected] + actual_order = [f.name for f in actual] + if expected_order != actual_order: + order_differs = True + + # Report missing fields + if missing: + lines.append(" Missing fields (expected but not found):") + for 
name in sorted(missing): + field = expected_fields[name] + lines.append(f" - {name}: {field.type}") + + # Report extra fields + if extra: + lines.append(" Extra fields (found but not expected):") + for name in sorted(extra): + field = actual_fields[name] + lines.append(f" - {name}: {field.type}") + + # Report type mismatches + if type_mismatches: + lines.append(" Type mismatches:") + for name, exp_type, act_type in type_mismatches: + lines.append(f" - {name}: expected {exp_type}, got {act_type}") + + # Report order differences + if order_differs: + lines.append(" Field order differs:") + lines.append(f" Expected: {[f.name for f in expected]}") + lines.append(f" Actual: {[f.name for f in actual]}") + + # Summary of schemas + lines.append("") + lines.append(" Expected schema:") + for field in expected: + nullable = " (nullable)" if field.nullable else "" + lines.append(f" {field.name}: {field.type}{nullable}") + + lines.append(" Actual schema:") + for field in actual: + nullable = " (nullable)" if field.nullable else "" + lines.append(f" {field.name}: {field.type}{nullable}") + + return "\n".join(lines) + + +class InvocationType(Enum): + """Type of VGI invocation for protocol dispatch. + + Used by the client to determine the correct init data format to send + to the worker. Scalar functions use FunctionInitInput (no projection), + while table functions use TableFunctionInitInput (with projection support). + + Note: This is distinct from vgi.metadata.FunctionType which is used for + DuckDB catalog registration and includes AGGREGATE. + + Attributes: + SCALAR: Scalar function that transforms input batches to single-column output. + TABLE: Table function (either generator or table-in-out) that produces + multi-column output. + + """ + + SCALAR = "scalar" + TABLE = "table" + + class Serializable(Protocol): """Protocol for objects that can be serialized to/from bytes. 
@@ -176,21 +339,23 @@ class Invocation: Invocation encapsulates all information needed to bind and execute a function: the function name, its arguments, the expected input schema (for table - functions), and identifiers for logging and correlation. + functions), the function type, and identifiers for logging and correlation. This is serialized to Arrow IPC format and sent as the first message when the client connects to a worker subprocess. Attributes: function_name: Name of the function to invoke, must exist in worker registry. - arguments: Positional and named arguments passed to the function. - in_out_function_input_schema: Arrow schema of input data (required for - Function, None for scalar functions or functions that don't - process input tables). + input_schema: Arrow schema of input data. Required for table-in-out and + scalar functions that process input batches. None for table functions + that generate output without input. + function_type: Type of function being invoked (SCALAR or TABLE). Used by + the client to determine the correct init data format to send. correlation_id: String identifier for logging and correlation purposes. invocation_id: Unique bytes identifying this function binding. Used to correlate multiple parallel workers processing the same logical call. global_init_identifier: Optional result from global initialization phase. + arguments: Positional and named arguments passed to the function. client_features: Feature flags supported by the client. The worker will respond with active_features in OutputSpec indicating which features will be used for this invocation. 
@@ -201,16 +366,18 @@ class Invocation: Example: invocation = Invocation( function_name="sum_columns", - arguments=Arguments(positional=("col1", "col2")), - in_out_function_input_schema=pa.schema([pa.field("col1", pa.int64())]), + input_schema=pa.schema([pa.field("col1", pa.int64())]), + function_type=InvocationType.TABLE, correlation_id="request-123", invocation_id=None, # Set by worker after binding + arguments=Arguments(positional=("col1", "col2")), ) """ function_name: str - in_out_function_input_schema: pa.Schema | None + input_schema: pa.Schema | None + function_type: InvocationType correlation_id: str # The unique identifier for the call, typically this may be a uuid. @@ -248,11 +415,12 @@ def serialize(self) -> bytes: { "function_name": self.function_name, "arguments": args_dict, - "in_out_function_input_schema": ( - self.in_out_function_input_schema.serialize().to_pybytes() - if self.in_out_function_input_schema + "input_schema": ( + self.input_schema.serialize().to_pybytes() + if self.input_schema else None ), + "function_type": self.function_type.value, "invocation_id": self.invocation_id, "correlation_id": self.correlation_id, GlobalInitResult._IDENTIFIER_FIELD_NAME: ( @@ -268,9 +436,8 @@ def serialize(self) -> bytes: [ pa.field("function_name", pa.string(), nullable=False), pa.field("arguments", args_struct_type, nullable=True), - pa.field( - "in_out_function_input_schema", pa.binary(), nullable=True - ), + pa.field("input_schema", pa.binary(), nullable=True), + pa.field("function_type", pa.string(), nullable=False), pa.field("invocation_id", pa.binary(), nullable=True), pa.field("correlation_id", pa.string(), nullable=False), pa.field( @@ -303,7 +470,8 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": required_fields = [ "function_name", "arguments", - "in_out_function_input_schema", + "input_schema", + "function_type", "invocation_id", "correlation_id", ] @@ -311,11 +479,12 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": data, 
"Invocation", required_fields=required_fields ) - in_out_function_input_schema = None - if first_row["in_out_function_input_schema"] is not None: - in_out_function_input_schema = pa.ipc.read_schema( - pa.py_buffer(first_row["in_out_function_input_schema"]) - ) + input_schema = None + if first_row["input_schema"] is not None: + input_schema = pa.ipc.read_schema(pa.py_buffer(first_row["input_schema"])) + + # Parse function_type from string value + function_type = InvocationType(first_row["function_type"]) # Parse global_init_identifier - only create GlobalInitResult if field exists # and has a non-None value @@ -339,8 +508,9 @@ def deserialize(data: pa.RecordBatch) -> "Invocation": return Invocation( function_name=first_row["function_name"], + input_schema=input_schema, + function_type=function_type, arguments=Arguments.decode(data.column("arguments")[0]), - in_out_function_input_schema=in_out_function_input_schema, invocation_id=first_row["invocation_id"], correlation_id=first_row["correlation_id"], global_init_identifier=global_init_identifier, @@ -787,7 +957,54 @@ def cleanup_queue(self, invocation_id: bytes) -> int: conn.close() -class Function(MetadataMixin): +class FunctionInitInput: + """Input sent to initialize global state for a Function. + + This is the base init input class for functions that don't require + any initialization data (like scalar functions). It serializes to + an empty single-row batch. + """ + + def serialize(self) -> bytes: + """Serialize FunctionInitInput to bytes. + + Creates a single-row batch with an empty schema. The batch must have + exactly 1 row so that deserialize can access row 0. 
+ """ + # Create a batch with 1 row using a struct array approach + struct_array: pa.StructArray = pa.array([{}], type=pa.struct([])) # type: ignore[assignment] + batch = pa.RecordBatch.from_struct_array(struct_array) + return vgi.ipc_utils.serialize_record_batch(batch) + + @classmethod + def deserialize(cls, _batch: pa.RecordBatch) -> Self: + """Deserialize FunctionInitInput from a RecordBatch. + + Args: + _batch: RecordBatch (unused - FunctionInitInput has no fields). + + Returns: + New FunctionInitInput instance. + + """ + return cls() + + @classmethod + def deserialize_bytes(cls, data: bytes) -> Self: + """Deserialize FunctionInitInput from bytes. + + Args: + data: Serialized bytes. + + Returns: + New FunctionInitInput instance. + + """ + batch = vgi.ipc_utils.deserialize_record_batch(data) + return cls.deserialize(batch) + + +class Function[T: FunctionInitInput](MetadataMixin): """Base class for all VGI functions. Functions are instantiated with Invocation describing the invocation, @@ -830,12 +1047,16 @@ class Meta: init_storage: ClassVar[SqliteInitStorage] = SqliteInitStorage() state_storage: ClassVar[SqliteWorkerStateStorage] = SqliteWorkerStateStorage() + # Cache for resolved metadata + _metadata_cache: ClassVar[ResolvedMetadata | None] = None + # The unique identifier for init data in storage. Set by perform_init() # or retrieve_init(). Used to correlate parallel workers and for state storage. init_identifier: bytes | None = None - # Cache for resolved metadata - _metadata_cache: ClassVar[ResolvedMetadata | None] = None + # This is the init data that may have been read. + InitDataCls: type[T] + init_data: T | None = None def __init__( self, @@ -887,27 +1108,6 @@ def create_invocation_id(self) -> bytes: """ return uuid.uuid4().bytes - def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: - """Perform any global initialization required before processing. 
- - This method is called once per worker process before any data - batches are processed. Override to set up shared resources, load - models, or perform expensive setup tasks. - - Args: - init_input: An initial RecordBatch that may contain configuration - or context information for initialization. - - """ - # If there is an id supplied, detect it so it will be passed on. - if GlobalInitResult.has_identifier(init_input): - return GlobalInitResult.deserialize(init_input) - - return GlobalInitResult() - - def retrieve_init(self, init_input: GlobalInitResult) -> None: - """Retrieve init data from storage (default does nothing).""" - @property def output_schema(self) -> pa.Schema: """Return the output schema (must be implemented by subclass).""" @@ -1062,3 +1262,98 @@ def process(self) -> OutputGenerator: "retrieve_init() for secondary workers." ) return self.state_storage.dequeue_work(self.init_identifier) + + @final + @cached_property + def empty_output_batch(self) -> pa.RecordBatch: + """Return an empty batch conforming to output_schema. Cached.""" + output_schema = self.output_schema + return pa.RecordBatch.from_arrays( + [pa.array([], type=field.type) for field in output_schema], + schema=output_schema, + ) + + @final + def _validate_output_schema(self, batch: pa.RecordBatch) -> None: + """Validate that a batch conforms to the expected output schema.""" + if batch.schema != self.output_schema: + raise SchemaValidationError( + "Output batch schema does not match expected output_schema.", + expected=self.output_schema, + actual=batch.schema, + context=f"output from {type(self).__name__}", + ) + + @property + def input_schema(self) -> pa.Schema: + """Return the input schema from the invocation. + + This property is available for functions that receive input batches + (ScalarFunction, TableInOutFunction). For TableFunctionGenerator, + the invocation.input_schema is None. + + Raises: + ValueError: If invocation.input_schema is None. 
+ + """ + if self.invocation.input_schema is None: + raise ValueError( + "input_schema is not available for this function type. " + "TableFunctionGenerator does not receive input batches." + ) + return self.invocation.input_schema + + @final + def _validate_input_schema(self, batch: pa.RecordBatch) -> None: + """Validate that a batch conforms to the expected input schema.""" + if batch.schema != self.input_schema: + raise SchemaValidationError( + "Input batch schema does not match expected input_schema.", + expected=self.input_schema, + actual=batch.schema, + context=f"input to {type(self).__name__}", + ) + + def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult: + """Perform a new init call and store it in the storage.""" + self.init_data = self.InitDataCls.deserialize(init_input) + assert self.init_data is not None + self.init_identifier = self.init_storage.create(self.init_data.serialize()) + return GlobalInitResult(self.init_identifier) + + def retrieve_init(self, init_input: GlobalInitResult) -> None: + """Retrieve and store init data from the storage.""" + if init_input.global_init_identifier is None: + raise ValueError( + "global_init_identifier is required but was None. " + "This indicates the GlobalInitResult was not properly initialized. " + "Ensure perform_init() returns a GlobalInitResult with a valid " + "identifier." + ) + self.init_identifier = init_input.global_init_identifier + self.init_data = self.InitDataCls.deserialize_bytes( + self.init_storage.get(self.init_identifier) + ) + + def setup(self) -> None: + """Acquire resources before processing starts. + + Override to acquire resources like database connections, file handles, + or external service clients. Called after init_data is available. 
+ + Available at this point: + - self.init_data: The init data (FunctionInitInput or + TableFunctionInitInput) + - self.init_identifier: Storage key for distributed state + - self.invocation: The complete invocation request + + """ + pass + + def teardown(self) -> None: + """Release resources after processing completes. + + Always called, even if an error occurred during processing. + + """ + pass diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py new file mode 100644 index 0000000..3e73573 --- /dev/null +++ b/vgi/scalar_function.py @@ -0,0 +1,445 @@ +"""Base classes for scalar functions that transform input batches to single-column. + +Scalar functions transform input batches to single-column output. + +Scalar functions receive input batches and produce output batches where: +1. Each Output yield must have row count matching input (1:1 mapping) +2. Output schema has exactly one column +3. Message yields (for logging) are exempt from row count validation + +This module provides: +- ScalarFunctionGenerator: Generator-based base class (like TableInOutGeneratorFunction) +- ScalarFunction: Callback-based API with compute() method (like TableInOutFunction) + +Class Hierarchy: + Function (vgi.function) + └── ScalarFunctionGenerator (generator protocol, validates row count) + └── ScalarFunction (callback API with compute()) + +ScalarFunctionGenerator is useful for functions that need full generator control +including yielding log messages. For most use cases, use ScalarFunction with its +simpler compute() method. 
+""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any, final + +import pyarrow as pa +import structlog + +import vgi.function +import vgi.log +from vgi.function import SchemaValidationError +from vgi.table_function import Output, ProtocolOutput, RowCountMismatchError + +__all__ = [ + "ScalarFunctionGenerator", + "ScalarFunction", + "Output", + "ScalarOutputGenerator", + "ProtocolInput", +] + + +# Scalar functions must always produce output - None is not valid +# (unlike OutputGenerator which allows None for buffering/aggregation) +ScalarOutputGenerator = Generator[vgi.log.Message | Output, pa.RecordBatch | None, None] + + +@dataclass(frozen=True, slots=True) +class ProtocolInput: + """Input sent to the scalar function generator via send(). + + This is a simplified version of table_in_out_function.ProtocolInput + without finalization support, since scalar functions don't have a + finalize phase. + + Attributes: + batch: The input RecordBatch to process. + metadata: Optional metadata from the IPC stream. + + """ + + batch: pa.RecordBatch + metadata: pa.KeyValueMetadata | None = None + + +@dataclass(frozen=True, slots=True) +class _ScalarOutputComplete: + """Internal: Output with guaranteed non-None batch for scalar functions.""" + + batch: pa.RecordBatch + log_message: vgi.log.Message | None = None + + @classmethod + def from_process_result( + cls, + source: vgi.log.Message | Output, + empty_batch: pa.RecordBatch, + ) -> _ScalarOutputComplete: + """Create from user's yield value. + + Args: + source: What the user yielded (Output or Message). + empty_batch: Empty batch to substitute when yielding Message. + + Returns: + Normalized output with guaranteed non-None batch. 
class ScalarFunctionGenerator(vgi.function.Function[vgi.function.FunctionInitInput]):
    """Base class for scalar functions with generator protocol.

    Scalar functions transform input batches to single-column output with
    1:1 row mapping. Unlike TableInOutGeneratorFunction, scalar functions:
    - Have no finalize() phase
    - Must produce exactly one output row per input row
    - Must have exactly one column in output_schema

    Override process() for full generator control. Must yield Output or Message
    (unlike table-in-out functions, yielding None is not allowed):

        def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
            _ = yield Output(self.empty_output_batch)  # Priming yield
            while True:
                # Optional: yield log messages
                yield Message(Level.INFO, f"Processing {batch.num_rows} rows")

                result_array = compute_result(batch)
                output_batch = pa.RecordBatch.from_arrays(
                    [result_array], schema=self.output_schema
                )
                batch = yield Output(output_batch)
                if batch is None:
                    break

    METHODS TO OVERRIDE
    -------------------
    output_schema -> pa.Schema (property)
        Override to define the single-column output schema.

    process(batch: pa.RecordBatch) -> ScalarOutputGenerator
        Generator that processes input batches. Must yield Output or Message.

    AVAILABLE ATTRIBUTES
    --------------------
    self.invocation: Invocation - The complete invocation request
    self.input_schema: pa.Schema - Input schema (from invocation)
    self.output_schema: pa.Schema - Property returning the output schema
    self.empty_output_batch - Empty batch conforming to output_schema
    """

    InitDataCls = vgi.function.FunctionInitInput

    def __init__(
        self,
        invocation: vgi.function.Invocation,
        logger: structlog.stdlib.BoundLogger,
    ):
        """Initialize the scalar function with invocation data and logger.

        Raises:
            ValueError: If the invocation carries no input schema.
            SchemaValidationError: If output_schema does not have exactly
                one column.

        """
        super().__init__(invocation=invocation, logger=logger)
        if invocation.input_schema is None:
            raise ValueError(
                f"{type(self).__name__} requires an input schema, but none was "
                f"provided. ScalarFunction processes input batches and requires "
                f"input_schema to be set in the Invocation."
            )
        # Validate single-column output at construction
        if len(self.output_schema) != 1:
            cols = [f.name for f in self.output_schema]
            raise SchemaValidationError(
                f"ScalarFunction must have exactly 1 output column, "
                f"but output_schema has {len(self.output_schema)} columns.\n\n"
                f" Columns found: {cols}\n\n"
                f" Scalar functions transform each input row to a single value.\n"
                f" If you need multiple output columns, use TableInOutFunction."
            )

    # input_schema property and _validate_input_schema inherited from Function

    @final
    def _validate_row_count(
        self, output_batch: pa.RecordBatch, input_batch: pa.RecordBatch
    ) -> None:
        """Validate that output row count matches input row count.

        Raises:
            RowCountMismatchError: If the two batches differ in num_rows.

        """
        if output_batch.num_rows != input_batch.num_rows:
            raise RowCountMismatchError(
                "Scalar function output must have same row count as input.",
                input_rows=input_batch.num_rows,
                output_rows=output_batch.num_rows,
                function_name=type(self).__name__,
            )

    @final
    def _process_and_validate(
        self,
        generator: ScalarOutputGenerator,
        input_batch: pa.RecordBatch,
    ) -> _ScalarOutputComplete:
        """Process a batch and validate schemas and row count.

        Args:
            generator: The user's process() generator.
            input_batch: The input RecordBatch to process.

        Returns:
            _ScalarOutputComplete with validated output batch.

        Raises:
            SchemaValidationError: If input or output batch schema doesn't match.
            RowCountMismatchError: If output row count doesn't match input.

        """
        self._validate_input_schema(input_batch)
        result: _ScalarOutputComplete = _ScalarOutputComplete.from_process_result(
            generator.send(input_batch),
            self.empty_output_batch,
        )
        self._validate_output_schema(result.batch)
        # Validate row count for actual output (not log messages)
        if result.log_message is None:
            self._validate_row_count(result.batch, input_batch)
        return result

    @final
    def _process_with_exception_handling(
        self,
        generator: ScalarOutputGenerator,
        input_batch: pa.RecordBatch,
    ) -> _ScalarOutputComplete:
        """Process a batch with exception handling.

        Wraps _process_and_validate to catch exceptions and convert them
        to _ScalarOutputComplete with an error log message, so the failure
        is reported through the protocol instead of unwinding run().
        """
        try:
            return self._process_and_validate(generator, input_batch)
        except Exception as e:
            return _ScalarOutputComplete(
                batch=self.empty_output_batch,
                log_message=vgi.log.Message.from_exception(e),
            )

    @final
    def _should_terminate(self, result: _ScalarOutputComplete) -> bool:
        """Check if processing should terminate due to an exception."""
        return (
            result.log_message is not None
            and result.log_message.level == vgi.log.Level.EXCEPTION
        )

    @abstractmethod
    def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
        """Process input batches.

        Override this method to implement your scalar transformation.

        Args:
            batch: First input batch (later batches arrive as the value of
                each ``yield`` expression).

        Yields:
            Output: Batch with same row count as input.
            Message: Log message (input will be re-sent).

        """
        ...

    @final
    def run(self) -> Generator[ProtocolOutput, ProtocolInput | None, None]:
        """Run the scalar function protocol. Do not override.

        This generator implements the SETUP -> DATA -> TEARDOWN lifecycle.
        The generator is closed by the caller when input is exhausted.

        Protocol:
        - Caller primes with next() or send(None)
        - Caller sends ProtocolInput for each batch
        - When input exhausted, caller closes the generator

        Once setup() has returned, teardown() is guaranteed to run: creating
        and priming the process() generator happens inside the protected
        region, and a failure while closing that generator cannot skip
        teardown().

        Raises:
            ValueError: If the caller sends None where a ProtocolInput is
                expected.

        """
        # Priming yield - caller calls next() or send(None)
        protocol_input: ProtocolInput | None = yield ProtocolOutput(batch=None)
        if protocol_input is None:
            raise ValueError("Expected ProtocolInput, got None")

        # Acquire resources before processing
        self.setup()

        generator: ScalarOutputGenerator | None = None
        try:
            generator = self.process(protocol_input.batch)
            # Prime the process() generator past the initial yield
            generator.send(None)

            # DATA phase - process batches until generator is closed
            while True:
                result = self._process_with_exception_handling(
                    generator, protocol_input.batch
                )

                protocol_input = yield ProtocolOutput(
                    batch=result.batch,
                    log_message=result.log_message,
                )
                if protocol_input is None:
                    raise ValueError("Expected ProtocolInput, got None")
                if self._should_terminate(result):
                    return
        finally:
            try:
                # generator is None only if process() itself raised
                if generator is not None:
                    generator.close()
            finally:
                # Release resources after processing completes
                self.teardown()
+ + LOGGING + ------- + Call self.log(level, message) from compute() to emit log messages: + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column("x"), 2) + + Example: + ------- + class DoubleColumn(ScalarFunction): + column = Arg[str](0, doc="Column to double") + + @property + def output_type(self) -> pa.DataType: + return self.input_schema.field(self.column).type + + def compute(self, batch: pa.RecordBatch) -> pa.Array: + return pc.multiply(batch.column(self.column), 2) + + """ + + # Message queue for log() method (same pattern as TableInOutFunction) + _pending_messages: list[vgi.log.Message] + + def __init__( + self, + invocation: vgi.function.Invocation, + logger: structlog.stdlib.BoundLogger, + ): + """Initialize the scalar function.""" + # Initialize pending messages before super().__init__ because + # output_schema property may be accessed during init + self._pending_messages = [] + super().__init__(invocation=invocation, logger=logger) + + def log(self, level: vgi.log.Level, message: str) -> None: + """Queue a log message to be emitted with the output. + + Messages are yielded before the compute() result. + + Args: + level: Log level (DEBUG, INFO, WARNING, ERROR). + message: Log message text. + + Example: + def compute(self, batch: pa.RecordBatch) -> pa.Array: + self.log(Level.INFO, f"Processing {batch.num_rows} rows") + return pc.multiply(batch.column(self.column), 2) + + """ + self._pending_messages.append(vgi.log.Message(level=level, message=message)) + + @property + @abstractmethod + def output_type(self) -> pa.DataType: + """Return the Arrow type for the output column. + + Override this property to specify the output column type. + + Example: + @property + def output_type(self) -> pa.DataType: + return pa.int64() + + # Or derive from input: + @property + def output_type(self) -> pa.DataType: + return self.input_schema.field(self.column).type + + """ + ... 
+ + @property + @final + def output_schema(self) -> pa.Schema: + """Return single-column output schema. Do not override.""" + return pa.schema([pa.field("result", self.output_type)]) + + @abstractmethod + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + """Compute output array from input batch. + + Override this method to implement your scalar transformation. + + Args: + batch: Input RecordBatch. + + Returns: + Array with exactly batch.num_rows elements. + + Example: + def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]: + return pc.multiply(batch.column("x"), 2) + + """ + ... + + @final + def _yield_pending_messages(self) -> ScalarOutputGenerator: + """Yield all pending log messages. Helper for process().""" + while self._pending_messages: + msg = self._pending_messages.pop(0) + _ = yield msg + + @final + def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator: + """Convert compute() to generator protocol. Do not override. + + This method implements the generator protocol by calling your compute() + method for each input batch. 
+ """ + # Priming yield + _ = yield Output(self.empty_output_batch) + + while True: + result = self.compute(batch) + + # Yield any pending log messages first + yield from self._yield_pending_messages() + + # Create output batch from result array + output = pa.RecordBatch.from_arrays([result], schema=self.output_schema) + received = yield Output(output) + + if received is None: + break + batch = received diff --git a/vgi/table_function.py b/vgi/table_function.py index e8698c1..4f6a68d 100644 --- a/vgi/table_function.py +++ b/vgi/table_function.py @@ -23,8 +23,7 @@ from collections.abc import Generator from dataclasses import dataclass -from functools import cached_property -from typing import Any, final +from typing import Any, Self, final import pyarrow as pa import structlog @@ -35,25 +34,98 @@ __all__ = [ "CardinalityInfo", - "GlobalStateInitInput", + "TableFunctionInitInput", "Output", "OutputGenerator", "OutputSpec", "ProtocolOutput", - "SchemaValidationError", + "RowCountMismatchError", "TableFunctionBase", "TableFunctionGenerator", ] -class SchemaValidationError(Exception): - """Raised when a batch schema doesn't match the expected schema. +class RowCountMismatchError(Exception): + """Raised when scalar function output row count doesn't match input. + + Scalar functions must produce exactly one output row for each input row. + This error indicates the compute() method returned an array with the + wrong number of elements. + + Attributes: + input_rows: Number of rows in the input batch. + output_rows: Number of rows in the output batch. + function_name: Name of the function that produced the mismatch. - This error is raised by the framework during input/output validation. - It indicates a programming error where a batch doesn't conform to the - declared schema. """ + def __init__( + self, + message: str, + *, + input_rows: int | None = None, + output_rows: int | None = None, + function_name: str = "", + ) -> None: + """Initialize with row count details. 
+ + Args: + message: Base error message. + input_rows: Number of input rows. + output_rows: Number of output rows. + function_name: Name of the function class. + + """ + self.input_rows = input_rows + self.output_rows = output_rows + self.function_name = function_name + + if input_rows is not None and output_rows is not None: + full_message = self._build_detailed_message( + message, input_rows, output_rows + ) + else: + full_message = message + + super().__init__(full_message) + + def _build_detailed_message( + self, base_message: str, input_rows: int, output_rows: int + ) -> str: + """Build a detailed, helpful error message.""" + lines = [base_message, ""] + + if self.function_name: + lines.append(f" Function: {self.function_name}") + + lines.append(f" Input rows: {input_rows}") + lines.append(f" Output rows: {output_rows}") + + # Provide specific guidance based on the mismatch type + lines.append("") + if output_rows < input_rows: + lines.append(" Problem: Output has fewer rows than input.") + lines.append("") + lines.append(" Possible causes:") + lines.append(" - compute() is filtering rows (not allowed)") + lines.append(" - compute() is aggregating (use TableInOutFunction)") + lines.append(" - Bug in array construction") + lines.append("") + lines.append(" If you need to filter or aggregate rows:") + lines.append(" Use TableInOutFunction instead of ScalarFunction.") + else: + lines.append(" Problem: Output has more rows than input.") + lines.append("") + lines.append(" Possible causes:") + lines.append(" - compute() is expanding rows (not allowed)") + lines.append(" - compute() is unnesting arrays") + lines.append(" - Bug in array construction") + lines.append("") + lines.append(" If you need to expand rows (1→N mapping):") + lines.append(" Use TableInOutFunction instead of ScalarFunction.") + + return "\n".join(lines) + @dataclass(frozen=True, slots=True) class CardinalityInfo: @@ -82,44 +154,6 @@ class CardinalityInfo: max: int | None -@dataclass(frozen=True, 
@dataclass(frozen=True, slots=True)
class TableFunctionInitInput(vgi.function.FunctionInitInput):
    """Init payload used to set up global state for a TableFunction.

    Attributes:
        projection_ids: Optional list of column indices to project, or None for all.

    Note:
        For parallel execution, functions should use the work queue pattern
        via enqueue_work() and dequeue_work() methods on the Function base class
        instead of static partitioning.

    """

    projection_ids: list[int] | None = None

    def serialize(self) -> bytes:
        """Encode this payload as a one-row RecordBatch and return its bytes."""
        id_list_type = pa.list_(pa.int32())
        column = pa.array([self.projection_ids], type=id_list_type)
        schema = pa.schema([pa.field("projection_ids", pa.list_(pa.int32()))])
        one_row = pa.RecordBatch.from_arrays([column], schema=schema)
        return vgi.ipc_utils.serialize_record_batch(one_row)

    @classmethod
    def deserialize(cls, batch: pa.RecordBatch) -> Self:  # type: ignore[override]
        """Build an instance from the first row of ``batch``."""
        first_row = batch.to_pylist()[0]
        # Backward compatibility: only the known key is read, extras are ignored
        projection = first_row.get("projection_ids")
        return cls(projection_ids=projection)

    @classmethod
    def deserialize_bytes(cls, data: bytes) -> Self:
        """Decode ``data`` into a RecordBatch and build an instance from it."""
        return cls.deserialize(vgi.ipc_utils.deserialize_record_batch(data))
Subclass either: @@ -288,7 +356,7 @@ class TableFunctionBase(vgi.function.Function): - TableInOutGeneratorFunction: For functions that transform input batches Attributes: - init_data: GlobalStateInitInput with projection info (set after init) + init_data: TableFunctionInitInput with projection info (set after init) empty_output_batch: Cached empty batch conforming to output_schema See Also: @@ -297,8 +365,8 @@ class TableFunctionBase(vgi.function.Function): """ - # This is the init data that may be been read. - init_data: GlobalStateInitInput | None = None + InitDataCls = TableFunctionInitInput + init_data: TableFunctionInitInput | None = None def __init__( self, @@ -316,28 +384,6 @@ def __init__( """ super().__init__(invocation=invocation, logger=logger) - def setup(self) -> None: - """Acquire resources before processing starts. - - Override to acquire resources like database connections, file handles, - or external service clients. Called after init_data is available. - - Available at this point: - - self.init_data: The GlobalStateInitInput with projection info - - self.init_identifier: Storage key for distributed state - - self.invocation: The complete invocation request - - """ - pass - - def teardown(self) -> None: - """Release resources after processing completes. - - Always called, even if an error occurred during processing. - - """ - pass - def cardinality(self) -> CardinalityInfo | None: """Return optional cardinality estimate for the output. @@ -350,50 +396,6 @@ def cardinality(self) -> CardinalityInfo | None: """ return None - @final - @cached_property - def empty_output_batch(self) -> pa.RecordBatch: - """Return an empty batch conforming to output_schema. 
Cached.""" - output_schema = self.output_schema - return pa.RecordBatch.from_arrays( - [pa.array([], type=field.type) for field in output_schema], - schema=output_schema, - ) - - @final - def _validate_output_schema(self, batch: pa.RecordBatch) -> None: - """Validate that a batch conforms to the expected output schema.""" - if batch.schema != self.output_schema: - raise SchemaValidationError( - f"Output batch schema does not match expected output_schema. " - f"Expected: {self.output_schema}, got: {batch.schema}" - ) - - @property - def output_schema(self) -> pa.Schema: - """Return the output schema (default: passthrough input schema).""" - raise NotImplementedError("Subclasses must implement output_schema property") - - def perform_init(self, init_input: pa.RecordBatch) -> vgi.function.GlobalInitResult: - """Perform a new init call and store it in the storage.""" - self.init_data = GlobalStateInitInput.deserialize(init_input) - self.init_identifier = self.init_storage.create(self.init_data.serialize()) - return vgi.function.GlobalInitResult(self.init_identifier) - - def retrieve_init(self, init_input: vgi.function.GlobalInitResult) -> None: - """Retrieve and store init data from the storage.""" - if init_input.global_init_identifier is None: - raise ValueError( - "global_init_identifier is required but was None. " - "This indicates the GlobalInitResult was not properly initialized. " - "Ensure perform_init() returns a GlobalInitResult with a valid " - "identifier." - ) - self.init_identifier = init_input.global_init_identifier - self.init_data = GlobalStateInitInput.deserialize_bytes( - self.init_storage.get(self.init_identifier) - ) - def apply_projection(self, schema: pa.Schema) -> pa.Schema: """Apply any projection specified in the init data to the schema. 
diff --git a/vgi/table_in_out_function.py b/vgi/table_in_out_function.py index 9720629..7288609 100644 --- a/vgi/table_in_out_function.py +++ b/vgi/table_in_out_function.py @@ -523,7 +523,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: invocation = vgi.function.Invocation( function_name="my_function", arguments=vgi.function.Arguments(positional=[], named={}), - in_out_function_input_schema=input_schema, + input_schema=input_schema, correlation_id="", invocation_id=None, ) @@ -553,21 +553,16 @@ def __init__( ): """Initialize the function with invocation data and logger.""" super().__init__(invocation=invocation, logger=logger) - if invocation.in_out_function_input_schema is None: + if invocation.input_schema is None: raise ValueError( f"{type(self).__name__} requires an input schema, but none was " f"provided. TableInOutGeneratorFunction processes input batches and " - f"requires in_out_function_input_schema to be set in the Invocation. " + f"requires input_schema to be set in the Invocation. " f"If your function generates output without input, inherit from " f"TableFunctionGenerator instead." ) - @property - def input_schema(self) -> pa.Schema: - """Return the input schema from the invocation.""" - # Validated as non-None in __init__ - assert self.invocation.in_out_function_input_schema is not None - return self.invocation.in_out_function_input_schema + # input_schema property inherited from Function def teardown(self) -> None: """Release resources after processing completes. @@ -586,14 +581,7 @@ def output_schema(self) -> pa.Schema: """Return the output schema (default: passthrough input schema).""" return self.input_schema - @final - def _validate_input_schema(self, batch: pa.RecordBatch) -> None: - """Validate that a batch conforms to the expected input schema.""" - if batch.schema != self.input_schema: - raise vgi.table_function.SchemaValidationError( - f"Input batch schema does not match expected input_schema. 
" - f"Expected: {self.input_schema}, got: {batch.schema}" - ) + # _validate_input_schema inherited from Function @final def _process_and_validate( diff --git a/vgi/testing.py b/vgi/testing.py index f2c8520..df85a71 100644 --- a/vgi/testing.py +++ b/vgi/testing.py @@ -82,15 +82,22 @@ import structlog import structlog.stdlib -from vgi.function import Arguments, Invocation +from vgi.function import Arguments, Invocation, InvocationType from vgi.log import Level, Message -from vgi.table_function import ( - GlobalStateInitInput, - TableFunctionGenerator, +from vgi.scalar_function import ( + ProtocolInput as ScalarProtocolInput, +) +from vgi.scalar_function import ( + ScalarFunction, + ScalarFunctionGenerator, ) from vgi.table_function import ( ProtocolOutput as TableProtocolOutput, ) +from vgi.table_function import ( + TableFunctionGenerator, + TableFunctionInitInput, +) from vgi.table_in_out_function import ( ProtocolInput, ProtocolOutput, @@ -103,12 +110,15 @@ "FunctionTestClient", "FunctionTestClientError", "TableFunctionTestClient", + "ScalarFunctionTestClient", "batch", "assert_function_output", "assert_function_logs", "run_function", "run_table_function", "assert_table_function_output", + "run_scalar_function", + "assert_scalar_function_output", ] @@ -203,10 +213,11 @@ def table_in_out_function( invocation_id = uuid.uuid4().bytes invocation = Invocation( function_name=self.function_class.__name__, - arguments=arguments, - in_out_function_input_schema=input_schema, + input_schema=input_schema, + function_type=InvocationType.TABLE, correlation_id="test", invocation_id=invocation_id, + arguments=arguments, ) # Instantiate function @@ -235,8 +246,8 @@ def table_in_out_function( ) bind_result_callback(bind_batch) - # Perform init with GlobalStateInitInput - init_input = GlobalStateInitInput(projection_ids=projection_ids) + # Perform init with TableFunctionInitInput + init_input = TableFunctionInitInput(projection_ids=projection_ids) init_batch = 
# =============================================================================
# Scalar Function Test Client and Helpers
# =============================================================================


class ScalarFunctionTestClient:
    """In-process client for testing ScalarFunction and ScalarFunctionGenerator.

    Scalar functions transform input batches to single-column output with 1:1
    row mapping. Unlike TableInOut functions, scalar functions have no finalize
    phase.

    Example:
        with ScalarFunctionTestClient(DoubleColumnFunction) as client:
            outputs = list(client.scalar_function(
                input=iter([batch]),
                arguments=Arguments(positional=(pa.scalar("x"),)),
            ))

    Attributes:
        logs: List of log messages emitted during the last function call.

    """

    def __init__(
        self,
        function_class: type[ScalarFunctionGenerator] | type[ScalarFunction],
    ) -> None:
        """Initialize the ScalarFunctionTestClient.

        Args:
            function_class: The scalar function class to test (not an instance).

        """
        self.function_class = function_class
        self.logs: list[Message] = []
        self._logger: structlog.stdlib.BoundLogger = structlog.get_logger().bind(
            component="scalar_test_client"
        )

    def __enter__(self) -> "ScalarFunctionTestClient":
        """Enter context manager."""
        return self

    def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """Exit context manager."""
        pass

    def scalar_function(
        self,
        *,
        input: Iterator[pa.RecordBatch],
        arguments: Arguments | None = None,
        bind_result_callback: Callable[[pa.RecordBatch], None] | None = None,
    ) -> Generator[pa.RecordBatch, None, None]:
        """Call the scalar function with the given input data.

        This method implements the VGI scalar function protocol directly in-process,
        without any IPC or subprocess communication.

        The function's run() generator is always closed before this method
        finishes, including when the function reports an exception or the
        consumer stops iterating early.

        Args:
            input: Iterator yielding input RecordBatches.
            arguments: Arguments container with positional and named arguments.
            bind_result_callback: Optional callback invoked with the bind result.

        Yields:
            Output RecordBatches from the function (single-column).

        Raises:
            FunctionTestClientError: If the function raises an exception.

        """
        # Clear logs from previous invocation
        self.logs = []

        if arguments is None:
            arguments = Arguments()

        # Get first batch to determine input schema
        try:
            first_batch = next(input)
        except StopIteration:
            # No input batches - nothing to process
            return

        input_schema = first_batch.schema

        # Create invocation
        invocation_id = uuid.uuid4().bytes
        invocation = Invocation(
            function_name=self.function_class.__name__,
            input_schema=input_schema,
            function_type=InvocationType.SCALAR,
            correlation_id="test",
            invocation_id=invocation_id,
            arguments=arguments,
        )

        # Instantiate function
        func = self.function_class(invocation=invocation, logger=self._logger)

        # Create bind result for callback
        if bind_result_callback is not None:
            bind_batch = pa.RecordBatch.from_pylist(
                [
                    {
                        "output_schema": func.output_schema.serialize().to_pybytes(),
                        "max_processes": func.max_processes(),
                        "invocation_id": invocation_id,
                    }
                ],
                schema=pa.schema(
                    cast(
                        list[tuple[str, pa.DataType]],
                        [
                            ("output_schema", pa.binary()),
                            ("max_processes", pa.int64()),
                            ("invocation_id", pa.binary()),
                        ],
                    )
                ),
            )
            bind_result_callback(bind_batch)

        # Get the run generator
        generator = func.run()

        try:
            # Prime the generator
            try:
                next(generator)  # Priming output is discarded
            except StopIteration:
                return

            # Process first batch
            yield from self._process_scalar_batch(generator, first_batch)

            # Process remaining batches
            for batch in input:
                yield from self._process_scalar_batch(generator, batch)
        finally:
            # No finalize for scalar functions - just close. Closing here
            # (rather than only on success) lets run() reach its cleanup
            # even when an error is raised above.
            generator.close()

    def _process_scalar_batch(
        self,
        generator: Generator[TableProtocolOutput, ScalarProtocolInput | None, None],
        batch: pa.RecordBatch,
    ) -> Generator[pa.RecordBatch, None, None]:
        """Process a single input batch, handling log messages.

        The same batch is re-sent after every log message until the function
        produces a non-log output for it.

        Raises:
            FunctionTestClientError: If an EXCEPTION-level message is received.

        """
        while True:
            try:
                output = generator.send(ScalarProtocolInput(batch=batch))
            except StopIteration:
                return

            # Capture log message if present
            if output.log_message is not None:
                self.logs.append(output.log_message)
                # Check for exception
                if output.log_message.level == Level.EXCEPTION:
                    raise FunctionTestClientError(output.log_message.message)
                # Re-send the same batch to get actual output after log
                continue

            # Yield output batch if it has rows
            if output.batch is not None and output.batch.num_rows > 0:
                yield output.batch

            # No log message means we're done with this batch
            break


def run_scalar_function(
    function: type[ScalarFunctionGenerator] | type[ScalarFunction],
    input_batches: list[pa.RecordBatch],
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> tuple[list[pa.RecordBatch], list[Message]]:
    """Run a scalar function and return outputs and logs.

    A convenience wrapper around ScalarFunctionTestClient for simple test cases.

    Args:
        function: The scalar function class to test.
        input_batches: List of input RecordBatches.
        args: Optional positional arguments as a tuple.
        kwargs: Optional named arguments as a dict.

    Returns:
        Tuple of (output_batches, log_messages).

    Example:
        outputs, logs = run_scalar_function(
            DoubleColumnFunction,
            [batch(x=[1, 2, 3])],
            args=("x",),
        )
        assert outputs[0].to_pydict() == {"result": [2, 4, 6]}

    """
    # Build Arguments from args/kwargs
    positional: tuple[pa.Scalar[Any], ...] = ()
    named: dict[str, pa.Scalar[Any]] = {}

    if args:
        positional = tuple(pa.scalar(a) for a in args)
    if kwargs:
        named = {k: pa.scalar(v) for k, v in kwargs.items()}

    arguments = Arguments(positional=positional, named=named)

    with ScalarFunctionTestClient(function) as client:
        outputs = list(
            client.scalar_function(
                input=iter(input_batches),
                arguments=arguments,
            )
        )
        return outputs, client.logs


def assert_scalar_function_output(
    function: type[ScalarFunctionGenerator] | type[ScalarFunction],
    input: list[pa.RecordBatch],
    expected: list[pa.RecordBatch],
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
    check_order: bool = True,
    msg: str | None = None,
) -> list[Message]:
    """Assert that a scalar function produces expected output batches.

    Runs the function with the given input and compares output to expected batches.
    Returns captured log messages for additional assertions.

    Args:
        function: The scalar function class to test.
        input: List of input RecordBatches.
        expected: List of expected output RecordBatches.
        args: Optional positional arguments as a tuple.
        kwargs: Optional named arguments as a dict.
        check_order: If True, order of output batches must match. Default True.
            If False, batches are matched as a multiset: each output batch
            can satisfy at most one expected batch.
        msg: Optional custom assertion message prefix.

    Returns:
        List of log messages captured during execution.

    Raises:
        AssertionError: If output doesn't match expected.
        FunctionTestClientError: If the function raises an exception.

    Examples:
        # Double column test
        assert_scalar_function_output(
            DoubleColumnFunction,
            input=[batch(x=[1, 2, 3])],
            expected=[batch(result=[2, 4, 6])],
            args=("x",),
        )

        # Add columns test
        assert_scalar_function_output(
            AddColumnsFunction,
            input=[batch(a=[1, 2], b=[10, 20])],
            expected=[batch(result=[11, 22])],
            args=("a", "b"),
        )

    """
    outputs, logs = run_scalar_function(
        function=function,
        input_batches=input,
        args=args,
        kwargs=kwargs,
    )

    prefix = f"{msg}: " if msg else ""

    # Check batch count
    if len(outputs) != len(expected):
        actual_rows = [o.num_rows for o in outputs]
        expected_rows = [e.num_rows for e in expected]
        raise AssertionError(
            f"{prefix}Expected {len(expected)} output batches, got {len(outputs)}. "
            f"Output rows: {actual_rows}, Expected rows: {expected_rows}"
        )

    # Compare batches
    if check_order:
        for i, (actual, exp) in enumerate(zip(outputs, expected, strict=True)):
            if not actual.equals(exp):
                raise AssertionError(
                    f"{prefix}Batch {i} mismatch.\n"
                    f"Expected:\n{exp.to_pydict()}\n"
                    f"Got:\n{actual.to_pydict()}"
                )
    else:
        # Compare as dicts; dicts are unhashable, so use lists, not sets
        actual_dicts = [b.to_pydict() for b in outputs]
        expected_dicts = [b.to_pydict() for b in expected]

        # Each actual batch may satisfy only one expected batch; without
        # consuming matches, expected [A, A] would pass against [A, B].
        unmatched = list(actual_dicts)
        for exp_dict in expected_dicts:
            if exp_dict not in unmatched:
                raise AssertionError(
                    f"{prefix}Expected batch not found in output.\n"
                    f"Expected:\n{exp_dict}\n"
                    f"Actual outputs:\n{actual_dicts}"
                )
            unmatched.remove(exp_dict)

    return logs
ScalarFunctionGenerator: Transforms input batches to single-column output + with 1:1 row mapping. Use for per-row computations like add(), upper(), etc. + +2. TableInOutGeneratorFunction: Reads input batches, produces output batches. Use for transforming, filtering, or aggregating input data. -2. TableFunctionGenerator: Generates output batches without reading input. +3. TableFunctionGenerator: Generates output batches without reading input. Use for data generation functions like sequence(), range(), random_sample(). QUICK START @@ -18,9 +21,14 @@ Create a worker by subclassing Worker and listing your functions: from vgi.worker import Worker + from vgi.scalar_function import ScalarFunction from vgi.table_in_out_function import TableInOutGeneratorFunction from vgi.table_function import TableFunctionGenerator + class DoubleColumn(ScalarFunction): + # Single-column output with 1:1 row mapping + ... + class EchoFunction(TableInOutGeneratorFunction): # Transforms input batches ... @@ -30,7 +38,7 @@ class SequenceFunction(TableFunctionGenerator): ... class MyWorker(Worker): - functions = [EchoFunction, SequenceFunction] + functions = [DoubleColumn, EchoFunction, SequenceFunction] if __name__ == "__main__": MyWorker().run() @@ -38,11 +46,19 @@ class MyWorker(Worker): Function names are derived from metadata (Meta.name or class name converted to snake_case). No manual name mapping required. +PROTOCOL FLOW (ScalarFunctionGenerator) +--------------------------------------- +1. Read Invocation: function name, arguments, input schema +2. Write OutputSpec: output schema, max_processes, invocation_id +3. Read/write FunctionInitInput/GlobalInitResult for initialization +4. Stream: read input batches -> compute -> write single-column output batches + (ends when input exhausted, no FINALIZE phase) + PROTOCOL FLOW (TableInOutGeneratorFunction) ------------------------------------------- 1. Read Invocation: function name, arguments, input schema 2. 
Write OutputSpec: output schema, max_processes, invocation_id -3. Read/write GlobalStateInitInput/GlobalInitResult for initialization +3. Read/write TableFunctionInitInput/GlobalInitResult for initialization 4. Stream: read input batches -> process -> write output batches 5. Finalize: receive FINALIZE signal -> emit final results @@ -50,7 +66,7 @@ class MyWorker(Worker): -------------------------------------- 1. Read Invocation: function name, arguments (no input schema) 2. Write OutputSpec: output schema, max_processes, invocation_id -3. Read/write GlobalStateInitInput/GlobalInitResult for initialization +3. Read/write TableFunctionInitInput/GlobalInitResult for initialization 4. Generate: produce output batches until generator exhausted KEY CLASSES @@ -71,7 +87,7 @@ class MyWorker(Worker): from collections.abc import Sequence from dataclasses import dataclass from io import IOBase -from typing import cast +from typing import Any, cast import pyarrow as pa import structlog @@ -82,8 +98,11 @@ class MyWorker(Worker): Function, Invocation, OutputSpec, + SchemaValidationError, ) from vgi.ipc_utils import read_ipc_batch +from vgi.scalar_function import ProtocolInput as ScalarProtocolInput +from vgi.scalar_function import ScalarFunctionGenerator from vgi.table_function import TableFunctionGenerator from vgi.table_in_out_function import ( ProtocolInput, @@ -128,11 +147,11 @@ class MyWorker(Worker): """ - functions: Sequence[type[Function]] = [] - _registry: dict[str, list[type[Function]]] | None = None + functions: Sequence[type[Function[Any]]] = [] + _registry: dict[str, list[type[Function[Any]]]] | None = None @classmethod - def _build_registry(cls) -> dict[str, list[type[Function]]]: + def _build_registry(cls) -> dict[str, list[type[Function[Any]]]]: """Build function name -> list of classes mapping from functions list. 
Multiple functions can share the same name if they have different @@ -141,7 +160,7 @@ def _build_registry(cls) -> dict[str, list[type[Function]]]: if cls._registry is not None: return cls._registry - registry: dict[str, list[type[Function]]] = {} + registry: dict[str, list[type[Function[Any]]]] = {} for func_cls in cls.functions: meta = func_cls.get_metadata() if meta.name not in registry: @@ -154,8 +173,8 @@ def _build_registry(cls) -> dict[str, list[type[Function]]]: @staticmethod def _match_function( invocation: Invocation, - candidates: Sequence[type[Function]], - ) -> type[Function]: + candidates: Sequence[type[Function[Any]]], + ) -> type[Function[Any]]: """Find the function that matches the invocation's arguments. Compares the invocation's positional and named arguments against each @@ -176,7 +195,7 @@ def _match_function( num_positional = len(args.positional) named_keys = set(args.named.keys()) if args.named else set() - matches: list[type[Function]] = [] + matches: list[type[Function[Any]]] = [] for func_cls in candidates: meta = func_cls.get_metadata() @@ -239,6 +258,50 @@ def _match_function( return matches[0] + @staticmethod + def _suggest_similar_names(name: str, candidates: list[str]) -> list[str]: + """Find function names similar to the given name. + + Uses prefix matching, substring matching, and character overlap to + suggest likely alternatives for typos. + + Args: + name: The unknown function name. + candidates: List of valid function names. + + Returns: + List of similar names, sorted by relevance. 
+ + """ + if not candidates: + return [] + + name_lower = name.lower() + scored: list[tuple[int, str]] = [] + + for candidate in candidates: + candidate_lower = candidate.lower() + + # Exact prefix match (highest priority) + if candidate_lower.startswith(name_lower): + scored.append((0, candidate)) + elif name_lower.startswith(candidate_lower): + scored.append((1, candidate)) + # Substring matches + elif name_lower in candidate_lower or candidate_lower in name_lower: + scored.append((2, candidate)) + else: + # Character overlap score (for typos) + name_chars = set(name_lower) + candidate_chars = set(candidate_lower) + overlap = len(name_chars & candidate_chars) + # Require at least half the characters to match + if overlap > len(name_lower) // 2: + scored.append((10 - overlap, candidate)) + + scored.sort(key=lambda x: (x[0], x[1])) + return [candidate for _, candidate in scored] + def __init__(self) -> None: """Initialize the worker with structured logging.""" structlog.configure( @@ -278,6 +341,97 @@ def _read_init_data(self) -> pa.RecordBatch: """Read and parse the init data from stdin.""" return self._read_ipc_batch("init_data") + def _process_scalar_batches( + self, + instance: ScalarFunctionGenerator, + invocation: Invocation, + fn_log: structlog.stdlib.BoundLogger, + ) -> WorkerStats: + """Process data batches through a scalar function. + + Similar to _process_batches but simplified: + - No FINALIZE phase (ends when input exhausted) + - HAVE_MORE_OUTPUT only used for log messages (not multiple output batches) + + Returns: + WorkerStats with batch_count, total_input_rows, total_output_rows. + + """ + if invocation.global_init_identifier is None: + raise ValueError( + "global_init_identifier is required but was None. " + "This is an internal protocol error - the worker should have set " + "global_init_identifier after perform_init() completed successfully." 
+ ) + generator = instance.run() + next(generator) # Prime the run() generator + + with ( + ipc.new_stream(cast(IOBase, sys.stdout), instance.output_schema) as writer, + ipc.open_stream(cast(IOBase, sys.stdin)) as data_reader, + ): + # Validate data stream schema matches expected input schema + if data_reader.schema != invocation.input_schema: + raise SchemaValidationError( + "Data stream schema does not match expected input schema.", + expected=invocation.input_schema, + actual=data_reader.schema, + context="input stream to scalar function", + ) + + batch_count = 0 + total_input_rows = 0 + total_output_rows = 0 + while True: + fn_log.debug("batch_waiting") + + try: + batch, metadata = data_reader.read_next_batch_with_custom_metadata() + except StopIteration: + fn_log.debug("input_stream_ended") + # Close the generator - no FINALIZE for scalar functions + generator.close() + break + + batch_count += 1 + total_input_rows += batch.num_rows + fn_log.debug( + "batch_received", + batch_index=batch_count, + input_rows=batch.num_rows, + ) + + protocol_input = ScalarProtocolInput(batch=batch, metadata=metadata) + output = generator.send(protocol_input) + + # Handle log messages (indicated by log_message being set) + while output.log_message is not None: + fn_log.debug("log_message_received", output=output) + assert output.batch is not None + writer.write_batch( + output.batch, custom_metadata=output.metadata(invocation) + ) + # Re-send same input to continue + output = generator.send(protocol_input) + + fn_log.debug("batch_processed", output=output) + assert output.batch is not None + output_rows = output.batch.num_rows + total_output_rows += output_rows + writer.write_batch( + output.batch, custom_metadata=output.metadata(invocation) + ) + fn_log.debug( + "batch_written", + batch_index=batch_count, + output_rows=output_rows, + ) + return WorkerStats( + batch_count=batch_count, + total_input_rows=total_input_rows, + total_output_rows=total_output_rows, + ) + def 
_process_batches( self, instance: TableInOutGeneratorFunction, @@ -308,11 +462,12 @@ def _process_batches( ipc.open_stream(cast(IOBase, sys.stdin)) as data_reader, ): # Validate data stream schema matches expected input schema - if data_reader.schema != invocation.in_out_function_input_schema: - expected = invocation.in_out_function_input_schema - raise ValueError( - f"Data stream schema mismatch. Expected: {expected}, " - f"got: {data_reader.schema}" + if data_reader.schema != invocation.input_schema: + raise SchemaValidationError( + "Data stream schema does not match expected input schema.", + expected=invocation.input_schema, + actual=data_reader.schema, + context="input stream to table-in-out function", ) batch_count = 0 @@ -354,7 +509,6 @@ def _process_batches( "batch_written", batch_index=batch_count, output_rows=output_rows, - status=output.status.value if output.status else None, ) return WorkerStats( batch_count=batch_count, @@ -415,16 +569,25 @@ def run(self) -> None: fn_log = self.log.bind(function=invocation.function_name) fn_log.info("init_received", arguments=invocation.arguments) - fn_log.debug( - "input_schema_parsed", schema=str(invocation.in_out_function_input_schema) - ) + fn_log.debug("input_schema_parsed", schema=str(invocation.input_schema)) registry = self._build_registry() if invocation.function_name not in registry: available = sorted(registry.keys()) - raise ValueError( - f"Unknown function: {invocation.function_name}. 
Available: {available}" + suggestions = self._suggest_similar_names( + invocation.function_name, available ) + msg_lines = [ + f"Unknown function: '{invocation.function_name}'", + "", + ] + if suggestions: + msg_lines.append(" Did you mean:") + for suggestion in suggestions[:3]: + msg_lines.append(f" - {suggestion}") + msg_lines.append("") + msg_lines.append(f" Available functions: {available}") + raise ValueError("\n".join(msg_lines)) candidates = registry[invocation.function_name] func_cls = self._match_function(invocation, candidates) @@ -467,21 +630,26 @@ def run(self) -> None: instance.retrieve_init(invocation.global_init_identifier) # Dispatch to appropriate processing method based on function type. + # ScalarFunctionGenerator processes input batches to single-column output. # TableInOutGeneratorFunction reads input batches and produces output. # TableFunctionGenerator generates output without input batches. - # Note: Check TableInOutGeneratorFunction first since it's more specific - # (TableFunctionGenerator is a parent of TableInOutGeneratorFunction's parent). - if isinstance(instance, TableInOutGeneratorFunction): + # Note: Check ScalarFunctionGenerator first since it doesn't inherit from + # TableInOutGeneratorFunction, then TableInOutGeneratorFunction. + if isinstance(instance, ScalarFunctionGenerator): + stats = self._process_scalar_batches(instance, invocation, fn_log) + elif isinstance(instance, TableInOutGeneratorFunction): stats = self._process_batches(instance, invocation, fn_log) elif isinstance(instance, TableFunctionGenerator): stats = self._generate_batches(instance, invocation, fn_log) else: raise TypeError( f"Unsupported function type: {type(instance).__name__}. " - f"Functions must inherit from TableInOutGeneratorFunction (for " - f"functions that process input batches) or TableFunctionGenerator " - f"(for functions that generate output without input). " - f"See vgi.table_in_out_function and vgi.table_function modules." 
+ f"Functions must inherit from ScalarFunctionGenerator (for " + f"scalar functions), TableInOutGeneratorFunction (for functions " + f"that process input batches), or TableFunctionGenerator (for " + f"functions that generate output without input). " + f"See vgi.scalar_function, vgi.table_in_out_function, and " + f"vgi.table_function modules." ) fn_log.info(