diff --git a/VERSION b/VERSION index cb2b00e..b502146 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.0.1 +3.0.2 diff --git a/src/glassflow/etl/models/__init__.py b/src/glassflow/etl/models/__init__.py index 3b5ae08..4f39346 100644 --- a/src/glassflow/etl/models/__init__.py +++ b/src/glassflow/etl/models/__init__.py @@ -8,7 +8,7 @@ JoinSourceConfigPatch, JoinType, ) -from .pipeline import PipelineConfig, PipelineConfigPatch +from .pipeline import PipelineConfig, PipelineConfigPatch, PipelineStatus from .sink import SinkConfig, SinkConfigPatch, SinkType, TableMapping from .source import ( ConsumerGroupOffset, @@ -40,6 +40,7 @@ "JoinType", "PipelineConfig", "PipelineConfigPatch", + "PipelineStatus", "SinkConfig", "SinkType", "TableMapping", diff --git a/src/glassflow/etl/models/pipeline.py b/src/glassflow/etl/models/pipeline.py index 8d79905..9555bb8 100644 --- a/src/glassflow/etl/models/pipeline.py +++ b/src/glassflow/etl/models/pipeline.py @@ -4,12 +4,24 @@ from pydantic import BaseModel, Field, field_validator, model_validator from ..errors import InvalidDataTypeMappingError +from .base import CaseInsensitiveStrEnum from .data_types import kafka_to_clickhouse_data_type_mappings from .join import JoinConfig, JoinConfigPatch from .sink import SinkConfig, SinkConfigPatch from .source import SourceConfig, SourceConfigPatch +class PipelineStatus(CaseInsensitiveStrEnum): + CREATED = "Created" + RUNNING = "Running" + PAUSING = "Pausing" + PAUSED = "Paused" + RESUMING = "Resuming" + TERMINATING = "Terminating" + TERMINATED = "Terminated" + FAILED = "Failed" + + class PipelineConfig(BaseModel): pipeline_id: str name: Optional[str] = Field(default=None) diff --git a/src/glassflow/etl/pipeline.py b/src/glassflow/etl/pipeline.py index ef8406d..616f1b5 100644 --- a/src/glassflow/etl/pipeline.py +++ b/src/glassflow/etl/pipeline.py @@ -52,6 +52,7 @@ def __init__( self.config = None self._dlq = DLQ(pipeline_id=self.pipeline_id, host=host) + self.status: models.PipelineStatus | None = None def get(self) -> Pipeline: """Fetch a pipeline by its ID. @@ -67,6 +68,7 @@ def get(self) -> Pipeline: "GET", f"{self.ENDPOINT}/{self.pipeline_id}", event_name="PipelineGet" ) self.config = models.PipelineConfig.model_validate(response.json()) + self.status = models.PipelineStatus(response.json()["status"]) self._dlq = DLQ(pipeline_id=self.pipeline_id, host=self.host) return self @@ -94,6 +96,7 @@ def create(self) -> Pipeline: ), event_name="PipelineCreated", ) + self.status = models.PipelineStatus.CREATED return self except errors.ForbiddenError as e: @@ -160,6 +163,7 @@ def delete(self, terminate: bool = True) -> None: self.get() endpoint = f"{self.ENDPOINT}/{self.pipeline_id}/terminate" self._request("DELETE", endpoint, event_name="PipelineDeleted") + self.status = models.PipelineStatus.TERMINATING def pause(self) -> Pipeline: """Pauses the pipeline with the given ID. @@ -191,11 +195,13 @@ def health(self) -> dict[str, Any]: Returns: dict: Pipeline health """ - return self._request( + response = self._request( "GET", f"{self.ENDPOINT}/{self.pipeline_id}/health", event_name="PipelineHealth", ).json() + self.status = models.PipelineStatus(response["overall_status"]) + return response def to_dict(self) -> dict[str, Any]: """Convert the pipeline configuration to a dictionary. diff --git a/tests/conftest.py b/tests/conftest.py index ed717fc..322338f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,14 @@ def valid_config() -> dict: return pipeline_configs.get_valid_pipeline_config() +@pytest.fixture +def get_pipeline_response(valid_config) -> dict: + """Fixture for a valid pipeline configuration with status.""" + config = valid_config + config["status"] = "Running" + return config + + @pytest.fixture def valid_config_without_joins() -> dict: """Fixture for a valid pipeline configuration without joins.""" @@ -88,11 +96,11 @@ def mock_connection_error(): @pytest.fixture -def mock_success_get_pipeline(valid_config): +def mock_success_get_pipeline(get_pipeline_response): """Fixture for a successful GET pipeline response.""" return mock_responses.create_mock_response_factory()( status_code=200, - json_data=valid_config, + json_data=get_pipeline_response, ) diff --git a/tests/test_client.py b/tests/test_client.py index 8de60ad..fe1694c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -18,12 +18,14 @@ def test_client_init(self): assert client.host == "https://example.com" assert client.http_client.base_url == "https://example.com" - def test_client_get_pipeline_success(self, valid_config, mock_success_response): + def test_client_get_pipeline_success( + self, get_pipeline_response, mock_success_response + ): """Test successful pipeline retrieval by ID.""" client = Client() pipeline_id = "test-pipeline-id" - mock_success_response.json.return_value = valid_config + mock_success_response.json.return_value = get_pipeline_response with patch( "httpx.Client.request", return_value=mock_success_response diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index eb31c38..b8f32c9 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -91,14 +91,14 @@ def test_lifecycle_operations( method, endpoint, params, - valid_config, + get_pipeline_response, ): """Test common pipeline lifecycle operations.""" with patch( "httpx.Client.request", return_value=mock_success_response ) as mock_request: if method == "GET": - mock_request.return_value.json.return_value = valid_config + mock_request.return_value.json.return_value = get_pipeline_response result = getattr(pipeline, operation)(**params) expected_endpoint = f"{pipeline.ENDPOINT}/{pipeline.pipeline_id}{endpoint}" mock_request.assert_called_once_with(method, expected_endpoint)