diff --git a/vgi/examples/table.py b/vgi/examples/table.py index eecdf61..ca98763 100644 --- a/vgi/examples/table.py +++ b/vgi/examples/table.py @@ -700,7 +700,7 @@ def output_schema(self) -> pa.Schema: When vgi_verbose_mode is "true", includes an extra "details" column. This demonstrates how settings can affect the bind result. """ - fields: list[pa.Field] = [ + fields: list[pa.Field[pa.DataType]] = [ pa.field("id", pa.int64()), pa.field("value", pa.float64()), ] diff --git a/vgi/output_complete.py b/vgi/output_complete.py new file mode 100644 index 0000000..7395029 --- /dev/null +++ b/vgi/output_complete.py @@ -0,0 +1,72 @@ +"""Internal output normalization for VGI function generators. + +This module provides the _OutputComplete class used by all function types +to normalize generator yields into a consistent format with guaranteed +non-None batches. + +This is an internal module - users should not import from here directly. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import pyarrow as pa + +import vgi.log + +if TYPE_CHECKING: + from vgi.table_function import Output as TableOutput + from vgi.table_in_out_function import Output as TableInOutOutput + +__all__ = ["OutputComplete"] + + +@dataclass(frozen=True, slots=True) +class OutputComplete: + """Internal: Output with guaranteed non-None batch. + + Used by the framework to normalize generator yields. When the user yields + None, Output with None batch, or Message, this class ensures we always + have a valid RecordBatch for the protocol. + + Attributes: + batch: Always a valid RecordBatch (never None). + has_more: If True, generator expects another send() call. + Only used by TableInOutGeneratorFunction. + log_message: Present when user yielded Message directly. + + """ + + batch: pa.RecordBatch + has_more: bool = False + log_message: vgi.log.Message | None = None + + @classmethod + def from_process_result( + cls, + source: "vgi.log.Message | TableOutput | TableInOutOutput | None", + empty_batch: pa.RecordBatch, + ) -> "OutputComplete": + """Create from user's yield value. + + Args: + source: What the user yielded (Output, Message, or None). + empty_batch: Empty batch to substitute when needed. + + Returns: + Normalized output with guaranteed non-None batch. + + """ + if source is None: + return cls(batch=empty_batch) + if isinstance(source, vgi.log.Message): + # When yielding a log message, has_more=True so the caller + # re-sends the current input after the message is processed + return cls(batch=empty_batch, has_more=True, log_message=source) + # source is Output (either TableOutput or TableInOutOutput) + # TableOutput doesn't have has_more, TableInOutOutput does + has_more = getattr(source, "has_more", False) + return cls( + batch=source.batch if source.batch is not None else empty_batch, + has_more=has_more, + ) diff --git a/vgi/scalar_function.py b/vgi/scalar_function.py index cfac8ca..c101d47 100644 --- a/vgi/scalar_function.py +++ b/vgi/scalar_function.py @@ -50,6 +50,7 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array: import vgi.function import vgi.log from vgi.exceptions import SchemaValidationError +from vgi.output_complete import OutputComplete from vgi.table_function import Output, ProtocolOutput __all__ = [ @@ -165,37 +166,6 @@ class ProtocolInput: metadata: pa.KeyValueMetadata | None = None -@dataclass(frozen=True, slots=True) -class _ScalarOutputComplete: - """Internal: Output with guaranteed non-None batch for scalar functions.""" - - batch: pa.RecordBatch - log_message: vgi.log.Message | None = None - - @classmethod - def from_process_result( - cls, - source: vgi.log.Message | Output, - empty_batch: pa.RecordBatch, - ) -> _ScalarOutputComplete: - """Create from user's yield value. - - Args: - source: What the user yielded (Output or Message). - empty_batch: Empty batch to substitute when yielding Message. - - Returns: - Normalized output with guaranteed non-None batch. - - """ - if isinstance(source, vgi.log.Message): - return cls(batch=empty_batch, log_message=source) - # source is Output - return cls( - batch=source.batch if source.batch is not None else empty_batch, - ) - - class ScalarFunctionGenerator(vgi.function.Function[vgi.function.FunctionInitInput]): """Generator-based base class for scalar functions. @@ -299,7 +269,7 @@ def _process_and_validate( self, generator: ScalarOutputGenerator, input_batch: pa.RecordBatch, - ) -> _ScalarOutputComplete: + ) -> OutputComplete: """Process a batch and validate schemas and row count. Args: @@ -307,7 +277,7 @@ def _process_and_validate( input_batch: The input RecordBatch to process. Returns: - _ScalarOutputComplete with validated output batch. + OutputComplete with validated output batch. Raises: SchemaValidationError: If input or output batch schema doesn't match. @@ -315,7 +285,7 @@ def _process_and_validate( """ self._validate_input_schema(input_batch) - result: _ScalarOutputComplete = _ScalarOutputComplete.from_process_result( + result: OutputComplete = OutputComplete.from_process_result( generator.send(input_batch), self.empty_output_batch, ) @@ -330,22 +300,22 @@ def _process_with_exception_handling( self, generator: ScalarOutputGenerator, input_batch: pa.RecordBatch, - ) -> _ScalarOutputComplete: + ) -> OutputComplete: """Process a batch with exception handling. Wraps _process_and_validate to catch exceptions and convert them - to _ScalarOutputComplete with an error log message. + to OutputComplete with an error log message. """ try: return self._process_and_validate(generator, input_batch) except Exception as e: - return _ScalarOutputComplete( + return OutputComplete( batch=self.empty_output_batch, log_message=vgi.log.Message.from_exception(e), ) @final - def _should_terminate(self, result: _ScalarOutputComplete) -> bool: + def _should_terminate(self, result: OutputComplete) -> bool: """Check if processing should terminate due to an exception.""" return ( result.log_message is not None diff --git a/vgi/table_function.py b/vgi/table_function.py index 9073439..c742726 100644 --- a/vgi/table_function.py +++ b/vgi/table_function.py @@ -31,6 +31,7 @@ import vgi.function import vgi.ipc_utils import vgi.log +from vgi.output_complete import OutputComplete __all__ = [ "TableCardinality", @@ -133,47 +134,6 @@ class Output: OutputGenerator = Generator[vgi.log.Message | Output, None, None] -@dataclass(frozen=True, slots=True) -class _OutputComplete: - """Internal: Output with guaranteed non-None batch. - - Used by the framework to normalize generator yields. When the user yields - None, Output with None batch, or Message, this class ensures we always - have a valid RecordBatch for the protocol. - - Attributes: - batch: Always a valid RecordBatch (never None). - log_message: Present when user yielded Message directly. - - """ - - batch: pa.RecordBatch - log_message: vgi.log.Message | None = None - - @classmethod - def from_process_result( - cls, - source: vgi.log.Message | Output, - empty_batch: pa.RecordBatch, - ) -> "_OutputComplete": - """Create from user's yield value. - - Args: - source: What the user yielded (Output or Message). - empty_batch: Empty batch to substitute when needed. - - Returns: - Normalized output with guaranteed non-None batch. - - """ - if isinstance(source, vgi.log.Message): - return cls(batch=empty_batch, log_message=source) - # source is Output - return cls( - batch=source.batch if source.batch is not None else empty_batch, - ) - - @dataclass(frozen=True, slots=True) class ProtocolOutput: """Output yielded by the generator after each send(). @@ -210,7 +170,7 @@ def metadata( ) @classmethod - def from_process_result(cls, process_result: "_OutputComplete") -> "ProtocolOutput": + def from_process_result(cls, process_result: "OutputComplete") -> "ProtocolOutput": """Create a ProtocolOutput from an Output and status. Args: @@ -384,7 +344,7 @@ def process(self) -> OutputGenerator: """ @final - def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete: + def _process_and_validate(self, generator: OutputGenerator) -> OutputComplete: """Process a batch and validate the output schema. Converts the result of the generator to OutputComplete, and @@ -400,7 +360,7 @@ def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete: SchemaValidationError: If output batch schema doesn't match. """ - result: _OutputComplete = _OutputComplete.from_process_result( + result: OutputComplete = OutputComplete.from_process_result( generator.send(None), self.empty_output_batch, ) @@ -411,7 +371,7 @@ def _process_and_validate(self, generator: OutputGenerator) -> _OutputComplete: def _process_with_exception_handling( self, generator: OutputGenerator, - ) -> _OutputComplete: + ) -> OutputComplete: """Process a batch with exception handling. Wraps _process_and_validate to catch exceptions and convert them @@ -425,13 +385,13 @@ def _process_with_exception_handling( except StopIteration: raise except Exception as e: - return _OutputComplete( + return OutputComplete( batch=self.empty_output_batch, log_message=vgi.log.Message.from_exception(e), ) @final - def _should_terminate(self, result: _OutputComplete) -> bool: + def _should_terminate(self, result: OutputComplete) -> bool: """Check if processing should terminate due to an exception.""" return ( result.log_message is not None diff --git a/vgi/table_in_out_function.py b/vgi/table_in_out_function.py index 879ca5d..a3d8378 100644 --- a/vgi/table_in_out_function.py +++ b/vgi/table_in_out_function.py @@ -78,6 +78,7 @@ def process(self, batch: pa.RecordBatch) -> OutputGenerator: import vgi.ipc_utils import vgi.log import vgi.table_function +from vgi.output_complete import OutputComplete __all__ = [ "ProtocolInput", @@ -181,7 +182,7 @@ def metadata( @classmethod def from_process_result( - cls, process_result: "_OutputComplete", in_finalize_phase: bool + cls, process_result: "OutputComplete", in_finalize_phase: bool ) -> "ProtocolOutput": """Create a ProtocolOutput from an Output and status. @@ -353,52 +354,6 @@ def wrapper(self: T, first_batch: pa.RecordBatch) -> OutputGenerator: return wrapper -@dataclass(frozen=True, slots=True) -class _OutputComplete: - """Internal: Output with guaranteed non-None batch. - - Used by the framework to normalize generator yields. When the user yields - None, Output with None batch, or Message, this class ensures we always - have a valid RecordBatch for the protocol. - - Attributes: - batch: Always a valid RecordBatch (never None). - has_more: If True, generator expects another send() call. - log_message: Present when user yielded Message directly. - - """ - - batch: pa.RecordBatch - has_more: bool = False - log_message: vgi.log.Message | None = None - - @classmethod - def from_process_result( - cls, - source: vgi.log.Message | Output | None, - empty_batch: pa.RecordBatch, - ) -> "_OutputComplete": - """Create from user's yield value. - - Args: - source: What the user yielded (Output, Message, or None). - empty_batch: Empty batch to substitute when needed. - - Returns: - Normalized output with guaranteed non-None batch. - - """ - if source is None: - return cls(batch=empty_batch) - if isinstance(source, vgi.log.Message): - return cls(batch=empty_batch, has_more=True, log_message=source) - # source is Output - return cls( - batch=source.batch if source.batch is not None else empty_batch, - has_more=source.has_more, - ) - - class TableInOutGeneratorFunction(vgi.table_function.TableFunctionBase): """Base class for streaming table functions that transform Arrow RecordBatches. @@ -588,7 +543,7 @@ def _process_and_validate( self, generator: OutputGenerator, batch: pa.RecordBatch | None, - ) -> _OutputComplete: + ) -> OutputComplete: """Process a batch and validate both input and output schemas. Validates the input batch schema, sends it to the generator, converts @@ -607,7 +562,7 @@ def _process_and_validate( """ if batch is not None: self._validate_input_schema(batch) - result: _OutputComplete = _OutputComplete.from_process_result( + result: OutputComplete = OutputComplete.from_process_result( generator.send(batch), self.empty_output_batch, ) @@ -619,7 +574,7 @@ def _process_with_exception_handling( self, generator: OutputGenerator, batch: pa.RecordBatch | None, - ) -> _OutputComplete: + ) -> OutputComplete: """Process a batch with exception handling. Wraps _process_and_validate to catch exceptions and convert them @@ -628,13 +583,13 @@ def _process_with_exception_handling( try: return self._process_and_validate(generator, batch) except Exception as e: - return _OutputComplete( + return OutputComplete( batch=self.empty_output_batch, log_message=vgi.log.Message.from_exception(e), ) @final - def _should_terminate(self, result: _OutputComplete) -> bool: + def _should_terminate(self, result: OutputComplete) -> bool: """Check if processing should terminate due to an exception.""" return ( result.log_message is not None