Skip to content

Commit 3dac957

Browse files
authored
Merge pull request #1 from Query-farm/fix-types3
Add pyarrow-stubs and improve type coverage
2 parents f87bb3d + 05760bc commit 3dac957

30 files changed

Lines changed: 474 additions & 228 deletions

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,3 +8,7 @@ wheels/
88

99
# Virtual environments
1010
.venv
11+
12+
# Mypy reports
13+
mypy-reports/
14+
mypy-html-report/

CLAUDE.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,11 @@ uv run ruff format . # Format
1212
uv run mypy vgi/ # Type check
1313
```
1414

15+
**Before committing**, always run the lint, format, and type checks:
16+
```bash
17+
uv run ruff check --fix . && uv run ruff format . && uv run mypy vgi/
18+
```
19+
1520
## Project Overview
1621

1722
VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for connecting DuckDB to external programs. It enables user-defined functions to run in separate processes, communicating via stdin/stdout using Arrow IPC streaming.

pyproject.toml

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ requires-python = ">=3.12.4"
77
dependencies = ["click", "pyarrow", "structlog", "platformdirs"]
88

99
[project.optional-dependencies]
10-
dev = ["mypy", "pytest", "pytest-mypy", "pytest-ruff", "pytest-xdist", "ruff"]
10+
dev = ["mypy", "pyarrow-stubs", "pytest", "pytest-mypy", "pytest-ruff", "pytest-xdist", "ruff"]
1111

1212
[project.scripts]
1313
vgi-client = "vgi.client.cli:main"
@@ -38,14 +38,16 @@ strict = true
3838
warn_return_any = true
3939
warn_unused_ignores = true
4040

41-
[[tool.mypy.overrides]]
42-
module = "pyarrow.*"
43-
ignore_missing_imports = true
44-
4541
[[tool.mypy.overrides]]
4642
module = "structlog.*"
4743
ignore_missing_imports = true
4844

4945
[tool.pytest.ini_options]
5046
addopts = "--mypy --ruff"
5147
testpaths = ["tests"]
48+
49+
[dependency-groups]
50+
dev = [
51+
"lxml>=6.0.2",
52+
"pytest-timeout>=2.4.0",
53+
]

tests/conftest.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
"""Shared fixtures for VGI tests."""
22

3+
from typing import Any
4+
35
import pyarrow as pa
46
import pytest
57

@@ -13,13 +15,12 @@ def example_worker() -> str:
1315
@pytest.fixture
1416
def simple_batches() -> list[pa.RecordBatch]:
1517
"""Create simple test batches with integer and string columns."""
16-
schema = pa.schema(
17-
[
18-
pa.field("id", pa.int64()),
19-
pa.field("value", pa.int64()),
20-
pa.field("name", pa.string()),
21-
]
22-
)
18+
fields: list[pa.Field[Any]] = [
19+
pa.field("id", pa.int64()),
20+
pa.field("value", pa.int64()),
21+
pa.field("name", pa.string()),
22+
]
23+
schema = pa.schema(fields)
2324
batch1 = pa.RecordBatch.from_pydict(
2425
{"id": [1, 2], "value": [10, 20], "name": ["a", "b"]},
2526
schema=schema,
@@ -34,12 +35,11 @@ def simple_batches() -> list[pa.RecordBatch]:
3435
@pytest.fixture
3536
def numeric_batches() -> list[pa.RecordBatch]:
3637
"""Create test batches with only numeric columns for sum tests."""
37-
schema = pa.schema(
38-
[
39-
pa.field("a", pa.int32()),
40-
pa.field("b", pa.float64()),
41-
]
42-
)
38+
fields: list[pa.Field[Any]] = [
39+
pa.field("a", pa.int32()),
40+
pa.field("b", pa.float64()),
41+
]
42+
schema = pa.schema(fields)
4343
batch1 = pa.RecordBatch.from_pydict(
4444
{"a": [1, 2, 3], "b": [1.5, 2.5, 3.0]},
4545
schema=schema,

tests/table/generator/test_partitioned_function.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
11
"""Tests for the PartitionedRangeFunction with multi-worker support."""
22

3+
from __future__ import annotations
4+
5+
from typing import Any
6+
37
import pyarrow as pa
48

59
from vgi.client import Client
@@ -10,6 +14,11 @@
1014
from .conftest import RunnerWithMode
1115

1216

17+
def _sorted_non_null(values: list[Any | None]) -> list[Any]:
18+
"""Return sorted list, filtering out None values for type safety."""
19+
return sorted(v for v in values if v is not None)
20+
21+
1322
class TestPartitionedRangeFunctionInProcess:
1423
"""In-process tests for the partitioned_range function."""
1524

@@ -23,7 +32,7 @@ def test_generates_full_range_single_worker(self) -> None:
2332
table = pa.Table.from_batches(outputs)
2433
assert table.num_rows == 10
2534

26-
values = sorted(table.column("value").to_pylist())
35+
values = _sorted_non_null(table.column("value").to_pylist())
2736
assert values == list(range(10))
2837

2938
def test_metadata(self) -> None:
@@ -64,7 +73,7 @@ def test_values_are_sequential(
6473
outputs, logs = runner(PartitionedRangeFunction, (50,))
6574

6675
table = pa.Table.from_batches(outputs)
67-
values = sorted(table.column("value").to_pylist())
76+
values = _sorted_non_null(table.column("value").to_pylist())
6877
assert values == list(range(50))
6978

7079

@@ -84,7 +93,7 @@ def test_two_workers_produce_complete_range(self) -> None:
8493
table = pa.Table.from_batches(outputs)
8594
assert table.num_rows == 20
8695

87-
values = sorted(table.column("value").to_pylist())
96+
values = _sorted_non_null(table.column("value").to_pylist())
8897
assert values == list(range(20))
8998

9099
def test_three_workers_produce_complete_range(self) -> None:
@@ -100,7 +109,7 @@ def test_three_workers_produce_complete_range(self) -> None:
100109
table = pa.Table.from_batches(outputs)
101110
assert table.num_rows == 30
102111

103-
values = sorted(table.column("value").to_pylist())
112+
values = _sorted_non_null(table.column("value").to_pylist())
104113
assert values == list(range(30))
105114

106115
def test_workers_produce_large_range(self) -> None:
@@ -116,7 +125,7 @@ def test_workers_produce_large_range(self) -> None:
116125
table = pa.Table.from_batches(outputs)
117126
assert table.num_rows == 10000
118127

119-
values = sorted(table.column("value").to_pylist())
128+
values = _sorted_non_null(table.column("value").to_pylist())
120129
assert values == list(range(10000))
121130

122131
def test_uneven_distribution(self) -> None:
@@ -134,7 +143,7 @@ def test_uneven_distribution(self) -> None:
134143
table = pa.Table.from_batches(outputs)
135144
assert table.num_rows == 7
136145

137-
values = sorted(table.column("value").to_pylist())
146+
values = _sorted_non_null(table.column("value").to_pylist())
138147
assert values == list(range(7))
139148

140149
def test_single_worker_fallback(self) -> None:
@@ -150,5 +159,5 @@ def test_single_worker_fallback(self) -> None:
150159
table = pa.Table.from_batches(outputs)
151160
assert table.num_rows == 15
152161

153-
values = sorted(table.column("value").to_pylist())
162+
values = _sorted_non_null(table.column("value").to_pylist())
154163
assert values == list(range(15))

tests/table/generator/test_random_sample_function.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -95,4 +95,4 @@ def test_values_in_range(self, run_table_function_mode: RunnerWithMode) -> None:
9595
table = pa.Table.from_batches(outputs)
9696
values = table.column("value").to_pylist()
9797

98-
assert all(0 <= v < 1 for v in values)
98+
assert all(v is not None and 0 <= v < 1 for v in values)

tests/table/test_function.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
11
"""Generic tests for TableFunctionGenerator behavior."""
22

3+
from __future__ import annotations
4+
35
import pyarrow as pa
46
import pytest
57
import structlog
68

9+
from tests.utils import make_schema
710
from vgi.function import Arguments, Invocation
811
from vgi.table_function import (
912
CardinalityInfo,
@@ -23,7 +26,7 @@ def test_empty_process_generator(self) -> None:
2326
class EmptyFunction(TableFunctionGenerator):
2427
@property
2528
def output_schema(self) -> pa.Schema:
26-
return pa.schema([pa.field("x", pa.int64())])
29+
return make_schema([pa.field("x", pa.int64())])
2730

2831
with TableFunctionTestClient(EmptyFunction) as client:
2932
outputs = list(client.table_function())
@@ -37,7 +40,7 @@ def test_single_batch_output(self) -> None:
3740
class SingleBatchFunction(TableFunctionGenerator):
3841
@property
3942
def output_schema(self) -> pa.Schema:
40-
return pa.schema([pa.field("x", pa.int64())])
43+
return make_schema([pa.field("x", pa.int64())])
4144

4245
def process(self) -> OutputGenerator:
4346
yield Output(
@@ -58,7 +61,7 @@ def test_multiple_batch_output(self) -> None:
5861
class MultiBatchFunction(TableFunctionGenerator):
5962
@property
6063
def output_schema(self) -> pa.Schema:
61-
return pa.schema([pa.field("n", pa.int64())])
64+
return make_schema([pa.field("n", pa.int64())])
6265

6366
def process(self) -> OutputGenerator:
6467
for i in range(3):
@@ -86,7 +89,7 @@ def test_setup_called_before_process(self) -> None:
8689
class LifecycleFunction(TableFunctionGenerator):
8790
@property
8891
def output_schema(self) -> pa.Schema:
89-
return pa.schema([pa.field("x", pa.int64())])
92+
return make_schema([pa.field("x", pa.int64())])
9093

9194
def setup(self) -> None:
9295
call_order.append("setup")
@@ -112,7 +115,7 @@ def test_teardown_called_on_exception(self) -> None:
112115
class ExceptionFunction(TableFunctionGenerator):
113116
@property
114117
def output_schema(self) -> pa.Schema:
115-
return pa.schema([pa.field("x", pa.int64())])
118+
return make_schema([pa.field("x", pa.int64())])
116119

117120
def process(self) -> OutputGenerator:
118121
raise ValueError("test error")
@@ -140,7 +143,7 @@ def test_valid_schema_passes(self) -> None:
140143
class ValidSchemaFunction(TableFunctionGenerator):
141144
@property
142145
def output_schema(self) -> pa.Schema:
143-
return pa.schema([pa.field("x", pa.int64())])
146+
return make_schema([pa.field("x", pa.int64())])
144147

145148
def process(self) -> OutputGenerator:
146149
yield Output(
@@ -158,11 +161,11 @@ def test_invalid_schema_raises(self) -> None:
158161
class InvalidSchemaFunction(TableFunctionGenerator):
159162
@property
160163
def output_schema(self) -> pa.Schema:
161-
return pa.schema([pa.field("x", pa.int64())])
164+
return make_schema([pa.field("x", pa.int64())])
162165

163166
def process(self) -> OutputGenerator:
164167
# Return batch with wrong column name
165-
wrong_schema = pa.schema([pa.field("y", pa.int64())])
168+
wrong_schema = make_schema([pa.field("y", pa.int64())])
166169
wrong_batch = pa.RecordBatch.from_pydict(
167170
{"y": [1]}, schema=wrong_schema
168171
)
@@ -187,7 +190,7 @@ def test_default_cardinality_is_none(self) -> None:
187190
class NoCardinalityFunction(TableFunctionGenerator):
188191
@property
189192
def output_schema(self) -> pa.Schema:
190-
return pa.schema([pa.field("x", pa.int64())])
193+
return make_schema([pa.field("x", pa.int64())])
191194

192195
invocation = Invocation(
193196
function_name="test",
@@ -209,7 +212,7 @@ def test_custom_cardinality(self) -> None:
209212
class CardinalityFunction(TableFunctionGenerator):
210213
@property
211214
def output_schema(self) -> pa.Schema:
212-
return pa.schema([pa.field("x", pa.int64())])
215+
return make_schema([pa.field("x", pa.int64())])
213216

214217
def cardinality(self) -> CardinalityInfo:
215218
return CardinalityInfo(estimate=100, max=1000)
@@ -244,7 +247,7 @@ class ArgFunction(TableFunctionGenerator):
244247

245248
@property
246249
def output_schema(self) -> pa.Schema:
247-
return pa.schema([pa.field("n", pa.int64())])
250+
return make_schema([pa.field("n", pa.int64())])
248251

249252
def process(self) -> OutputGenerator:
250253
yield Output(
@@ -270,7 +273,7 @@ class NamedArgFunction(TableFunctionGenerator):
270273

271274
@property
272275
def output_schema(self) -> pa.Schema:
273-
return pa.schema([pa.field("result", pa.int64())])
276+
return make_schema([pa.field("result", pa.int64())])
274277

275278
def process(self) -> OutputGenerator:
276279
yield Output(
@@ -299,7 +302,7 @@ def test_empty_output_batch_property(self) -> None:
299302
class TestFunction(TableFunctionGenerator):
300303
@property
301304
def output_schema(self) -> pa.Schema:
302-
return pa.schema(
305+
return make_schema(
303306
[
304307
pa.field("a", pa.int64()),
305308
pa.field("b", pa.string()),

tests/table_in_out/generator/test_repeat_inputs_function.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,10 @@
11
"""Tests for the RepeatInputsFunction (explosion)."""
22

3+
from __future__ import annotations
4+
35
import pyarrow as pa
46

7+
from tests.utils import make_schema
58
from vgi.client import Client
69
from vgi.function import Arguments
710

@@ -18,7 +21,7 @@ def test_repeat_custom_count(
1821
output_batches = list(
1922
client.table_in_out_function(
2023
function_name="repeat_inputs",
21-
arguments=Arguments(positional=tuple([repeat_count]), named={}),
24+
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
2225
input=iter(simple_batches),
2326
)
2427
)
@@ -35,7 +38,7 @@ def test_repeat_single_time(
3538
output_batches = list(
3639
client.table_in_out_function(
3740
function_name="repeat_inputs",
38-
arguments=Arguments(positional=tuple([1]), named={}),
41+
arguments=Arguments(positional=(pa.scalar(1),), named={}),
3942
input=iter(simple_batches),
4043
)
4144
)
@@ -46,7 +49,7 @@ def test_repeat_single_time(
4649

4750
def test_repeat_distributed_many_batches(self, example_worker: str) -> None:
4851
"""Should correctly repeat across many batches with multiple workers."""
49-
schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.float64())])
52+
schema = make_schema([pa.field("a", pa.int64()), pa.field("b", pa.float64())])
5053

5154
# Create 100 batches, each with 50 rows
5255
num_batches = 100
@@ -69,7 +72,7 @@ def test_repeat_distributed_many_batches(self, example_worker: str) -> None:
6972
output_batches = list(
7073
client.table_in_out_function(
7174
function_name="repeat_inputs",
72-
arguments=Arguments(positional=tuple([repeat_count]), named={}),
75+
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
7376
input=iter(batches),
7477
)
7578
)
@@ -80,7 +83,9 @@ def test_repeat_distributed_many_batches(self, example_worker: str) -> None:
8083

8184
def test_repeat_distributed_preserves_data(self, example_worker: str) -> None:
8285
"""Should preserve data correctly when repeated across workers."""
83-
schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.string())])
86+
schema = make_schema(
87+
[pa.field("id", pa.int64()), pa.field("value", pa.string())]
88+
)
8489

8590
# Create batches with distinct values to verify data integrity
8691
batches = [
@@ -101,7 +106,7 @@ def test_repeat_distributed_preserves_data(self, example_worker: str) -> None:
101106
output_batches = list(
102107
client.table_in_out_function(
103108
function_name="repeat_inputs",
104-
arguments=Arguments(positional=tuple([repeat_count]), named={}),
109+
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
105110
input=iter(batches),
106111
)
107112
)

0 commit comments

Comments
 (0)