Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/glassflow/ee/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from typing import Any, Dict, List

from glassflow.etl import errors
from glassflow.etl.pipeline import Pipeline as _OSSPipeline

from .dlq import DLQ
Expand All @@ -10,8 +13,9 @@ class Pipeline(_OSSPipeline):

Extends the open-source :class:`glassflow.etl.pipeline.Pipeline`. Its ``dlq``
property exposes the Enterprise :class:`~.dlq.DLQ` (with
``list``/``reprocess``/``discard``). Construction is inherited unchanged;
only the DLQ collaborator class is swapped via ``_dlq_class``.
``list``/``reprocess``/``discard``), and it adds :meth:`get_streams`.
Construction is inherited unchanged; only the DLQ collaborator class is
swapped via ``_dlq_class``.
"""

_dlq_class = DLQ
Expand All @@ -24,3 +28,35 @@ def dlq(self) -> DLQ:
@dlq.setter
def dlq(self, dlq: DLQ) -> None:
self._dlq = dlq

def get_streams(self) -> List[Dict[str, Any]]:
"""Return the NATS JetStream streams backing this pipeline.

Each entry has a ``stream_name`` and the ``component`` the stream belongs
to (for example ``ingestor``, ``join``, ``sink``, ``dedup``, ``dlq``).
Useful for diagnosing NATS-level issues.

Returns:
List of ``{"stream_name": ..., "component": ...}`` dicts.

Raises:
PipelineNotFoundError: If the pipeline does not exist.
FeatureNotLicensedError: If the backend is not licensed for this.
APIError: If the API request fails.
"""
try:
response = self._request(
"GET",
f"{self.ENDPOINT}/{self.pipeline_id}/streams",
event_name="PipelineStreamsGet",
)
except errors.ForbiddenError as e:
raise errors.FeatureNotLicensedError(
status_code=e.status_code,
message="Pipeline streams require a GlassFlow Enterprise license",
response=e.response,
details=e.details,
) from e
if response.status_code == 204 or not response.content:
return []
return response.json().get("streams", [])
55 changes: 55 additions & 0 deletions tests/test_ee.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
DLQ-specific Enterprise capabilities are covered in a follow-up PR.
"""

from unittest.mock import patch

import pytest

from glassflow import ee
from glassflow.etl import errors
from glassflow.etl.client import Client as OSSClient
from glassflow.etl.pipeline import Pipeline as OSSPipeline
from tests.data import mock_responses


@pytest.fixture
Expand Down Expand Up @@ -43,3 +47,54 @@ def test_get_pipeline_returns_ee_pipeline(
pipeline = client.get_pipeline("test-pipeline-id")

assert isinstance(pipeline, ee.Pipeline)


class TestGetStreams:
@pytest.fixture
def ee_pipeline_by_id(self):
return ee.Pipeline(host="http://localhost:8080", pipeline_id="p1")

def test_get_streams_success(self, ee_pipeline_by_id, mock_success, mock_track):
payload = {
"pipeline_id": "p1",
"streams": [
{"stream_name": "gfm-abc-DLQ", "component": "dlq"},
{"stream_name": "gfm-abc-ingestor", "component": "ingestor"},
],
}
with mock_success(json_payloads=[payload]) as mock_get:
streams = ee_pipeline_by_id.get_streams()

mock_get.assert_called_once_with("GET", "/api/v1/pipeline/p1/streams")
assert streams == payload["streams"]
assert streams[0]["component"] == "dlq"

def test_get_streams_empty_on_204(self, ee_pipeline_by_id, mock_track):
resp = mock_responses.create_mock_response_factory()(
status_code=204, json_data=None
)
with patch("httpx.Client.request", return_value=resp):
assert ee_pipeline_by_id.get_streams() == []

def test_get_streams_not_found(self, ee_pipeline_by_id, mock_track):
resp = mock_responses.create_mock_response_factory()(
status_code=404, json_data={"message": "not found"}
)
with patch(
"httpx.Client.request", side_effect=resp.raise_for_status.side_effect
):
with pytest.raises(errors.PipelineNotFoundError):
ee_pipeline_by_id.get_streams()

def test_get_streams_forbidden_maps_to_feature_not_licensed(
self, ee_pipeline_by_id, mock_track
):
resp = mock_responses.create_mock_response_factory()(
status_code=403, json_data={"message": "Forbidden"}
)
with patch(
"httpx.Client.request", side_effect=resp.raise_for_status.side_effect
):
with pytest.raises(errors.FeatureNotLicensedError) as exc:
ee_pipeline_by_id.get_streams()
assert isinstance(exc.value, errors.ForbiddenError)