Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vgi/examples/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def output_schema(self) -> pa.Schema:
When vgi_verbose_mode is "true", includes an extra "details" column.
This demonstrates how settings can affect the bind result.
"""
fields: list[pa.Field] = [
fields: list[pa.Field[pa.DataType]] = [
pa.field("id", pa.int64()),
pa.field("value", pa.float64()),
]
Expand Down
72 changes: 72 additions & 0 deletions vgi/output_complete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Internal output normalization for VGI function generators.

This module provides the _OutputComplete class used by all function types
to normalize generator yields into a consistent format with guaranteed
non-None batches.

This is an internal module - users should not import from here directly.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING

import pyarrow as pa

import vgi.log

if TYPE_CHECKING:
from vgi.table_function import Output as TableOutput
from vgi.table_in_out_function import Output as TableInOutOutput

__all__ = ["OutputComplete"]


@dataclass(frozen=True, slots=True)
class OutputComplete:
"""Internal: Output with guaranteed non-None batch.

Used by the framework to normalize generator yields. When the user yields
None, Output with None batch, or Message, this class ensures we always
have a valid RecordBatch for the protocol.

Attributes:
batch: Always a valid RecordBatch (never None).
has_more: If True, generator expects another send() call.
Only used by TableInOutGeneratorFunction.
log_message: Present when user yielded Message directly.

"""

batch: pa.RecordBatch
has_more: bool = False
log_message: vgi.log.Message | None = None

@classmethod
def from_process_result(
cls,
source: "vgi.log.Message | TableOutput | TableInOutOutput | None",
empty_batch: pa.RecordBatch,
) -> "OutputComplete":
"""Create from user's yield value.

Args:
source: What the user yielded (Output, Message, or None).
empty_batch: Empty batch to substitute when needed.

Returns:
Normalized output with guaranteed non-None batch.

"""
if source is None:
return cls(batch=empty_batch)
if isinstance(source, vgi.log.Message):
# When yielding a log message, has_more=True so the caller
# re-sends the current input after the message is processed
return cls(batch=empty_batch, has_more=True, log_message=source)
# source is Output (either TableOutput or TableInOutOutput)
# TableOutput doesn't have has_more, TableInOutOutput does
has_more = getattr(source, "has_more", False)
return cls(
batch=source.batch if source.batch is not None else empty_batch,
has_more=has_more,
)
46 changes: 8 additions & 38 deletions vgi/scalar_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array:
import vgi.function
import vgi.log
from vgi.exceptions import SchemaValidationError
from vgi.output_complete import OutputComplete
from vgi.table_function import Output, ProtocolOutput

__all__ = [
Expand Down Expand Up @@ -165,37 +166,6 @@ class ProtocolInput:
metadata: pa.KeyValueMetadata | None = None


@dataclass(frozen=True, slots=True)
class _ScalarOutputComplete:
"""Internal: Output with guaranteed non-None batch for scalar functions."""

batch: pa.RecordBatch
log_message: vgi.log.Message | None = None

@classmethod
def from_process_result(
cls,
source: vgi.log.Message | Output,
empty_batch: pa.RecordBatch,
) -> _ScalarOutputComplete:
"""Create from user's yield value.

Args:
source: What the user yielded (Output or Message).
empty_batch: Empty batch to substitute when yielding Message.

Returns:
Normalized output with guaranteed non-None batch.

"""
if isinstance(source, vgi.log.Message):
return cls(batch=empty_batch, log_message=source)
# source is Output
return cls(
batch=source.batch if source.batch is not None else empty_batch,
)


class ScalarFunctionGenerator(vgi.function.Function[vgi.function.FunctionInitInput]):
"""Generator-based base class for scalar functions.

