Skip to content

Commit 2eba66d

Browse files
committed
table functions continued
1 parent 76af089 commit 2eba66d

3 files changed

Lines changed: 466 additions & 0 deletions

File tree

tests/table/generator/conftest.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Shared fixtures for table generator function tests."""
2+
3+
from collections.abc import Callable, Generator
4+
from typing import Any, Literal
5+
6+
import pyarrow as pa
7+
import pytest
8+
9+
from vgi.client import Client, ClientError
10+
from vgi.function import Arguments
11+
from vgi.log import Message
12+
from vgi.table_function import TableFunctionGenerator
13+
from vgi.testing import FunctionTestClientError, TableFunctionTestClient
14+
15+
# The two execution modes a test can be run under.
TestMode = Literal["in_process", "client"]

# Signature shared by the runner helpers in this module:
# (function class, positional argument values) -> (output batches, log messages).
TableFunctionRunner = Callable[
    [type[TableFunctionGenerator], tuple[Any, ...]],
    tuple[list[pa.RecordBatch], list[Message]],
]

# Value produced by the parametrized fixture: the runner paired with the
# name of the mode it runs in.
RunnerWithMode = tuple[TableFunctionRunner, TestMode]
26+
27+
28+
def run_in_process(
    func_class: type[TableFunctionGenerator],
    args: tuple[Any, ...],
) -> tuple[list[pa.RecordBatch], list[Message]]:
    """Execute a table function in-process through TableFunctionTestClient.

    Each element of ``args`` is wrapped with ``pa.scalar`` and passed to the
    function as a positional argument. The result is the pair
    ``(output batches, client.logs)``.
    """
    with TableFunctionTestClient(func_class) as client:
        positional = tuple(pa.scalar(value) for value in args)
        call_args = Arguments(positional=positional)
        # Drain the iterator completely while the client is still open.
        batches = list(client.table_function(arguments=call_args))
        return batches, client.logs
40+
41+
42+
def run_via_client(
    func_class: type[TableFunctionGenerator],
    args: tuple[Any, ...],
) -> tuple[list[pa.RecordBatch], list[Message]]:
    """Execute a table function in a subprocess through Client.table_function.

    The function is looked up by the name reported in its metadata. The
    client is created with max_workers=1 so behavior stays consistent with
    in-process mode; multi-worker tests should use Client directly with an
    explicit max_workers.
    """
    function_name = func_class.get_metadata().name

    with Client("vgi-example-worker", max_workers=1) as client:
        positional = tuple(pa.scalar(value) for value in args)
        call_args = Arguments(positional=positional)
        batches = list(
            client.table_function(
                function_name=function_name,
                arguments=call_args,
            )
        )

    # Note: logs are not captured via Client (they go to stderr), so the
    # second element of the result is always an empty list.
    return batches, []
63+
64+
65+
@pytest.fixture(params=["in_process", "client"])
def run_table_function_mode(
    request: pytest.FixtureRequest,
) -> Generator[RunnerWithMode, None, None]:
    """Parametrized fixture supplying the in-process and the Client runner.

    A test requesting this fixture is executed once per parameter: once with
    the in-process runner and once with the Client-based runner. The yielded
    value is a ``(runner_function, mode_name)`` tuple, so a test can make
    mode-specific assertions.
    """
    mode: TestMode = request.param
    runner = run_in_process if mode == "in_process" else run_via_client
    yield runner, mode
79+
80+
81+
@pytest.fixture
def run_function() -> TableFunctionRunner:
    """Provide the in-process runner only (e.g., for log capture tests)."""
    runner: TableFunctionRunner = run_in_process
    return runner
85+
86+
87+
# Convenience re-exports of the error types under mode-oriented names:
# SubprocessError is vgi.client.ClientError and InProcessError is
# vgi.testing.FunctionTestClientError.
SubprocessError = ClientError
InProcessError = FunctionTestClientError
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Tests for the PartitionedRangeFunction with multi-worker support."""
2+
3+
import pyarrow as pa
4+
5+
from vgi.client import Client
6+
from vgi.examples.table import PartitionedRangeFunction
7+
from vgi.function import Arguments
8+
from vgi.testing import TableFunctionTestClient
9+
10+
from .conftest import RunnerWithMode
11+
12+
13+
class TestPartitionedRangeFunctionInProcess:
    """Checks of partitioned_range run through the in-process test client."""

    @staticmethod
    def _collect(count: int) -> list[pa.RecordBatch]:
        """Run the function in-process with ``count`` as its only argument."""
        with TableFunctionTestClient(PartitionedRangeFunction) as client:
            call_args = Arguments(positional=(pa.scalar(count),))
            return list(client.table_function(arguments=call_args))

    def test_generates_full_range_single_worker(self) -> None:
        """A single worker should emit every value of the range."""
        batches = self._collect(10)

        table = pa.Table.from_batches(batches)
        assert table.num_rows == 10

        # Compare in sorted order, as the check is about the set of values.
        values = table.column("value").to_pylist()
        values.sort()
        assert values == list(range(10))

    def test_metadata(self) -> None:
        """The function's metadata should carry the expected name and limits."""
        metadata = PartitionedRangeFunction.get_metadata()
        assert metadata.name == "partitioned_range"
        # Should not have max_workers limit (parallelizable)
        assert metadata.max_workers is None

    def test_zero_count(self) -> None:
        """A count of 0 should yield no output batches at all."""
        batches = self._collect(0)

        assert len(batches) == 0
