Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ wheels/

# Virtual environments
.venv

# Mypy reports
mypy-reports/
mypy-html-report/
5 changes: 5 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ uv run ruff format . # Format
uv run mypy vgi/ # Type check
```

**Before committing**, always run lint and format checks:
```bash
uv run ruff check --fix . && uv run ruff format . && uv run mypy vgi/
```

## Project Overview

VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for connecting DuckDB to external programs. It enables user-defined functions to run in separate processes, communicating via stdin/stdout using Arrow IPC streaming.
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.12.4"
dependencies = ["click", "pyarrow", "structlog", "platformdirs"]

[project.optional-dependencies]
dev = ["mypy", "pytest", "pytest-mypy", "pytest-ruff", "pytest-xdist", "ruff"]
dev = ["mypy", "pyarrow-stubs", "pytest", "pytest-mypy", "pytest-ruff", "pytest-xdist", "ruff"]

[project.scripts]
vgi-client = "vgi.client.cli:main"
Expand Down Expand Up @@ -38,14 +38,16 @@ strict = true
warn_return_any = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module = "pyarrow.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "structlog.*"
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = "--mypy --ruff"
testpaths = ["tests"]

[dependency-groups]
dev = [
"lxml>=6.0.2",
"pytest-timeout>=2.4.0",
]
26 changes: 13 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Shared fixtures for VGI tests."""

from typing import Any

import pyarrow as pa
import pytest

Expand All @@ -13,13 +15,12 @@ def example_worker() -> str:
@pytest.fixture
def simple_batches() -> list[pa.RecordBatch]:
"""Create simple test batches with integer and string columns."""
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("value", pa.int64()),
pa.field("name", pa.string()),
]
)
fields: list[pa.Field[Any]] = [
pa.field("id", pa.int64()),
pa.field("value", pa.int64()),
pa.field("name", pa.string()),
]
schema = pa.schema(fields)
batch1 = pa.RecordBatch.from_pydict(
{"id": [1, 2], "value": [10, 20], "name": ["a", "b"]},
schema=schema,
Expand All @@ -34,12 +35,11 @@ def simple_batches() -> list[pa.RecordBatch]:
@pytest.fixture
def numeric_batches() -> list[pa.RecordBatch]:
"""Create test batches with only numeric columns for sum tests."""
schema = pa.schema(
[
pa.field("a", pa.int32()),
pa.field("b", pa.float64()),
]
)
fields: list[pa.Field[Any]] = [
pa.field("a", pa.int32()),
pa.field("b", pa.float64()),
]
schema = pa.schema(fields)
batch1 = pa.RecordBatch.from_pydict(
{"a": [1, 2, 3], "b": [1.5, 2.5, 3.0]},
schema=schema,
Expand Down
23 changes: 16 additions & 7 deletions tests/table/generator/test_partitioned_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Tests for the PartitionedRangeFunction with multi-worker support."""

from __future__ import annotations

from typing import Any

import pyarrow as pa

from vgi.client import Client
Expand All @@ -10,6 +14,11 @@
from .conftest import RunnerWithMode


def _sorted_non_null(values: list[Any | None]) -> list[Any]:
"""Return sorted list, filtering out None values for type safety."""
return sorted(v for v in values if v is not None)


class TestPartitionedRangeFunctionInProcess:
"""In-process tests for the partitioned_range function."""

Expand All @@ -23,7 +32,7 @@ def test_generates_full_range_single_worker(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 10

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(10))

def test_metadata(self) -> None:
Expand Down Expand Up @@ -64,7 +73,7 @@ def test_values_are_sequential(
outputs, logs = runner(PartitionedRangeFunction, (50,))

table = pa.Table.from_batches(outputs)
values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(50))


Expand All @@ -84,7 +93,7 @@ def test_two_workers_produce_complete_range(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 20

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(20))

def test_three_workers_produce_complete_range(self) -> None:
Expand All @@ -100,7 +109,7 @@ def test_three_workers_produce_complete_range(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 30

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(30))

def test_workers_produce_large_range(self) -> None:
Expand All @@ -116,7 +125,7 @@ def test_workers_produce_large_range(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 10000

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(10000))

def test_uneven_distribution(self) -> None:
Expand All @@ -134,7 +143,7 @@ def test_uneven_distribution(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 7

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(7))

def test_single_worker_fallback(self) -> None:
Expand All @@ -150,5 +159,5 @@ def test_single_worker_fallback(self) -> None:
table = pa.Table.from_batches(outputs)
assert table.num_rows == 15

values = sorted(table.column("value").to_pylist())
values = _sorted_non_null(table.column("value").to_pylist())
assert values == list(range(15))
2 changes: 1 addition & 1 deletion tests/table/generator/test_random_sample_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ def test_values_in_range(self, run_table_function_mode: RunnerWithMode) -> None:
table = pa.Table.from_batches(outputs)
values = table.column("value").to_pylist()

assert all(0 <= v < 1 for v in values)
assert all(v is not None and 0 <= v < 1 for v in values)
29 changes: 16 additions & 13 deletions tests/table/test_function.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Generic tests for TableFunctionGenerator behavior."""

from __future__ import annotations

import pyarrow as pa
import pytest
import structlog

from tests.utils import make_schema
from vgi.function import Arguments, Invocation
from vgi.table_function import (
CardinalityInfo,
Expand All @@ -23,7 +26,7 @@ def test_empty_process_generator(self) -> None:
class EmptyFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

with TableFunctionTestClient(EmptyFunction) as client:
outputs = list(client.table_function())
Expand All @@ -37,7 +40,7 @@ def test_single_batch_output(self) -> None:
class SingleBatchFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def process(self) -> OutputGenerator:
yield Output(
Expand All @@ -58,7 +61,7 @@ def test_multiple_batch_output(self) -> None:
class MultiBatchFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("n", pa.int64())])
return make_schema([pa.field("n", pa.int64())])

def process(self) -> OutputGenerator:
for i in range(3):
Expand Down Expand Up @@ -86,7 +89,7 @@ def test_setup_called_before_process(self) -> None:
class LifecycleFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def setup(self) -> None:
call_order.append("setup")
Expand All @@ -112,7 +115,7 @@ def test_teardown_called_on_exception(self) -> None:
class ExceptionFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def process(self) -> OutputGenerator:
raise ValueError("test error")
Expand Down Expand Up @@ -140,7 +143,7 @@ def test_valid_schema_passes(self) -> None:
class ValidSchemaFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def process(self) -> OutputGenerator:
yield Output(
Expand All @@ -158,11 +161,11 @@ def test_invalid_schema_raises(self) -> None:
class InvalidSchemaFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def process(self) -> OutputGenerator:
# Return batch with wrong column name
wrong_schema = pa.schema([pa.field("y", pa.int64())])
wrong_schema = make_schema([pa.field("y", pa.int64())])
wrong_batch = pa.RecordBatch.from_pydict(
{"y": [1]}, schema=wrong_schema
)
Expand All @@ -187,7 +190,7 @@ def test_default_cardinality_is_none(self) -> None:
class NoCardinalityFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

invocation = Invocation(
function_name="test",
Expand All @@ -209,7 +212,7 @@ def test_custom_cardinality(self) -> None:
class CardinalityFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("x", pa.int64())])
return make_schema([pa.field("x", pa.int64())])

def cardinality(self) -> CardinalityInfo:
return CardinalityInfo(estimate=100, max=1000)
Expand Down Expand Up @@ -244,7 +247,7 @@ class ArgFunction(TableFunctionGenerator):

@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("n", pa.int64())])
return make_schema([pa.field("n", pa.int64())])

def process(self) -> OutputGenerator:
yield Output(
Expand All @@ -270,7 +273,7 @@ class NamedArgFunction(TableFunctionGenerator):

@property
def output_schema(self) -> pa.Schema:
return pa.schema([pa.field("result", pa.int64())])
return make_schema([pa.field("result", pa.int64())])

def process(self) -> OutputGenerator:
yield Output(
Expand Down Expand Up @@ -299,7 +302,7 @@ def test_empty_output_batch_property(self) -> None:
class TestFunction(TableFunctionGenerator):
@property
def output_schema(self) -> pa.Schema:
return pa.schema(
return make_schema(
[
pa.field("a", pa.int64()),
pa.field("b", pa.string()),
Expand Down
17 changes: 11 additions & 6 deletions tests/table_in_out/generator/test_repeat_inputs_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for the RepeatInputsFunction (explosion)."""

from __future__ import annotations

import pyarrow as pa

from tests.utils import make_schema
from vgi.client import Client
from vgi.function import Arguments

Expand All @@ -18,7 +21,7 @@ def test_repeat_custom_count(
output_batches = list(
client.table_in_out_function(
function_name="repeat_inputs",
arguments=Arguments(positional=tuple([repeat_count]), named={}),
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
input=iter(simple_batches),
)
)
Expand All @@ -35,7 +38,7 @@ def test_repeat_single_time(
output_batches = list(
client.table_in_out_function(
function_name="repeat_inputs",
arguments=Arguments(positional=tuple([1]), named={}),
arguments=Arguments(positional=(pa.scalar(1),), named={}),
input=iter(simple_batches),
)
)
Expand All @@ -46,7 +49,7 @@ def test_repeat_single_time(

def test_repeat_distributed_many_batches(self, example_worker: str) -> None:
"""Should correctly repeat across many batches with multiple workers."""
schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.float64())])
schema = make_schema([pa.field("a", pa.int64()), pa.field("b", pa.float64())])

# Create 100 batches, each with 50 rows
num_batches = 100
Expand All @@ -69,7 +72,7 @@ def test_repeat_distributed_many_batches(self, example_worker: str) -> None:
output_batches = list(
client.table_in_out_function(
function_name="repeat_inputs",
arguments=Arguments(positional=tuple([repeat_count]), named={}),
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
input=iter(batches),
)
)
Expand All @@ -80,7 +83,9 @@ def test_repeat_distributed_many_batches(self, example_worker: str) -> None:

def test_repeat_distributed_preserves_data(self, example_worker: str) -> None:
"""Should preserve data correctly when repeated across workers."""
schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.string())])
schema = make_schema(
[pa.field("id", pa.int64()), pa.field("value", pa.string())]
)

# Create batches with distinct values to verify data integrity
batches = [
Expand All @@ -101,7 +106,7 @@ def test_repeat_distributed_preserves_data(self, example_worker: str) -> None:
output_batches = list(
client.table_in_out_function(
function_name="repeat_inputs",
arguments=Arguments(positional=tuple([repeat_count]), named={}),
arguments=Arguments(positional=(pa.scalar(repeat_count),)),
input=iter(batches),
)
)
Expand Down
Loading