Expand Down Expand Up @@ -299,23 +269,23 @@ def _process_and_validate(
self,
generator: ScalarOutputGenerator,
input_batch: pa.RecordBatch,
) -> _ScalarOutputComplete:
) -> OutputComplete:
"""Process a batch and validate schemas and row count.

Args:
generator: The user's process() generator.
input_batch: The input RecordBatch to process.

Returns:
_ScalarOutputComplete with validated output batch.
OutputComplete with validated output batch.

Raises:
SchemaValidationError: If input or output batch schema doesn't match.
RowCountMismatchError: If output row count doesn't match input.

"""
self._validate_input_schema(input_batch)
result: _ScalarOutputComplete = _ScalarOutputComplete.from_process_result(
result: OutputComplete = OutputComplete.from_process_result(
generator.send(input_batch),
self.empty_output_batch,
)
Expand All @@ -330,22 +300,22 @@ def _process_with_exception_handling(
self,
generator: ScalarOutputGenerator,
input_batch: pa.RecordBatch,
) -> _ScalarOutputComplete:
) -> OutputComplete:
"""Process a batch with exception handling.

Wraps _process_and_validate to catch exceptions and convert them
to _ScalarOutputComplete with an error log message.
to OutputComplete with an error log message.
"""
try:
return self._process_and_validate(generator, input_batch)
except Exception as e:
return _ScalarOutputComplete(
return OutputComplete(
batch=self.empty_output_batch,
log_message=vgi.log.Message.from_exception(e),
)

@final
def _should_terminate(self, result: _ScalarOutputComplete) -> bool:
def _should_terminate(self, result: OutputComplete) -> bool:
"""Check if processing should terminate due to an exception."""
return (
result.log_message is not None
Expand Down
54 changes: 7 additions & 47 deletions vgi/table_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import vgi.function
import vgi.ipc_utils
import vgi.log
from vgi.output_complete import OutputComplete

__all__ = [
"TableCardinality",
Expand Down Expand Up @@ -133,47 +134,6 @@ class Output:
OutputGenerator = Generator[vgi.log.Message | Output, None, None]


@dataclass(frozen=True, slots=True)
class _OutputComplete:
"""Internal: Output with guaranteed non-None batch.

Used by the framework to normalize generator yields. When the user yields
None, Output with None batch, or Message, this class ensures we always
have a valid RecordBatch for the protocol.

Attributes:
batch: Always a valid RecordBatch (never None).
log_message: Present when user yielded Message directly.

"""

batch: pa.RecordBatch
log_message: vgi.log.Message | None = None

@classmethod
def from_process_result(
cls,
source: vgi.log.Message | Output,
empty_batch: pa.RecordBatch,
) -> "_OutputComplete":
"""Create from user's yield value.

Args:
source: What the user yielded (Output or Message).
empty_batch: Empty batch to substitute when needed.

Returns:
Normalized output with guaranteed non-None batch.

"""
if isinstance(source, vgi.log.Message):
return cls(batch=empty_batch, log_message=source)
# source is Output
return cls(
batch=source.batch if source.batch is not None else empty_batch,
)


@dataclass(frozen=True, slots=True)
class ProtocolOutput:
"""Output yielded by the generator after each send().
Expand Down Expand Up @@ -210,7 +170,7 @@ def metadata(
)

@classmethod
def from_process_result(cls, process_result: "_OutputComplete") -> "ProtocolOutput":
def from_process_result(cls, process_result: "OutputComplete") -> "ProtocolOutput":
"""Create a ProtocolOutput from an Output and status.

Args:
Expand Down Expand Up @@ -384,7 +344,7 @@ def process(self) -> OutputGenerator:
"""

@final
def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete:
def _process_and_validate(self, generator: OutputGenerator) -> OutputComplete:
"""Process a batch and validate the output schema.

Converts the result of the generator to OutputComplete, and
Expand All @@ -400,7 +360,7 @@ def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete:
SchemaValidationError: If output batch schema doesn't match.

"""
result: _OutputComplete = _OutputComplete.from_process_result(
result: OutputComplete = OutputComplete.from_process_result(
generator.send(None),
self.empty_output_batch,
)
Expand All @@ -411,7 +371,7 @@ def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete:
def _process_with_exception_handling(
self,
generator: OutputGenerator,
) -> _OutputComplete:
) -> OutputComplete:
"""Process a batch with exception handling.

Wraps _process_and_validate to catch exceptions and convert them
Expand All @@ -425,13 +385,13 @@ def _process_with_exception_handling(
except StopIteration:
raise
except Exception as e:
return _OutputComplete(
return OutputComplete(
batch=self.empty_output_batch,
log_message=vgi.log.Message.from_exception(e),
)

@final
def _should_terminate(self, result: _OutputComplete) -> bool:
def _should_terminate(self, result: OutputComplete) -> bool:
"""Check if processing should terminate due to an exception."""
return (
result.log_message is not None
Expand Down
Loading
Loading