48+
49+
50+
class TestPartitionedRangeFunctionBothModes:
    """Tests that run both in-process and via Client subprocess.

    Each test receives a ``(runner, mode)`` tuple from the
    ``run_table_function_mode`` fixture and calls the runner, which returns
    ``(outputs, logs)``. Only the runner and the outputs are needed here, so
    the other two values are bound to underscore-prefixed names to mark them
    as intentionally unused.
    """

    def test_generates_correct_count(
        self, run_table_function_mode: RunnerWithMode
    ) -> None:
        """Partitioned range should generate exactly the requested number of rows."""
        runner, _mode = run_table_function_mode
        outputs, _logs = runner(PartitionedRangeFunction, (100,))

        table = pa.Table.from_batches(outputs)
        assert table.num_rows == 100

    def test_values_are_sequential(
        self, run_table_function_mode: RunnerWithMode
    ) -> None:
        """Partitioned range should produce all values in range."""
        runner, _mode = run_table_function_mode
        outputs, _logs = runner(PartitionedRangeFunction, (50,))

        table = pa.Table.from_batches(outputs)
        # Sort first: the assertion is about which values appear, not their order.
        values = sorted(table.column("value").to_pylist())
        assert values == list(range(50))
73+
74+
75+
class TestPartitionedRangeFunctionMultiWorker:
    """Tests for multi-worker partitioned execution via Client.

    Every test performs the same steps with a different row count and
    ``max_workers`` setting, so the shared steps live in
    ``_assert_complete_range``.
    """

    @staticmethod
    def _assert_complete_range(count: int, max_workers: int) -> None:
        """Run partitioned_range via Client and check it yields 0..count-1.

        Args:
            count: Passed as the function's single positional argument; also
                the number of rows expected in the combined output.
            max_workers: Forwarded to ``Client`` as ``max_workers``.
        """
        with Client("vgi-example-worker", max_workers=max_workers) as client:
            outputs = list(
                client.table_function(
                    function_name="partitioned_range",
                    arguments=Arguments(positional=(pa.scalar(count),)),
                )
            )

        table = pa.Table.from_batches(outputs)
        assert table.num_rows == count

        # Sort first: the assertion is about which values appear, not the
        # order in which the batches arrived.
        values = sorted(table.column("value").to_pylist())
        assert values == list(range(count))

    def test_two_workers_produce_complete_range(self) -> None:
        """Two workers should together produce the complete range."""
        self._assert_complete_range(count=20, max_workers=2)

    def test_three_workers_produce_complete_range(self) -> None:
        """Three workers should together produce the complete range."""
        self._assert_complete_range(count=30, max_workers=3)

    def test_workers_produce_large_range(self) -> None:
        """Multiple workers should handle large ranges."""
        self._assert_complete_range(count=10000, max_workers=4)

    def test_uneven_distribution(self) -> None:
        """Workers should handle ranges that don't divide evenly."""
        # 7 items with 3 workers: worker 0 gets [0,3,6], worker 1 gets [1,4],
        # worker 2 gets [2,5]
        self._assert_complete_range(count=7, max_workers=3)

    def test_single_worker_fallback(self) -> None:
        """max_workers=1 should work like single worker mode."""
        self._assert_complete_range(count=15, max_workers=1)

0 commit comments

Comments
 (0)