diff --git a/README.md b/README.md index 2b3bea5..e99ffdb 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ - +

@@ -31,6 +31,7 @@ A Python SDK for creating and managing data pipelines between Kafka and ClickHou - Pipeline configuration via YAML or JSON - Schema validation and configuration management - Fine-grained resource control per pipeline component +- Enterprise Edition client (`glassflow.ee`) with DLQ reprocessing and discard ## Installation @@ -157,6 +158,44 @@ client.delete_pipeline("my-pipeline-id") pipeline.delete() ``` +## Enterprise Edition + +The GlassFlow Enterprise Edition adds capabilities on top of the Open Source engine. The SDK exposes them through a drop-in client that extends the Open Source one. Import `Client` from `glassflow.ee` instead of `glassflow.etl`: + +```python +from glassflow.ee import Client + +client = Client(host="your-glassflow-etl-url") +``` + +The Enterprise client does everything the Open Source client does, plus the Enterprise-only features below. Entitlement is enforced by the backend: calling an Enterprise-only operation against a backend that is not licensed for it raises `FeatureNotLicensedError`. + +### DLQ message processing + +When a pipeline component fails to process a message, that message lands in the pipeline's dead-letter queue (DLQ). On the Enterprise client, `pipeline.dlq` adds message management on top of the Open Source `state`, `consume`, and `purge`: + +- `list(batch_size, cursor, component)`: non-destructive paginated read. Returns a page dict with `messages` (each carrying a stable `message_id`, `component`, `error`, `original_message`, and `received_at`), `has_more`, and `next_cursor`. Pass `component` to filter to a single component (`ingestor`, `join`, `sink`, `dedup`, `oltp-receiver`), and pass `next_cursor` back as `cursor` to page. +- `list_iter(batch_size, component)`: lazily iterate over every message, paging via the cursor for you. Yields individual messages, so you do not manage the cursor by hand. +- `reprocess(message_ids)` / `reprocess_all()`: move messages back into the pipeline input to be processed again. +- `discard(message_ids)` / `discard_all()`: permanently remove messages. + +```python +pipeline = client.get_pipeline("my-pipeline-id") + +# Inspect failed messages from the sink only (paged automatically) +ids = [m["message_id"] for m in pipeline.dlq.list_iter(component="sink")] + +# Retry them after fixing the underlying issue +pipeline.dlq.reprocess(ids) # or pipeline.dlq.reprocess_all() + +# Or drop the ones you do not want +pipeline.dlq.discard(["seq_200"]) # or pipeline.dlq.discard_all() +``` + +Reprocessing replays messages through the running pipeline, so the pipeline must be in the `Running` state. Calling `reprocess` on a stopped, terminated, or failed pipeline raises `PipelineNotRunningError`. Discard acts on the queue directly and works in any state. + +`reprocess` and `discard` accept at most 1000 `message_id` values per call. For larger sets, use the `*_all` variants. See the [DLQ documentation](https://docs.glassflow.dev/configuration/dlq) for the full reference. + ## Migrating from V2 to V3 Pipeline version `v2` has been removed. Use `Client.migrate_pipeline_v2_to_v3()` to convert an existing configuration automatically: diff --git a/VERSION b/VERSION index fcdb2e1..ee74734 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -4.0.0 +4.1.0 diff --git a/coverage.xml b/coverage.xml index 8dfb618..afee7dd 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,13 +1,132 @@ - - + + /home/runner/work/glassflow-python-sdk/glassflow-python-sdk/src /home/runner/work/glassflow-python-sdk/glassflow-python-sdk/src/glassflow - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -19,7 +138,7 @@ - + @@ -53,42 +172,46 @@ - + - + + - - - + + + + - - - - - - + + + + - - - - - + + + - - - - - + + + + + + - + + + + + + - + @@ -98,45 +221,46 @@ - + - - - - - - - + + + + + - - - - - - - - + + + + + + + + - - - - + + + + - - - + + + + + - - - - - - + + - - + + + + + + + @@ -196,29 +320,33 @@ - + - - - - - - - - - + + + + + + + + - - - + + + + + + + + - + @@ -231,145 +359,152 @@ - - - - - - - - - + + + + - + - + + - + - - - - + + + + + + + + - - + + + + - - - + + - - - + + + - - - - - - - - - - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - - - - + + + + + - - - - - + + + + + + + + + + + + - - - - + + - - - - - - + + + + + + - + + + - - - + + + - - - - - - - - + + + + + + - - - - - - - - + + + + + + + - - + + + - + + + - - + + - - - - - - - + + - + - - + + + + + - - + + + + + + + + + + @@ -405,7 +540,7 @@ - + @@ -418,8 +553,8 @@ - - + + @@ -571,139 +706,176 @@ - + - - - - - - + + + + + + + - - - - - - - + + - - + + - - + - - - - + + + + - - + + + + + + + - - - - + - - - - - + - + + - + + + - - - - + + + + + + - - - - - + + + + + + - - - - + + + + - - - - - - - - - - + + + + + + - - - + + + + + + - - - + + + + - - - + + - + - + - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -887,7 +1059,7 @@ - + @@ -896,15 +1068,18 @@ - - - - - + + + + + + + + - + @@ -916,91 +1091,138 @@ - + - - + + + + - - - + - + + + - - - + - - - - - - - - - - - - - + - - + + + + + + + + - - + + - - - + + - - - + - + + + + + - - - + + - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/glassflow/ee/__init__.py b/src/glassflow/ee/__init__.py new file mode 100644 index 0000000..0aff448 --- /dev/null +++ b/src/glassflow/ee/__init__.py @@ -0,0 +1,36 @@ +""" +GlassFlow Enterprise SDK. + +Drop-in superset of :mod:`glassflow.etl` that adds Enterprise-only capabilities. +Use it exactly like the OSS client:: + + from glassflow.ee import Client + + client = Client(host="https://...") + pipeline = client.get_pipeline("my-pipeline") + pipeline.dlq.reprocess_all() # Enterprise DLQ management + +All open-source models are re-exported from :mod:`glassflow.etl` for +convenience, so a single import path covers both tiers. +""" + +from glassflow.etl.models import ( + JoinConfig, + PipelineConfig, + SinkConfig, + SourceConfig, +) + +from .client import Client +from .dlq import DLQ +from .pipeline import Pipeline + +__all__ = [ + "Pipeline", + "Client", + "DLQ", + "PipelineConfig", + "SourceConfig", + "SinkConfig", + "JoinConfig", +] diff --git a/src/glassflow/ee/client.py b/src/glassflow/ee/client.py new file mode 100644 index 0000000..3932fb6 --- /dev/null +++ b/src/glassflow/ee/client.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from glassflow.etl.client import Client as _OSSClient + +from .pipeline import Pipeline + + +class Client(_OSSClient): + """Enterprise GlassFlow client. + + Extends the open-source :class:`glassflow.etl.client.Client`. Every pipeline + it returns is the Enterprise :class:`~.pipeline.Pipeline`, giving + Enterprise-only capabilities a home as they are added. Backend entitlement + is enforced server-side. + """ + + _pipeline_class = Pipeline diff --git a/src/glassflow/ee/dlq.py b/src/glassflow/ee/dlq.py new file mode 100644 index 0000000..b6aa3a3 --- /dev/null +++ b/src/glassflow/ee/dlq.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import warnings +from typing import Any, Dict, Iterator, List, Optional + +from glassflow.etl import errors +from glassflow.etl.dlq import DLQ as _OSSDLQ + +# Mirrors the backend cap on message_ids per mode=selected request +# (dlqSelectedMaxMessageIDs in glassflow-etl-ee). Validated client-side for a +# fast, offline error instead of a round-trip 400. +MAX_SELECTED_MESSAGE_IDS = 1000 + +# Data-plane components a DLQ message can come from (DataPlaneRoles in +# glassflow-etl-ee). Used to validate the list() component filter client-side. +DLQ_COMPONENTS = ("ingestor", "join", "sink", "dedup", "oltp-receiver") + + +class DLQ(_OSSDLQ): + """Enterprise Dead Letter Queue client. + + Extends the open-source :class:`glassflow.etl.dlq.DLQ` with message + management: a non-destructive paginated :meth:`list`, and + :meth:`reprocess`/:meth:`discard` (plus their ``*_all`` variants) that move + messages back into the pipeline or permanently remove them. + + Backend entitlement is enforced server-side; calling these against a backend + that is not licensed for them raises :class:`FeatureNotLicensedError`. + """ + + def list( + self, + batch_size: int = 100, + cursor: Optional[str] = None, + component: Optional[str] = None, + ) -> Dict[str, Any]: + """Read messages from the DLQ without removing them. + + The non-destructive successor to :meth:`consume`. Each message in the + returned page carries a stable ``message_id`` (NATS sequence number, as + a string) plus ``component``, ``error``, ``original_message`` and + ``received_at``. The ``message_id`` values are what :meth:`reprocess` + and :meth:`discard` act on in ``mode=selected``. + + Args: + batch_size: Number of messages per page (between 1 and 1000). + cursor: NATS sequence to resume from, taken from the previous page's + ``next_cursor``; omit for the first page. + component: Filter to messages from a single data-plane component; + must be one of :data:`DLQ_COMPONENTS` (``ingestor``, ``join``, + ``sink``, ``dedup``, ``oltp-receiver``); omit for all components. + + Returns: + A dict with ``messages`` (list of message dicts), ``has_more`` + (bool), and ``next_cursor`` (str, present when ``has_more`` is + true; pass it back as ``cursor`` to fetch the next page). + + Raises: + ValueError: If ``batch_size`` or ``component`` is invalid. + APIError: If the API request fails. + """ + if ( + not isinstance(batch_size, int) + or batch_size < 1 + or batch_size > MAX_SELECTED_MESSAGE_IDS + ): + raise ValueError( + f"batch_size must be an integer between 1 and " + f"{MAX_SELECTED_MESSAGE_IDS}" + ) + + if component is not None and component not in DLQ_COMPONENTS: + raise ValueError(f"component must be one of {', '.join(DLQ_COMPONENTS)}") + + params: Dict[str, Any] = {"batch_size": batch_size} + if cursor is not None: + params["cursor"] = cursor + if component is not None: + params["component"] = component + + response = self._request("GET", f"{self.endpoint}/list", params=params) + if response.status_code == 204 or not response.content: + return {"messages": [], "has_more": False} + return response.json() + + def list_iter( + self, + batch_size: int = 100, + component: Optional[str] = None, + cursor: Optional[str] = None, + ) -> Iterator[Dict[str, Any]]: + """Lazily iterate over every DLQ message, paging via the cursor for you. + + The streaming companion to :meth:`list`: it calls :meth:`list` page by + page and yields each message, so callers do not manage the cursor + themselves. Memory stays flat (one page at a time) and it composes with + ``itertools`` (e.g. ``itertools.islice`` for the first N). + + Args: + batch_size: Messages fetched per underlying page request (1-1000). + component: Optional component filter; see :meth:`list`. + cursor: Optional starting cursor (e.g. to resume a previous run); + omit to start at the beginning of the DLQ. + + Yields: + Individual DLQ message dicts. + """ + while True: + page = self.list(batch_size=batch_size, cursor=cursor, component=component) + yield from page.get("messages", []) + if not page.get("has_more"): + return + cursor = page.get("next_cursor") + # Defensive: a truthy has_more without a cursor would loop forever. + if not cursor: + return + + def reprocess(self, message_ids: List[str]) -> Dict[str, Any]: + """Move specific messages from the DLQ back into the pipeline input. + + Args: + message_ids: ``message_id`` values (from :meth:`list`) to reprocess; + must be non-empty and at most ``MAX_SELECTED_MESSAGE_IDS``. + + Returns: + Dict with ``request_id`` and ``status`` ("accepted"). The republish + happens asynchronously. + + Raises: + ValueError: If ``message_ids`` is empty or too large. + FeatureNotLicensedError: If the backend is not licensed for this. + PipelineNotRunningError: If the pipeline is not in the Running state. + APIError: If the API request fails. + """ + return self._action("reprocess", "selected", message_ids) + + def reprocess_all(self) -> Dict[str, Any]: + """Reprocess every message currently in the DLQ (up to the head at + request time). See :meth:`reprocess`.""" + return self._action("reprocess", "all", None) + + def discard(self, message_ids: List[str]) -> Dict[str, Any]: + """Permanently remove specific messages from the DLQ without + reprocessing them. + + Args: + message_ids: ``message_id`` values (from :meth:`list`) to discard; + must be non-empty and at most ``MAX_SELECTED_MESSAGE_IDS``. + + Returns: + Dict with ``request_id`` and ``discarded_count``. + + Raises: + ValueError: If ``message_ids`` is empty or too large. + FeatureNotLicensedError: If the backend is not licensed for this. + APIError: If the API request fails. + """ + return self._action("discard", "selected", message_ids) + + def discard_all(self) -> Dict[str, Any]: + """Permanently remove every message currently in the DLQ. See + :meth:`discard`.""" + return self._action("discard", "all", None) + + # Deprecated, inherited operations ------------------------------------- + + def consume(self, batch_size: int = 100) -> List[Dict[str, Any]]: + """Deprecated: use :meth:`list`. + + The ``/dlq/consume`` endpoint is being removed in favour of the + non-destructive ``/dlq/list``. This override delegates to :meth:`list` + so existing callers keep working through the transition. + """ + warnings.warn( + "DLQ.consume() is deprecated and the /dlq/consume endpoint is being " + "removed; use DLQ.list() (non-destructive).", + DeprecationWarning, + stacklevel=2, + ) + return self.list(batch_size=batch_size).get("messages", []) + + def purge(self) -> None: + """Deprecated: use :meth:`discard_all`. + + ``/dlq/purge`` remains for backward compatibility but is superseded by + :meth:`discard_all`. + """ + warnings.warn( + "DLQ.purge() is deprecated; use DLQ.discard_all().", + DeprecationWarning, + stacklevel=2, + ) + super().purge() + + def _action( + self, action: str, mode: str, message_ids: Optional[List[str]] + ) -> Dict[str, Any]: + body: Dict[str, Any] = {"mode": mode} + if mode == "selected": + if not message_ids: + raise ValueError( + "message_ids must be non-empty when selecting messages" + ) + if len(message_ids) > MAX_SELECTED_MESSAGE_IDS: + raise ValueError( + f"message_ids cannot exceed {MAX_SELECTED_MESSAGE_IDS} entries" + ) + body["message_ids"] = message_ids + + try: + response = self._request("POST", f"{self.endpoint}/{action}", json=body) + if response.status_code == 204 or not response.content: + return {} + return response.json() + except errors.ForbiddenError as e: + raise errors.FeatureNotLicensedError( + status_code=e.status_code, + message=(f"DLQ {action} requires a GlassFlow Enterprise license"), + response=e.response, + ) from e + except errors.ConflictError as e: + # Reprocess replays through the running pipeline, so the API rejects + # it with 409 when the pipeline is not in the Running state. Discard + # acts on the queue directly and has no such constraint. + if action == "reprocess": + raise errors.PipelineNotRunningError( + status_code=e.status_code, + message=( + "Pipeline must be in the Running state to reprocess DLQ " + "messages" + ), + response=e.response, + ) from e + raise diff --git a/src/glassflow/ee/pipeline.py b/src/glassflow/ee/pipeline.py new file mode 100644 index 0000000..f1a936f --- /dev/null +++ b/src/glassflow/ee/pipeline.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Any, Dict, List + +from glassflow.etl import errors +from glassflow.etl.pipeline import Pipeline as _OSSPipeline + +from .dlq import DLQ + + +class Pipeline(_OSSPipeline): + """Enterprise Pipeline. + + Extends the open-source :class:`glassflow.etl.pipeline.Pipeline`. Its ``dlq`` + property exposes the Enterprise :class:`~.dlq.DLQ` (with + ``list``/``reprocess``/``discard``), and it adds :meth:`get_streams`. + Construction is inherited unchanged; only the DLQ collaborator class is + swapped via ``_dlq_class``. + """ + + _dlq_class = DLQ + + @property + def dlq(self) -> DLQ: + """Get the Enterprise DLQ client for this pipeline.""" + return self._dlq + + @dlq.setter + def dlq(self, dlq: DLQ) -> None: + self._dlq = dlq + + def get_streams(self) -> List[Dict[str, Any]]: + """Return the NATS JetStream streams backing this pipeline. + + Each entry has a ``stream_name`` and the ``component`` the stream belongs + to (for example ``ingestor``, ``join``, ``sink``, ``dedup``, ``dlq``). + Useful for diagnosing NATS-level issues. + + Returns: + List of ``{"stream_name": ..., "component": ...}`` dicts. + + Raises: + PipelineNotFoundError: If the pipeline does not exist. + FeatureNotLicensedError: If the backend is not licensed for this. + APIError: If the API request fails. + """ + try: + response = self._request( + "GET", + f"{self.ENDPOINT}/{self.pipeline_id}/streams", + event_name="PipelineStreamsGet", + ) + except errors.ForbiddenError as e: + raise errors.FeatureNotLicensedError( + status_code=e.status_code, + message="Pipeline streams require a GlassFlow Enterprise license", + response=e.response, + details=e.details, + ) from e + if response.status_code == 204 or not response.content: + return [] + return response.json().get("streams", []) diff --git a/src/glassflow/etl/api_client.py b/src/glassflow/etl/api_client.py index 221f6ce..362dfcc 100644 --- a/src/glassflow/etl/api_client.py +++ b/src/glassflow/etl/api_client.py @@ -67,64 +67,81 @@ def _raise_api_error(response: httpx.Response) -> None: error_data = response.json() message = error_data.get("message", None) code = error_data.get("code", None) + details = error_data.get("details", None) except json.JSONDecodeError: message = f"{status_code} {response.reason_phrase}" code = None + details = None error_data = {} if status_code == 400: # Handle specific status validation error codes if code == "TERMINAL_STATE_VIOLATION": raise errors.TerminalStateViolationError( - status_code, message, response=response + status_code, message, response=response, details=details ) elif code == "INVALID_STATUS_TRANSITION": raise errors.InvalidStatusTransitionError( - status_code, - message, - response=response, + status_code, message, response=response, details=details ) elif code == "UNKNOWN_STATUS": - raise errors.UnknownStatusError(status_code, message, response=response) + raise errors.UnknownStatusError( + status_code, message, response=response, details=details + ) elif code == "PIPELINE_ALREADY_IN_STATE": raise errors.PipelineAlreadyInStateError( - status_code, message, response=response + status_code, message, response=response, details=details ) elif code == "PIPELINE_IN_TRANSITION": raise errors.PipelineInTransitionError( - status_code, message, response=response + status_code, message, response=response, details=details ) elif message and message.startswith("invalid json:"): - raise errors.InvalidJsonError(status_code, message, response=response) + raise errors.InvalidJsonError( + status_code, message, response=response, details=details + ) elif message and message == "pipeline id cannot be empty": raise errors.EmptyPipelineIdError( - status_code, message, response=response + status_code, message, response=response, details=details ) elif message and message.startswith( "pipeline can only be deleted if it's stopped or terminated" ): raise errors.PipelineDeletionStateViolationError( - status_code, message, response=response + status_code, message, response=response, details=details ) else: # Generic 400 error for unknown codes - raise errors.ValidationError(status_code, message, response=response) + raise errors.ValidationError( + status_code, message, response=response, details=details + ) elif status_code == 403: - raise errors.ForbiddenError(status_code, message, response=response) + raise errors.ForbiddenError( + status_code, message, response=response, details=details + ) elif status_code == 404: - raise errors.NotFoundError(status_code, message, response=response) + raise errors.NotFoundError( + status_code, message, response=response, details=details + ) + elif status_code == 409: + raise errors.ConflictError( + status_code, message, response=response, details=details + ) elif status_code == 422: raise errors.UnprocessableContentError( - status_code, message, response=response + status_code, message, response=response, details=details ) elif status_code == 500: - raise errors.ServerError(status_code, message, response=response) + raise errors.ServerError( + status_code, message, response=response, details=details + ) else: raise errors.APIError( status_code, message="An error occurred: " f"({status_code} {response.reason_phrase}) {message}", response=response, + details=details, ) def _track_event(self, event_name: str, **kwargs: Any) -> None: diff --git a/src/glassflow/etl/client.py b/src/glassflow/etl/client.py index 333a400..0a8e85d 100644 --- a/src/glassflow/etl/client.py +++ b/src/glassflow/etl/client.py @@ -14,6 +14,12 @@ class Client(APIClient): ENDPOINT = "/api/v1/pipeline" + # Class of the Pipeline this client constructs and returns. Editions + # (e.g. the enterprise client) override this so every pipeline handed back + # is their own subclass, propagating edition-specific behaviour down the + # Client -> Pipeline -> DLQ chain without re-implementing these methods. + _pipeline_class: type[Pipeline] = Pipeline + def __init__(self, host: str | None = None) -> None: """Initialize the PipelineManager class. @@ -35,7 +41,7 @@ def get_pipeline(self, pipeline_id: str): PipelineNotFoundError: If pipeline is not found APIError: If the API request fails """ - return Pipeline(host=self.host, pipeline_id=pipeline_id).get() + return self._pipeline_class(host=self.host, pipeline_id=pipeline_id).get() def list_pipelines(self) -> List[dict]: """Returns a list of available pipelines. @@ -91,9 +97,13 @@ def create_pipeline( "pipeline_config_json_path must be provided" ) if pipeline_config_yaml_path is not None: - pipeline = Pipeline.from_yaml(pipeline_config_yaml_path, host=self.host) + pipeline = self._pipeline_class.from_yaml( + pipeline_config_yaml_path, host=self.host + ) elif pipeline_config_json_path is not None: - pipeline = Pipeline.from_json(pipeline_config_json_path, host=self.host) + pipeline = self._pipeline_class.from_json( + pipeline_config_json_path, host=self.host + ) else: if ( pipeline_config_yaml_path is not None @@ -103,7 +113,7 @@ def create_pipeline( "Either pipeline_config or pipeline_config_yaml_path or " "pipeline_config_json_path must be provided" ) - pipeline = Pipeline(config=pipeline_config, host=self.host) + pipeline = self._pipeline_class(config=pipeline_config, host=self.host) return pipeline.create() @@ -120,7 +130,9 @@ def stop_pipeline(self, pipeline_id: str, terminate: bool = False) -> None: PipelineNotFoundError: If pipeline is not found APIError: If the API request fails """ - Pipeline(host=self.host, pipeline_id=pipeline_id).stop(terminate=terminate) + self._pipeline_class(host=self.host, pipeline_id=pipeline_id).stop( + terminate=terminate + ) def delete_pipeline(self, pipeline_id: str) -> None: """Deletes the pipeline with the given ID. @@ -134,7 +146,7 @@ def delete_pipeline(self, pipeline_id: str) -> None: PipelineNotFoundError: If pipeline is not found APIError: If the API request fails """ - Pipeline(host=self.host, pipeline_id=pipeline_id).delete() + self._pipeline_class(host=self.host, pipeline_id=pipeline_id).delete() def migrate_pipeline_v2_to_v3( self, pipeline_config: dict[str, Any] diff --git a/src/glassflow/etl/errors.py b/src/glassflow/etl/errors.py index b3651f7..b92f015 100644 --- a/src/glassflow/etl/errors.py +++ b/src/glassflow/etl/errors.py @@ -15,10 +15,14 @@ class ConnectionError(RequestError): class APIError(GlassFlowError): """Base for API response errors.""" - def __init__(self, status_code, message=None, response=None): + def __init__(self, status_code, message=None, response=None, details=None): self.status_code = status_code self.response = response self.message = message + # The API's structured ``details`` object, when present. For invalid + # configs it carries the specific cause under ``details["error"]`` (for + # example an Avro/Protobuf schema compilation error). + self.details = details or {} super().__init__(self.message) @@ -34,6 +38,23 @@ class ForbiddenError(APIError): """Raised on 403 Forbidden errors.""" +class FeatureNotLicensedError(ForbiddenError): + """Raised when an Enterprise-only capability is invoked against a backend + that is not licensed for it (the API responds 403). Subclasses + ForbiddenError so existing 403 handling still catches it.""" + + +class ConflictError(APIError): + """Raised on 409 Conflict errors.""" + + +class PipelineNotRunningError(ConflictError): + """Raised when an operation requires a Running pipeline but the pipeline is + in another state (the API responds 409). For example, DLQ reprocessing + replays messages through the running pipeline and is rejected when the + pipeline is stopped, terminated, or failed.""" + + class UnprocessableContentError(APIError): """Raised on 422 Unprocessable Content errors.""" diff --git a/src/glassflow/etl/models/__init__.py b/src/glassflow/etl/models/__init__.py index 9628469..edda7e0 100644 --- a/src/glassflow/etl/models/__init__.py +++ b/src/glassflow/etl/models/__init__.py @@ -31,8 +31,10 @@ KafkaConnectionParams, KafkaConnectionParamsPatch, KafkaField, + KafkaFormat, KafkaMechanism, KafkaProtocol, + KafkaSchema, KafkaSource, KafkaSourcePatch, OTLPLogsSource, @@ -81,8 +83,10 @@ "KafkaConnectionParamsPatch", "KafkaDataType", "KafkaField", + "KafkaFormat", "KafkaMechanism", "KafkaProtocol", + "KafkaSchema", "KafkaSource", "KafkaSourcePatch", "MetadataConfig", diff --git a/src/glassflow/etl/models/pipeline.py b/src/glassflow/etl/models/pipeline.py index a51ef9f..e453b85 100644 --- a/src/glassflow/etl/models/pipeline.py +++ b/src/glassflow/etl/models/pipeline.py @@ -1,12 +1,20 @@ import re -from typing import List, Optional - -from pydantic import BaseModel, Field, field_validator, model_validator +from typing import Any, List, Optional + +from pydantic import ( + BaseModel, + Field, + SerializeAsAny, + field_validator, + model_validator, +) from .base import CaseInsensitiveStrEnum from .metadata import MetadataConfig +from .registry import resolve_source from .resources import PipelineResourcesConfig from .sink import SinkConfig, SinkConfigPatch +from .source import SourceBaseConfig from .sources import KafkaSource, OTLPSource, SourceConfig from .transforms import ( DedupTransform, @@ -38,13 +46,24 @@ class PipelineConfig(BaseModel): version: PipelineVersion = Field(default=PipelineVersion.V3) pipeline_id: str name: Optional[str] = Field(default=None) - sources: List[SourceConfig] + # Each source is dispatched to its concrete class via the source registry + # (see resolve_sources), so editions can add source types without + # redefining a union here. SerializeAsAny preserves subclass-only fields. + sources: List[SerializeAsAny[SourceBaseConfig]] transforms: Optional[List[TransformEntry]] = Field(default=None) join: Optional[JoinConfig] = Field(default=None) sink: SinkConfig metadata: Optional[MetadataConfig] = Field(default=MetadataConfig()) resources: Optional[PipelineResourcesConfig] = Field(default=None) + @field_validator("sources", mode="before") + @classmethod + def resolve_sources(cls, value: Any) -> Any: + """Dispatch each raw source dict to its concrete registered class.""" + if isinstance(value, list): + return [resolve_source(item) for item in value] + return value + @field_validator("version") @classmethod def validate_version(cls, v: PipelineVersion) -> PipelineVersion: diff --git a/src/glassflow/etl/models/registry.py b/src/glassflow/etl/models/registry.py new file mode 100644 index 0000000..7ed65da --- /dev/null +++ b/src/glassflow/etl/models/registry.py @@ -0,0 +1,61 @@ +"""Extensible registry for source types. + +Source types (``kafka``, ``otlp.logs``, ...) are open-ended: an edition can add +new ones without modifying this package. Instead of a static Pydantic +discriminated union (whose members are frozen at class-definition time), models +annotate their field with the base type and dispatch each value to the concrete +class by its ``type`` string at validation time, looked up here. + +Editions register their classes at import:: + + from glassflow.etl.models.registry import register_source + register_source(KinesisSource) +""" + +from __future__ import annotations + +from typing import Any, Dict, Type, TypeVar + +from pydantic import BaseModel + +_T = TypeVar("_T", bound=BaseModel) + +_SOURCE_CLASSES: Dict[str, Type[BaseModel]] = {} + + +def _type_key(cls: Type[BaseModel]) -> str: + """Derive the discriminator key from a model's ``type`` field default.""" + field = cls.model_fields.get("type") + if field is None or field.default is None: + raise ValueError( + f"{cls.__name__} must define a 'type' field with a literal default " + "to be registered" + ) + return str(field.default) + + +def register_source(cls: Type[_T]) -> Type[_T]: + """Register a source class, keyed by its ``type`` default. Usable as a + decorator. Idempotent for re-imports.""" + _SOURCE_CLASSES[_type_key(cls)] = cls + return cls + + +def _resolve(value: Any, registry: Dict[str, Type[BaseModel]], kind: str) -> Any: + """Coerce a raw dict to the registered concrete model by its ``type``. + + Already-constructed model instances and non-dict values pass through + untouched so Pydantic can validate (or reject) them normally. + """ + if isinstance(value, BaseModel) or not isinstance(value, dict): + return value + type_value = value.get("type") + cls = registry.get(str(type_value)) if type_value is not None else None + if cls is None: + known = ", ".join(sorted(registry)) or "(none registered)" + raise ValueError(f"Unknown {kind} type {type_value!r}. Known types: {known}") + return cls.model_validate(value) + + +def resolve_source(value: Any) -> Any: + return _resolve(value, _SOURCE_CLASSES, "source") diff --git a/src/glassflow/etl/models/sources/__init__.py b/src/glassflow/etl/models/sources/__init__.py index 6b8ebcd..bb93ca8 100644 --- a/src/glassflow/etl/models/sources/__init__.py +++ b/src/glassflow/etl/models/sources/__init__.py @@ -14,14 +14,17 @@ from pydantic import Field # noqa: F401 +from ..registry import register_source from ..source import SourceBaseConfig, SourceBaseConfigPatch, SourceType from .kafka import ( ConsumerGroupOffset, KafkaConnectionParams, KafkaConnectionParamsPatch, KafkaField, + KafkaFormat, KafkaMechanism, KafkaProtocol, + KafkaSchema, KafkaSource, KafkaSourcePatch, SchemaRegistry, @@ -34,7 +37,10 @@ OTLPTracesSource, ) -# Discriminated union -- Pydantic resolves the concrete class via the `type` field. +# Discriminated union -- kept as a convenience type alias for the OSS source +# set. The PipelineConfig.sources field no longer uses it directly; instead it +# dispatches via the source registry so editions can add new source types +# without redefining this union. SourceConfig = Annotated[ Union[KafkaSource, OTLPLogsSource, OTLPMetricsSource, OTLPTracesSource], Field(discriminator="type"), @@ -46,6 +52,15 @@ AnySource = SourceConfig AnySourcePatch = SourceConfigPatch +# Register the OSS source types so the registry-backed dispatch can resolve them. +for _source_cls in ( + KafkaSource, + OTLPLogsSource, + OTLPMetricsSource, + OTLPTracesSource, +): + register_source(_source_cls) + __all__ = [ # Base "SourceType", @@ -56,8 +71,10 @@ "KafkaConnectionParams", "KafkaConnectionParamsPatch", "KafkaField", + "KafkaFormat", "KafkaMechanism", "KafkaProtocol", + "KafkaSchema", "KafkaSource", "KafkaSourcePatch", "SchemaRegistry", diff --git a/src/glassflow/etl/models/sources/kafka.py b/src/glassflow/etl/models/sources/kafka.py index 3f6158c..092b269 100644 --- a/src/glassflow/etl/models/sources/kafka.py +++ b/src/glassflow/etl/models/sources/kafka.py @@ -2,13 +2,19 @@ from typing import Any, List, Literal, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from ..base import CaseInsensitiveStrEnum from ..data_types import KafkaDataType from ..source import SourceBaseConfig, SourceBaseConfigPatch, SourceType +class KafkaFormat(CaseInsensitiveStrEnum): + JSON = "json" + AVRO = "avro" + PROTOBUF = "protobuf" + + class KafkaProtocol(CaseInsensitiveStrEnum): SSL = "SSL" SASL_SSL = "SASL_SSL" @@ -44,6 +50,26 @@ class KafkaField(BaseModel): type: KafkaDataType +class KafkaSchema(BaseModel): + """Unified schema for a Kafka source. The shape used is selected by the + source's :class:`KafkaFormat`: + + - ``json`` -> ``fields`` (GlassFlow field declarations) + - ``avro`` -> ``file`` (the inline ``.avsc`` schema text) + - ``protobuf`` -> ``file`` (the inline ``.proto`` text) and ``message_type`` + (the message within it to decode) + + On reads the backend also returns ``parsed_fields``: the field list parsed + from the avsc/proto. It is read-only, exposed for inspection, and never sent + back on create/edit. + """ + + fields: Optional[List[KafkaField]] = Field(default=None) + file: Optional[str] = Field(default=None) + message_type: Optional[str] = Field(default=None) + parsed_fields: Optional[List[KafkaField]] = Field(default=None, exclude=True) + + class KafkaConnectionParams(BaseModel): brokers: List[str] protocol: KafkaProtocol @@ -72,6 +98,26 @@ def update(self, patch: "KafkaConnectionParamsPatch") -> "KafkaConnectionParams" return KafkaConnectionParams.model_validate(merged_dict) +def _parse_source_schema(data: Any) -> Any: + """Accept the legacy top-level ``schema_fields`` (deprecated) by folding it + into the unified ``schema`` object's ``fields``. The ``schema`` object itself + maps directly to ``source_schema`` via its alias. Also drops an empty + ``schema_registry``. Left untouched if ``source_schema`` is supplied by name. + """ + if not isinstance(data, dict) or "source_schema" in data: + return data + data = dict(data) + if data.get("schema_registry", None) == {}: + data.pop("schema_registry", None) + if "schema_fields" in data: + fields = data.pop("schema_fields") + schema = data.get("schema") + schema = dict(schema) if isinstance(schema, dict) else {} + schema.setdefault("fields", fields) + data["schema"] = schema + return data + + class KafkaSource(SourceBaseConfig): """Kafka source configuration. @@ -79,21 +125,35 @@ class KafkaSource(SourceBaseConfig): and a single topic string. """ + model_config = ConfigDict(populate_by_name=True) + type: Literal[SourceType.KAFKA] = SourceType.KAFKA connection_params: KafkaConnectionParams topic: str consumer_group_initial_offset: ConsumerGroupOffset = ConsumerGroupOffset.LATEST schema_registry: Optional[SchemaRegistry] = Field(default=None) schema_version: Optional[str] = Field(default=None) - schema_fields: Optional[List[KafkaField]] = Field(default=None) + # Payload wire format. ``None`` means JSON (the backend default) and is + # omitted from the serialized config. ``avro`` and ``protobuf`` are + # Enterprise features and are rejected by an unlicensed backend. + format: Optional[KafkaFormat] = Field(default=None) + # All schema-related config in one place, serialized as the ``schema`` + # object (``fields`` for json, ``file`` [+ ``message_type``] for + # avro/protobuf). The legacy top-level ``schema_fields`` is still accepted on + # input (see _parse_source_schema). + source_schema: Optional[KafkaSchema] = Field(default=None, alias="schema") @model_validator(mode="before") @classmethod - def validate_empty_schema_registry(cls, data: Any) -> Any: - if isinstance(data, dict): - if data.get("schema_registry", None) == {}: - data.pop("schema_registry", None) - return data + def parse_schema(cls, data: Any) -> Any: + return _parse_source_schema(data) + + @property + def schema_fields(self) -> Optional[List[KafkaField]]: + """Backward-compatible accessor for the JSON field declarations, held + under ``schema.fields``. (``schema.parsed_fields`` is read-only backend + info and is intentionally not surfaced here.)""" + return self.source_schema.fields if self.source_schema else None @model_validator(mode="after") def validate_schema_registry_requires_version(self) -> "KafkaSource": @@ -104,6 +164,24 @@ def validate_schema_registry_requires_version(self) -> "KafkaSource": ) return self + @model_validator(mode="after") + def validate_format_schema(self) -> "KafkaSource": + """The schema shape must match the declared format.""" + schema = self.source_schema + if self.format == KafkaFormat.AVRO: + if schema is None or not schema.file: + raise ValueError("avro format requires schema.file") + elif self.format == KafkaFormat.PROTOBUF: + if schema is None or not schema.file or not schema.message_type: + raise ValueError( + "protobuf format requires schema.file and schema.message_type" + ) + elif schema is not None and (schema.file or schema.message_type): + raise ValueError( + "schema.file / message_type require format 'avro' or 'protobuf'" + ) + return self + def update(self, patch: "KafkaSourcePatch") -> "KafkaSource": """Apply a patch to this source config.""" update_dict = self.model_copy(deep=True) @@ -116,8 +194,11 @@ def update(self, patch: "KafkaSourcePatch") -> "KafkaSource": if patch.topic is not None: update_dict.topic = patch.topic - if patch.schema_fields is not None: - update_dict.schema_fields = patch.schema_fields + if patch.format is not None: + update_dict.format = patch.format + + if patch.source_schema is not None: + update_dict.source_schema = patch.source_schema return update_dict @@ -142,6 +223,19 @@ class KafkaConnectionParamsPatch(BaseModel): class KafkaSourcePatch(SourceBaseConfigPatch): """Patch model for KafkaSource.""" + model_config = ConfigDict(populate_by_name=True) + connection_params: Optional[KafkaConnectionParamsPatch] = Field(default=None) topic: Optional[str] = Field(default=None) - schema_fields: Optional[List[KafkaField]] = Field(default=None) + format: Optional[KafkaFormat] = Field(default=None) + source_schema: Optional[KafkaSchema] = Field(default=None, alias="schema") + + @model_validator(mode="before") + @classmethod + def parse_schema(cls, data: Any) -> Any: + return _parse_source_schema(data) + + @property + def schema_fields(self) -> Optional[List[KafkaField]]: + """Backward-compatible accessor for ``schema.fields``.""" + return self.source_schema.fields if self.source_schema else None diff --git a/src/glassflow/etl/pipeline.py b/src/glassflow/etl/pipeline.py index 1e34cda..ae51586 100644 --- a/src/glassflow/etl/pipeline.py +++ b/src/glassflow/etl/pipeline.py @@ -18,6 +18,17 @@ class Pipeline(APIClient): ENDPOINT = "/api/v1/pipeline" + # Class of the DLQ client this pipeline constructs. Editions override this + # to have ``self.dlq`` expose their own DLQ subclass without re-implementing + # construction. The Enterprise DLQ itself lands in a follow-up PR. + _dlq_class: type[DLQ] = DLQ + + # Config models this pipeline validates/serializes. Editions override these + # to use their extended PipelineConfig (e.g. with EE-only source types and + # formats) without re-implementing __init__/get/update. + _config_class: type[models.PipelineConfig] = models.PipelineConfig + _config_patch_class: type[models.PipelineConfigPatch] = models.PipelineConfigPatch + def __init__( self, host: str | None = None, @@ -43,14 +54,14 @@ def __init__( if config is not None: if isinstance(config, dict): - self.config = models.PipelineConfig.model_validate(config) + self.config = self._config_class.model_validate(config) else: self.config = config self.pipeline_id = self.config.pipeline_id else: self.config = None - self._dlq = DLQ(pipeline_id=self.pipeline_id, host=host) + self._dlq = self._dlq_class(pipeline_id=self.pipeline_id, host=host) self.status: models.PipelineStatus | None = None def get( @@ -85,9 +96,9 @@ def get( event_name="PipelineGet", **kwargs, ) - self.config = models.PipelineConfig.model_validate(response.json()) + self.config = self._config_class.model_validate(response.json()) self.health() - self._dlq = DLQ(pipeline_id=self.pipeline_id, host=self.host) + self._dlq = self._dlq_class(pipeline_id=self.pipeline_id, host=self.host) return self def create(self) -> Pipeline: @@ -169,7 +180,7 @@ def update( """ self.get() # Get latest config if isinstance(config_patch, dict): - config_patch = models.PipelineConfigPatch.model_validate(config_patch) + config_patch = self._config_patch_class.model_validate(config_patch) else: config_patch = config_patch updated_config = self.config.update(config_patch) @@ -333,8 +344,8 @@ def from_json(cls, json_path: str, host: str | None = None) -> Pipeline: config = json.load(f) return cls(config=config, host=host) - @staticmethod - def validate_config(config: dict[str, Any]) -> bool: + @classmethod + def validate_config(cls, config: dict[str, Any]) -> bool: """ Validate a pipeline configuration. @@ -348,7 +359,7 @@ def validate_config(config: dict[str, Any]) -> bool: ValueError: If the configuration is invalid ValidationError: If the configuration fails Pydantic validation """ - models.PipelineConfig.model_validate(config) + cls._config_class.model_validate(config) return True @property @@ -437,9 +448,17 @@ def _request( ) from e except errors.UnprocessableContentError as e: self._track_event(event_name, error_type="InvalidPipelineConfig") + message = e.message or "Invalid pipeline configuration" + # The specific cause (e.g. an Avro/Protobuf schema compilation error) + # is in details.error; surface it instead of the generic message. + detail = e.details.get("error") if e.details else None + if detail: + message = f"{message}: {detail}" raise errors.PipelineInvalidConfigurationError( status_code=e.status_code, - message=e.message or "Invalid pipeline configuration", + message=message, + response=e.response, + details=e.details, ) from e except errors.APIError as e: self._track_event(event_name, error_type="InternalServerError") diff --git a/tests/test_ee.py b/tests/test_ee.py new file mode 100644 index 0000000..48446f0 --- /dev/null +++ b/tests/test_ee.py @@ -0,0 +1,100 @@ +"""Tests for the Enterprise (ee) client and pipeline scaffold. + +DLQ-specific Enterprise capabilities are covered in a follow-up PR. +""" + +from unittest.mock import patch + +import pytest + +from glassflow import ee +from glassflow.etl import errors +from glassflow.etl.client import Client as OSSClient +from glassflow.etl.pipeline import Pipeline as OSSPipeline +from tests.data import mock_responses + + +@pytest.fixture +def ee_pipeline(valid_config): + """Fixture for an Enterprise Pipeline with a valid config.""" + config = ee.PipelineConfig(**valid_config) + return ee.Pipeline(host="http://localhost:8080", config=config) + + +class TestEEInheritance: + """The ee classes extend, not replace, the OSS ones.""" + + def test_ee_client_subclasses_oss(self): + assert issubclass(ee.Client, OSSClient) + + def test_ee_pipeline_subclasses_oss(self): + assert issubclass(ee.Pipeline, OSSPipeline) + + +class TestEEWiring: + """Edition propagates from Client to the Pipeline it returns.""" + + def test_client_constructs_ee_pipeline(self): + assert ee.Client._pipeline_class is ee.Pipeline + + def test_get_pipeline_returns_ee_pipeline( + self, mock_success, get_pipeline_response, get_health_payload + ): + client = ee.Client(host="http://localhost:8080") + with mock_success( + [get_pipeline_response, get_health_payload("test-pipeline-id")] + ): + pipeline = client.get_pipeline("test-pipeline-id") + + assert isinstance(pipeline, ee.Pipeline) + + +class TestGetStreams: + @pytest.fixture + def ee_pipeline_by_id(self): + return ee.Pipeline(host="http://localhost:8080", pipeline_id="p1") + + def test_get_streams_success(self, ee_pipeline_by_id, mock_success, mock_track): + payload = { + "pipeline_id": "p1", + "streams": [ + {"stream_name": "gfm-abc-DLQ", "component": "dlq"}, + {"stream_name": "gfm-abc-ingestor", "component": "ingestor"}, + ], + } + with mock_success(json_payloads=[payload]) as mock_get: + streams = ee_pipeline_by_id.get_streams() + + mock_get.assert_called_once_with("GET", "/api/v1/pipeline/p1/streams") + assert streams == payload["streams"] + assert streams[0]["component"] == "dlq" + + def test_get_streams_empty_on_204(self, ee_pipeline_by_id, mock_track): + resp = mock_responses.create_mock_response_factory()( + status_code=204, json_data=None + ) + with patch("httpx.Client.request", return_value=resp): + assert ee_pipeline_by_id.get_streams() == [] + + def test_get_streams_not_found(self, ee_pipeline_by_id, mock_track): + resp = mock_responses.create_mock_response_factory()( + status_code=404, json_data={"message": "not found"} + ) + with patch( + "httpx.Client.request", side_effect=resp.raise_for_status.side_effect + ): + with pytest.raises(errors.PipelineNotFoundError): + ee_pipeline_by_id.get_streams() + + def test_get_streams_forbidden_maps_to_feature_not_licensed( + self, ee_pipeline_by_id, mock_track + ): + resp = mock_responses.create_mock_response_factory()( + status_code=403, json_data={"message": "Forbidden"} + ) + with patch( + "httpx.Client.request", side_effect=resp.raise_for_status.side_effect + ): + with pytest.raises(errors.FeatureNotLicensedError) as exc: + ee_pipeline_by_id.get_streams() + assert isinstance(exc.value, errors.ForbiddenError) diff --git a/tests/test_ee_dlq.py b/tests/test_ee_dlq.py new file mode 100644 index 0000000..8d53dc6 --- /dev/null +++ b/tests/test_ee_dlq.py @@ -0,0 +1,279 @@ +"""Tests for the Enterprise DLQ: list / reprocess / discard.""" + +from unittest.mock import patch + +import pytest + +from glassflow import ee +from glassflow.etl import errors +from glassflow.etl.dlq import DLQ as OSSDLQ +from tests.data import mock_responses + + +@pytest.fixture +def ee_dlq(): + return ee.DLQ(host="http://localhost:8080", pipeline_id="test-pipeline") + + +@pytest.fixture +def ee_pipeline(valid_config): + config = ee.PipelineConfig(**valid_config) + return ee.Pipeline(host="http://localhost:8080", config=config) + + +class TestEEDLQWiring: + def test_ee_dlq_subclasses_oss(self): + assert issubclass(ee.DLQ, OSSDLQ) + + def test_pipeline_exposes_ee_dlq(self, ee_pipeline): + assert isinstance(ee_pipeline.dlq, ee.DLQ) + + def test_dlq_setter_still_works(self, ee_pipeline, ee_dlq): + ee_pipeline.dlq = ee_dlq + assert ee_pipeline.dlq is ee_dlq + + def test_get_pipeline_dlq_is_ee( + self, mock_success, get_pipeline_response, get_health_payload + ): + client = ee.Client(host="http://localhost:8080") + with mock_success( + [get_pipeline_response, get_health_payload("test-pipeline-id")] + ): + pipeline = client.get_pipeline("test-pipeline-id") + assert isinstance(pipeline.dlq, ee.DLQ) + + +class TestList: + def test_list_success(self, ee_dlq, mock_success): + payload = { + "messages": [ + { + "message_id": "seq_101", + "component": "sink", + "error": "connection refused", + "original_message": "{}", + "received_at": "2026-05-29T14:00:00Z", + } + ], + "next_cursor": "seq_101", + "has_more": True, + } + with mock_success(json_payloads=[payload]) as mock_get: + result = ee_dlq.list(batch_size=50) + + mock_get.assert_called_once_with( + "GET", f"{ee_dlq.endpoint}/list", params={"batch_size": 50} + ) + assert result == payload + assert result["messages"][0]["message_id"] == "seq_101" + + def test_list_with_cursor(self, ee_dlq, mock_success): + with mock_success(json_payloads=[{"messages": [], "has_more": False}]) as m: + ee_dlq.list(batch_size=10, cursor="seq_200") + m.assert_called_once_with( + "GET", + f"{ee_dlq.endpoint}/list", + params={"batch_size": 10, "cursor": "seq_200"}, + ) + + @pytest.mark.parametrize( + "component", ["ingestor", "join", "sink", "dedup", "oltp-receiver"] + ) + def test_list_with_valid_component_filter(self, ee_dlq, mock_success, component): + with mock_success(json_payloads=[{"messages": [], "has_more": False}]) as m: + ee_dlq.list(batch_size=10, component=component) + m.assert_called_once_with( + "GET", + f"{ee_dlq.endpoint}/list", + params={"batch_size": 10, "component": component}, + ) + + @pytest.mark.parametrize("bad", ["otlp-receiver", "Sink", "", "transform"]) + def test_list_invalid_component_raises_client_side(self, ee_dlq, bad): + # Validated client-side before any HTTP call (no mock needed). + with pytest.raises(ValueError, match="component must be one of"): + ee_dlq.list(component=bad) + + def test_list_empty_on_204(self, ee_dlq): + mock_response = mock_responses.create_mock_response_factory()( + status_code=204, json_data=None + ) + with patch("httpx.Client.request", return_value=mock_response): + assert ee_dlq.list() == {"messages": [], "has_more": False} + + @pytest.mark.parametrize("bad", [0, 1001, -1, "10"]) + def test_list_invalid_batch_size(self, ee_dlq, bad): + with pytest.raises(ValueError, match="batch_size must be an integer"): + ee_dlq.list(batch_size=bad) + + +class TestListIter: + def test_pages_through_and_advances_cursor(self, ee_dlq, mock_success): + page1 = { + "messages": [{"message_id": "1"}, {"message_id": "2"}], + "next_cursor": "2", + "has_more": True, + } + page2 = {"messages": [{"message_id": "3"}], "has_more": False} + with mock_success(json_payloads=[page1, page2]) as mock_get: + ids = [m["message_id"] for m in ee_dlq.list_iter(batch_size=2)] + + assert ids == ["1", "2", "3"] + assert mock_get.call_count == 2 + # Second page resumes from the first page's next_cursor. + assert mock_get.call_args_list[1].kwargs["params"] == { + "batch_size": 2, + "cursor": "2", + } + + def test_forwards_component_and_is_lazy(self, ee_dlq, mock_success): + page = {"messages": [{"message_id": "1"}], "has_more": False} + with mock_success(json_payloads=[page]) as mock_get: + it = ee_dlq.list_iter(component="sink") + # Generator is lazy: no request until first iteration. + assert mock_get.call_count == 0 + next(it) + assert mock_get.call_args_list[0].kwargs["params"] == { + "batch_size": 100, + "component": "sink", + } + + def test_stops_when_has_more_without_cursor(self, ee_dlq, mock_success): + # Defensive guard: truthy has_more but no next_cursor must not loop. + page = {"messages": [{"message_id": "1"}], "has_more": True} + with mock_success(json_payloads=[page]) as mock_get: + ids = [m["message_id"] for m in ee_dlq.list_iter()] + assert ids == ["1"] + assert mock_get.call_count == 1 + + +class TestReprocess: + def test_reprocess_selected(self, ee_dlq, mock_success): + with mock_success( + json_payloads=[{"request_id": "rep_1", "status": "accepted"}] + ) as mock_post: + result = ee_dlq.reprocess(["seq_101", "seq_102"]) + + mock_post.assert_called_once_with( + "POST", + f"{ee_dlq.endpoint}/reprocess", + json={"mode": "selected", "message_ids": ["seq_101", "seq_102"]}, + ) + assert result == {"request_id": "rep_1", "status": "accepted"} + + def test_reprocess_all(self, ee_dlq, mock_success): + with mock_success( + json_payloads=[{"request_id": "rep_2", "status": "accepted"}] + ) as mock_post: + ee_dlq.reprocess_all() + mock_post.assert_called_once_with( + "POST", f"{ee_dlq.endpoint}/reprocess", json={"mode": "all"} + ) + + def test_reprocess_empty_ids_raises(self, ee_dlq): + with pytest.raises(ValueError, match="must be non-empty"): + ee_dlq.reprocess([]) + + def test_reprocess_too_many_ids_raises(self, ee_dlq): + with pytest.raises(ValueError, match="cannot exceed 1000"): + ee_dlq.reprocess([str(i) for i in range(1001)]) + + +class TestDiscard: + def test_discard_selected(self, ee_dlq, mock_success): + with mock_success( + json_payloads=[{"request_id": "dis_1", "discarded_count": 2}] + ) as mock_post: + result = ee_dlq.discard(["seq_1", "seq_2"]) + + mock_post.assert_called_once_with( + "POST", + f"{ee_dlq.endpoint}/discard", + json={"mode": "selected", "message_ids": ["seq_1", "seq_2"]}, + ) + assert result == {"request_id": "dis_1", "discarded_count": 2} + + def test_discard_all(self, ee_dlq, mock_success): + with mock_success( + json_payloads=[{"request_id": "dis_2", "discarded_count": 9}] + ) as mock_post: + ee_dlq.discard_all() + mock_post.assert_called_once_with( + "POST", f"{ee_dlq.endpoint}/discard", json={"mode": "all"} + ) + + def test_discard_empty_ids_raises(self, ee_dlq): + with pytest.raises(ValueError, match="must be non-empty"): + ee_dlq.discard([]) + + +class TestDeprecatedInherited: + def test_consume_warns_and_delegates_to_list(self, ee_dlq, mock_success): + envelope = {"messages": [{"message_id": "seq_1"}], "has_more": False} + with mock_success(json_payloads=[envelope]) as mock_get: + with pytest.warns(DeprecationWarning, match="use DLQ.list"): + result = ee_dlq.consume(batch_size=25) + + # Hits /dlq/list, not /dlq/consume. + mock_get.assert_called_once_with( + "GET", f"{ee_dlq.endpoint}/list", params={"batch_size": 25} + ) + # consume() unwraps the envelope to the legacy list shape. + assert result == [{"message_id": "seq_1"}] + + def test_purge_warns_and_hits_purge_endpoint(self, ee_dlq, mock_success): + with mock_success() as mock_post: + with pytest.warns(DeprecationWarning, match="use DLQ.discard_all"): + ee_dlq.purge() + + mock_post.assert_called_once_with("POST", f"{ee_dlq.endpoint}/purge") + + +class TestEntitlement: + def test_forbidden_maps_to_feature_not_licensed(self, ee_dlq): + mock_response = mock_responses.create_mock_response_factory()( + status_code=403, json_data={"message": "Forbidden"} + ) + with patch( + "httpx.Client.request", + side_effect=mock_response.raise_for_status.side_effect, + ): + with pytest.raises(errors.FeatureNotLicensedError) as exc_info: + ee_dlq.reprocess_all() + + assert "Enterprise" in str(exc_info.value) + # Still catchable as a ForbiddenError by existing 403 handling. + assert isinstance(exc_info.value, errors.ForbiddenError) + + +class TestPipelineState: + def _conflict_patch(self): + mock_response = mock_responses.create_mock_response_factory()( + status_code=409, json_data={"message": "pipeline is not running"} + ) + return patch( + "httpx.Client.request", + side_effect=mock_response.raise_for_status.side_effect, + ) + + def test_reprocess_on_non_running_raises_pipeline_not_running(self, ee_dlq): + with self._conflict_patch(): + with pytest.raises(errors.PipelineNotRunningError) as exc_info: + ee_dlq.reprocess(["seq_1"]) + + assert "Running" in str(exc_info.value) + # Still catchable as the generic 409 ConflictError. + assert isinstance(exc_info.value, errors.ConflictError) + + def test_reprocess_all_on_non_running_raises_pipeline_not_running(self, ee_dlq): + with self._conflict_patch(): + with pytest.raises(errors.PipelineNotRunningError): + ee_dlq.reprocess_all() + + def test_discard_409_stays_conflict_error(self, ee_dlq): + # Discard has no Running-state constraint, so a 409 is not remapped. + with self._conflict_patch(): + with pytest.raises(errors.ConflictError) as exc_info: + ee_dlq.discard_all() + + assert not isinstance(exc_info.value, errors.PipelineNotRunningError) diff --git a/tests/test_models/test_source_formats.py b/tests/test_models/test_source_formats.py new file mode 100644 index 0000000..2306e10 --- /dev/null +++ b/tests/test_models/test_source_formats.py @@ -0,0 +1,210 @@ +"""Tests for Kafka source formats (json/avro/protobuf) and the source registry. + +The unified ``source_schema`` (wire key ``schema``) holds all schema config: +``fields`` (json), ``file`` (the inline avsc/proto text), ``message_type`` +(protobuf), and read-only ``parsed_fields`` (returned on GET). The legacy +top-level ``schema_fields`` is still accepted on input and upgraded to +``schema.fields``. The source registry lets an out-of-tree source type plug in +without changing OSS models. +""" + +from typing import Literal +from unittest.mock import patch + +import pytest + +from glassflow.etl import errors, models +from glassflow.etl.models import registry +from tests.data import mock_responses + + +def _kafka(**overrides) -> dict: + base = { + "type": "kafka", + "source_id": "events", + "connection_params": {"brokers": ["b:9092"], "protocol": "PLAINTEXT"}, + "topic": "events", + "consumer_group_initial_offset": "earliest", + } + base.update(overrides) + return base + + +AVSC_TEXT = ( + '{"type": "record", "name": "Event", "fields": [{"name": "id", "type": "string"}]}' +) +PROTO_TEXT = 'syntax = "proto3";\npackage test;\nmessage Event {\n string id = 1;\n}' + + +class TestJsonFormat: + def test_format_defaults_to_none_and_is_omitted(self): + src = models.KafkaSource.model_validate(_kafka()) + assert src.format is None + assert "format" not in src.model_dump(by_alias=True, exclude_none=True) + + def test_legacy_schema_fields_accepted_and_upgraded(self): + # The deprecated top-level schema_fields is folded into schema.fields. + wire = _kafka(format="json", schema_fields=[{"name": "id", "type": "string"}]) + src = models.KafkaSource.model_validate(wire) + assert src.source_schema.fields[0].name == "id" + assert src.schema_fields[0].name == "id" # compat accessor + dumped = src.model_dump(by_alias=True, exclude_none=True) + assert dumped["schema"] == {"fields": [{"name": "id", "type": "string"}]} + assert "schema_fields" not in dumped # upgraded to the unified schema + + def test_schema_fields_round_trips(self): + wire = _kafka( + format="json", schema={"fields": [{"name": "id", "type": "string"}]} + ) + src = models.KafkaSource.model_validate(wire) + assert src.schema_fields[0].name == "id" + dumped = src.model_dump(by_alias=True, exclude_none=True) + assert dumped["schema"] == {"fields": [{"name": "id", "type": "string"}]} + + +class TestAvroFormat: + def test_avro_round_trips(self): + src = models.KafkaSource.model_validate( + _kafka(format="avro", schema={"file": AVSC_TEXT}) + ) + assert src.format == models.KafkaFormat.AVRO + assert src.source_schema.file == AVSC_TEXT + dumped = src.model_dump(by_alias=True, exclude_none=True) + assert dumped["schema"] == {"file": AVSC_TEXT} + + def test_avro_requires_file(self): + with pytest.raises(ValueError, match="avro format requires schema.file"): + models.KafkaSource.model_validate(_kafka(format="avro")) + + +class TestProtobufFormat: + def test_protobuf_round_trips(self): + src = models.KafkaSource.model_validate( + _kafka( + format="protobuf", + schema={"file": PROTO_TEXT, "message_type": "Event"}, + ) + ) + assert src.format == models.KafkaFormat.PROTOBUF + assert src.source_schema.file == PROTO_TEXT + assert src.source_schema.message_type == "Event" + dumped = src.model_dump(by_alias=True, exclude_none=True) + assert dumped["schema"] == {"file": PROTO_TEXT, "message_type": "Event"} + + def test_protobuf_requires_file_and_message_type(self): + with pytest.raises(ValueError, match="schema.file and schema.message_type"): + models.KafkaSource.model_validate( + _kafka(format="protobuf", schema={"file": PROTO_TEXT}) + ) + + +class TestParsedFields: + def test_parsed_fields_read_only(self): + # Shape the backend returns on GET for an avro source. + src = models.KafkaSource.model_validate( + _kafka( + format="avro", + schema={ + "file": AVSC_TEXT, + "parsed_fields": [{"name": "id", "type": "string"}], + }, + ) + ) + # Available as informational backend output... + assert src.source_schema.parsed_fields[0].name == "id" + # ...but not surfaced via the schema_fields compat accessor (json only), + assert src.schema_fields is None + # ...and not emitted back on dump. + dumped = src.model_dump(by_alias=True, exclude_none=True) + assert dumped["schema"] == {"file": AVSC_TEXT} + assert "parsed_fields" not in dumped["schema"] + + +class TestSchemaFormatConsistency: + def test_file_without_avro_or_protobuf_format_rejected(self): + with pytest.raises(ValueError, match="require format 'avro' or 'protobuf'"): + models.KafkaSource.model_validate( + _kafka(format="json", schema={"file": AVSC_TEXT}) + ) + + +class TestSchemaErrorSurfacing: + """A backend 422 puts the specific cause in details.error; the SDK surfaces + it on the create/edit path instead of the generic message.""" + + def test_invalid_schema_surfaces_backend_detail(self, pipeline, mock_track): + resp = mock_responses.create_mock_response_factory()( + status_code=422, + json_data={ + "status": 422, + "code": "unprocessable_entity", + "message": "failed to convert request to pipeline model", + "details": {"error": 'source "events": proto compilation error: boom'}, + }, + ) + with patch( + "httpx.Client.request", side_effect=resp.raise_for_status.side_effect + ): + with pytest.raises(errors.PipelineInvalidConfigurationError) as exc: + pipeline.create() + + assert "proto compilation error: boom" in str(exc.value) + assert exc.value.details["error"].startswith('source "events"') + + +# --- Source registry: an out-of-tree source type plugs in ------------------- + + +class KinesisSource(models.SourceBaseConfig): + type: Literal["kinesis"] = "kinesis" + stream_name: str + region: str + + +@pytest.fixture +def register_kinesis(): + registry.register_source(KinesisSource) + yield + registry._SOURCE_CLASSES.pop("kinesis", None) + + +class TestRegisteredSourceType: + def _config(self, source: dict) -> dict: + return { + "pipeline_id": "p1", + "sources": [source], + "sink": { + "type": "clickhouse", + "connection_params": { + "host": "h", + "port": "9000", + "database": "db", + "username": "u", + "password": "p", + "secure": True, + }, + "table": "t", + "max_batch_size": 1, + "mapping": [ + {"name": "id", "column_name": "id", "column_type": "String"} + ], + }, + } + + def test_kinesis_dispatches_and_roundtrips(self, register_kinesis): + kinesis = { + "type": "kinesis", + "source_id": "k1", + "stream_name": "events", + "region": "eu-central-1", + } + cfg = models.PipelineConfig.model_validate(self._config(kinesis)) + assert isinstance(cfg.sources[0], KinesisSource) + dumped = cfg.model_dump(by_alias=True, exclude_none=True) + assert dumped["sources"][0]["stream_name"] == "events" + + def test_unknown_source_type_raises(self): + with pytest.raises(ValueError, match="Unknown source type 'pubsub'"): + models.PipelineConfig.model_validate( + self._config({"type": "pubsub", "source_id": "x"}) + )