From 35a52eef3bf5150aef376e8788d64d3137efbefa Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 25 Sep 2025 17:23:59 -0700 Subject: [PATCH 01/39] wip commit --- python/mkdocs.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/mkdocs.yml b/python/mkdocs.yml index 00ae85827..1644828f4 100644 --- a/python/mkdocs.yml +++ b/python/mkdocs.yml @@ -35,6 +35,7 @@ theme: code: IBM Plex Mono features: - navigation.instant + - navigation.instant.progress - navigation.tabs - navigation.sections - navigation.tracking @@ -52,12 +53,21 @@ extra: nav: - Home: index.md + - Examples: + - examples/sift_client.ipynb + - Sift Py API + - Sift Client API (New) +# - Guides: +# - Logging +# - Error Handling plugins: - search - autorefs - mike: # For docs versioning deploy_prefix: 'python' # In case we want to use doc sites for other client libs too + - mkdocs-jupyter: + include_source: True - mkdocstrings: default_handler: python handlers: From 768b69dd2069470f4b83617ee99ff7dcdf4e3708 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 25 Sep 2025 17:31:40 -0700 Subject: [PATCH 02/39] change extras isntall testing to only be all or none --- python/scripts/build_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/scripts/build_utils.py b/python/scripts/build_utils.py index 877658faa..48a247e2e 100644 --- a/python/scripts/build_utils.py +++ b/python/scripts/build_utils.py @@ -4,7 +4,6 @@ import os import subprocess import venv -from itertools import combinations from pathlib import Path from typing import List, Optional from zipfile import ZipFile @@ -35,7 +34,7 @@ def get_extras_from_wheel(wheel_path: str) -> List[str]: def get_extra_combinations(extras: List[str], exclude: Optional[List[str]] = None) -> List[str]: - """Generate all possible combinations of extras. + """Generate different extras permutations for install testing. Args: extras: List of extra names to generate combinations from. @@ -50,10 +49,8 @@ def get_extra_combinations(extras: List[str], exclude: Optional[List[str]] = Non else: filtered_extras = extras - all_combinations = [] - for r in range(len(filtered_extras) + 1): - all_combinations.extend(",".join(c) for c in combinations(filtered_extras, r)) - return all_combinations + # Get only the full extras lists, no additional permutations at the moment + return [','.join(filtered_extras)] def test_install( @@ -107,7 +104,7 @@ def main(): # Get all extras from the wheel extras = get_extras_from_wheel(str(wheel_file)) - combinations = get_extra_combinations(extras, ["development", "docs"]) + combinations = get_extra_combinations(extras, exclude=["development", "docs"]) # Test base installation first test_install( From d763b05dd79dc5121b2dbe2e9d32c1e2390ba976 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 25 Sep 2025 18:50:27 -0700 Subject: [PATCH 03/39] add support for .env files and integration test marker for pytest --- python/pyproject.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/pyproject.toml b/python/pyproject.toml index cc833a58c..035f64551 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,6 +54,7 @@ development = [ "pytest-asyncio==0.23.7", "pytest-benchmark==4.0.0", "pytest-mock==3.14.0", + "pytest-dotenv==0.5.2", "ruff~=0.12.10", ] build = ["pdoc==14.5.0", "build==1.2.1"] @@ -200,3 +201,12 @@ select = [ "N", # pep8-naming "TID", # flake8-tidy-imports ] + +[tool.pytest.ini_options] +env_files = [ + ".env" +] +markers = [ + "integration: mark a test as an integration test (requires API)" +] + From 5d46aa795c2810d41aa35882dccd3ad600ba7d88 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 26 Sep 2025 09:06:53 -0700 Subject: [PATCH 04/39] add integration tests --- .../_internal/low_level_wrappers/base.py | 6 +- .../_internal/low_level_wrappers/__init__.py | 0 .../_internal/low_level_wrappers/test_base.py | 203 +++++++++++++ .../lib/sift_client/_tests/integrated/runs.py | 286 ------------------ .../sift_client/_tests/resources/__init__.py | 0 .../_tests/resources/test_assets.py | 215 +++++++++++++ .../sift_client/_tests/resources/test_ping.py | 74 +++++ .../sift_client/_tests/resources/test_runs.py | 267 ++++++++++++++++ python/lib/sift_client/_tests/test_client.py | 0 python/scripts/dev | 17 ++ 10 files changed, 781 insertions(+), 287 deletions(-) create mode 100644 python/lib/sift_client/_tests/_internal/low_level_wrappers/__init__.py create mode 100644 python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py delete mode 100644 python/lib/sift_client/_tests/integrated/runs.py create mode 100644 python/lib/sift_client/_tests/resources/__init__.py create mode 100644 python/lib/sift_client/_tests/resources/test_assets.py create mode 100644 python/lib/sift_client/_tests/resources/test_ping.py create mode 100644 python/lib/sift_client/_tests/resources/test_runs.py create mode 100644 python/lib/sift_client/_tests/test_client.py diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index a8b93ed68..d349e51a2 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -22,7 +22,7 @@ async def _handle_pagination( page_size: The number of results to return per page. page_token: The token to use for the next page. order_by: How to order the retrieved results. - max_results: Maximum number of results to return. NOTE: Will be in increments of page_size or default page size defined by the call if no page_size is provided. + max_results: Maximum number of results to return. Returns: A list of all matching results. @@ -31,6 +31,8 @@ async def _handle_pagination( kwargs = {} results: list[Any] = [] + if max_results == 0: + return results if page_token is None: page_token = "" while True: @@ -45,4 +47,6 @@ async def _handle_pagination( results.extend(response) if page_token == "": break + if max_results and len(results) > max_results: + results = results[:max_results] return results diff --git a/python/lib/sift_client/_tests/_internal/low_level_wrappers/__init__.py b/python/lib/sift_client/_tests/_internal/low_level_wrappers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py new file mode 100644 index 000000000..7bb9ab941 --- /dev/null +++ b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py @@ -0,0 +1,203 @@ +"""Tests for LowLevelClientBase. + +These tests validate the functionality of the LowLevelClientBase class including: +- Pagination handling with various scenarios +- Edge cases and error handling +- Parameter validation and behavior +""" + +from unittest.mock import AsyncMock + +import pytest + +from sift_client._internal.low_level_wrappers.base import LowLevelClientBase + + +class TestLowLevelClientBase: + """Test suite for LowLevelClientBase functionality.""" + + class TestHandlePagination: + """Tests for the _handle_pagination static method.""" + + @pytest.mark.asyncio + async def test_basic_pagination_single_page(self): + """Test pagination with a single page of results.""" + # Mock function that returns results and empty page token (indicating no more pages) + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [1, 2, 3] + mock_func.assert_called_once_with( + page_size=None, + page_token="", + order_by=None + ) + + @pytest.mark.asyncio + async def test_pagination_multiple_pages(self): + """Test pagination across multiple pages.""" + # Mock function that returns different results for different page tokens + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page + ([4, 5, 6], "token2"), # Second page + ([7, 8, 9], ""), # Last page (empty token) + ] + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9] + assert mock_func.call_count == 3 + + # Verify the calls were made with correct page tokens + calls = mock_func.call_args_list + assert calls[0][1]["page_token"] == "" + assert calls[1][1]["page_token"] == "token1" + assert calls[2][1]["page_token"] == "token2" + + @pytest.mark.asyncio + async def test_pagination_with_page_size(self): + """Test pagination with specified page size.""" + mock_func = AsyncMock(return_value=([1, 2], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, + page_size=2 + ) + + assert results == [1, 2] + mock_func.assert_called_once_with( + page_size=2, + page_token="", + order_by=None + ) + + @pytest.mark.asyncio + async def test_pagination_with_order_by(self): + """Test pagination with order_by parameter.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, + order_by="name asc" + ) + + assert results == [1, 2, 3] + mock_func.assert_called_once_with( + page_size=None, + page_token="", + order_by="name asc" + ) + + @pytest.mark.asyncio + async def test_pagination_with_initial_page_token(self): + """Test pagination starting with a specific page token.""" + mock_func = AsyncMock(return_value=([4, 5, 6], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, + page_token="start_token" + ) + + assert results == [4, 5, 6] + mock_func.assert_called_once_with( + page_size=None, + page_token="start_token", + order_by=None + ) + + @pytest.mark.asyncio + async def test_pagination_with_kwargs(self): + """Test pagination with additional keyword arguments.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + kwargs = {"filter": "active", "include_archived": False} + + results = await LowLevelClientBase._handle_pagination( + mock_func, + kwargs=kwargs + ) + + assert results == [1, 2, 3] + mock_func.assert_called_once_with( + page_size=None, + page_token="", + order_by=None, + filter="active", + include_archived=False + ) + + @pytest.mark.asyncio + async def test_pagination_with_max_results_single_page(self): + """Test pagination with max_results that fits in a single page.""" + mock_func = AsyncMock(return_value=([1, 2, 3, 4, 5], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, + max_results=3 + ) + + # Should return only the max results + assert results == [1, 2, 3] + mock_func.assert_called_once() + + @pytest.mark.asyncio + async def test_pagination_with_max_results_multiple_pages(self): + """Test pagination with max_results across multiple pages.""" + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page (3 items) + ([4, 5, 6], "token2"), # Second page (6 total items, exceeds max_results=5) + ] + + results = await LowLevelClientBase._handle_pagination( + mock_func, + max_results=5 + ) + + # Should include 2 pages and return the full first page but limited 2nd page + assert results == [1, 2, 3, 4, 5] + assert mock_func.call_count == 2 + + @pytest.mark.asyncio + async def test_pagination_with_max_results_exact_match(self): + """Test pagination when results exactly match max_results.""" + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page + ([4, 5], ""), # Second page, total = 5 + ] + + results = await LowLevelClientBase._handle_pagination( + mock_func, + max_results=5 + ) + + assert results == [1, 2, 3, 4, 5] + assert mock_func.call_count == 2 + + @pytest.mark.asyncio + async def test_pagination_empty_results(self): + """Test pagination when function returns empty results.""" + mock_func = AsyncMock(return_value=([], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [] + mock_func.assert_called_once() + + + @pytest.mark.asyncio + async def test_pagination_max_results_zero(self): + """Test pagination with max_results=0.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, + max_results=0 + ) + + # Should return empty list without calling the function + assert results == [] + mock_func.assert_not_called() + diff --git a/python/lib/sift_client/_tests/integrated/runs.py b/python/lib/sift_client/_tests/integrated/runs.py deleted file mode 100644 index 9c55c2aa1..000000000 --- a/python/lib/sift_client/_tests/integrated/runs.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -"""This test demonstrates the usage of the Runs API. - -It creates a new run, updates it, and associates assets with it. -It also lists runs, filters them, and deletes the run. - -It uses the SiftClient to interact with the API. -""" - -import asyncio -import os -from datetime import datetime, timedelta, timezone - -from sift_client import SiftClient - - -async def main(): - """Main function demonstrating the Runs API usage.""" - # Initialize the client - # You can set these environment variables or pass them directly - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - client = SiftClient( - api_key=api_key, - grpc_url=grpc_url, - rest_url=rest_url, - ) - - # Use a known asset to fetch a run. - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # List runs for this asset - runs = asset.runs - print( - f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" - ) - - # Pick one. - run = runs[0] - run_id = run.id_ - print(f"Using run: {run.name} (ID: {run_id})") - - # List other assets for this run. - all_assets = run.assets - other_assets = [asset for asset in all_assets if asset.id_ != asset_id] - print( - f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" - ) - - # Example 1: List all runs - print("\n1. Listing all runs...") - runs = client.runs.list_(limit=5) - print(f" Found {len(runs)} runs:") - for run in runs: - print(f" - {run.name} (ID: {run.id_}), Organization ID: {run.organization_id}") - - # Example 2: Test different filter options - print("\n2. Testing different filter options...") - - # Get a sample run for testing filters - sample_runs = client.runs.list_(limit=3) - if not sample_runs: - print(" No runs available for filter testing") - return - - sample_run = sample_runs[0] - - # 2a: Filter by exact name - print("\n 2a. Filter by exact name...") - run_name = sample_run.name - runs = client.runs.list_(name=run_name, limit=5) - print(f" Found {len(runs)} runs with exact name '{run_name}':") - for run in runs: - print(f" - {run.name} (ID: {run.id_})") - - # 2b: Filter by name containing text - print("\n 2b. Filter by name containing text...") - runs = client.runs.list_(name_contains="test", limit=5) - print(f" Found {len(runs)} runs with 'test' in name:") - for run in runs: - print(f" - {run.name}") - - # 2c: Filter by name using regex - print("\n 2c. Filter by name using regex...") - runs = client.runs.list_(name_regex=".*test.*", limit=5) - print(f" Found {len(runs)} runs with 'test' in name (regex):") - for run in runs: - print(f" - {run.name}") - - # 2d: Filter by exact description - print("\n 2d. Filter by description contains...") - if sample_run.description: - runs = client.runs.list_(description_contains=sample_run.description, limit=5) - print(f" Found {len(runs)} runs with description contains '{sample_run.description}':") - for run in runs: - print(f" - {run.name}: {run.description}") - else: - print(" No description available for testing") - - # 2e: Filter by description containing text - print("\n 2e. Filter by description containing text...") - runs = client.runs.list_(description_contains="test", limit=5) - print(f" Found {len(runs)} runs with 'test' in description:") - for run in runs: - print(f" - {run.name}: {run.description}") - - # 2f: Filter by duration seconds - print("\n 2f. Filter by duration seconds...") - # Calculate duration for sample run if it has start and stop times - if sample_run.start_time and sample_run.stop_time: - duration_seconds = int((sample_run.stop_time - sample_run.start_time).total_seconds()) - runs = client.runs.list_(duration_greater_than=timedelta(seconds=duration_seconds), limit=5) - print(f" Found {len(runs)} runs with duration greater than {duration_seconds} seconds:") - for run in runs: - if run.start_time and run.stop_time: - run_duration = int((run.stop_time - run.start_time).total_seconds()) - print(f" - {run.name} (duration: {run_duration}s)") - else: - print(" No start/stop times available for duration testing") - - # 2g: Filter by client key - print("\n 2g. Filter by client key...") - if sample_run.client_key: - run = client.runs.get(client_key=sample_run.client_key) - print(f" Found run with client key '{run.name}'") - else: - print(" No client key available for testing") - - # 2h: Filter by asset ID - print("\n 2h. Filter by asset ID...") - if sample_run.asset_ids: - asset_id = sample_run.asset_ids[0] - runs = client.runs.list_(assets=[asset_id], limit=5) - print(f" Found {len(runs)} runs associated with asset {asset_id}:") - for run in runs: - print(f" - {run.name} (asset_ids: {list(run.asset_ids)})") - else: - print(" No asset IDs available for testing") - - # 2i: Filter by asset name - print("\n 2i. Filter by asset name...") - runs = client.runs.list_(assets=[client.assets.find(name="NostromoLV426")], limit=5) - print(f" Found {len(runs)} runs associated with asset 'NostromoLV426':") - for run in runs: - print(f" - {run.name}") - - # 2j: Filter by created by user ID - print("\n 2j. Filter by created by user ID...") - created_by = sample_run.created_by_user_id - runs = client.runs.list_(created_by=created_by, limit=5) - print(f" Found {len(runs)} runs created by user {created_by}:") - for run in runs: - print(f" - {run.name} (created by: {run.created_by})") - - # 2l: Test ordering options - print("\n 2l. Testing ordering options...") - - # Order by name ascending - runs = client.runs.list_(order_by="name", limit=3) - print(" First 3 runs ordered by name (ascending):") - for run in runs: - print(f" - {run.name}") - - # Order by name descending - runs = client.runs.list_(order_by="name desc", limit=3) - print(" First 3 runs ordered by name (descending):") - for run in runs: - print(f" - {run.name}") - - # Order by creation date (newest first - default) - runs = client.runs.list_(order_by="created_date desc", limit=3) - print(" First 3 runs ordered by creation date (newest first):") - for run in runs: - print(f" - {run.name} (created: {run.created_date})") - - # Order by creation date (oldest first) - runs = client.runs.list_(order_by="created_date", limit=3) - print(" First 3 runs ordered by creation date (oldest first):") - for run in runs: - print(f" - {run.name} (created: {run.created_date})") - - # Example 3: Find a single run by name - print("\n3. Finding a single run by name...") - run_name = "test-run" # Replace with an actual run name - run = client.runs.find(name=run_name) - if run: - print(f" Found run: {run.name}") - print(f" Description: {run.description}") - else: - print(f" No run found with name '{run_name}'") - - # Example 4: Create a new run - print("\n4. Creating a new run...") - # Create metadata for the run - metadata = { - "environment": "production", - "test_type": "integration", - } - - # Create a run with start and stop times - start_time = datetime.now(timezone.utc) - stop_time = start_time + timedelta(minutes=2) - - previously_created_runs = client.runs.list_(name_regex="Example Test Run.*") - if previously_created_runs: - print(f" Deleting previously created runs: {previously_created_runs}") - for run in previously_created_runs: - print(f" Deleting run: {run.name}") - client.runs.archive(run=run) - - new_run = client.runs.create( - dict( - name=f"Example Test Run {datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}", - description="A test run created via the API", - tags=["api-created", "test"], - start_time=start_time, - stop_time=stop_time, - # Use a unique client key for each run - client_key=f"example-run-key-{datetime.now(tz=timezone.utc).timestamp()}", - metadata=metadata, - ) - ) - print(f" Created run: {new_run.name} (ID: {new_run.id_})") - print(f" Client key: {new_run.client_key}") - print(f" Tags: {new_run.tags}") - - # Example 5: Update a run - print("\n5. Updating a run...") - - run_to_update = new_run - print(f" Updating run: {run_to_update.name}") - - # Update the run - new_description = "Updated description via API" - new_metadata = { - "test_type": "ci", - } - new_tags = ["updated", "api-modified"] - updated_run = client.runs.update( - run=run_to_update, - update={ - "description": new_description, - "tags": new_tags, - "metadata": new_metadata, - }, - ) - print(f" Updated run: {updated_run.name}") - print(f" New description: {updated_run.description}") - print(f" New tags: {updated_run.tags}") - print(f" New metadata: {updated_run.metadata}") - assert updated_run.description == new_description - assert sorted(updated_run.tags) == sorted(new_tags) - assert updated_run.metadata == new_metadata - - # Example 6: Associate assets with a run - print("\n6. Associating assets with a run...") - ongoing_runs = client.runs.list_( - name_regex="Example Test Run.*", include_archived=True, is_stopped=False - ) - if ongoing_runs: - print(" Ensuring previously created runs are stopped:") - for run in ongoing_runs: - if run.stop_time is None: - print(f" Stopping run: {run.name}") - client.runs.stop(run=run) - - # Get a run to associate assets with - asset_names = ["asset1", "asset2"] # Replace with actual asset names - print(f" Associating assets {asset_names} with run: {new_run.name}") - - client.runs.create_automatic_association_for_assets(run=new_run, asset_names=asset_names) - print(f" Successfully associated assets with run: {new_run.name}") - - # Example 7: Delete a run - print("\n7. Deleting a run") - run_to_delete = new_run - print(f" Deleting run: {run_to_delete.name}") - client.runs.archive(run=run_to_delete) - print(f" Successfully archived run: {run_to_delete.name}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/resources/__init__.py b/python/lib/sift_client/_tests/resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_tests/resources/test_assets.py b/python/lib/sift_client/_tests/resources/test_assets.py new file mode 100644 index 000000000..9ed200c36 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_assets.py @@ -0,0 +1,215 @@ +"""Pytest tests for the Assets API. + +These tests demonstrate and validate the usage of the Assets API including: +- Basic asset operations (get, list, find) +- Asset filtering and searching +- Asset updates and archiving +- Error handling and edge cases +""" + +import os + +import pytest + +from sift_client import SiftClient +from sift_client.resources import AssetsAPI, AssetsAPIAsync +from sift_client.sift_types import Asset + +pytestmark = pytest.mark.integration + + +@pytest.fixture(scope="session") +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing.""" + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + ) + +def test_client_binding(sift_client): + assert sift_client.assets + assert isinstance(sift_client.assets, AssetsAPI) + assert sift_client.async_.assets + assert isinstance(sift_client.async_.assets, AssetsAPIAsync) + + + +@pytest.fixture +def assets_api_async(sift_client: SiftClient): + """Get the async assets API instance.""" + return sift_client.async_.assets + + +@pytest.fixture +def assets_api_sync(sift_client: SiftClient): + """Get the synchronous assets API instance.""" + return sift_client.assets + +@pytest.fixture +def test_asset(assets_api_sync): + assets = assets_api_sync.list_(limit=1) + assert assets + assert len(assets) >= 1 + return assets[0] + + +class TestAssetsAPIAsync: + """Test suite for the async Assets API functionality.""" + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, assets_api_async): + """Test basic asset listing functionality.""" + assets = await assets_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(assets, list) + assert assets + + # If we have assets, verify their structure + asset = assets[0] + assert isinstance(asset, Asset) + assert asset.id_ is not None + assert asset.name is not None + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, assets_api_async): + """Test asset listing with name filtering.""" + # First get some assets to work with + all_assets = await assets_api_async.list_(limit=10) + + if all_assets: + # Use the first asset's name for filtering + test_asset_name = all_assets[0].name + filtered_assets = await assets_api_async.list_(name=test_asset_name) + + # Should find at least one asset with exact name match + assert isinstance(filtered_assets, list) + assert len(filtered_assets) >= 1 + + # All returned assets should have the exact name + for asset in filtered_assets: + assert asset.name == test_asset_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, assets_api_async): + """Test asset listing with name contains filtering.""" + # Test with a common substring that might exist in asset names + assets = await assets_api_async.list_(name_contains="test", limit=5) + + assert isinstance(assets, list) + + # If we found assets, verify they contain the substring + for asset in assets: + assert "test" in asset.name.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, assets_api_async): + """Test asset listing with different limits.""" + # Test with limit of 1 + assets_1 = await assets_api_async.list_(limit=1) + assert isinstance(assets_1, list) + assert len(assets_1) <= 1 + + # Test with limit of 3 + assets_3 = await assets_api_async.list_(limit=3) + assert isinstance(assets_3, list) + assert len(assets_3) <= 3 + + @pytest.mark.asyncio + async def test_list_include_archived(self, assets_api_async): + """Test asset listing with archived assets included.""" + # Test without archived assets (default) + assets_active = await assets_api_async.list_(limit=5, include_archived=False) + assert isinstance(assets_active, list) + + # Test with archived assets included + assets_all = await assets_api_async.list_(limit=5, include_archived=True) + assert isinstance(assets_all, list) + + # Should have at least as many assets when including archived + assert len(assets_all) >= len(assets_active) + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_name(self, assets_api_async, test_asset): + """Test getting a specific asset by name.""" + retrieved_asset = await assets_api_async.get(name=test_asset.name) + + assert retrieved_asset is not None + assert retrieved_asset.id_ == test_asset.id_ + assert retrieved_asset.name == test_asset.name + + @pytest.mark.asyncio + async def test_get_by_id(self, assets_api_async, test_asset): + """Test getting a specific asset by ID.""" + retrieved_asset = await assets_api_async.get(asset_id=test_asset.id_) + + assert retrieved_asset is not None + assert retrieved_asset.id_ == test_asset.id_ + + @pytest.mark.asyncio + async def test_get_without_params_raises_error(self, assets_api_async): + """Test that getting an asset without parameters raises an error.""" + with pytest.raises(ValueError, match="Either asset_id or name must be provided"): + await assets_api_async.get() + + @pytest.mark.asyncio + async def test_get_nonexistent_asset_raises_error(self, assets_api_async): + """Test that getting a non-existent asset raises an error.""" + with pytest.raises(ValueError, match="No asset found"): + await assets_api_async.get(name="nonexistent-asset-name-12345") + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_asset(self, assets_api_async, test_asset): + """Test finding a single asset.""" + # Find the same asset by name + found_asset = await assets_api_async.find(name=test_asset.name) + + assert found_asset is not None + assert found_asset.id_ == test_asset.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_asset(self, assets_api_async): + """Test finding a non-existent asset returns None.""" + found_asset = await assets_api_async.find(name="nonexistent-asset-name-12345") + assert found_asset is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, assets_api_async): + """Test finding multiple assets raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await assets_api_async.find(name_contains="a") + + +class TestAssetsAPISync: + """Test suite for the synchronous Assets API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestList: + """Tests for the sync list_ method.""" + + def test_basic_list(self, assets_api_sync): + """Test basic synchronous asset listing functionality.""" + assets = assets_api_sync.list_(limit=5) + + # Verify we get a list + assert isinstance(assets, list) + assert assets + assert isinstance(assets[0], Asset) + + diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py new file mode 100644 index 000000000..6f26c6c33 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -0,0 +1,74 @@ +"""Pytest tests for the Ping API. + +These tests demonstrate and validate the usage of the Ping API including: +- Basic ping functionality +- Connection health checks +- Error handling and edge cases +""" + +import os + +import pytest + +from sift_client import SiftClient +from sift_client.resources import PingAPI, PingAPIAsync + +pytestmark = pytest.mark.integration + +@pytest.fixture(scope="session") +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing.""" + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + ) + +def test_client_binding(sift_client): + assert sift_client.ping + assert isinstance(sift_client.ping, PingAPI) + assert sift_client.async_.ping + assert isinstance(sift_client.async_.ping, PingAPIAsync) + + +@pytest.fixture +def ping_api_async(sift_client: SiftClient): + """Get the ping async API instance.""" + return sift_client.async_.ping + +@pytest.fixture +def ping_api_sync(sift_client: SiftClient): + """Get the synchronous ping API instance.""" + return sift_client.ping + +class TestPingAPIAsync: + """Test suite for the Ping API functionality.""" + + @pytest.mark.asyncio + async def test_basic_ping(self, ping_api_async): + """Test basic ping functionality.""" + response = await ping_api_async.ping() + + # Verify response is a string + assert isinstance(response, str) + + # Verify response is not empty + assert len(response) > 0 + + +class TestPingAPISync: + """Test suite for the Ping API functionality.""" + + def test_basic_ping(self, ping_api_sync): + """Test basic synchronous ping functionality.""" + response = ping_api_sync.ping() + + # Verify response is a string + assert isinstance(response, str) + + # Verify response is not empty + assert len(response) > 0 diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py new file mode 100644 index 000000000..0e3490fbd --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -0,0 +1,267 @@ +"""Pytest tests for the Runs API. + +These tests demonstrate and validate the usage of the Runs API including: +- Basic run operations (get, list, find) +- Run filtering and searching +- Run creation, updates, and archiving +- Error handling and edge cases +""" + +import os + +import pytest + +from sift_client import SiftClient +from sift_client.resources import RunsAPI, RunsAPIAsync +from sift_client.sift_types import Run + +pytestmark = pytest.mark.integration + + +@pytest.fixture(scope="session") +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing.""" + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + ) + +def test_client_binding(sift_client): + assert sift_client.runs + assert isinstance(sift_client.runs, RunsAPI) + assert sift_client.async_.runs + assert isinstance(sift_client.async_.runs, RunsAPIAsync) + +@pytest.fixture +def runs_api_async(sift_client: SiftClient): + """Get the async runs API instance.""" + return sift_client.async_.runs + + +@pytest.fixture +def runs_api_sync(sift_client: SiftClient): + """Get the synchronous runs API instance.""" + return sift_client.runs + + +@pytest.fixture +def test_run(runs_api_sync): + runs = runs_api_sync.list_(limit=1) + assert runs + assert len(runs) >= 1 + return runs[0] + + +class TestRunsAPIAsync: + """Test suite for the async Runs API functionality.""" + + class TestList: + """Tests for the async list method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, runs_api_async): + """Test basic run listing functionality.""" + runs = await runs_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(runs, list) + assert runs + + # If we have runs, verify their structure + run = runs[0] + assert isinstance(run, Run) + assert run.id_ is not None + assert run.name is not None + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, runs_api_async): + """Test run listing with name filtering.""" + # First get some runs to work with + all_runs = await runs_api_async.list_(limit=10) + + if all_runs: + # Use the first run's name for filtering + test_run_name = all_runs[0].name + filtered_runs = await runs_api_async.list_(name=test_run_name) + + # Should find at least one run with exact name match + assert isinstance(filtered_runs, list) + assert len(filtered_runs) >= 1 + + # All returned runs should have the exact name + for run in filtered_runs: + assert run.name == test_run_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, runs_api_async): + """Test run listing with name contains filtering.""" + # Test with a common substring that might exist in run names + runs = await runs_api_async.list_(name_contains="test", limit=5) + + assert isinstance(runs, list) + + # If we found runs, verify they contain the substring + for run in runs: + assert "test" in run.name.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, runs_api_async): + """Test run listing with different limits.""" + # Test with limit of 1 + runs_1 = await runs_api_async.list_(limit=1) + assert isinstance(runs_1, list) + assert len(runs_1) <= 1 + + # Test with limit of 3 + runs_3 = await runs_api_async.list_(limit=3) + assert isinstance(runs_3, list) + assert len(runs_3) <= 3 + + @pytest.mark.asyncio + async def test_list_include_archived(self, runs_api_async): + """Test run listing with archived runs included.""" + # Test without archived runs (default) + runs_active = await runs_api_async.list_(limit=5, include_archived=False) + assert isinstance(runs_active, list) + + # Test with archived runs included + runs_all = await runs_api_async.list_(limit=5, include_archived=True) + assert isinstance(runs_all, list) + + # Should have at least as many runs when including archived + assert len(runs_all) >= len(runs_active) + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_name(self, runs_api_async, test_run): + """Test getting a specific run by name.""" + retrieved_run = await runs_api_async.get(name=test_run.name) + + assert retrieved_run is not None + assert retrieved_run.id_ == test_run.id_ + assert retrieved_run.name == test_run.name + + @pytest.mark.asyncio + async def test_get_by_id(self, runs_api_async, test_run): + """Test getting a specific run by ID.""" + retrieved_run = await runs_api_async.get(run_id=test_run.id_) + + assert retrieved_run is not None + assert retrieved_run.id_ == test_run.id_ + + @pytest.mark.asyncio + async def test_get_without_params_raises_error(self, runs_api_async): + """Test that getting a run without parameters raises an error.""" + with pytest.raises(ValueError, match="Either run_id or name must be provided"): + await runs_api_async.get() + + @pytest.mark.asyncio + async def test_get_nonexistent_run_raises_error(self, runs_api_async): + """Test that getting a non-existent run raises an error.""" + with pytest.raises(ValueError, match="No run found"): + await runs_api_async.get(name="nonexistent-run-name-12345") + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_run(self, runs_api_async, test_run): + """Test finding a single run.""" + # Find the same run by name + found_run = await runs_api_async.find(name=test_run.name) + + assert found_run is not None + assert found_run.id_ == test_run.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_run(self, runs_api_async): + """Test finding a non-existent run returns None.""" + found_run = await runs_api_async.find(name="nonexistent-run-name-12345") + assert found_run is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, runs_api_async): + """Test finding multiple runs raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await runs_api_async.find(name_contains="a") + + + class TestErrorHandling: + """Tests for error handling scenarios.""" + + @pytest.mark.asyncio + async def test_invalid_client_configuration(self): + """Test that invalid client configurations are handled gracefully.""" + # Create a client with invalid configuration + invalid_client = SiftClient( + api_key="invalid-key", + grpc_url="invalid-url:99999", + rest_url="invalid-url:99999", + ) + + # The client should be created but API calls should fail gracefully + runs_api = invalid_client.async_.runs + + # This should raise an appropriate error, not crash + with pytest.raises(Exception): # Could be connection error, auth error, etc. + await runs_api.list_(limit=1) + + @pytest.mark.asyncio + async def test_integration_with_ping(self, runs_api_async, sift_client): + """Test that runs API works in conjunction with ping API.""" + # First verify connectivity with ping + ping_response = await sift_client.async_.ping.ping() + assert isinstance(ping_response, str) + assert len(ping_response) > 0 + + # Then test runs API + runs = await runs_api_async.list_(limit=1) + assert isinstance(runs, list) + + +class TestRunsAPISync: + """Test suite for the synchronous Runs API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestList: + """Tests for the sync list method.""" + + def test_basic_list(self, runs_api_sync): + """Test basic synchronous run listing functionality.""" + runs = runs_api_sync.list_(limit=5) + + # Verify we get a list + assert isinstance(runs, list) + assert runs + assert isinstance(runs[0], Run) + + class TestGet: + """Tests for the sync get method.""" + + def test_basic_get(self, runs_api_sync, test_run): + """Test basic synchronous get functionality.""" + retrieved_run = runs_api_sync.get(name=test_run.name) + + assert retrieved_run is not None + assert isinstance(retrieved_run, Run) + assert retrieved_run.id_ == test_run.id_ + + class TestFind: + """Tests for the sync find method.""" + + def test_basic_find(self, runs_api_sync, test_run): + """Test basic synchronous find functionality.""" + found_run = runs_api_sync.find(name=test_run.name) + + assert found_run is not None + assert isinstance(found_run, Run) + assert found_run.id_ == test_run.id_ diff --git a/python/lib/sift_client/_tests/test_client.py b/python/lib/sift_client/_tests/test_client.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/scripts/dev b/python/scripts/dev index 1ecf01ba9..6329dce2f 100755 --- a/python/scripts/dev +++ b/python/scripts/dev @@ -23,6 +23,8 @@ Subcommands: mypy-stubs Runs stubtest (mypy) on the generated pyi stubs pip-install Install project dependencies test Execute tests + test-integration Execute integration tests + test-all Execute all non-integration tests update-dev Copies changes to the dev script over to the venv bin Options: @@ -90,9 +92,18 @@ doc_build() { } run_tests() { + pytest -m "not integration" +} + +run_tests_integration() { + pytest -m "integration" +} + +run_tests_all() { pytest } + gen_stubs() { source venv/bin/activate cd lib @@ -148,6 +159,12 @@ case "$1" in test) run_tests ;; + test_integration) + run_tests_integration + ;; + test_all) + run_tests_all + ;; gen-stubs) gen_stubs ;; From c33bdc6f9d7e2d8732ec060ffcb3a6e7eaa2f99b Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:03:49 -0700 Subject: [PATCH 05/39] linting and reorganized pre-push hook --- .githooks/pre-push-python/stubs-check.sh | 62 +++++++++++++++++ .../_internal/low_level_wrappers/test_base.py | 68 ++++--------------- .../_tests/resources/test_assets.py | 5 +- .../sift_client/_tests/resources/test_ping.py | 4 ++ .../sift_client/_tests/resources/test_runs.py | 3 +- python/scripts/build_utils.py | 2 +- 6 files changed, 86 insertions(+), 58 deletions(-) create mode 100644 .githooks/pre-push-python/stubs-check.sh diff --git a/.githooks/pre-push-python/stubs-check.sh b/.githooks/pre-push-python/stubs-check.sh new file mode 100644 index 000000000..2d7ab6d38 --- /dev/null +++ b/.githooks/pre-push-python/stubs-check.sh @@ -0,0 +1,62 @@ +#!/bin/bash +set -e + + +# ensure generated python stubs are up-to-date, from sync clients and sift_stream_bindings + +REPO_ROOT="$(git rev-parse --show-toplevel)" +PYTHON_DIR="$REPO_ROOT/python" +BINDINGS_DIR="$REPO_ROOT/rust/crates/sift_stream_bindings" +STUBS_DIR="$PYTHON_DIR/lib/sift_client/resources/sync_stubs" + +# Function to check if generated stub files have changed +check_stub_changes() { + local target_path="$1" + local changed_files=$(git status --porcelain "$target_path" | grep -E '\.pyi$' || true) + + if [ -n "$changed_files" ]; then + echo "ERROR: Generated python stubs are not up-to-date. Please commit the changed files:" + echo "$changed_files" + exit 1 + fi +} + +# Function to generate Python stubs +generate_python_stubs() { + echo "Generating Python stubs..." + cd "$PYTHON_DIR" + + if [[ ! -d "$PYTHON_DIR/venv" ]]; then + echo "Running bootstrap script..." + bash ./scripts/dev bootstrap + fi + + bash ./scripts/dev gen-stubs + check_stub_changes "$STUBS_DIR" +} + +# Function to generate bindings stubs +generate_bindings_stubs() { + echo "Generating bindings stubs..." + cd "$BINDINGS_DIR" + cargo run --bin stub_gen + + # The stub file is generated in the bindings directory + local stub_file="$BINDINGS_DIR/sift_stream_bindings.pyi" + check_stub_changes "$stub_file" +} + +# Check for changes in relevant files +python_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^python/lib/sift_client/' || true)) +bindings_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^rust/crates/sift_stream_bindings/src/' || true)) + +# Generate stubs if needed +if [[ -n "$python_changed_files" ]]; then + generate_python_stubs +fi + +if [[ -n "$bindings_changed_files" ]]; then + generate_bindings_stubs +fi + +echo "All stubs are up-to-date." \ No newline at end of file diff --git a/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py index 7bb9ab941..02ecf8142 100644 --- a/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py +++ b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py @@ -28,11 +28,7 @@ async def test_basic_pagination_single_page(self): results = await LowLevelClientBase._handle_pagination(mock_func) assert results == [1, 2, 3] - mock_func.assert_called_once_with( - page_size=None, - page_token="", - order_by=None - ) + mock_func.assert_called_once_with(page_size=None, page_token="", order_by=None) @pytest.mark.asyncio async def test_pagination_multiple_pages(self): @@ -42,7 +38,7 @@ async def test_pagination_multiple_pages(self): mock_func.side_effect = [ ([1, 2, 3], "token1"), # First page ([4, 5, 6], "token2"), # Second page - ([7, 8, 9], ""), # Last page (empty token) + ([7, 8, 9], ""), # Last page (empty token) ] results = await LowLevelClientBase._handle_pagination(mock_func) @@ -61,34 +57,20 @@ async def test_pagination_with_page_size(self): """Test pagination with specified page size.""" mock_func = AsyncMock(return_value=([1, 2], "")) - results = await LowLevelClientBase._handle_pagination( - mock_func, - page_size=2 - ) + results = await LowLevelClientBase._handle_pagination(mock_func, page_size=2) assert results == [1, 2] - mock_func.assert_called_once_with( - page_size=2, - page_token="", - order_by=None - ) + mock_func.assert_called_once_with(page_size=2, page_token="", order_by=None) @pytest.mark.asyncio async def test_pagination_with_order_by(self): """Test pagination with order_by parameter.""" mock_func = AsyncMock(return_value=([1, 2, 3], "")) - results = await LowLevelClientBase._handle_pagination( - mock_func, - order_by="name asc" - ) + results = await LowLevelClientBase._handle_pagination(mock_func, order_by="name asc") assert results == [1, 2, 3] - mock_func.assert_called_once_with( - page_size=None, - page_token="", - order_by="name asc" - ) + mock_func.assert_called_once_with(page_size=None, page_token="", order_by="name asc") @pytest.mark.asyncio async def test_pagination_with_initial_page_token(self): @@ -96,15 +78,12 @@ async def test_pagination_with_initial_page_token(self): mock_func = AsyncMock(return_value=([4, 5, 6], "")) results = await LowLevelClientBase._handle_pagination( - mock_func, - page_token="start_token" + mock_func, page_token="start_token" ) assert results == [4, 5, 6] mock_func.assert_called_once_with( - page_size=None, - page_token="start_token", - order_by=None + page_size=None, page_token="start_token", order_by=None ) @pytest.mark.asyncio @@ -113,10 +92,7 @@ async def test_pagination_with_kwargs(self): mock_func = AsyncMock(return_value=([1, 2, 3], "")) kwargs = {"filter": "active", "include_archived": False} - results = await LowLevelClientBase._handle_pagination( - mock_func, - kwargs=kwargs - ) + results = await LowLevelClientBase._handle_pagination(mock_func, kwargs=kwargs) assert results == [1, 2, 3] mock_func.assert_called_once_with( @@ -124,7 +100,7 @@ async def test_pagination_with_kwargs(self): page_token="", order_by=None, filter="active", - include_archived=False + include_archived=False, ) @pytest.mark.asyncio @@ -132,10 +108,7 @@ async def test_pagination_with_max_results_single_page(self): """Test pagination with max_results that fits in a single page.""" mock_func = AsyncMock(return_value=([1, 2, 3, 4, 5], "")) - results = await LowLevelClientBase._handle_pagination( - mock_func, - max_results=3 - ) + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=3) # Should return only the max results assert results == [1, 2, 3] @@ -150,10 +123,7 @@ async def test_pagination_with_max_results_multiple_pages(self): ([4, 5, 6], "token2"), # Second page (6 total items, exceeds max_results=5) ] - results = await LowLevelClientBase._handle_pagination( - mock_func, - max_results=5 - ) + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=5) # Should include 2 pages and return the full first page but limited 2nd page assert results == [1, 2, 3, 4, 5] @@ -165,13 +135,10 @@ async def test_pagination_with_max_results_exact_match(self): mock_func = AsyncMock() mock_func.side_effect = [ ([1, 2, 3], "token1"), # First page - ([4, 5], ""), # Second page, total = 5 + ([4, 5], ""), # Second page, total = 5 ] - results = await LowLevelClientBase._handle_pagination( - mock_func, - max_results=5 - ) + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=5) assert results == [1, 2, 3, 4, 5] assert mock_func.call_count == 2 @@ -186,18 +153,13 @@ async def test_pagination_empty_results(self): assert results == [] mock_func.assert_called_once() - @pytest.mark.asyncio async def test_pagination_max_results_zero(self): """Test pagination with max_results=0.""" mock_func = AsyncMock(return_value=([1, 2, 3], "")) - results = await LowLevelClientBase._handle_pagination( - mock_func, - max_results=0 - ) + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=0) # Should return empty list without calling the function assert results == [] mock_func.assert_not_called() - diff --git a/python/lib/sift_client/_tests/resources/test_assets.py b/python/lib/sift_client/_tests/resources/test_assets.py index 9ed200c36..10fda6c15 100644 --- a/python/lib/sift_client/_tests/resources/test_assets.py +++ b/python/lib/sift_client/_tests/resources/test_assets.py @@ -31,6 +31,7 @@ def sift_client() -> SiftClient: rest_url=rest_url, ) + def test_client_binding(sift_client): assert sift_client.assets assert isinstance(sift_client.assets, AssetsAPI) @@ -38,7 +39,6 @@ def test_client_binding(sift_client): assert isinstance(sift_client.async_.assets, AssetsAPIAsync) - @pytest.fixture def assets_api_async(sift_client: SiftClient): """Get the async assets API instance.""" @@ -50,6 +50,7 @@ def assets_api_sync(sift_client: SiftClient): """Get the synchronous assets API instance.""" return sift_client.assets + @pytest.fixture def test_asset(assets_api_sync): assets = assets_api_sync.list_(limit=1) @@ -211,5 +212,3 @@ def test_basic_list(self, assets_api_sync): assert isinstance(assets, list) assert assets assert isinstance(assets[0], Asset) - - diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index 6f26c6c33..ca924f92b 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -15,6 +15,7 @@ pytestmark = pytest.mark.integration + @pytest.fixture(scope="session") def sift_client() -> SiftClient: """Create a SiftClient instance for testing.""" @@ -28,6 +29,7 @@ def sift_client() -> SiftClient: rest_url=rest_url, ) + def test_client_binding(sift_client): assert sift_client.ping assert isinstance(sift_client.ping, PingAPI) @@ -40,11 +42,13 @@ def ping_api_async(sift_client: SiftClient): """Get the ping async API instance.""" return sift_client.async_.ping + @pytest.fixture def ping_api_sync(sift_client: SiftClient): """Get the synchronous ping API instance.""" return sift_client.ping + class TestPingAPIAsync: """Test suite for the Ping API functionality.""" diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 0e3490fbd..01e6b3a6e 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -31,12 +31,14 @@ def sift_client() -> SiftClient: rest_url=rest_url, ) + def test_client_binding(sift_client): assert sift_client.runs assert isinstance(sift_client.runs, RunsAPI) assert sift_client.async_.runs assert isinstance(sift_client.async_.runs, RunsAPIAsync) + @pytest.fixture def runs_api_async(sift_client: SiftClient): """Get the async runs API instance.""" @@ -192,7 +194,6 @@ async def test_find_multiple_raises_error(self, runs_api_async): with pytest.raises(ValueError, match="Multiple"): await runs_api_async.find(name_contains="a") - class TestErrorHandling: """Tests for error handling scenarios.""" diff --git a/python/scripts/build_utils.py b/python/scripts/build_utils.py index 48a247e2e..43225ac8e 100644 --- a/python/scripts/build_utils.py +++ b/python/scripts/build_utils.py @@ -50,7 +50,7 @@ def get_extra_combinations(extras: List[str], exclude: Optional[List[str]] = Non filtered_extras = extras # Get only the full extras lists, no additional permutations at the moment - return [','.join(filtered_extras)] + return [",".join(filtered_extras)] def test_install( From 0cc9a4bc55bc27aaacb94a46c4527faee6306ada Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:24:41 -0700 Subject: [PATCH 06/39] break out unit and integration tests --- .github/workflows/python_ci.yaml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 4dbac8c5b..2c86f0447 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -44,9 +44,13 @@ jobs: run: | pyright lib - - name: Pytest + - name: Pytest Unit Tests run: | - pytest + pytest -m "not integration" + + - name: Pytest Integration Tests + run: | + pytest -m "integration" - name: Sync Stubs Mypy working-directory: python/lib From 4fb2523b36927de9c3c8c82ae6c86122f7167cbc Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:39:08 -0700 Subject: [PATCH 07/39] linting --- python/lib/sift_client/_tests/resources/test_runs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 01e6b3a6e..838b3155d 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -9,6 +9,7 @@ import os +import grpc.aio._call import pytest from sift_client import SiftClient @@ -211,7 +212,9 @@ async def test_invalid_client_configuration(self): runs_api = invalid_client.async_.runs # This should raise an appropriate error, not crash - with pytest.raises(Exception): # Could be connection error, auth error, etc. + with pytest.raises( + grpc.aio._call.AioRpcError + ): # Could be connection error, auth error, etc. await runs_api.list_(limit=1) @pytest.mark.asyncio From fcc4bba24fabf17f556e63a487afec2241162321 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:44:30 -0700 Subject: [PATCH 08/39] update CI --- .github/workflows/python_ci.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 2c86f0447..0c8968b80 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -25,34 +25,42 @@ jobs: python-version: "3.8" - name: Pip install + id: install run: | python -m pip install --upgrade pip pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' - name: Lint + if: steps.install.outcome == 'success' run: | ruff check - name: Format + if: steps.install.outcome == 'success' run: | ruff format --check - name: MyPy + if: steps.install.outcome == 'success' run: | mypy lib - name: Pyright + if: steps.install.outcome == 'success' run: | pyright lib - name: Pytest Unit Tests + if: steps.install.outcome == 'success' run: | pytest -m "not integration" - name: Pytest Integration Tests + if: steps.install.outcome == 'success' run: | pytest -m "integration" - name: Sync Stubs Mypy + if: steps.install.outcome == 'success' working-directory: python/lib run: | stubtest \ From 2d59cd43a226cb6aafad18d7eec14b56021052d8 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:51:44 -0700 Subject: [PATCH 09/39] CI update to run all python checks in parallel --- .github/workflows/python_ci.yaml | 164 +++++++++++++++++++++++++------ 1 file changed, 135 insertions(+), 29 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 0c8968b80..d162cd9f9 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -10,59 +10,165 @@ on: workflow_call: jobs: - test-python: + setup: runs-on: ubuntu-latest defaults: run: working-directory: python steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: "3.8" - - name: Pip install - id: install + - name: Install dependencies run: | + python -m venv venv + source venv/bin/activate python -m pip install --upgrade pip pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' - - name: Lint - if: steps.install.outcome == 'success' - run: | + + - name: Upload venv + uses: actions/upload-artifact@v4 + with: + name: python-venv + path: python/venv + + lint: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate ruff check - - name: Format - if: steps.install.outcome == 'success' - run: | + format: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate ruff format --check - - name: MyPy - if: steps.install.outcome == 'success' - run: | + mypy: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate mypy lib - - name: Pyright - if: steps.install.outcome == 'success' - run: | + pyright: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate pyright lib - - name: Pytest Unit Tests - if: steps.install.outcome == 'success' - run: | + unit-tests: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate pytest -m "not integration" - - name: Pytest Integration Tests - if: steps.install.outcome == 'success' - run: | + integration-tests: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate pytest -m "integration" - - name: Sync Stubs Mypy - if: steps.install.outcome == 'success' - working-directory: python/lib - run: | + sync-stubs-mypy: + runs-on: ubuntu-latest + needs: setup + defaults: + run: + working-directory: python + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + - uses: actions/download-artifact@v4 + with: + name: python-venv + path: python/venv + - run: | + source venv/bin/activate + cd lib stubtest \ - --mypy-config-file ../pyproject.toml \ - sift_client.resources.sync_stubs \ No newline at end of file + --mypy-config-file ../pyproject.toml \ + sift_client.resources.sync_stubs From d4abaff85d24cb675cdb630875ebbd66c003f3a2 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:56:14 -0700 Subject: [PATCH 10/39] update ci to use pip cahce instead of a full venv --- .github/workflows/python_ci.yaml | 120 +++++++++++++++---------------- 1 file changed, 56 insertions(+), 64 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index d162cd9f9..1f0ec3c24 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -1,5 +1,8 @@ name: python-ci +permissions: + contents: read + on: release: types: [ created ] @@ -10,55 +13,32 @@ on: workflow_call: jobs: - setup: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - - name: Install dependencies - run: | - python -m venv venv - source venv/bin/activate - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' - - - name: Upload venv - uses: actions/upload-artifact@v4 - with: - name: python-venv - path: python/venv - lint: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python steps: - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' ruff check format: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python @@ -67,17 +47,19 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' ruff format --check mypy: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python @@ -86,17 +68,19 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' mypy lib pyright: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python @@ -105,17 +89,19 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' pyright lib unit-tests: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python @@ -124,17 +110,19 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' pytest -m "not integration" integration-tests: runs-on: ubuntu-latest - needs: setup defaults: run: working-directory: python @@ -143,32 +131,36 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' pytest -m "integration" sync-stubs-mypy: runs-on: ubuntu-latest - needs: setup defaults: run: - working-directory: python + working-directory: python/lib steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: "3.8" - - uses: actions/download-artifact@v4 + - uses: actions/cache@v3 with: - name: python-venv - path: python/venv + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- - run: | - source venv/bin/activate - cd lib + python -m pip install --upgrade pip + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' stubtest \ --mypy-config-file ../pyproject.toml \ sift_client.resources.sync_stubs From 449e73b910fed7eaf9fc6eae9d8572272d5396e0 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:58:38 -0700 Subject: [PATCH 11/39] revert to steps only --- .github/workflows/python_ci.yaml | 154 +++++-------------------------- 1 file changed, 25 insertions(+), 129 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 1f0ec3c24..160a7f525 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -1,8 +1,5 @@ name: python-ci -permissions: - contents: read - on: release: types: [ created ] @@ -13,154 +10,53 @@ on: workflow_call: jobs: - lint: + test-python: runs-on: ubuntu-latest defaults: run: working-directory: python steps: - - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v2 - - uses: actions/setup-python@v4 + - name: Set up Python + uses: actions/setup-python@v2 with: python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - - run: | + - name: Pip install + id: install + run: | python -m pip install --upgrade pip pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + + - name: Lint + run: | ruff check - format: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: Format + run: | ruff format --check - mypy: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: MyPy + run: | mypy lib - pyright: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: Pyright + run: | pyright lib - unit-tests: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: Pytest Unit Tests + run: | pytest -m "not integration" - integration-tests: - runs-on: ubuntu-latest - defaults: - run: - working-directory: python - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: Pytest Integration Tests + run: | pytest -m "integration" - sync-stubs-mypy: - runs-on: ubuntu-latest - defaults: - run: + - name: Sync Stubs Mypy working-directory: python/lib - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - run: | - python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + run: | stubtest \ - --mypy-config-file ../pyproject.toml \ - sift_client.resources.sync_stubs + --mypy-config-file ../pyproject.toml \ + sift_client.resources.sync_stubs \ No newline at end of file From ec54a232fc72ffb1dbab653796e4bb0d390e0434 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 10:59:52 -0700 Subject: [PATCH 12/39] remove error test from test_runs --- .../sift_client/_tests/resources/test_runs.py | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 838b3155d..592e51b6f 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -9,7 +9,6 @@ import os -import grpc.aio._call import pytest from sift_client import SiftClient @@ -195,39 +194,6 @@ async def test_find_multiple_raises_error(self, runs_api_async): with pytest.raises(ValueError, match="Multiple"): await runs_api_async.find(name_contains="a") - class TestErrorHandling: - """Tests for error handling scenarios.""" - - @pytest.mark.asyncio - async def test_invalid_client_configuration(self): - """Test that invalid client configurations are handled gracefully.""" - # Create a client with invalid configuration - invalid_client = SiftClient( - api_key="invalid-key", - grpc_url="invalid-url:99999", - rest_url="invalid-url:99999", - ) - - # The client should be created but API calls should fail gracefully - runs_api = invalid_client.async_.runs - - # This should raise an appropriate error, not crash - with pytest.raises( - grpc.aio._call.AioRpcError - ): # Could be connection error, auth error, etc. - await runs_api.list_(limit=1) - - @pytest.mark.asyncio - async def test_integration_with_ping(self, runs_api_async, sift_client): - """Test that runs API works in conjunction with ping API.""" - # First verify connectivity with ping - ping_response = await sift_client.async_.ping.ping() - assert isinstance(ping_response, str) - assert len(ping_response) > 0 - - # Then test runs API - runs = await runs_api_async.list_(limit=1) - assert isinstance(runs, list) class TestRunsAPISync: From bc82512577de69b5f70286ee9ac156c41a67a2d7 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 11:01:28 -0700 Subject: [PATCH 13/39] lint --- python/lib/sift_client/_tests/resources/test_runs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 592e51b6f..9898e903c 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -195,7 +195,6 @@ async def test_find_multiple_raises_error(self, runs_api_async): await runs_api_async.find(name_contains="a") - class TestRunsAPISync: """Test suite for the synchronous Runs API functionality. From ddeafd3c7634692d3e5270250e252d444117689e Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 11:12:14 -0700 Subject: [PATCH 14/39] update env --- .github/workflows/python_ci.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 160a7f525..8a3e272d7 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -12,6 +12,7 @@ on: jobs: test-python: runs-on: ubuntu-latest + environment: python-integration-tests defaults: run: working-directory: python @@ -51,6 +52,10 @@ jobs: pytest -m "not integration" - name: Pytest Integration Tests + env: + SIFT_GRPC_URI: ${{ secrets.SIFT_GRPC_URI }} + SIFT_REST_URI: ${{ secrets.SIFT_REST_URI }} + SIFT_API_KEY: ${{ secrets.SIFT_API_KEY }} run: | pytest -m "integration" From fa0960cfac60e68d34c22a679db3edc9588ad3a2 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 11:17:47 -0700 Subject: [PATCH 15/39] update env --- .github/workflows/python_ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 8a3e272d7..b3bef49b9 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -53,8 +53,8 @@ jobs: - name: Pytest Integration Tests env: - SIFT_GRPC_URI: ${{ secrets.SIFT_GRPC_URI }} - SIFT_REST_URI: ${{ secrets.SIFT_REST_URI }} + SIFT_GRPC_URI: ${{ env.SIFT_GRPC_URI }} + SIFT_REST_URI: ${{ env.SIFT_REST_URI }} SIFT_API_KEY: ${{ secrets.SIFT_API_KEY }} run: | pytest -m "integration" From 1b3e02b04114b60da3944eb6ed2ddff0b3eef398 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 11:23:51 -0700 Subject: [PATCH 16/39] update env --- .github/workflows/python_ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index b3bef49b9..4b7d539fa 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -53,8 +53,8 @@ jobs: - name: Pytest Integration Tests env: - SIFT_GRPC_URI: ${{ env.SIFT_GRPC_URI }} - SIFT_REST_URI: ${{ env.SIFT_REST_URI }} + SIFT_GRPC_URI: ${{ vars.SIFT_GRPC_URI }} + SIFT_REST_URI: ${{ vars.SIFT_REST_URI }} SIFT_API_KEY: ${{ secrets.SIFT_API_KEY }} run: | pytest -m "integration" From d1937b9f66a150eb01499b70884adc7c68d7b4ad Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 16:32:59 -0700 Subject: [PATCH 17/39] update runs integration test --- .../sift_client/_tests/resources/test_runs.py | 43 +++++-------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 9898e903c..0b1db81d0 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -142,18 +142,19 @@ class TestGet: """Tests for the async get method.""" @pytest.mark.asyncio - async def test_get_by_name(self, runs_api_async, test_run): - """Test getting a specific run by name.""" - retrieved_run = await runs_api_async.get(name=test_run.name) + async def test_get_by_id(self, runs_api_async, test_run): + """Test getting a specific run by ID.""" + retrieved_run = await runs_api_async.get(run_id=test_run.id_) assert retrieved_run is not None assert retrieved_run.id_ == test_run.id_ - assert retrieved_run.name == test_run.name + # TODO: test for client key @pytest.mark.asyncio - async def test_get_by_id(self, runs_api_async, test_run): - """Test getting a specific run by ID.""" - retrieved_run = await runs_api_async.get(run_id=test_run.id_) + async def test_get_by_id_with_client_key(self, runs_api_async, test_run): + """Test getting a specific run by client key.""" + assert test_run.client_key is not None + retrieved_run = await runs_api_async.get(client_key=test_run.client_key) assert retrieved_run is not None assert retrieved_run.id_ == test_run.id_ @@ -161,14 +162,14 @@ async def test_get_by_id(self, runs_api_async, test_run): @pytest.mark.asyncio async def test_get_without_params_raises_error(self, runs_api_async): """Test that getting a run without parameters raises an error.""" - with pytest.raises(ValueError, match="Either run_id or name must be provided"): + with pytest.raises(ValueError, match="must be provided"): await runs_api_async.get() @pytest.mark.asyncio async def test_get_nonexistent_run_raises_error(self, runs_api_async): """Test that getting a non-existent run raises an error.""" - with pytest.raises(ValueError, match="No run found"): - await runs_api_async.get(name="nonexistent-run-name-12345") + with pytest.raises(ValueError, match="not found"): + await runs_api_async.get(client_key="nonexistent-run-name-12345") class TestFind: """Tests for the async find method.""" @@ -212,25 +213,3 @@ def test_basic_list(self, runs_api_sync): assert isinstance(runs, list) assert runs assert isinstance(runs[0], Run) - - class TestGet: - """Tests for the sync get method.""" - - def test_basic_get(self, runs_api_sync, test_run): - """Test basic synchronous get functionality.""" - retrieved_run = runs_api_sync.get(name=test_run.name) - - assert retrieved_run is not None - assert isinstance(retrieved_run, Run) - assert retrieved_run.id_ == test_run.id_ - - class TestFind: - """Tests for the sync find method.""" - - def test_basic_find(self, runs_api_sync, test_run): - """Test basic synchronous find functionality.""" - found_run = runs_api_sync.find(name=test_run.name) - - assert found_run is not None - assert isinstance(found_run, Run) - assert found_run.id_ == test_run.id_ From f3fb8a6308fc587f048d7d99a3fc781d924b7822 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 18:18:36 -0700 Subject: [PATCH 18/39] update runs integration test --- .../sift_client/_tests/resources/test_runs.py | 293 ++++++++++++++++++ python/lib/sift_client/resources/runs.py | 3 +- python/lib/sift_client/sift_types/_base.py | 4 +- 3 files changed, 297 insertions(+), 3 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 0b1db81d0..f541732a7 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -8,12 +8,14 @@ """ import os +from datetime import datetime, timedelta, timezone import pytest from sift_client import SiftClient from sift_client.resources import RunsAPI, RunsAPIAsync from sift_client.sift_types import Run +from sift_client.sift_types.run import RunCreate, RunUpdate pytestmark = pytest.mark.integration @@ -58,6 +60,20 @@ def test_run(runs_api_sync): assert len(runs) >= 1 return runs[0] +@pytest.fixture(scope="function") +def new_run(runs_api_sync): + """Create a test run for update tests.""" + run_name = f"test_run_update_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + created_run = runs_api_sync.create( + RunCreate( + name=run_name, + description=description, + tags=["sift-client-pytest"], + ) + ) + return created_run + class TestRunsAPIAsync: """Test suite for the async Runs API functionality.""" @@ -111,6 +127,8 @@ async def test_list_with_name_contains_filter(self, runs_api_async): for run in runs: assert "test" in run.name.lower() + # TODO: test run-specific filters + @pytest.mark.asyncio async def test_list_with_limit(self, runs_api_async): """Test run listing with different limits.""" @@ -195,6 +213,281 @@ async def test_find_multiple_raises_error(self, runs_api_async): with pytest.raises(ValueError, match="Multiple"): await runs_api_async.find(name_contains="a") + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_run(self, runs_api_async): + """Test creating a basic run with minimal fields.""" + run_name = f"test_run_create_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + run_create = RunCreate( + name=run_name, + description=description, + tags=["sift-client-pytest"], + ) + + created_run = await runs_api_async.create(run_create) + + try: + # Verify the run was created + assert created_run is not None + assert isinstance(created_run, Run) + assert created_run.id_ is not None + assert created_run.name == run_name + assert created_run.description == description + assert created_run.created_date is not None + assert created_run.modified_date is not None + finally: + # Clean up: archive the test run + await runs_api_async.archive(created_run) + + @pytest.mark.asyncio + async def test_create_run_with_all_fields(self, runs_api_async): + """Test creating a run with all optional fields.""" + run_name = f"test_run_full_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + start_time = datetime.now(timezone.utc) - timedelta(hours=1) + stop_time = datetime.now(timezone.utc) + + run_create = RunCreate( + name=run_name, + description=description, + client_key=f"client_key_{datetime.now(timezone.utc).timestamp()}", + start_time=start_time, + stop_time=stop_time, + tags=["test", "pytest", "integration", "sift-client-pytest"], + metadata={"test_type": "integration", "version": "1.0", "is_automated": True}, + ) + + created_run = await runs_api_async.create(run_create) + + try: + # Verify all fields + assert created_run.name == run_name + assert created_run.description == description + assert created_run.client_key is not None + assert created_run.start_time is not None + assert created_run.stop_time is not None + assert created_run.tags == ["test", "pytest", "integration", "sift-client-pytest"] + assert created_run.metadata["test_type"] == "integration" + assert created_run.metadata["version"] == "1.0" + assert created_run.metadata["is_automated"] is True + finally: + # Clean up + await runs_api_async.archive(created_run) + + @pytest.mark.asyncio + async def test_create_run_with_dict(self, runs_api_async): + """Test creating a run using a dictionary instead of RunCreate object.""" + run_name = f"test_run_dict_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + + run_dict = { + "name": run_name, + "description": description, + "tags": ["sift-client-pytest"], + } + + created_run = await runs_api_async.create(run_dict) + + try: + assert created_run.name == run_name + assert created_run.description == description + assert created_run.tags == ["sift-client-pytest"] + finally: + await runs_api_async.archive(created_run) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_run_description(self, runs_api_async, new_run): + """Test updating a run's description.""" + try: + # Update the description + update = RunUpdate(description="Updated description") + updated_run = await runs_api_async.update(new_run, update) + + # Verify the update + assert updated_run.id_ == new_run.id_ + assert updated_run.description == "Updated description" + assert updated_run.name == new_run.name # Name should remain unchanged + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_name(self, runs_api_async, new_run): + """Test updating a run's name.""" + try: + # Update the name + new_name = f"updated_{new_run.name}" + update = RunUpdate(name=new_name) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the update + assert updated_run.name == new_name + assert updated_run.id_ == new_run.id_ + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_tags_and_metadata(self, runs_api_async, new_run): + """Test updating a run's tags and metadata.""" + try: + # Update tags + update = RunUpdate( + tags=["updated", "new-tag", "sift-client-pytest"], + ) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the updates + assert set(updated_run.tags) == {"updated", "new-tag", "sift-client-pytest"} + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_times(self, runs_api_async, new_run): + """Test updating a run's start and stop times.""" + try: + # Update with start and stop times + start_time = datetime.now(timezone.utc) - timedelta(hours=2) + stop_time = datetime.now(timezone.utc) - timedelta(hours=1) + update = RunUpdate(start_time=start_time, stop_time=stop_time) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the times were set + assert updated_run.start_time is not None + assert updated_run.stop_time is not None + # Allow for small time differences due to serialization + assert abs((updated_run.start_time - start_time).total_seconds()) < 1 + assert abs((updated_run.stop_time - stop_time).total_seconds()) < 1 + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_with_dict(self, runs_api_async, new_run): + """Test updating a run using a dictionary instead of RunUpdate object.""" + try: + # Update using dict + update_dict = {"description": "Updated via dict"} + updated_run = await runs_api_async.update(new_run, update_dict) + + assert updated_run.description == "Updated via dict" + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_with_run_id_string(self, runs_api_async, new_run): + """Test updating a run by passing run ID as string.""" + try: + # Update using run ID string + update = RunUpdate(description="Updated via ID string") + updated_run = await runs_api_async.update(new_run.id_, update) + + assert updated_run.id_ == new_run.id_ + assert updated_run.description == "Updated via ID string" + finally: + await runs_api_async.archive(new_run.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_run(self, runs_api_async, new_run): + """Test archiving a run.""" + run = await runs_api_async.archive(new_run) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + # Verify it's archived by checking it doesn't appear in normal list + runs_without_archived = await runs_api_async.list_( + name=new_run.name, include_archived=False + ) + assert len(runs_without_archived) == 0 + + # Verify it appears when including archived + runs_with_archived = await runs_api_async.list_( + name=new_run.name, include_archived=True + ) + assert len(runs_with_archived) == 1 + assert runs_with_archived[0].id_ == new_run.id_ + assert runs_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_run_id_string(self, runs_api_async, new_run): + """Test archiving a run by passing run ID as string.""" + # Archive using run ID string + run = await runs_api_async.archive(new_run.id_) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + @pytest.mark.asyncio + async def test_get_archived_run_by_id(self, runs_api_async, new_run): + """Test that we can still get an archived run by ID.""" + # Archive the test run + run = await runs_api_async.archive(new_run) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + class TestStop: + """Tests for the async stop method.""" + + @pytest.mark.asyncio + async def test_stop_run(self, runs_api_async, new_run): + """Test stopping a run.""" + try: + # Stop the run + stopped_run = await runs_api_async.stop(new_run) + + # Verify the run was stopped + assert isinstance(stopped_run, Run) + assert stopped_run.id_ == new_run.id_ + assert stopped_run.stop_time is not None + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_stop_run_with_id_string(self, runs_api_async, new_run): + """Test stopping a run by passing run ID as string.""" + try: + # Stop using run ID string + stopped_run = await runs_api_async.stop(new_run.id_) + + # Verify the run was stopped + assert isinstance(stopped_run, Run) + assert stopped_run.id_ == new_run.id_ + assert stopped_run.stop_time is not None + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_stop_run_with_start_time(self, runs_api_async, new_run): + """Test stopping a run that has a start time.""" + try: + # Set start time first + start_time = datetime.now(timezone.utc) - timedelta(hours=1) + update = RunUpdate(start_time=start_time) + await runs_api_async.update(new_run, update) + + # Stop the run + stopped_run = await runs_api_async.stop(new_run) + + # Verify the run was stopped and times are valid + assert stopped_run.stop_time is not None + assert stopped_run.start_time is not None + assert stopped_run.stop_time > stopped_run.start_time + finally: + await runs_api_async.archive(new_run.id_) + + + class TestRunsAPISync: """Test suite for the synchronous Runs API functionality. diff --git a/python/lib/sift_client/resources/runs.py b/python/lib/sift_client/resources/runs.py index cf62b6b29..1aa34e267 100644 --- a/python/lib/sift_client/resources/runs.py +++ b/python/lib/sift_client/resources/runs.py @@ -252,7 +252,7 @@ async def unarchive( async def stop( self, run: str | Run, - ) -> None: + ) -> Run: """Stop a run by setting its stop time to the current time. Args: @@ -260,6 +260,7 @@ async def stop( """ run_id = run._id_or_error if isinstance(run, Run) else run await self._low_level_client.stop_run(run_id=run_id or "") + return await self.get(run_id=run_id) async def create_automatic_association_for_assets( self, diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index adbdf37e7..b88535e2c 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -155,7 +155,7 @@ def _build_proto_and_paths( try: setattr(proto_msg, field_name, value) paths.append(path) - except TypeError as e: + except (TypeError, AttributeError) as e: raise TypeError( f"Can't set {field_name} to {value} on {proto_msg.__class__.__name__}" ) from e @@ -178,7 +178,7 @@ def to_proto(self) -> ProtoT: proto_msg = proto_cls() # Get all fields - data = self.model_dump(exclude_none=False) + data = self.model_dump(exclude_unset=True) self._build_proto_and_paths(proto_msg, data) return proto_msg From c29ec36196fa575d114159c7b3ba6b28ad3abc46 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 18:30:32 -0700 Subject: [PATCH 19/39] fmt --- python/lib/sift_client/_tests/resources/test_runs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index f541732a7..1cd18396c 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -60,6 +60,7 @@ def test_run(runs_api_sync): assert len(runs) >= 1 return runs[0] + @pytest.fixture(scope="function") def new_run(runs_api_sync): """Create a test run for update tests.""" @@ -487,8 +488,6 @@ async def test_stop_run_with_start_time(self, runs_api_async, new_run): await runs_api_async.archive(new_run.id_) - - class TestRunsAPISync: """Test suite for the synchronous Runs API functionality. From b7d3dca6b8668996587673553e53f221a77894ac Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 29 Sep 2025 18:49:59 -0700 Subject: [PATCH 20/39] fix tests --- .../sift_client/_tests/resources/test_runs.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 1cd18396c..0444a407f 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -169,14 +169,14 @@ async def test_get_by_id(self, runs_api_async, test_run): assert retrieved_run.id_ == test_run.id_ # TODO: test for client key - @pytest.mark.asyncio - async def test_get_by_id_with_client_key(self, runs_api_async, test_run): - """Test getting a specific run by client key.""" - assert test_run.client_key is not None - retrieved_run = await runs_api_async.get(client_key=test_run.client_key) - - assert retrieved_run is not None - assert retrieved_run.id_ == test_run.id_ + # @pytest.mark.asyncio + # async def test_get_by_id_with_client_key(self, runs_api_async, test_run): + # """Test getting a specific run by client key.""" + # assert test_run.client_key is not None + # retrieved_run = await runs_api_async.get(client_key=test_run.client_key) + # + # assert retrieved_run is not None + # assert retrieved_run.id_ == test_run.id_ @pytest.mark.asyncio async def test_get_without_params_raises_error(self, runs_api_async): @@ -258,7 +258,7 @@ async def test_create_run_with_all_fields(self, runs_api_async): start_time=start_time, stop_time=stop_time, tags=["test", "pytest", "integration", "sift-client-pytest"], - metadata={"test_type": "integration", "version": "1.0", "is_automated": True}, + metadata={"pytest_type": "integration"}, ) created_run = await runs_api_async.create(run_create) @@ -271,9 +271,8 @@ async def test_create_run_with_all_fields(self, runs_api_async): assert created_run.start_time is not None assert created_run.stop_time is not None assert created_run.tags == ["test", "pytest", "integration", "sift-client-pytest"] - assert created_run.metadata["test_type"] == "integration" - assert created_run.metadata["version"] == "1.0" - assert created_run.metadata["is_automated"] is True + assert created_run.metadata["pytest_type"] == "integration" + finally: # Clean up await runs_api_async.archive(created_run) From df3d868fea5e19ac4a859f4d0264653db5c59eee Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 6 Oct 2025 19:19:22 -0700 Subject: [PATCH 21/39] fix dev script args --- python/scripts/dev | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/scripts/dev b/python/scripts/dev index 6329dce2f..7a95a5ced 100755 --- a/python/scripts/dev +++ b/python/scripts/dev @@ -159,10 +159,10 @@ case "$1" in test) run_tests ;; - test_integration) + test-integration) run_tests_integration ;; - test_all) + test-all) run_tests_all ;; gen-stubs) From a3ae253ea1fe87533bf1d413f620fbc4771ef0a0 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Tue, 7 Oct 2025 17:50:27 -0700 Subject: [PATCH 22/39] move sift client fixture to a shared file --- .../_tests/resources/test_assets.py | 14 ----------- .../sift_client/_tests/resources/test_ping.py | 16 ------------ .../sift_client/_tests/resources/test_runs.py | 14 ----------- .../lib/sift_client/_tests/shared_fixtures.py | 25 +++++++++++++++++++ 4 files changed, 25 insertions(+), 44 deletions(-) create mode 100644 python/lib/sift_client/_tests/shared_fixtures.py diff --git a/python/lib/sift_client/_tests/resources/test_assets.py b/python/lib/sift_client/_tests/resources/test_assets.py index 10fda6c15..ca69ee500 100644 --- a/python/lib/sift_client/_tests/resources/test_assets.py +++ b/python/lib/sift_client/_tests/resources/test_assets.py @@ -18,20 +18,6 @@ pytestmark = pytest.mark.integration -@pytest.fixture(scope="session") -def sift_client() -> SiftClient: - """Create a SiftClient instance for testing.""" - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - - return SiftClient( - api_key=api_key, - grpc_url=grpc_url, - rest_url=rest_url, - ) - - def test_client_binding(sift_client): assert sift_client.assets assert isinstance(sift_client.assets, AssetsAPI) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index ca924f92b..587d8a7f0 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -6,8 +6,6 @@ - Error handling and edge cases """ -import os - import pytest from sift_client import SiftClient @@ -16,20 +14,6 @@ pytestmark = pytest.mark.integration -@pytest.fixture(scope="session") -def sift_client() -> SiftClient: - """Create a SiftClient instance for testing.""" - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - - return SiftClient( - api_key=api_key, - grpc_url=grpc_url, - rest_url=rest_url, - ) - - def test_client_binding(sift_client): assert sift_client.ping assert isinstance(sift_client.ping, PingAPI) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 0444a407f..a4e3fc750 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -20,20 +20,6 @@ pytestmark = pytest.mark.integration -@pytest.fixture(scope="session") -def sift_client() -> SiftClient: - """Create a SiftClient instance for testing.""" - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - - return SiftClient( - api_key=api_key, - grpc_url=grpc_url, - rest_url=rest_url, - ) - - def test_client_binding(sift_client): assert sift_client.runs assert isinstance(sift_client.runs, RunsAPI) diff --git a/python/lib/sift_client/_tests/shared_fixtures.py b/python/lib/sift_client/_tests/shared_fixtures.py new file mode 100644 index 000000000..45566fd1b --- /dev/null +++ b/python/lib/sift_client/_tests/shared_fixtures.py @@ -0,0 +1,25 @@ +"""Shared pytest fixtures for all tests.""" + +import os + +import pytest + +from sift_client import SiftClient + + +@pytest.fixture(scope="session") +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing. + + This fixture is shared across all test files and is session-scoped + to avoid creating multiple client instances. + """ + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + ) From 37091850899c2cdea3248f2dc9ea4c1b78b869ba Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Wed, 8 Oct 2025 18:58:28 -0700 Subject: [PATCH 23/39] add channels test --- .../{shared_fixtures.py => conftest.py} | 2 +- .../_tests/resources/test_channels.py | 358 ++++++++++++++++++ python/lib/sift_client/resources/_base.py | 10 +- 3 files changed, 365 insertions(+), 5 deletions(-) rename python/lib/sift_client/_tests/{shared_fixtures.py => conftest.py} (99%) create mode 100644 python/lib/sift_client/_tests/resources/test_channels.py diff --git a/python/lib/sift_client/_tests/shared_fixtures.py b/python/lib/sift_client/_tests/conftest.py similarity index 99% rename from python/lib/sift_client/_tests/shared_fixtures.py rename to python/lib/sift_client/_tests/conftest.py index 45566fd1b..a8ff39d55 100644 --- a/python/lib/sift_client/_tests/shared_fixtures.py +++ b/python/lib/sift_client/_tests/conftest.py @@ -10,7 +10,7 @@ @pytest.fixture(scope="session") def sift_client() -> SiftClient: """Create a SiftClient instance for testing. - + This fixture is shared across all test files and is session-scoped to avoid creating multiple client instances. """ diff --git a/python/lib/sift_client/_tests/resources/test_channels.py b/python/lib/sift_client/_tests/resources/test_channels.py new file mode 100644 index 000000000..8afdb60be --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_channels.py @@ -0,0 +1,358 @@ +"""Pytest tests for the Channels API. + +These tests demonstrate and validate the usage of the Channels API including: +- Basic channel operations (get, list, find) +- Channel filtering and searching +- Channel data retrieval +- Error handling and edge cases +""" + +import pytest + +from sift_client import SiftClient +from sift_client.resources import ChannelsAPI, ChannelsAPIAsync +from sift_client.sift_types import Channel + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.channels + assert isinstance(sift_client.channels, ChannelsAPI) + assert sift_client.async_.channels + assert isinstance(sift_client.async_.channels, ChannelsAPIAsync) + + +@pytest.fixture +def channels_api_async(sift_client: SiftClient): + """Get the async channels API instance.""" + return sift_client.async_.channels + + +@pytest.fixture +def channels_api_sync(sift_client: SiftClient): + """Get the synchronous channels API instance.""" + return sift_client.channels + + +@pytest.fixture +def test_channel(channels_api_sync): + channels = channels_api_sync.list_(limit=1) + assert channels + assert len(channels) >= 1 + return channels[0] + + +class TestChannelsAPIAsync: + """Test suite for the async Channels API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, channels_api_async, test_channel): + """Test getting a specific channel by ID.""" + retrieved_channel = await channels_api_async.get(channel_id=test_channel.id_) + + assert isinstance(retrieved_channel, Channel) + assert retrieved_channel.id_ == test_channel.id_ + assert retrieved_channel.name == test_channel.name + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, channels_api_async): + """Test basic channel listing functionality.""" + channels = await channels_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(channels, list) + assert len(channels) == 5 + + # If we have channels, verify their structure + channel = channels[0] + assert isinstance(channel, Channel) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, channels_api_async): + """Test channel listing with name filtering.""" + # First get some channels to work with + all_channels = await channels_api_async.list_(limit=10) + + test_channel_name = all_channels[0].name + filtered_channels = await channels_api_async.list_(name=test_channel_name) + + # Should find at least one channel with exact name match + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= 1 + + # All returned channels should have the exact name + for channel in filtered_channels: + assert channel.name == test_channel_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, channels_api_async): + """Test channel listing with name contains filtering.""" + # Test with a common substring that might exist in channel names + channels = await channels_api_async.list_(name_contains="test", limit=5) + + assert isinstance(channels, list) + assert channels + + for channel in channels: + assert "test" in channel.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, channels_api_async): + """Test channel listing with regex name filtering.""" + # Test with a regex pattern + channels = await channels_api_async.list_(name_regex=r".*test.*", limit=5) + + assert isinstance(channels, list) + assert channels + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for channel in channels: + assert pattern.match(channel.name) + + @pytest.mark.asyncio + async def test_list_with_channel_ids_filter(self, channels_api_async): + """Test channel listing with channel IDs filter.""" + all_channels = await channels_api_async.list_(limit=3) + + if all_channels: + channel_ids = [ch.id_ for ch in all_channels] + filtered_channels = await channels_api_async.list_(channel_ids=channel_ids) + + # Should find at least the channels we specified + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= len(all_channels) + + # All returned channels should have IDs in our list + for channel in filtered_channels: + assert channel.id_ in channel_ids + + @pytest.mark.asyncio + async def test_list_with_asset_filter(self, channels_api_async): + """Test channel listing with asset filter.""" + # First get a channel to get its asset + all_channels = await channels_api_async.list_(limit=1) + + if all_channels: + test_channel = all_channels[0] + # Filter by asset ID + filtered_channels = await channels_api_async.list_(asset=test_channel.asset_id) + + # Should find at least one channel for this asset + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= 1 + + # All returned channels should belong to the same asset + for channel in filtered_channels: + assert channel.asset_id == test_channel.asset_id + + @pytest.mark.asyncio + async def test_list_with_description_contains_filter(self, channels_api_async): + """Test channel listing with description contains filtering.""" + # Test with a common substring that might exist in descriptions + channels = await channels_api_async.list_(description_contains="test", limit=5) + + assert isinstance(channels, list) + assert channels + + # If we found channels, verify they contain the substring in description + for channel in channels: + assert "test" in channel.description.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, channels_api_async): + """Test channel listing with different limits.""" + # Test with limit of 1 + channels_1 = await channels_api_async.list_(limit=1) + assert isinstance(channels_1, list) + assert len(channels_1) <= 1 + + # Test with limit of 3 + channels_3 = await channels_api_async.list_(limit=3) + assert isinstance(channels_3, list) + assert len(channels_3) <= 3 + + # TODO: active channel test + # @pytest.mark.asyncio + # async def test_list_include_archived(self, channels_api_async): + # """Test channel listing with archived channels included.""" + # # Test without archived channels (default) + # channels_active = await channels_api_async.list_(limit=5, include_archived=False) + # assert isinstance(channels_active, list) + # + # # Test with archived channels included + # channels_all = await channels_api_async.list_(limit=5, include_archived=True) + # assert isinstance(channels_all, list) + # + # # Should have at least as many channels when including archived + # assert len(channels_all) >= len(channels_active) + + @pytest.mark.asyncio + async def test_list_with_time_filters(self, channels_api_async): + """Test channel listing with time-based filters.""" + from datetime import datetime, timedelta, timezone + + # Get channels created in the last year + one_year_ago = datetime.now(timezone.utc) - timedelta(days=365) + channels = await channels_api_async.list_(created_after=one_year_ago, limit=5) + + assert isinstance(channels, list) + assert channels + + # If we found channels, verify they were created after the specified time + for channel in channels: + assert channel.created_date >= one_year_ago + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_channel(self, channels_api_async, test_channel): + """Test finding a single channel.""" + # Find the same channel by name and asset + found_channel = await channels_api_async.find( + name=test_channel.name, asset=test_channel.asset_id + ) + + assert found_channel is not None + assert found_channel.id_ == test_channel.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_channel(self, channels_api_async): + """Test finding a non-existent channel returns None.""" + found_channel = await channels_api_async.find(name="nonexistent-channel-name-12345") + assert found_channel is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, channels_api_async): + """Test finding multiple channels raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await channels_api_async.find(name_contains="test", limit=5) + + # TODO: data retrieval tests + # class TestGetData: + # """Tests for the async get_data method.""" + # + # @pytest.mark.asyncio + # async def test_get_data_basic(self, channels_api_async, test_channel): + # """Test getting channel data.""" + # # Get the channel's asset to find a run + # from sift_client.sift_types import Asset + # + # asset = await channels_api_async.client.async_.assets.get( + # asset_id=test_channel.asset_id + # ) + # assert isinstance(asset, Asset) + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Get data for the channel + # data = await channels_api_async.get_data( + # channels=[test_channel], run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # assert data + # + # # Should have an entry for our channel + # assert test_channel.name in data or len(data) > 0 + # + # @pytest.mark.asyncio + # async def test_get_data_with_time_range(self, channels_api_async, test_channel): + # """Test getting channel data with time range.""" + # from datetime import datetime, timedelta, timezone + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Define a time range + # end_time = datetime.now(timezone.utc) + # start_time = end_time - timedelta(hours=1) + # + # # Get data for the channel with time range + # data = await channels_api_async.get_data( + # channels=[test_channel], + # run=run.id_, + # start_time=start_time, + # end_time=end_time, + # limit=10, + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # + # @pytest.mark.asyncio + # async def test_get_data_as_arrow(self, channels_api_async, test_channel): + # """Test getting channel data as Arrow table.""" + # import pyarrow as pa + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Get data as Arrow table + # data = await channels_api_async.get_data_as_arrow( + # channels=[test_channel], run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # assert data + # for table in data.values(): + # assert isinstance(table, pa.Table) + # + # @pytest.mark.asyncio + # async def test_get_data_multiple_channels(self, channels_api_async): + # """Test getting data for multiple channels.""" + # # Get multiple channels from the same asset + # channels = await channels_api_async.list_(limit=3) + # + # if len(channels) >= 2: + # # Get the first asset's channels + # first_asset_id = channels[0].asset_id + # asset_channels = [ + # ch for ch in channels if ch.asset_id == first_asset_id + # ][:2] + # + # if len(asset_channels) >= 2: + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # + # if runs: + # run = runs[0] + # # Get data for multiple channels + # data = await channels_api_async.get_data( + # channels=asset_channels, run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + + +class TestChannelsAPISync: + """Test suite for the synchronous Channels API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, channels_api_sync, test_channel): + """Test getting a specific channel by ID synchronously.""" + retrieved_channel = channels_api_sync.get(channel_id=test_channel.id_) + + assert retrieved_channel is not None + assert retrieved_channel.id_ == test_channel.id_ diff --git a/python/lib/sift_client/resources/_base.py b/python/lib/sift_client/resources/_base.py index 8708411be..a46ea0fa7 100644 --- a/python/lib/sift_client/resources/_base.py +++ b/python/lib/sift_client/resources/_base.py @@ -91,7 +91,9 @@ def _build_time_cel_filters( return filter_parts def _build_tags_metadata_cel_filters( - self, tags: list[Any] | list[str] | None = None, metadata: list[Any] | None = None + self, + tags: list[Any] | list[str] | None = None, + metadata: list[Any] | None = None, ) -> list[str]: filter_parts = [] if tags: @@ -106,14 +108,14 @@ def _build_tags_metadata_cel_filters( def _build_common_cel_filters( self, description_contains: str | None = None, - include_archived: bool = False, + include_archived: bool | None = None, filter_query: str | None = None, ) -> list[str]: filter_parts = [] if description_contains: filter_parts.append(cel.contains("description", description_contains)) - if not include_archived: - filter_parts.append(cel.equals("is_archived", False)) + if include_archived is not None: + filter_parts.append(cel.equals("is_archived", include_archived)) if filter_query: filter_parts.append(filter_query) return filter_parts From 4223aa707ca0517012b8fe172e51ae5840ba2117 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 11:25:10 -0700 Subject: [PATCH 24/39] add test_calculated_channels.py --- .../sift_client/_tests/integrated/channels.py | 205 -------- .../resources/test_calculated_channels.py | 473 ++++++++++++++++++ .../sift_types/calculated_channel.py | 29 +- 3 files changed, 488 insertions(+), 219 deletions(-) delete mode 100644 python/lib/sift_client/_tests/integrated/channels.py create mode 100644 python/lib/sift_client/_tests/resources/test_calculated_channels.py diff --git a/python/lib/sift_client/_tests/integrated/channels.py b/python/lib/sift_client/_tests/integrated/channels.py deleted file mode 100644 index 73c54e0df..000000000 --- a/python/lib/sift_client/_tests/integrated/channels.py +++ /dev/null @@ -1,205 +0,0 @@ -import asyncio -import os -import time -from datetime import datetime, timezone - -import numpy as np -import pandas as pd -import pyarrow as pa - -from sift_client.client import SiftClient - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # List runs for this asset - runs = asset.runs - print( - f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" - ) - - # Pick one. - run = runs[0] - run_id = run.id_ - print(f"Using run: {run.name} (ID: {run_id})") - - # List other assets for this run. - all_assets = run.assets - other_assets = [asset for asset in all_assets if asset.id_ != asset_id] - print( - f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" - ) - - # List channels for this asset (find a run w/ data) - channels = [] - for run in runs: - asset_channels = asset.channels(run=run.id_, limit=10) - other_channels = [] - for c in asset_channels: - if c.name in {"voltage", "gpio", "temperature", "mainmotor.velocity"}: - channels.append(c) - else: - other_channels.append(c) - - if len(channels) > 3: - print( - f"Found {len(channels)} channel(s): {[channel.name for channel in channels]} for asset {asset.name} on run {run.name}" - ) - if len(other_channels) > 0: - print( - f"Found {len(other_channels)} other channel(s): {[c.name for c in other_channels]} for asset {asset.name} on run {run.name}" - ) - break - - # Get the channel data during a specific run - channel = channels[0] - channel_data = channel.data(run_id="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", limit=100) - print(f"Channel data for {channel.name} has {len(channel_data)} points") - - # Get data for multiple channels - print("Getting data for multiple channels:") - perf_start = time.perf_counter() - channel_data = client.channels.get_data( - run="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", channels=channels, limit=100 - ) - first_time = time.perf_counter() - perf_start - start_time = None - end_time = None - for i, (channel_name, data) in enumerate(channel_data.items()): - print(f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name])}") - - # Pick a random channel and grab the start end times so we can test the cache - if i == 1: - start_time = data.index[0] - end_time = data.index[-1] - print(f"Start time: {start_time}, End time: {end_time}") - - # Test cache with varying start_time and end_time parameters - if start_time and end_time: - print("\n=== Testing cache with varying time ranges ===") - - # Test 1: Exact same time range (should hit cache) - print("\nTest 1: Exact same time range no run_id (should hit cache)") - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=start_time, - end_time=end_time, - ) - exact_time = time.perf_counter() - perf_start - - # Test 2: Subset of time range (should hit cache if overlapping) - print("\nTest 2: Subset of time range (should hit cache if overlapping)") - mid_time = start_time + (end_time - start_time) / 2 - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=start_time, - end_time=mid_time, - ) - subset_time = time.perf_counter() - perf_start - - # Test 3: Extended time range (should hit cache for overlapping portion) - print("\nTest 3: Extended time range earlier (should hit cache for overlapping portion)") - extended_start = start_time - (end_time - start_time) * 0.1 - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=extended_start, - end_time=end_time, - ) - extended_time = time.perf_counter() - perf_start - - # Test 4: Different time range (should not hit cache) - print("\nTest 4: Different time encompassed range (should hit cache)") - different_start = extended_start + pd.Timedelta(seconds=2) - different_end = start_time + pd.Timedelta(seconds=3) - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=different_start, - end_time=different_end, - ) - different_time = time.perf_counter() - perf_start - - # Test 5: No time range specified (should miss cache from original call) - print("\nTest 5: No time range specified (should miss cache)") - # Since None end time is treated as now, we capture now so we can repeat it. - fake_no_end_time = datetime.now(timezone.utc) - perf_start = time.perf_counter() - channel_data_no_time = client.channels.get_data( - channels=channels, - end_time=fake_no_end_time, - limit=100, - ) - no_time_time = time.perf_counter() - perf_start - for i, (channel_name, data) in enumerate(channel_data_no_time.items()): - print(f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name])}") - - # Test 6: No time range specified again (should hit cache) - # NOTE: We're not comparing the results since limit combines with cache results. - print("\nTest 6: No time range specified again (should hit cache)") - perf_start = time.perf_counter() - channel_data_no_time = client.channels.get_data( - channels=channels, - end_time=fake_no_end_time, - limit=100, - ) - for i, (channel_name, data) in enumerate(channel_data_no_time.items()): - print( - f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name]) if channel_name in data else np.nan}" - ) - no_time_time_repeat = time.perf_counter() - perf_start - - # Test 7: Get data as arrow - print("\nTest 7: Get data as arrow") - perf_start = time.perf_counter() - channel_data_arrow = client.channels.get_data_as_arrow( - channels=channels, - end_time=fake_no_end_time, - ) - arrow_time = time.perf_counter() - perf_start - for i, (channel_name, data) in enumerate(channel_data_arrow.items()): - print( - f"{i}: {channel_name}: {len(data)} points. Avg: {pa.compute.mean(data[channel_name])}" - ) - - # Summary of cache performance - print("\n=== Cache Performance Summary ===") - print(f"Original call: {first_time:.4f} seconds") - print( - f"Exact time range no run_id: {exact_time:.4f} seconds ({(first_time / exact_time):.1f}x faster)" - ) - print( - f"Subset time range: {subset_time:.4f} seconds ({(first_time / subset_time):.1f}x faster)" - ) - print( - f"Extended time range earlier: {extended_time:.4f} seconds ({(first_time / extended_time):.1f}x faster)" - ) - print( - f"Different time range: {different_time:.4f} seconds ({(first_time / different_time):.1f}x faster)" - ) - print( - f"No time range: {no_time_time:.4f} seconds ({(no_time_time / first_time):.1f}x slower)" - ) - print( - f"No time range repeat: {no_time_time_repeat:.4f} seconds ({(no_time_time / no_time_time_repeat):.1f}x faster)" - ) - print(f"Arrow: {arrow_time:.4f} seconds ({(arrow_time / no_time_time_repeat):.1f}x faster)") - assert exact_time < first_time - assert subset_time < first_time - assert extended_time < first_time - assert different_time < first_time - assert no_time_time_repeat < no_time_time - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/resources/test_calculated_channels.py b/python/lib/sift_client/_tests/resources/test_calculated_channels.py new file mode 100644 index 000000000..17b0ba216 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_calculated_channels.py @@ -0,0 +1,473 @@ +"""Pytest tests for the Calculated Channels API. + +These tests demonstrate and validate the usage of the Calculated Channels API including: +- Basic calculated channel operations (get, list, find) +- Calculated channel filtering and searching +- Calculated channel creation, updates, and archiving +- Version management +- Error handling and edge cases +""" + +import uuid + +import pytest + +from sift_client import SiftClient +from sift_client.resources import CalculatedChannelsAPI, CalculatedChannelsAPIAsync +from sift_client.sift_types import CalculatedChannel +from sift_client.sift_types.calculated_channel import ( + CalculatedChannelCreate, + CalculatedChannelUpdate, +) +from sift_client.sift_types.channel import ChannelReference + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.calculated_channels + assert isinstance(sift_client.calculated_channels, CalculatedChannelsAPI) + assert sift_client.async_.calculated_channels + assert isinstance( + sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync + ) + + +@pytest.fixture +def calculated_channels_api_async(sift_client: SiftClient): + """Get the async calculated channels API instance.""" + return sift_client.async_.calculated_channels + + +@pytest.fixture +def calculated_channels_api_sync(sift_client: SiftClient): + """Get the synchronous calculated channels API instance.""" + return sift_client.calculated_channels + + +@pytest.fixture +def test_calculated_channel(calculated_channels_api_sync): + calculated_channels = calculated_channels_api_sync.list_(limit=1) + assert calculated_channels + assert len(calculated_channels) >= 1 + return calculated_channels[0] + + +@pytest.fixture(scope="function") +def new_calculated_channel(calculated_channels_api_sync, sift_client): + """Create a test calculated channel for update tests.""" + from datetime import datetime, timezone + + calc_channel_name = f"test_calc_channel_{datetime.now(timezone.utc).isoformat()}" + description = "Test calculated channel created by Sift Client pytest" + + # Get some channels to reference + channels = sift_client.channels.list_(limit=2) + assert len(channels) >= 2 + + created_calc_channel = calculated_channels_api_sync.create( + CalculatedChannelCreate( + name=calc_channel_name, + client_key=f"test_calc_chan_{str(uuid.uuid4())[-8:]}", + description=description, + expression="$1 + $2", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ], + all_assets=True, + ) + ) + return created_calc_channel + + +class TestCalculatedChannelsAPIAsync: + """Test suite for the async Calculated Channels API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test getting a specific calculated channel by ID.""" + retrieved_calc_channel = await calculated_channels_api_async.get( + calculated_channel_id=test_calculated_channel.id_ + ) + + assert isinstance(retrieved_calc_channel, CalculatedChannel) + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ + assert retrieved_calc_channel.name == test_calculated_channel.name + + @pytest.mark.asyncio + async def test_get_by_client_key( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test getting a specific calculated channel by client key.""" + retrieved_calc_channel = await calculated_channels_api_async.get( + client_key=test_calculated_channel.client_key + ) + + assert retrieved_calc_channel is not None + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, calculated_channels_api_async): + """Test basic calculated channel listing functionality.""" + calc_channels = await calculated_channels_api_async.list_(limit=5) + + assert isinstance(calc_channels, list) + assert len(calc_channels) == 5 + + calc_channel = calc_channels[0] + assert isinstance(calc_channel, CalculatedChannel) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, calculated_channels_api_async): + """Test calculated channel listing with name filtering.""" + all_calc_channels = await calculated_channels_api_async.list_(limit=10) + + test_calc_channel_name = all_calc_channels[0].name + filtered_calc_channels = await calculated_channels_api_async.list_( + name=test_calc_channel_name + ) + + assert isinstance(filtered_calc_channels, list) + assert len(filtered_calc_channels) >= 1 + + for calc_channel in filtered_calc_channels: + assert calc_channel.name == test_calc_channel_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter( + self, calculated_channels_api_async + ): + """Test calculated channel listing with name contains filtering.""" + calc_channels = await calculated_channels_api_async.list_( + name_contains="test", limit=5 + ) + + assert isinstance(calc_channels, list) + assert calc_channels + + for calc_channel in calc_channels: + assert "test" in calc_channel.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, calculated_channels_api_async): + """Test calculated channel listing with regex name filtering.""" + calc_channels = await calculated_channels_api_async.list_( + name_regex=r".*test.*", limit=5 + ) + + assert isinstance(calc_channels, list) + assert calc_channels + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for calc_channel in calc_channels: + assert pattern.match(calc_channel.name) + + @pytest.mark.asyncio + async def test_list_with_limit(self, calculated_channels_api_async): + """Test calculated channel listing with different limits.""" + calc_channels_1 = await calculated_channels_api_async.list_(limit=1) + assert isinstance(calc_channels_1, list) + assert len(calc_channels_1) <= 1 + + calc_channels_3 = await calculated_channels_api_async.list_(limit=3) + assert isinstance(calc_channels_3, list) + assert len(calc_channels_3) <= 3 + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_calculated_channel( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test finding a single calculated channel.""" + found_calc_channel = await calculated_channels_api_async.find( + name=test_calculated_channel.name + ) + + assert found_calc_channel is not None + assert found_calc_channel.id_ == test_calculated_channel.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_calculated_channel( + self, calculated_channels_api_async + ): + """Test finding a non-existent calculated channel returns None.""" + found_calc_channel = await calculated_channels_api_async.find( + name="nonexistent-calculated-channel-name-12345" + ) + assert found_calc_channel is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, calculated_channels_api_async): + """Test finding multiple calculated channels raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await calculated_channels_api_async.find(name_contains="test", limit=5) + + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_calculated_channel( + self, calculated_channels_api_async + ): + """Test creating a basic calculated channel with minimal fields.""" + from datetime import datetime, timezone + + calc_channel_name = ( + f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" + ) + description = "Test calculated channel created by Sift Client pytest" + + channels = await calculated_channels_api_async.client.async_.channels.list_( + limit=2 + ) + assert len(channels) >= 2 + + calc_channel_create = CalculatedChannelCreate( + name=calc_channel_name, + description=description, + expression="$1 + $2", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ], + all_assets=True, + ) + + created_calc_channel = await calculated_channels_api_async.create( + calc_channel_create + ) + + try: + assert created_calc_channel is not None + assert isinstance(created_calc_channel, CalculatedChannel) + assert created_calc_channel.id_ is not None + assert created_calc_channel.name == calc_channel_name + assert created_calc_channel.description == description + assert created_calc_channel.created_date is not None + assert created_calc_channel.modified_date is not None + finally: + await calculated_channels_api_async.archive(created_calc_channel) + + @pytest.mark.asyncio + async def test_create_calculated_channel_with_dict( + self, calculated_channels_api_async + ): + """Test creating a calculated channel using a dictionary.""" + from datetime import datetime, timezone + + calc_channel_name = ( + f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" + ) + description = "Test calculated channel created by Sift Client pytest" + + channels = await calculated_channels_api_async.client.async_.channels.list_( + limit=2 + ) + assert len(channels) >= 2 + + calc_channel_dict = { + "name": calc_channel_name, + "description": description, + "expression": "$1 + $2", + "expression_channel_references": [ + {"channel_reference": "$1", "channel_identifier": channels[0].name}, + {"channel_reference": "$2", "channel_identifier": channels[1].name}, + ], + "all_assets": True, + } + + created_calc_channel = await calculated_channels_api_async.create( + calc_channel_dict + ) + + try: + assert created_calc_channel.name == calc_channel_name + assert created_calc_channel.description == description + finally: + await calculated_channels_api_async.archive(created_calc_channel) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_calculated_channel_description( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's description.""" + try: + update = CalculatedChannelUpdate(description="Updated description") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.description == "Updated description" + assert updated_calc_channel.name == new_calculated_channel.name + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_calculated_channel_name( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's name.""" + try: + new_name = f"updated_{new_calculated_channel.name}" + update = CalculatedChannelUpdate(name=new_name) + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.name == new_name + assert updated_calc_channel.id_ == new_calculated_channel.id_ + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_dict( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel using a dictionary.""" + try: + update_dict = {"description": "Updated via dict"} + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update_dict + ) + + assert updated_calc_channel.description == "Updated via dict" + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_id_string( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel by passing ID as string.""" + try: + update = CalculatedChannelUpdate(description="Updated via ID string") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel.id_, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.description == "Updated via ID string" + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_calculated_channel( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test archiving a calculated channel.""" + calc_channel = await calculated_channels_api_async.archive( + new_calculated_channel + ) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is True + + calc_channels_without_archived = await calculated_channels_api_async.list_( + name=new_calculated_channel.name, include_archived=False + ) + assert len(calc_channels_without_archived) == 0 + + calc_channels_with_archived = await calculated_channels_api_async.list_( + name=new_calculated_channel.name, include_archived=True + ) + assert len(calc_channels_with_archived) == 1 + assert calc_channels_with_archived[0].id_ == new_calculated_channel.id_ + assert calc_channels_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_id_string( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test archiving a calculated channel by passing ID as string.""" + calc_channel = await calculated_channels_api_async.archive( + new_calculated_channel.id_ + ) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is True + + class TestUnarchive: + """Tests for the async unarchive method.""" + + @pytest.mark.asyncio + async def test_unarchive_calculated_channel( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test unarchiving a calculated channel.""" + try: + await calculated_channels_api_async.archive(new_calculated_channel) + + calc_channel = await calculated_channels_api_async.unarchive( + new_calculated_channel + ) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is False + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + class TestListVersions: + """Tests for the async list_versions method.""" + + @pytest.mark.asyncio + async def test_list_versions( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test listing versions of a calculated channel.""" + versions = await calculated_channels_api_async.list_versions( + calculated_channel=test_calculated_channel + ) + + assert isinstance(versions, list) + assert len(versions) >= 1 + + for version in versions: + assert isinstance(version, CalculatedChannel) + assert version.name == test_calculated_channel.name + + +class TestCalculatedChannelsAPISync: + """Test suite for the synchronous Calculated Channels API functionality.""" + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, calculated_channels_api_sync, test_calculated_channel): + """Test getting a specific calculated channel by ID synchronously.""" + retrieved_calc_channel = calculated_channels_api_sync.get( + calculated_channel_id=test_calculated_channel.id_ + ) + + assert retrieved_calc_channel is not None + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 987bb386a..7388c3bf1 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -135,7 +135,6 @@ class CalculatedChannelBase(ModelCreateUpdateBase): """Base class for CalculatedChannel create and update models with shared fields and validation.""" description: str | None = None - user_notes: str | None = None units: str | None = None expression: str | None = None @@ -180,8 +179,14 @@ class CalculatedChannelBase(ModelCreateUpdateBase): @model_validator(mode="after") def _validate_asset_configuration(self): """Validate that either all_assets is True or at least one of tag_ids or asset_ids is provided, but not both.""" - if self.all_assets is not None and self.all_assets and (self.asset_ids or self.tag_ids): - raise ValueError("Cannot specify both all_assets=True and asset_ids/tag_ids") + if ( + self.all_assets is not None + and self.all_assets + and (self.asset_ids or self.tag_ids) + ): + raise ValueError( + "Cannot specify both all_assets=True and asset_ids/tag_ids" + ) return self @model_validator(mode="after") @@ -194,31 +199,27 @@ def _validate_expression_and_channel_references(self): return self -class CalculatedChannelCreate(CalculatedChannelBase, ModelCreate[CreateCalculatedChannelRequest]): +class CalculatedChannelCreate( + CalculatedChannelBase, ModelCreate[CreateCalculatedChannelRequest] +): """Create model for a Calculated Channel.""" name: str + user_notes: str | None = None client_key: str | None = None def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: return CreateCalculatedChannelRequest -class CalculatedChannelUpdate(CalculatedChannelBase, ModelUpdate[CalculatedChannelProto]): +class CalculatedChannelUpdate( + CalculatedChannelBase, ModelUpdate[CalculatedChannelProto] +): """Update model for a Calculated Channel.""" name: str | None = None is_archived: bool | None = None - @model_validator(mode="after") - def _validate_non_updatable_fields(self): - """Validate that the fields that cannot be updated are not set.""" - if self.user_notes is not None: - raise ValueError("Cannot update user notes") - if self.client_key is not None: - raise ValueError("Cannot update client key") - return self - def _get_proto_class(self) -> type[CalculatedChannelProto]: return CalculatedChannelProto From dc1e389a214de250e79efb7f96779a23765da642 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 13:41:05 -0700 Subject: [PATCH 25/39] add test_rules.py --- .../_internal/low_level_wrappers/rules.py | 55 ++- .../_tests/integrated/calculated_channels.py | 251 ---------- .../sift_client/_tests/integrated/rules.py | 261 ---------- .../_tests/resources/test_assets.py | 2 - .../resources/test_calculated_channels.py | 84 +--- .../_tests/resources/test_rules.py | 458 ++++++++++++++++++ .../sift_client/_tests/resources/test_runs.py | 14 +- .../examples/generic_workflow_example.py | 2 +- .../resources/sync_stubs/__init__.pyi | 2 +- python/lib/sift_client/sift_types/_base.py | 13 +- .../sift_types/calculated_channel.py | 19 +- python/lib/sift_client/sift_types/rule.py | 23 +- 12 files changed, 557 insertions(+), 627 deletions(-) delete mode 100644 python/lib/sift_client/_tests/integrated/calculated_channels.py delete mode 100644 python/lib/sift_client/_tests/integrated/rules.py create mode 100644 python/lib/sift_client/_tests/resources/test_rules.py diff --git a/python/lib/sift_client/_internal/low_level_wrappers/rules.py b/python/lib/sift_client/_internal/low_level_wrappers/rules.py index bf5bdb905..b578b6177 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/rules.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/rules.py @@ -135,7 +135,8 @@ async def create_rule( ) conditions_request = [ UpdateConditionRequest( - expression=expression_proto, actions=[create.action._to_update_request()] + expression=expression_proto, + actions=[create.action._to_update_request()], ) ] update_request = UpdateRuleRequest( @@ -183,9 +184,7 @@ def _update_rule_request_from_update( "asset_tag_ids", ] # Need to manually copy fields that will be reset even if not provided in update dict. - copy_unset_fields = [ - "description", - ] + copy_unset_fields = ["description", "name"] # Populate the trivial fields first. update_dict.update( @@ -214,15 +213,17 @@ def _update_rule_request_from_update( "Expression and channel_references must both be provided or both be None" ) expression_proto = RuleConditionExpression( - calculated_channel=CalculatedChannelConfig( - expression=expression, - channel_references={ - c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) - for c in channel_references - }, + calculated_channel=( + CalculatedChannelConfig( + expression=expression, + channel_references={ + c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) + for c in channel_references + }, + ) + if expression + else None ) - if expression - else None ) conditions_request = [ UpdateConditionRequest( @@ -238,10 +239,10 @@ def _update_rule_request_from_update( # This always needs to be set, so handle the defaults. update_dict["asset_configuration"] = RuleAssetConfiguration( # type: ignore - asset_ids=update.asset_ids if "asset_ids" in model_dump else rule.asset_ids or [], - tag_ids=update.asset_tag_ids - if "asset_tag_ids" in model_dump - else rule.asset_tag_ids or [], + asset_ids=(update.asset_ids if "asset_ids" in model_dump else rule.asset_ids or []), + tag_ids=( + update.asset_tag_ids if "asset_tag_ids" in model_dump else rule.asset_tag_ids or [] + ), ) update_request = UpdateRuleRequest( @@ -254,7 +255,7 @@ def _update_rule_request_from_update( async def update_rule( self, rule: Rule, update: RuleUpdate, version_notes: str | None = None ) -> Rule: - """Update a rule. + """Update a rule. Also handles archive/unarchive to behave similar to other low-level clients. Args: rule: The rule to update. @@ -264,14 +265,26 @@ async def update_rule( Returns: The updated Rule. """ + + should_update_archive = "is_archived" in update.model_fields_set + update.resource_id = rule.id_ + if not should_update_archive or ( + should_update_archive and len(update.model_fields_set) > 1 + ): + update_request = self._update_rule_request_from_update(rule, update, version_notes) + + response = await self._grpc_client.get_stub(RuleServiceStub).UpdateRule(update_request) + _ = cast("UpdateRuleResponse", response) - update_request = self._update_rule_request_from_update(rule, update, version_notes) + if should_update_archive: + if update.is_archived: + await self.archive_rule(rule_id=rule.id_) + else: + await self.unarchive_rule(rule_id=rule.id_) - response = await self._grpc_client.get_stub(RuleServiceStub).UpdateRule(update_request) - updated_grpc_rule = cast("UpdateRuleResponse", response) # Get the updated rule - return await self.get_rule(rule_id=updated_grpc_rule.rule_id) + return await self.get_rule(rule_id=rule.id_) async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse: """Batch update rules. diff --git a/python/lib/sift_client/_tests/integrated/calculated_channels.py b/python/lib/sift_client/_tests/integrated/calculated_channels.py deleted file mode 100644 index 2d423b17c..000000000 --- a/python/lib/sift_client/_tests/integrated/calculated_channels.py +++ /dev/null @@ -1,251 +0,0 @@ -import asyncio -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - CalculatedChannelUpdate, - ChannelReference, -) -from sift_client.sift_types.calculated_channel import CalculatedChannelCreate - -""" -Comprehensive test script for calculated channels with extensive update field exercises. - -This test demonstrates all available update fields for calculated channels: -- name: Update the channel name -- description: Update the channel description -- units: Update the units of measurement -- expression: Update the calculation expression -- expression_channel_references: Update channel references (must be updated with expression) -- asset_ids: Update which assets the channel applies to -- tag_ids: Update associated tags - -The test also includes: -- Edge case testing (minimal updates, invalid expressions) -- Batch operations demonstration -- Comprehensive validation -- Error handling and graceful fallbacks -- Archive operations - -TODO: TBD if we move this to an example or keep it here as a test expected to be used just by us. - -If we keep it as a test, we should ideally have a setup that populates data, and then ensure we teardown all the test assets/channels/rules etc. -""" - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - # Find assets to work with - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # Create example calculated channels that will be unique to this test run in case things don't cleanup. - num_channels = 7 - unique_name_suffix = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S") - print( - f"\n=== Creating {num_channels} calculated channels with unique suffix: {unique_name_suffix} ===" - ) - - created_channels = [] - for i in range(num_channels): - new_chan = CalculatedChannelCreate( - name=f"test_channel_{unique_name_suffix}_{i}", - description=f"Test calculated channel {i} - initial description", - expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - units="velocity/voltage", - asset_ids=[asset_id], - user_notes=f"Created for testing update fields - channel {i}", - ) - calculated_channel = client.calculated_channels.create(new_chan) - created_channels.append(calculated_channel) - print( - f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.id_})" - ) - - # Find the channels we just created - search_results = client.calculated_channels.list( - name_regex="test_channel.*", - asset_id=asset_id, - ) - print(f"Found {len(search_results)} calculated channels: {[cc.name for cc in search_results]}") - - print("\n=== Testing comprehensive update scenarios ===") - - # Test 1: Update expression and channel references together - print("\n--- Test 1: Update expression and channel references ---") - channel_1 = created_channels[0] - updated_channel_1 = channel_1.update( - CalculatedChannelUpdate( - expression="$1 / $2 * 100", # Convert to percentage - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - ) - ) - print(f"Updated {updated_channel_1.name}: expression = {updated_channel_1.expression}") - - # Test 2: Update description - print("\n--- Test 2: Update description ---") - channel_2 = created_channels[1] - updated_channel_2 = channel_2.update( - CalculatedChannelUpdate( - description="Updated description with more details about velocity-to-voltage ratio calculation", - ) - ) - print(f"Updated {updated_channel_2.name}: description = {updated_channel_2.description}") - - # Test 3: Update units - print("\n--- Test 3: Update units ---") - channel_3 = created_channels[2] - updated_channel_3 = channel_3.update( - CalculatedChannelUpdate( - units="percentage", - ) - ) - print(f"Updated {updated_channel_3.name}: units = {updated_channel_3.units}") - - # Test 4: Update name - print("\n--- Test 4: Update name ---") - channel_4 = created_channels[3] - new_name = f"renamed_channel_{unique_name_suffix}_5" - updated_channel_4 = channel_4.update( - CalculatedChannelUpdate( - name=new_name, - ) - ) - print(f"Updated {channel_4.name} -> {updated_channel_4.name}") - - # Test 5: Update multiple fields at once - print("\n--- Test 5: Update multiple fields simultaneously ---") - channel_5 = created_channels[4] - updated_channel_5 = channel_5.update( - CalculatedChannelUpdate( - description="Multi-field update test", - units="ratio", - ), - user_notes="Updated via multi-field update", - ) - print(f"Updated {updated_channel_5.name}:") - print(f" - description: {updated_channel_5.description}") - print(f" - units: {updated_channel_5.units}") - print(f" - user_notes: {updated_channel_5.user_notes}") - - # Test 6: Update with complex expression - print("\n--- Test 6: Update with complex expression ---") - channel_6 = created_channels[5] - updated_channel_6 = channel_6.update( - CalculatedChannelUpdate( - expression="($1 / $2) * 100 + ($3 * 0.1)", # Complex expression with 3 variables - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ChannelReference(channel_reference="$3", channel_identifier="temperature"), - ], - ) - ) - print(f"Updated {updated_channel_6.name}: complex expression = {updated_channel_6.expression}") - - # Test 7: Update tag_ids (if tags are available) - print("\n--- Test 7: Update tag_ids ---") - channel_7 = created_channels[6] - # Note: This would require actual tag IDs from the system - # For now, we'll test with an empty list to show the capability - updated_channel_7 = channel_7.update( - CalculatedChannelUpdate( - tag_ids=[], # Empty list - in practice you'd use actual tag IDs - ) - ) - print(f"Updated {updated_channel_7.name}: tag_ids = {updated_channel_7.tag_ids}") - - # Test 7b: Edge case - Update with invalid expression (should fail gracefully) - print("\n--- Test 7b: Edge case - Invalid expression test ---") - try: - invalid_update = channel_7.update( - CalculatedChannelUpdate( - expression="invalid_expression", - expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - ) - ) - print(f"Invalid expression update succeeded (unexpected): {invalid_update.expression}") - # TODO: Ticket this? - except Exception as e: - print(f"Invalid expression update failed as expected: {e}") - - # Test 8: Archive channels - print("\n--- Test 8: Archive channels ---") - archived_count = 0 - for cc in created_channels: - cc.archive() - print(f"Archived: {cc.name}") - archived_count += 1 - - print("\n=== Test Summary ===") - print(f"Created: {len(created_channels)} channels") - print(f"Archived: {archived_count} channels") - - # Verify all channels were processed - assert len(created_channels) == num_channels, ( - f"Expected {num_channels} created channels, got {len(created_channels)}" - ) - assert archived_count == num_channels, ( - f"Expected {num_channels} archived channels, got {archived_count}" - ) - - # Additional validation - print("\n=== Validation Checks ===") - - # Verify that updates actually changed the values - assert updated_channel_1.expression == "$1 / $2 * 100", ( - f"Expression update failed: {updated_channel_1.expression}" - ) - assert "more details" in updated_channel_2.description, ( - f"Description update failed: {updated_channel_2.description}" - ) - assert updated_channel_3.units == "percentage", ( - f"Units update failed: {updated_channel_3.units}" - ) - assert updated_channel_4.name == new_name, f"Name update failed: {updated_channel_4.name}" - assert updated_channel_5.description == "Multi-field update test", ( - f"Description update failed: {updated_channel_5.description}" - ) - assert updated_channel_5.units == "ratio", f"Units update failed: {updated_channel_5.units}" - assert updated_channel_5.user_notes == "Updated via multi-field update", ( - f"User notes update failed: {updated_channel_5.user_notes}" - ) - assert updated_channel_6.expression == "($1 / $2) * 100 + ($3 * 0.1)", ( - f"Complex expression update failed: {updated_channel_6.expression}" - ) - assert len(updated_channel_6.channel_references) == 3, ( - f"Complex expression should have 3 references, got {len(updated_channel_6.channel_references)}" - ) - assert updated_channel_7.tag_ids == [], f"Tag IDs update failed: {updated_channel_7.tag_ids}" - - versions = client.calculated_channels.list_versions( - calculated_channel=channel_1.id_, - limit=10, - ) - print(f"Found {len(versions)} versions for {created_channels[0].name}") - - print("All validation checks passed!") - print("\n=== Test completed successfully ===") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/integrated/rules.py b/python/lib/sift_client/_tests/integrated/rules.py deleted file mode 100644 index dbe88aeea..000000000 --- a/python/lib/sift_client/_tests/integrated/rules.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - ChannelReference, - RuleAction, - RuleAnnotationType, - RuleUpdate, -) -from sift_client.sift_types.rule import RuleCreate - -""" -Comprehensive test script for rules with extensive update field exercises. - -This test demonstrates all available update fields for rules: -- name: Update the rule name -- description: Update the rule description -- expression: Update the rule expression -- channel_references: Update channel references (must be updated with expression) -- action: Update the rule action (annotation, notification, webhook) -- tag_ids: Update associated tags (TBD) -- contextual_channels: Update contextual channels -- version_notes: Update version notes - -The test also includes: -- Edge case testing (invalid expressions) -- Batch operations demonstration -- Comprehensive validation -- Archive operations - - -If we keep it as a test, we should ideally have a setup that populates data, and then ensure we teardown all the test assets/channels/rules etc. -""" - - -def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - unique_name_suffix = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S") - num_rules = 8 - print(f"\n=== Creating {num_rules} rules with unique suffix: {unique_name_suffix} ===") - created_rules = [] - for i in range(num_rules): - rule = client.rules.create( - RuleCreate( - name=f"test_rule_{unique_name_suffix}_{i}", - description=f"Test rule {i} - initial description", - expression="$1 > 0.1", # Simple threshold check - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["test", "initial"], - default_assignee_user_id=None, - ), - asset_ids=[asset_id], - ) - ) - created_rules.append(rule) - print(f"Created rule: {rule.name} (ID: {rule.id_})") - - # Find the rules we just created - search_results = client.rules.list_( - name_regex=f"test_rule_{unique_name_suffix}.*", - ) - assert len(search_results) == num_rules, ( - f"Expected {num_rules} created rules, got {len(search_results)}" - ) - - print("\n=== Testing comprehensive update scenarios ===") - - # Test 1: Update expression and channel references together - print("\n--- Test 1: Update expression and channel references ---") - rule_1 = created_rules[0] - rule_1_model_dump = rule_1.model_dump() - updated_rule_1 = rule_1.update( - RuleUpdate( - expression="$1 > 0.5", # Higher threshold - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ], - ) - ) - updated_rule_1_model_dump = updated_rule_1.model_dump() - print(f"Updated {updated_rule_1.name}: expression = {updated_rule_1.expression}") - - # Test 2: Update description - print("\n--- Test 2: Update description ---") - rule_2 = created_rules[1] - updated_rule_2 = rule_2.update( - RuleUpdate( - description="Updated description with more details about velocity-to-voltage ratio monitoring", - ) - ) - print(f"Updated {updated_rule_2.name}: description = {updated_rule_2.description}") - - # Test 3: Update action (change annotation type and tags) - print("\n--- Test 3: Update action ---") - rule_3 = created_rules[2] - updated_rule_3 = rule_3.update( - RuleUpdate( - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.PHASE, - tags=["updated", "phase", "alert"], - default_assignee_user_id=rule_3.created_by_user_id, - ), - ) - ) - print(f"Updated {updated_rule_3.name}: action type = {updated_rule_3.action.action_type}") - print(f" - annotation type: {updated_rule_3.action.annotation_type}") - print(f" - tags: {updated_rule_3.action.tags}") - print(f" - assignee: {updated_rule_3.action.default_assignee_user_id}") - - # Test 4: Update name - print("\n--- Test 4: Update name ---") - rule_4 = created_rules[3] - new_name = f"renamed_rule_{unique_name_suffix}_4" - updated_rule_4 = rule_4.update( - RuleUpdate( - name=new_name, - ) - ) - print(f"Updated {rule_4.name} -> {updated_rule_4.name}") - - # Test 5: Update multiple fields at once - print("\n--- Test 5: Update multiple fields simultaneously ---") - rule_5 = created_rules[4] - updated_rule_5 = rule_5.update( - RuleUpdate( - description="Multi-field update test", - ), - version_notes="Updated via multi-field update", - ) - print(f"Updated {updated_rule_5.name}:") - print(f" - description: {updated_rule_5.description}") - print( - f" - version_notes: {updated_rule_5.rule_version.version_notes if updated_rule_5.rule_version else None}" - ) - - # Test 6: Update with complex expression - print("\n--- Test 6: Update with complex expression ---") - rule_6 = created_rules[5] - updated_rule_6 = rule_6.update( - RuleUpdate( - expression="$1 > 0.3 && $1 < 0.8", # Range check - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ], - ) - ) - print(f"Updated {updated_rule_6.name}: complex expression = {updated_rule_6.expression}") - - # Test 7: Update action to notification type - # print("\n--- Test 7: Update action to notification ---") - # rule_7 = created_rules[6] - # updated_rule_7 = rule_7 - # Note: Notification actions are not supported yet. - # updated_rule_7 = rule_7.update( - # RuleUpdate( - # action=RuleAction.notification( - # notify_recipients=[rule_7.created_by_user_id] - # ), - # ) - # ) - # print(f"Updated {updated_rule_7.name}: action type = {updated_rule_7.action.action_type}") - # print(f" - notification recipients: {updated_rule_7.action.notification_recipients}") - - # Test 8: Update tag_ids and contextual_channels - print("\n--- Test 8: Update tag_ids and contextual_channels ---") - rule_8 = created_rules[7] - updated_rule_8 = rule_8.update( - RuleUpdate( - # tag_ids=["tag-123", "tag-456"], # Example tag IDs # TODO: Where are these IDs supposed to come from? They're supposed to be uuids? {grpc_message:"invalid argument: invalid input syntax for type uuid: \"tag-123\" - contextual_channels=[ - "temperature", - "pressure", - ], # Example contextual channels - ) - ) - print(f"Updated {updated_rule_8.name}:") - print(f" - asset_tag_ids: {updated_rule_8.asset_tag_ids}") - print(f" - contextual_channels: {updated_rule_8.contextual_channels}") - - # Test 8b: Edge case - Update with invalid expression (should fail gracefully) - print("\n--- Test 8b: Edge case - Invalid expression test ---") - try: - invalid_update = rule_8.update( - RuleUpdate( - expression="invalid_expression", - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - ) - ) - print(f"Invalid expression update succeeded (unexpected): {invalid_update.expression}") - except Exception as e: - print(f"Invalid expression update failed as expected: {e}") - - # Additional validation - print("\n=== Validation Checks ===") - - # Verify that updates actually changed the values - assert updated_rule_1.expression == "$1 > 0.5", ( - f"Expression update failed: {updated_rule_1.expression}" - ) - # For update 1, also verify that the fields that were not updated are not reset. - assert updated_rule_1_model_dump["description"] == rule_1_model_dump["description"], ( - f"Expected no description change, got {rule_1_model_dump['description']} -> {updated_rule_1.description}" - ) - assert ( - updated_rule_1_model_dump["channel_references"] == rule_1_model_dump["channel_references"] - ), ( - f"Expected no channel references change, got {rule_1_model_dump['channel_references']} -> {updated_rule_1.channel_references}" - ) - assert updated_rule_1_model_dump["asset_ids"] == rule_1_model_dump["asset_ids"], ( - f"Expected no asset IDs change, got {rule_1_model_dump['asset_ids']} -> {updated_rule_1.asset_ids}" - ) - assert updated_rule_1_model_dump["asset_tag_ids"] == rule_1_model_dump["asset_tag_ids"], ( - f"Expected no tag IDs change, got {rule_1_model_dump['asset_tag_ids']} -> {updated_rule_1.asset_tag_ids}" - ) - assert ( - updated_rule_1_model_dump["contextual_channels"] == rule_1_model_dump["contextual_channels"] - ), f"Contextual channels update failed: {updated_rule_1.contextual_channels}" - assert "more details" in updated_rule_2.description, ( - f"Description update failed: {updated_rule_2.description}" - ) - assert updated_rule_3.action.annotation_type == RuleAnnotationType.PHASE, ( - f"Action update failed: {updated_rule_3.action.annotation_type}" - ) - assert updated_rule_4.name == new_name, f"Name update failed: {updated_rule_4.name}" - - assert updated_rule_6.expression == "$1 > 0.3 && $1 < 0.8", ( - f"Complex expression update failed: {updated_rule_6.expression}" - ) - # assert updated_rule_7.action.action_type == RuleActionType.NOTIFICATION, f"Action type update failed: {updated_rule_7.action.action_type}" - # assert len(updated_rule_8.tag_ids) == 2, f"Tag IDs update failed: {updated_rule_8.tag_ids}" - assert len(updated_rule_8.contextual_channels) == 2, ( - f"Contextual channels update failed: {updated_rule_8.contextual_channels}" - ) - - print("All validation checks passed!") - print("\n=== Test completed successfully ===") - - -if __name__ == "__main__": - main() diff --git a/python/lib/sift_client/_tests/resources/test_assets.py b/python/lib/sift_client/_tests/resources/test_assets.py index ca69ee500..5fca01fa8 100644 --- a/python/lib/sift_client/_tests/resources/test_assets.py +++ b/python/lib/sift_client/_tests/resources/test_assets.py @@ -7,8 +7,6 @@ - Error handling and edge cases """ -import os - import pytest from sift_client import SiftClient diff --git a/python/lib/sift_client/_tests/resources/test_calculated_channels.py b/python/lib/sift_client/_tests/resources/test_calculated_channels.py index 17b0ba216..d75227823 100644 --- a/python/lib/sift_client/_tests/resources/test_calculated_channels.py +++ b/python/lib/sift_client/_tests/resources/test_calculated_channels.py @@ -28,9 +28,7 @@ def test_client_binding(sift_client): assert sift_client.calculated_channels assert isinstance(sift_client.calculated_channels, CalculatedChannelsAPI) assert sift_client.async_.calculated_channels - assert isinstance( - sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync - ) + assert isinstance(sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync) @pytest.fixture @@ -72,12 +70,8 @@ def new_calculated_channel(calculated_channels_api_sync, sift_client): description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], all_assets=True, ) @@ -92,9 +86,7 @@ class TestGet: """Tests for the async get method.""" @pytest.mark.asyncio - async def test_get_by_id( - self, calculated_channels_api_async, test_calculated_channel - ): + async def test_get_by_id(self, calculated_channels_api_async, test_calculated_channel): """Test getting a specific calculated channel by ID.""" retrieved_calc_channel = await calculated_channels_api_async.get( calculated_channel_id=test_calculated_channel.id_ @@ -147,13 +139,9 @@ async def test_list_with_name_filter(self, calculated_channels_api_async): assert calc_channel.name == test_calc_channel_name @pytest.mark.asyncio - async def test_list_with_name_contains_filter( - self, calculated_channels_api_async - ): + async def test_list_with_name_contains_filter(self, calculated_channels_api_async): """Test calculated channel listing with name contains filtering.""" - calc_channels = await calculated_channels_api_async.list_( - name_contains="test", limit=5 - ) + calc_channels = await calculated_channels_api_async.list_(name_contains="test", limit=5) assert isinstance(calc_channels, list) assert calc_channels @@ -204,9 +192,7 @@ async def test_find_calculated_channel( assert found_calc_channel.id_ == test_calculated_channel.id_ @pytest.mark.asyncio - async def test_find_nonexistent_calculated_channel( - self, calculated_channels_api_async - ): + async def test_find_nonexistent_calculated_channel(self, calculated_channels_api_async): """Test finding a non-existent calculated channel returns None.""" found_calc_channel = await calculated_channels_api_async.find( name="nonexistent-calculated-channel-name-12345" @@ -223,20 +209,14 @@ class TestCreate: """Tests for the async create method.""" @pytest.mark.asyncio - async def test_create_basic_calculated_channel( - self, calculated_channels_api_async - ): + async def test_create_basic_calculated_channel(self, calculated_channels_api_async): """Test creating a basic calculated channel with minimal fields.""" from datetime import datetime, timezone - calc_channel_name = ( - f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" - ) + calc_channel_name = f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_( - limit=2 - ) + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) assert len(channels) >= 2 calc_channel_create = CalculatedChannelCreate( @@ -244,19 +224,13 @@ async def test_create_basic_calculated_channel( description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], all_assets=True, ) - created_calc_channel = await calculated_channels_api_async.create( - calc_channel_create - ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) try: assert created_calc_channel is not None @@ -270,20 +244,14 @@ async def test_create_basic_calculated_channel( await calculated_channels_api_async.archive(created_calc_channel) @pytest.mark.asyncio - async def test_create_calculated_channel_with_dict( - self, calculated_channels_api_async - ): + async def test_create_calculated_channel_with_dict(self, calculated_channels_api_async): """Test creating a calculated channel using a dictionary.""" from datetime import datetime, timezone - calc_channel_name = ( - f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" - ) + calc_channel_name = f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_( - limit=2 - ) + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) assert len(channels) >= 2 calc_channel_dict = { @@ -297,9 +265,7 @@ async def test_create_calculated_channel_with_dict( "all_assets": True, } - created_calc_channel = await calculated_channels_api_async.create( - calc_channel_dict - ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_dict) try: assert created_calc_channel.name == calc_channel_name @@ -383,9 +349,7 @@ async def test_archive_calculated_channel( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel.""" - calc_channel = await calculated_channels_api_async.archive( - new_calculated_channel - ) + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -408,9 +372,7 @@ async def test_archive_with_id_string( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel by passing ID as string.""" - calc_channel = await calculated_channels_api_async.archive( - new_calculated_channel.id_ - ) + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel.id_) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -427,9 +389,7 @@ async def test_unarchive_calculated_channel( try: await calculated_channels_api_async.archive(new_calculated_channel) - calc_channel = await calculated_channels_api_async.unarchive( - new_calculated_channel - ) + calc_channel = await calculated_channels_api_async.unarchive(new_calculated_channel) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -441,9 +401,7 @@ class TestListVersions: """Tests for the async list_versions method.""" @pytest.mark.asyncio - async def test_list_versions( - self, calculated_channels_api_async, test_calculated_channel - ): + async def test_list_versions(self, calculated_channels_api_async, test_calculated_channel): """Test listing versions of a calculated channel.""" versions = await calculated_channels_api_async.list_versions( calculated_channel=test_calculated_channel diff --git a/python/lib/sift_client/_tests/resources/test_rules.py b/python/lib/sift_client/_tests/resources/test_rules.py new file mode 100644 index 000000000..bae0e6efa --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_rules.py @@ -0,0 +1,458 @@ +"""Pytest tests for the Rules API. + +These tests demonstrate and validate the usage of the Rules API including: +- Basic rule operations (get, list, find) +- Rule filtering and searching +- Rule creation, updates, and archiving +- Error handling and edge cases +""" + +import uuid + +import pytest + +from sift_client import SiftClient +from sift_client.resources import RulesAPI, RulesAPIAsync +from sift_client.sift_types import Rule +from sift_client.sift_types.channel import ChannelReference +from sift_client.sift_types.rule import ( + RuleAction, + RuleActionType, + RuleAnnotationType, + RuleCreate, + RuleUpdate, +) + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.rules + assert isinstance(sift_client.rules, RulesAPI) + assert sift_client.async_.rules + assert isinstance(sift_client.async_.rules, RulesAPIAsync) + + +@pytest.fixture +def rules_api_async(sift_client: SiftClient): + """Get the async rules API instance.""" + return sift_client.async_.rules + + +@pytest.fixture +def rules_api_sync(sift_client: SiftClient): + """Get the synchronous rules API instance.""" + return sift_client.rules + + +@pytest.fixture +def test_rule(rules_api_sync): + rules = rules_api_sync.list_(limit=1) + assert rules + assert len(rules) >= 1 + return rules[0] + + +@pytest.fixture(scope="function") +def new_rule(rules_api_sync, sift_client): + """Create a test rule for update tests.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + # Get some channels to reference + channels = sift_client.channels.list_(limit=2) + assert len(channels) >= 2 + + # Get an asset to apply the rule to + assets = sift_client.assets.list_(limit=1) + assert len(assets) >= 1 + + created_rule = rules_api_sync.create( + RuleCreate( + name=rule_name, + client_key=f"test_rule_{str(uuid.uuid4())[-8:]}", + description=description, + expression="$1 > $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=[], + ), + asset_ids=[assets[0].id_], + ) + ) + return created_rule + + +class TestRulesAPIAsync: + """Test suite for the async Rules API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, rules_api_async, test_rule): + """Test getting a specific rule by ID.""" + retrieved_rule = await rules_api_async.get(rule_id=test_rule.id_) + + assert isinstance(retrieved_rule, Rule) + assert retrieved_rule.id_ == test_rule.id_ + assert retrieved_rule.name == test_rule.name + + @pytest.mark.asyncio + async def test_get_by_client_key(self, rules_api_async, test_rule): + """Test getting a specific rule by client key.""" + if test_rule.client_key: + retrieved_rule = await rules_api_async.get(client_key=test_rule.client_key) + + assert retrieved_rule is not None + assert retrieved_rule.id_ == test_rule.id_ + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, rules_api_async): + """Test basic rule listing functionality.""" + rules = await rules_api_async.list_(limit=5) + + assert isinstance(rules, list) + assert len(rules) == 5 + + rule = rules[0] + assert isinstance(rule, Rule) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, rules_api_async, test_rule): + """Test rule listing with name filtering.""" + filtered_rules = await rules_api_async.list_(name=test_rule.name) + + assert isinstance(filtered_rules, list) + assert len(filtered_rules) >= 1 + + for rule in filtered_rules: + assert rule.name == test_rule.name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, rules_api_async): + """Test rule listing with name contains filtering.""" + rules = await rules_api_async.list_(name_contains="test", limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert "test" in rule.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, rules_api_async): + """Test rule listing with regex name filtering.""" + rules = await rules_api_async.list_(name_regex=r".*test.*", limit=5) + + assert isinstance(rules, list) + assert rules + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for rule in rules: + assert pattern.match(rule.name) + + @pytest.mark.asyncio + async def test_list_with_rule_ids_filter(self, rules_api_async): + """Test rule listing with rule IDs filter.""" + all_rules = await rules_api_async.list_(limit=3) + + if all_rules: + rule_ids = [r.id_ for r in all_rules] + filtered_rules = await rules_api_async.list_(rule_ids=rule_ids) + + assert isinstance(filtered_rules, list) + assert len(filtered_rules) >= len(all_rules) + + for rule in filtered_rules: + assert rule.id_ in rule_ids + + @pytest.mark.asyncio + async def test_list_with_description_contains_filter(self, rules_api_async): + """Test rule listing with description contains filtering.""" + rules = await rules_api_async.list_(description_contains="test", limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert "test" in rule.description.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, rules_api_async): + """Test rule listing with different limits.""" + rules_1 = await rules_api_async.list_(limit=1) + assert isinstance(rules_1, list) + assert len(rules_1) <= 1 + + rules_3 = await rules_api_async.list_(limit=3) + assert isinstance(rules_3, list) + assert len(rules_3) <= 3 + + @pytest.mark.asyncio + async def test_list_with_time_filters(self, rules_api_async): + """Test rule listing with time-based filters.""" + from datetime import datetime, timedelta, timezone + + one_year_ago = datetime.now(timezone.utc) - timedelta(days=365) + rules = await rules_api_async.list_(created_after=one_year_ago, limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert rule.created_date >= one_year_ago + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_rule(self, rules_api_async, test_rule): + """Test finding a single rule.""" + found_rule = await rules_api_async.find(rule_ids=[test_rule.id_]) + + assert found_rule is not None + assert found_rule.id_ == test_rule.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_rule(self, rules_api_async): + """Test finding a non-existent rule returns None.""" + found_rule = await rules_api_async.find(name="nonexistent-rule-name-12345") + assert found_rule is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, rules_api_async): + """Test finding multiple rules raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await rules_api_async.find(name_contains="test", limit=5) + + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_rule(self, rules_api_async): + """Test creating a basic rule with minimal fields.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_create_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + channels = await rules_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + assets = await rules_api_async.client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + rule_create = RuleCreate( + name=rule_name, + description=description, + expression="$1 > $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=[], + ), + asset_ids=[assets[0].id_], + ) + + created_rule = await rules_api_async.create(rule_create) + + try: + assert created_rule is not None + assert isinstance(created_rule, Rule) + assert created_rule.id_ is not None + assert created_rule.name == rule_name + assert created_rule.description == description + assert created_rule.created_date is not None + assert created_rule.modified_date is not None + finally: + await rules_api_async.archive(created_rule) + + @pytest.mark.asyncio + async def test_create_rule_with_dict(self, rules_api_async): + """Test creating a rule using a dictionary.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_dict_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + channels = await rules_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + assets = await rules_api_async.client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + rule_dict = { + "name": rule_name, + "description": description, + "expression": "$1 > $2", + "channel_references": [ + {"channel_reference": "$1", "channel_identifier": channels[0].name}, + {"channel_reference": "$2", "channel_identifier": channels[1].name}, + ], + "action": { + "action_type": RuleActionType.ANNOTATION, + "annotation_type": RuleAnnotationType.PHASE, + "tags": [], + }, + "asset_ids": [assets[0].id_], + } + + created_rule = await rules_api_async.create(rule_dict) + + try: + assert created_rule.name == rule_name + assert created_rule.description == description + finally: + await rules_api_async.archive(created_rule) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_rule_description(self, rules_api_async, new_rule): + """Test updating a rule's description.""" + try: + update = RuleUpdate(description="Updated description") + updated_rule = await rules_api_async.update(new_rule, update) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated description" + # Validate that things we didn't intentionally change didn't change + assert updated_rule.name == new_rule.name + assert updated_rule.is_enabled == new_rule.is_enabled + assert updated_rule.is_external == new_rule.is_external + assert updated_rule.expression == new_rule.expression + assert updated_rule.action.action_type == new_rule.action.action_type + assert updated_rule.client_key == new_rule.client_key + assert updated_rule.rule_version.created_date > new_rule.rule_version.created_date + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_rule_name(self, rules_api_async, new_rule): + """Test updating a rule's name.""" + try: + new_name = f"updated_{new_rule.name}" + update = RuleUpdate(name=new_name) + updated_rule = await rules_api_async.update(new_rule, update) + + assert updated_rule.name == new_name + assert updated_rule.id_ == new_rule.id_ + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_dict(self, rules_api_async, new_rule): + """Test updating a rule using a dictionary.""" + try: + update_dict = {"description": "Updated via dict"} + updated_rule = await rules_api_async.update(new_rule, update_dict) + + assert updated_rule.description == "Updated via dict" + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_id_string(self, rules_api_async, new_rule): + """Test updating a rule by passing ID as string.""" + try: + update = RuleUpdate(description="Updated via ID string") + updated_rule = await rules_api_async.update(new_rule.id_, update) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated via ID string" + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_version_notes(self, rules_api_async, new_rule): + """Test updating a rule with version notes.""" + try: + update = RuleUpdate(description="Updated with version notes") + updated_rule = await rules_api_async.update( + new_rule, update, version_notes="Test version notes" + ) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated with version notes" + finally: + await rules_api_async.archive(new_rule.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_rule(self, rules_api_async, new_rule): + """Test archiving a rule.""" + rule = await rules_api_async.archive(new_rule) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is True + + rules_without_archived = await rules_api_async.list_( + name=new_rule.name, include_archived=False + ) + assert len(rules_without_archived) == 0 + + rules_with_archived = await rules_api_async.list_( + name=new_rule.name, include_archived=True + ) + assert len(rules_with_archived) == 1 + assert rules_with_archived[0].id_ == new_rule.id_ + assert rules_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_id_string(self, rules_api_async, new_rule): + """Test archiving a rule by passing ID as string.""" + rule = await rules_api_async.archive(new_rule.id_) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is True + + class TestUnarchive: + """Tests for the async unarchive method.""" + + @pytest.mark.asyncio + async def test_unarchive_rule(self, rules_api_async, new_rule): + """Test unarchiving a rule.""" + try: + await rules_api_async.archive(new_rule) + + rule = await rules_api_async.unarchive(new_rule) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is False + finally: + await rules_api_async.archive(new_rule.id_) + + +class TestRulesAPISync: + """Test suite for the synchronous Rules API functionality.""" + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, rules_api_sync, test_rule): + """Test getting a specific rule by ID synchronously.""" + retrieved_rule = rules_api_sync.get(rule_id=test_rule.id_) + + assert retrieved_rule is not None + assert retrieved_rule.id_ == test_rule.id_ diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index a4e3fc750..8d99757a5 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -7,7 +7,6 @@ - Error handling and edge cases """ -import os from datetime import datetime, timedelta, timezone import pytest @@ -256,7 +255,12 @@ async def test_create_run_with_all_fields(self, runs_api_async): assert created_run.client_key is not None assert created_run.start_time is not None assert created_run.stop_time is not None - assert created_run.tags == ["test", "pytest", "integration", "sift-client-pytest"] + assert created_run.tags == [ + "test", + "pytest", + "integration", + "sift-client-pytest", + ] assert created_run.metadata["pytest_type"] == "integration" finally: @@ -328,7 +332,11 @@ async def test_update_run_tags_and_metadata(self, runs_api_async, new_run): updated_run = await runs_api_async.update(new_run, update) # Verify the updates - assert set(updated_run.tags) == {"updated", "new-tag", "sift-client-pytest"} + assert set(updated_run.tags) == { + "updated", + "new-tag", + "sift-client-pytest", + } finally: await runs_api_async.archive(new_run.id_) diff --git a/python/lib/sift_client/examples/generic_workflow_example.py b/python/lib/sift_client/examples/generic_workflow_example.py index 08901f70a..c3ec1b642 100644 --- a/python/lib/sift_client/examples/generic_workflow_example.py +++ b/python/lib/sift_client/examples/generic_workflow_example.py @@ -29,7 +29,7 @@ async def main(): asset_id = asset.id_ print("Found asset", asset.name) - calculated_channels = client.calculated_channels.list( + calculated_channels = client.calculated_channels.list_( name_regex="velocity_per.*", asset_id=asset_id, ) diff --git a/python/lib/sift_client/resources/sync_stubs/__init__.pyi b/python/lib/sift_client/resources/sync_stubs/__init__.pyi index 15302378b..85c49bad2 100644 --- a/python/lib/sift_client/resources/sync_stubs/__init__.pyi +++ b/python/lib/sift_client/resources/sync_stubs/__init__.pyi @@ -782,7 +782,7 @@ class RunsAPI: """ ... - def stop(self, run: str | Run) -> None: + def stop(self, run: str | Run) -> Run: """Stop a run by setting its stop time to the current time. Args: diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index b88535e2c..595285fda 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -1,10 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar from google.protobuf import field_mask_pb2, message -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator if TYPE_CHECKING: from sift_client.client import SiftClient @@ -55,6 +56,16 @@ def _update(self, other: BaseType[ProtoT, SelfT]) -> BaseType[ProtoT, SelfT]: self.__dict__["proto"] = other.proto return self + @model_validator(mode="after") + def _validate_timezones(self): + """Validate datetime fiels have timezone information.""" + for field_name in self.model_fields.keys(): + val = getattr(self, field_name) + if isinstance(val, datetime) and val.tzinfo is None: + raise ValueError(f"{field_name} must have timezone information") + + return self + class MappingHelper(BaseModel): """Helper class for mapping fields to proto attributes and update fields diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 7388c3bf1..19dc472d7 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -96,6 +96,7 @@ def _from_proto( cls, proto: CalculatedChannelProto, sift_client: SiftClient | None = None ) -> CalculatedChannel: return cls( + proto=proto, id_=proto.calculated_channel_id, name=proto.name, description=proto.description, @@ -179,14 +180,8 @@ class CalculatedChannelBase(ModelCreateUpdateBase): @model_validator(mode="after") def _validate_asset_configuration(self): """Validate that either all_assets is True or at least one of tag_ids or asset_ids is provided, but not both.""" - if ( - self.all_assets is not None - and self.all_assets - and (self.asset_ids or self.tag_ids) - ): - raise ValueError( - "Cannot specify both all_assets=True and asset_ids/tag_ids" - ) + if self.all_assets is not None and self.all_assets and (self.asset_ids or self.tag_ids): + raise ValueError("Cannot specify both all_assets=True and asset_ids/tag_ids") return self @model_validator(mode="after") @@ -199,9 +194,7 @@ def _validate_expression_and_channel_references(self): return self -class CalculatedChannelCreate( - CalculatedChannelBase, ModelCreate[CreateCalculatedChannelRequest] -): +class CalculatedChannelCreate(CalculatedChannelBase, ModelCreate[CreateCalculatedChannelRequest]): """Create model for a Calculated Channel.""" name: str @@ -212,9 +205,7 @@ def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: return CreateCalculatedChannelRequest -class CalculatedChannelUpdate( - CalculatedChannelBase, ModelUpdate[CalculatedChannelProto] -): +class CalculatedChannelUpdate(CalculatedChannelBase, ModelUpdate[CalculatedChannelProto]): """Update model for a Calculated Channel.""" name: str | None = None diff --git a/python/lib/sift_client/sift_types/rule.py b/python/lib/sift_client/sift_types/rule.py index ac4a4e3d6..df045f105 100644 --- a/python/lib/sift_client/sift_types/rule.py +++ b/python/lib/sift_client/sift_types/rule.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING @@ -32,8 +33,6 @@ from sift_client.sift_types.channel import ChannelReference if TYPE_CHECKING: - from datetime import datetime - from sift_client.client import SiftClient from sift_client.sift_types.asset import Asset @@ -135,8 +134,8 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> ], action=RuleAction._from_proto(proto.conditions[0].actions[0]), is_enabled=proto.is_enabled, - created_date=proto.created_date.ToDatetime(), - modified_date=proto.modified_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), + modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, organization_id=proto.organization_id, @@ -147,7 +146,9 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> asset_ids=proto.asset_configuration.asset_ids, # type: ignore asset_tag_ids=proto.asset_configuration.tag_ids, # type: ignore contextual_channels=[c.name for c in proto.contextual_channels.channels], - archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) if proto.archived_date else None + ), is_archived=proto.is_archived, is_external=proto.is_external, _client=sift_client, @@ -296,9 +297,10 @@ def _from_proto( ) -> RuleAction: action_type = RuleActionType(proto.action_type) return cls( + proto=proto, condition_id=proto.rule_condition_id, - created_date=proto.created_date.ToDatetime(), - modified_date=proto.modified_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), + modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, version_id=proto.rule_action_version_id, @@ -358,14 +360,17 @@ def _from_proto( cls, proto: RuleVersionProto, sift_client: SiftClient | None = None ) -> RuleVersion: return cls( + proto=proto, rule_id=proto.rule_id, rule_version_id=proto.rule_version_id, version=proto.version, - created_date=proto.created_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, version_notes=proto.version_notes, generated_change_message=proto.generated_change_message, - archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) if proto.archived_date else None + ), is_archived=proto.is_archived, _client=sift_client, ) From 289604c33b7f6f2a953efc5ce3231b4c6350d579 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 14:17:13 -0700 Subject: [PATCH 26/39] add test_ingestion.py --- .../sift_client/_tests/integrated/__init__.py | 0 .../_tests/integrated/ingestion.py | 218 -------- .../_tests/resources/test_ingestion.py | 499 ++++++++++++++++++ python/lib/sift_client/client.py | 1 - .../lib/sift_client/sift_types/ingestion.py | 69 ++- python/lib/sift_client/sift_types/run.py | 40 +- 6 files changed, 572 insertions(+), 255 deletions(-) delete mode 100644 python/lib/sift_client/_tests/integrated/__init__.py delete mode 100644 python/lib/sift_client/_tests/integrated/ingestion.py create mode 100644 python/lib/sift_client/_tests/resources/test_ingestion.py diff --git a/python/lib/sift_client/_tests/integrated/__init__.py b/python/lib/sift_client/_tests/integrated/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/python/lib/sift_client/_tests/integrated/ingestion.py b/python/lib/sift_client/_tests/integrated/ingestion.py deleted file mode 100644 index b5f49d894..000000000 --- a/python/lib/sift_client/_tests/integrated/ingestion.py +++ /dev/null @@ -1,218 +0,0 @@ -import asyncio -import math -import os -import random -import time -from datetime import datetime, timedelta, timezone - -from sift_client._tests import setup_logger -from sift_client.client import SiftClient -from sift_client.sift_types.channel import ( - ChannelBitFieldElement, - ChannelDataType, -) -from sift_client.sift_types.ingestion import ChannelConfig, Flow -from sift_client.transport import SiftConnectionConfig - -setup_logger() - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient( - connection_config=SiftConnectionConfig( - grpc_url=grpc_url, - api_key=api_key, - rest_url=rest_url, - use_ssl=True, - cert_via_openssl=True, - ) - ) - - asset = "ian-test-asset" - - # TODO:Get user id from current user - previously_created_runs = client.runs.list_( - name_regex="test-run-.*", created_by="1eba461b-fa36-4e98-8fe8-ff32d3e43a6e" - ) - if previously_created_runs: - print(f" Deleting previously created runs: {previously_created_runs}") - for run in previously_created_runs: - print(f" Deleting run: {run.name}") - client.runs.archive(run=run) - - run = client.runs.create( - dict( - name=f"test-run-{datetime.now(tz=timezone.utc).timestamp()}", - description="A test run created via the API", - tags=["api-created", "test"], - ) - ) - - regular_flow = Flow( - name="test-flow", - channels=[ - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), - ChannelConfig( - name="test-enum-channel", - data_type=ChannelDataType.ENUM, - enum_types={"enum1": 1, "enum2": 2}, - ), - ], - ) - regular_flow.add_channelConfig( - ChannelConfig( - name="test-bit-field-channel", - data_type=ChannelDataType.BIT_FIELD, - bit_field_elements=[ - ChannelBitFieldElement(name="12v", index=0, bit_count=4), - ChannelBitFieldElement(name="charge", index=4, bit_count=2), - ChannelBitFieldElement(name="led", index=6, bit_count=1), - ChannelBitFieldElement(name="heater", index=7, bit_count=1), - ], - ) - ) - - highspeed_flow = Flow( - name="highspeed-flow", - channels=[ - ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), - ], - ) - # This seals the flow and ingestion config - config_id = await client.async_.ingestion.create_ingestion_config( - asset_name=asset, - run_id=run.id_, - flows=[regular_flow, highspeed_flow], - ) - print(f"config_id: {config_id}") - try: - regular_flow.add_channelConfig( - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE) - ) - except ValueError as e: - assert repr(e) == "ValueError('Cannot add a channel to a flow after creation')" - - other_asset_flows = [ - Flow( - name="new-asset-flow", - channels=[ - # Same channel name as the regular flow, but on a different asset. - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), - ], - ) - ] - await client.async_.ingestion.create_ingestion_config( - asset_name="test-asset-ian2", - run_id=run.id_, - flows=other_asset_flows, - ) - sleep_time = 0.05 # Time between outer loop iterations to simulate real-time latency between ingestion calls. - simulated_duration = 50 - fake_hs_rate = 50 # Hz - fake_hs_period = 1 / fake_hs_rate - start = datetime.now(tz=timezone.utc) - for i in range(simulated_duration): - now = start + timedelta(seconds=i) - regular_flow.ingest( - timestamp=now, - channel_values={ - "test-channel": 3.0 * math.sin(2 * math.pi * fake_hs_rate * i + 0.07), - "test-enum-channel": i % 2 + 1, - "test-bit-field-channel": { - "12v": random.randint(3, 13), - "charge": random.randint(1, 3), - "led": random.choice([0, 1]), - "heater": random.choice([0, 1]), - }, - }, - ) - for j in range(fake_hs_rate): - val = 3.0 * math.sin(2 * math.pi * fake_hs_rate * (i + j * 0.001) + 0) - timestamp = now + timedelta(milliseconds=j * fake_hs_period * 1000) - channel_values = { - "highspeed-channel": val, - } - # Alternative way to ingest - client.ingestion.ingest( - flow=highspeed_flow, timestamp=timestamp, channel_values=channel_values - ) - time.sleep(sleep_time) - - other_asset_flows[0].ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": -6.66, - }, - ) - - # Test ingestion of a flow without all channels specified - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": 2, - # "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert "Expected all channels in flow to have a data point at same time." in repr(e) - - # Test ingestion of a bad enum value (string and int) - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": -3, - "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert "Could not find enum value: -3 in enum options: {'enum1': 1, 'enum2': 2}" in repr(e) - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": "nonexistent-enum", - "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert ( - "Could not find enum value: nonexistent-enum in enum options: {'enum1': 1, 'enum2': 2}" - in repr(e) - ) - - client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) - end = datetime.now(tz=timezone.utc) - # Test ingesting more data after letting a thread finish. Also exercise ingesting bitfield values as bytes. - time.sleep(1) - print("Restarting ingestion") - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration + 1), - channel_values={ - "test-channel": 7.77, - "test-enum-channel": 1, - "test-bit-field-channel": bytes([0b11111111]), - }, - ) - # Wait less time than threads nominal no_data_timeout so we can exercise forced cleanup. - client.async_.ingestion.wait_for_ingestion_to_complete(timeout=0.01) - client.runs.archive(run=run.id_) - - num_datapoints = fake_hs_rate * len( - highspeed_flow.channels - ) * simulated_duration + simulated_duration * len(regular_flow.channels) - print(f"Ingestion time: {end - start} seconds") - print(f"Ingested {num_datapoints} datapoints") - total_time = (end - start).total_seconds() - print(f"Ingestion rate: {num_datapoints / total_time:.2f} datapoints/second") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/resources/test_ingestion.py b/python/lib/sift_client/_tests/resources/test_ingestion.py new file mode 100644 index 000000000..c5f00d0b3 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_ingestion.py @@ -0,0 +1,499 @@ +"""Pytest tests for the Ingestion API. + +These tests demonstrate and validate the usage of the Ingestion API including: +- Creating ingestion configurations +- Ingesting data with various channel types (double, enum, bit field) +- Flow management and validation +- High-speed and regular flow ingestion +- Error handling and edge cases +""" + +import math +import random +import time +from datetime import datetime, timedelta, timezone + +import pytest + +from sift_client import SiftClient +from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType +from sift_client.sift_types.ingestion import ChannelConfig, Flow + +pytestmark = pytest.mark.integration + +ASSET_NAME = "test-ingestion-asset" + + +def test_client_binding(sift_client): + assert getattr(sift_client, "ingestion", None) is None # Only async! + assert sift_client.async_.ingestion + + +@pytest.fixture(scope="function") +def test_run(sift_client: SiftClient): + """Create a test run for ingestion tests.""" + run = sift_client.runs.create( + { + "name": f"test-ingestion-run-{datetime.now(tz=timezone.utc).timestamp()}", + "description": "Test run for ingestion integration tests", + "tags": ["test", "ingestion", "pytest"], + } + ) + yield run + # Cleanup + sift_client.runs.archive(run=run.id_) + + +class TestIngestionAPIAsync: + """Test suite for the async Ingestion API functionality.""" + + class TestCreateIngestionConfig: + """Tests for creating ingestion configurations.""" + + @pytest.mark.asyncio + async def test_create_basic_config(self, sift_client, test_run): + """Test creating a basic ingestion configuration.""" + flow = Flow( + name="test-basic-flow", + channels=[ + ChannelConfig( + name="test-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + assert isinstance(config_id, str) + + @pytest.mark.asyncio + async def test_create_config_with_multiple_flows(self, sift_client, test_run): + """Test creating an ingestion configuration with multiple flows.""" + regular_flow = Flow( + name="test-regular-flow", + channels=[ + ChannelConfig( + name="regular-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + highspeed_flow = Flow( + name="test-highspeed-flow", + channels=[ + ChannelConfig( + name="highspeed-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[regular_flow, highspeed_flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_create_config_with_enum_channel(self, sift_client, test_run): + """Test creating an ingestion configuration with enum channel.""" + flow = Flow( + name="test-enum-flow", + channels=[ + ChannelConfig( + name="test-enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"state1": 1, "state2": 2, "state3": 3}, + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_create_config_with_bit_field_channel( + self, sift_client, test_run + ): + """Test creating an ingestion configuration with bit field channel.""" + flow = Flow( + name="test-bitfield-flow", + channels=[ + ChannelConfig( + name="test-bit-field-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement( + name="voltage", index=0, bit_count=4 + ), + ChannelBitFieldElement( + name="current", index=4, bit_count=2 + ), + ChannelBitFieldElement(name="status", index=6, bit_count=2), + ], + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_flow_sealed_after_config_creation(self, sift_client, test_run): + """Test that flows are sealed after ingestion config creation.""" + flow = Flow( + name="test-sealed-flow", + channels=[ + ChannelConfig( + name="test-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + # Try to add a channel after config creation + with pytest.raises( + ValueError, match="Cannot add a channel to a flow after creation" + ): + flow.add_channel( + ChannelConfig(name="new-channel", data_type=ChannelDataType.DOUBLE) + ) + + class TestIngestData: + """Tests for ingesting data.""" + + @pytest.mark.asyncio + async def test_ingest_double_data(self, sift_client, test_run): + """Test ingesting double data.""" + flow = Flow( + name="test-double-flow", + channels=[ + ChannelConfig( + name="double-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={"double-channel": float(i)}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_enum_data(self, sift_client, test_run): + """Test ingesting enum data.""" + flow = Flow( + name="test-enum-ingest-flow", + channels=[ + ChannelConfig( + name="enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"low": 1, "medium": 2, "high": 3}, + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": (i % 3) + 1}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_bit_field_data_as_dict(self, sift_client, test_run): + """Test ingesting bit field data as dictionary.""" + flow = Flow( + name="test-bitfield-ingest-flow", + channels=[ + ChannelConfig( + name="bitfield-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement( + name="voltage", index=0, bit_count=4 + ), + ChannelBitFieldElement( + name="current", index=4, bit_count=2 + ), + ChannelBitFieldElement(name="led", index=6, bit_count=1), + ChannelBitFieldElement(name="heater", index=7, bit_count=1), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={ + "bitfield-channel": { + "voltage": random.randint(3, 13), + "current": random.randint(1, 3), + "led": random.choice([0, 1]), + "heater": random.choice([0, 1]), + } + }, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_bit_field_data_as_bytes(self, sift_client, test_run): + """Test ingesting bit field data as bytes.""" + flow = Flow( + name="test-bitfield-bytes-flow", + channels=[ + ChannelConfig( + name="bitfield-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="field1", index=0, bit_count=4), + ChannelBitFieldElement(name="field2", index=4, bit_count=4), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + timestamp = datetime.now(tz=timezone.utc) + flow.ingest( + timestamp=timestamp, + channel_values={"bitfield-channel": bytes([0b11110000])}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_multiple_channels(self, sift_client, test_run): + """Test ingesting data for multiple channels simultaneously.""" + flow = Flow( + name="test-multi-channel-flow", + channels=[ + ChannelConfig(name="channel1", data_type=ChannelDataType.DOUBLE), + ChannelConfig( + name="channel2", + data_type=ChannelDataType.ENUM, + enum_types={"a": 1, "b": 2}, + ), + ChannelConfig( + name="channel3", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="bit1", index=0, bit_count=4), + ChannelBitFieldElement(name="bit2", index=4, bit_count=4), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(5): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={ + "channel1": float(i), + "channel2": (i % 2) + 1, + "channel3": {"bit1": i % 16, "bit2": (i * 2) % 16}, + }, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_highspeed_data(self, sift_client, test_run): + """Test ingesting high-speed data.""" + flow = Flow( + name="test-highspeed-data-flow", + channels=[ + ChannelConfig( + name="highspeed-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + fake_hs_rate = 50 # Hz + fake_hs_period = 1 / fake_hs_rate + duration = 2 # seconds + + for i in range(duration): + for j in range(fake_hs_rate): + val = 3.0 * math.sin(2 * math.pi * fake_hs_rate * (i + j * 0.001)) + timestamp = start_time + timedelta( + seconds=i, milliseconds=j * fake_hs_period * 1000 + ) + flow.ingest( + timestamp=timestamp, + channel_values={"highspeed-channel": val}, + ) + time.sleep(0.01) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + class TestIngestionValidation: + """Tests for ingestion validation and error handling.""" + + @pytest.mark.asyncio + async def test_ingest_missing_channel_raises_error(self, sift_client, test_run): + """Test that ingesting without all channels raises an error.""" + flow = Flow( + name="test-validation-flow", + channels=[ + ChannelConfig(name="channel1", data_type=ChannelDataType.DOUBLE), + ChannelConfig(name="channel2", data_type=ChannelDataType.DOUBLE), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + timestamp = datetime.now(tz=timezone.utc) + with pytest.raises( + Exception, + match="Expected all channels in flow to have a data point at same time", + ): + flow.ingest( + timestamp=timestamp, + channel_values={"channel1": 1.0}, # Missing channel2 + ) + + @pytest.mark.asyncio + async def test_ingest_invalid_enum_value_raises_error( + self, sift_client, test_run + ): + """Test that ingesting an invalid enum value raises an error.""" + flow = Flow( + name="test-enum-validation-flow", + channels=[ + ChannelConfig( + name="enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"valid1": 1, "valid2": 2}, + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + timestamp = datetime.now(tz=timezone.utc) + # Test with invalid integer + with pytest.raises(ValueError, match="Could not find enum value"): + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": 99}, + ) + + # Test with invalid string + with pytest.raises(ValueError, match="Could not find enum value"): + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": "invalid-enum"}, + ) + + @pytest.mark.asyncio + async def test_resume_ingestion_after_wait(self, sift_client, test_run): + """Test that ingestion can resume after waiting for completion.""" + flow = Flow( + name="test-resume-flow", + channels=[ + ChannelConfig( + name="test-channel", data_type=ChannelDataType.DOUBLE + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + # First batch + timestamp1 = datetime.now(tz=timezone.utc) + flow.ingest(timestamp=timestamp1, channel_values={"test-channel": 1.0}) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + # Wait a bit + time.sleep(0.1) + + # Second batch after wait + timestamp2 = timestamp1 + timedelta(seconds=2) + flow.ingest(timestamp=timestamp2, channel_values={"test-channel": 2.0}) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index 427e4def5..e77ab5e20 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -127,7 +127,6 @@ def __init__( self.assets = AssetsAPI(self) self.calculated_channels = CalculatedChannelsAPI(self) self.channels = ChannelsAPI(self) - self.ingestion = IngestionAPIAsync(self) self.rules = RulesAPI(self) self.runs = RunsAPI(self) diff --git a/python/lib/sift_client/sift_types/ingestion.py b/python/lib/sift_client/sift_types/ingestion.py index 2d6d22693..12d9251eb 100644 --- a/python/lib/sift_client/sift_types/ingestion.py +++ b/python/lib/sift_client/sift_types/ingestion.py @@ -75,7 +75,9 @@ def _validate_enum_types(self): raise ValueError( f"Channel '{self.name}' has data_type ENUM but enum_types is not provided" ) - elif self.data_type == ChannelDataType.BIT_FIELD and not self.bit_field_elements: + elif ( + self.data_type == ChannelDataType.BIT_FIELD and not self.bit_field_elements + ): raise ValueError( f"Channel '{self.name}' has data_type BIT_FIELD but bit_field_elements is not provided" ) @@ -92,14 +94,19 @@ def _from_proto( data_type=ChannelDataType(proto.data_type), description=proto.description if proto.description else None, unit=proto.unit if proto.unit else None, - bit_field_elements=[ - ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements - ] - if proto.bit_field_elements - else None, - enum_types={enum.name: enum.key for enum in proto.enum_types} - if proto.enum_types - else None, + bit_field_elements=( + [ + ChannelBitFieldElement._from_proto(el) + for el in proto.bit_field_elements + ] + if proto.bit_field_elements + else None + ), + enum_types=( + {enum.name: enum.key for enum in proto.enum_types} + if proto.enum_types + else None + ), _client=sift_client, ) @@ -118,7 +125,9 @@ def from_channel(cls, channel: Channel) -> ChannelConfig: data_type=channel.data_type, description=channel.description, unit=channel.unit, - bit_field_elements=channel.bit_field_elements if channel.bit_field_elements else None, + bit_field_elements=( + channel.bit_field_elements if channel.bit_field_elements else None + ), enum_types=channel.enum_types, ) @@ -127,7 +136,9 @@ def _to_config_proto(self) -> ChannelConfigProto: from sift.common.type.v1.channel_bit_field_element_pb2 import ( ChannelBitFieldElement as ChannelBitFieldElementPb, ) - from sift.common.type.v1.channel_enum_type_pb2 import ChannelEnumType as ChannelEnumTypePb + from sift.common.type.v1.channel_enum_type_pb2 import ( + ChannelEnumType as ChannelEnumTypePb, + ) return ChannelConfigProto( name=self.name, @@ -162,7 +173,9 @@ class Flow(BaseType[FlowConfig, "Flow"]): run_id: str | None = None @classmethod - def _from_proto(cls, proto: FlowConfig, sift_client: SiftClient | None = None) -> Flow: + def _from_proto( + cls, proto: FlowConfig, sift_client: SiftClient | None = None + ) -> Flow: return cls( proto=proto, name=proto.name, @@ -207,7 +220,7 @@ def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]): """ if self.ingestion_config_id is None: raise ValueError("Ingestion config ID is not set.") - self.client.ingestion.ingest( + self.client.async_.ingestion.ingest( flow=self, timestamp=timestamp, channel_values=channel_values, @@ -222,15 +235,19 @@ def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: description=channel.description or "", unit=channel.unit or "", bit_field_elements=[ - ChannelBitFieldElementPy(name=bfe.name, index=bfe.index, bit_count=bfe.bit_count) + ChannelBitFieldElementPy( + name=bfe.name, index=bfe.index, bit_count=bfe.bit_count + ) for bfe in channel.bit_field_elements or [] ], - enum_types=[ - ChannelEnumTypePy(key=enum_key, name=enum_name) - for enum_name, enum_key in channel.enum_types.items() - ] - if channel.enum_types - else [], + enum_types=( + [ + ChannelEnumTypePy(key=enum_key, name=enum_name) + for enum_name, enum_key in channel.enum_types.items() + ] + if channel.enum_types + else [] + ), ) @@ -271,7 +288,9 @@ def _rust_channel_value_from_bitfield( return IngestWithConfigDataChannelValuePy.bitfield(byte_array) -def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataChannelValuePy: +def _to_rust_value( + channel: ChannelConfig, value: Any +) -> IngestWithConfigDataChannelValuePy: if value is None: return IngestWithConfigDataChannelValuePy.empty() if channel.data_type == ChannelDataType.ENUM and channel.enum_types is not None: @@ -279,7 +298,9 @@ def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataCh enum_val = channel.enum_types.get(enum_name) if enum_val is None: # Try to find the enum value by value instead of string. - for enum_name, enum_key in channel.enum_types.items() if channel.enum_types else []: + for enum_name, enum_key in ( + channel.enum_types.items() if channel.enum_types else [] + ): if enum_key == value: enum_name = enum_name enum_val = enum_key @@ -333,7 +354,9 @@ def _to_rust_type(data_type: ChannelDataType) -> ChannelDataTypePy: raise ValueError(f"Unknown data type: {data_type}") -def _to_ingestion_value(data_type: ChannelDataType, value: Any) -> IngestWithConfigDataChannelValue: +def _to_ingestion_value( + data_type: ChannelDataType, value: Any +) -> IngestWithConfigDataChannelValue: if value is None: return IngestWithConfigDataChannelValue(empty=Empty()) ingestion_type_string = data_type.name.lower().replace("int_", "int") diff --git a/python/lib/sift_client/sift_types/run.py b/python/lib/sift_client/sift_types/run.py index e8292b8d3..00f38ad06 100644 --- a/python/lib/sift_client/sift_types/run.py +++ b/python/lib/sift_client/sift_types/run.py @@ -56,13 +56,19 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, organization_id=proto.organization_id, - start_time=proto.start_time.ToDatetime(tzinfo=timezone.utc) - if proto.HasField("start_time") - else None, - stop_time=proto.stop_time.ToDatetime(tzinfo=timezone.utc) - if proto.HasField("stop_time") - else None, - duration=proto.duration.ToTimedelta() if proto.HasField("duration") else None, + start_time=( + proto.start_time.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("start_time") + else None + ), + stop_time=( + proto.stop_time.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("stop_time") + else None + ), + duration=( + proto.duration.ToTimedelta() if proto.HasField("duration") else None + ), name=proto.name, description=proto.description, tags=list(proto.tags), @@ -70,9 +76,11 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> client_key=proto.client_key if proto.HasField("client_key") else None, metadata=metadata_proto_to_dict(proto.metadata), # type: ignore asset_ids=list(proto.asset_ids), - archived_date=proto.archived_date.ToDatetime() - if proto.HasField("archived_date") - else None, + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("archived_date") + else None + ), is_archived=proto.is_archived, is_adhoc=proto.is_adhoc, _client=sift_client, @@ -122,7 +130,9 @@ class RunBase(ModelCreateUpdateBase): _to_proto_helpers: ClassVar = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } @@ -152,7 +162,9 @@ class RunCreate(RunBase, ModelCreate[CreateRunRequestProto]): _to_proto_helpers: ClassVar = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } @@ -172,7 +184,9 @@ class RunUpdate(RunBase, ModelUpdate[RunProto]): _to_proto_helpers: ClassVar = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } From 69739dac978afe701cfe1d17590bc91b9190c5a2 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 14:25:07 -0700 Subject: [PATCH 27/39] linting --- .../_tests/resources/test_ingestion.py | 58 ++-- .../examples/generic_workflow_example.py | 254 +++++++++--------- .../lib/sift_client/sift_types/ingestion.py | 37 +-- python/lib/sift_client/sift_types/run.py | 4 +- 4 files changed, 152 insertions(+), 201 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_ingestion.py b/python/lib/sift_client/_tests/resources/test_ingestion.py index c5f00d0b3..ce4a059be 100644 --- a/python/lib/sift_client/_tests/resources/test_ingestion.py +++ b/python/lib/sift_client/_tests/resources/test_ingestion.py @@ -41,7 +41,7 @@ def test_run(sift_client: SiftClient): ) yield run # Cleanup - sift_client.runs.archive(run=run.id_) + sift_client.runs.archive(run=run) class TestIngestionAPIAsync: @@ -56,9 +56,7 @@ async def test_create_basic_config(self, sift_client, test_run): flow = Flow( name="test-basic-flow", channels=[ - ChannelConfig( - name="test-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), ], ) @@ -77,18 +75,14 @@ async def test_create_config_with_multiple_flows(self, sift_client, test_run): regular_flow = Flow( name="test-regular-flow", channels=[ - ChannelConfig( - name="regular-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="regular-channel", data_type=ChannelDataType.DOUBLE), ], ) highspeed_flow = Flow( name="test-highspeed-flow", channels=[ - ChannelConfig( - name="highspeed-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), ], ) @@ -123,9 +117,7 @@ async def test_create_config_with_enum_channel(self, sift_client, test_run): assert config_id is not None @pytest.mark.asyncio - async def test_create_config_with_bit_field_channel( - self, sift_client, test_run - ): + async def test_create_config_with_bit_field_channel(self, sift_client, test_run): """Test creating an ingestion configuration with bit field channel.""" flow = Flow( name="test-bitfield-flow", @@ -134,12 +126,8 @@ async def test_create_config_with_bit_field_channel( name="test-bit-field-channel", data_type=ChannelDataType.BIT_FIELD, bit_field_elements=[ - ChannelBitFieldElement( - name="voltage", index=0, bit_count=4 - ), - ChannelBitFieldElement( - name="current", index=4, bit_count=2 - ), + ChannelBitFieldElement(name="voltage", index=0, bit_count=4), + ChannelBitFieldElement(name="current", index=4, bit_count=2), ChannelBitFieldElement(name="status", index=6, bit_count=2), ], ), @@ -160,9 +148,7 @@ async def test_flow_sealed_after_config_creation(self, sift_client, test_run): flow = Flow( name="test-sealed-flow", channels=[ - ChannelConfig( - name="test-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), ], ) @@ -173,9 +159,7 @@ async def test_flow_sealed_after_config_creation(self, sift_client, test_run): ) # Try to add a channel after config creation - with pytest.raises( - ValueError, match="Cannot add a channel to a flow after creation" - ): + with pytest.raises(ValueError, match="Cannot add a channel to a flow after creation"): flow.add_channel( ChannelConfig(name="new-channel", data_type=ChannelDataType.DOUBLE) ) @@ -189,9 +173,7 @@ async def test_ingest_double_data(self, sift_client, test_run): flow = Flow( name="test-double-flow", channels=[ - ChannelConfig( - name="double-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="double-channel", data_type=ChannelDataType.DOUBLE), ], ) @@ -251,12 +233,8 @@ async def test_ingest_bit_field_data_as_dict(self, sift_client, test_run): name="bitfield-channel", data_type=ChannelDataType.BIT_FIELD, bit_field_elements=[ - ChannelBitFieldElement( - name="voltage", index=0, bit_count=4 - ), - ChannelBitFieldElement( - name="current", index=4, bit_count=2 - ), + ChannelBitFieldElement(name="voltage", index=0, bit_count=4), + ChannelBitFieldElement(name="current", index=4, bit_count=2), ChannelBitFieldElement(name="led", index=6, bit_count=1), ChannelBitFieldElement(name="heater", index=7, bit_count=1), ], @@ -367,9 +345,7 @@ async def test_ingest_highspeed_data(self, sift_client, test_run): flow = Flow( name="test-highspeed-data-flow", channels=[ - ChannelConfig( - name="highspeed-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), ], ) @@ -429,9 +405,7 @@ async def test_ingest_missing_channel_raises_error(self, sift_client, test_run): ) @pytest.mark.asyncio - async def test_ingest_invalid_enum_value_raises_error( - self, sift_client, test_run - ): + async def test_ingest_invalid_enum_value_raises_error(self, sift_client, test_run): """Test that ingesting an invalid enum value raises an error.""" flow = Flow( name="test-enum-validation-flow", @@ -471,9 +445,7 @@ async def test_resume_ingestion_after_wait(self, sift_client, test_run): flow = Flow( name="test-resume-flow", channels=[ - ChannelConfig( - name="test-channel", data_type=ChannelDataType.DOUBLE - ), + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), ], ) diff --git a/python/lib/sift_client/examples/generic_workflow_example.py b/python/lib/sift_client/examples/generic_workflow_example.py index c3ec1b642..1263928ad 100644 --- a/python/lib/sift_client/examples/generic_workflow_example.py +++ b/python/lib/sift_client/examples/generic_workflow_example.py @@ -1,127 +1,127 @@ -import asyncio -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - CalculatedChannelUpdate, - ChannelReference, - RuleAction, - RuleAnnotationType, - RuleCreate, - RuleUpdate, -) - -""" -Placeholder for future examples. FD-67 -""" - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print("Found asset", asset.name) - - calculated_channels = client.calculated_channels.list_( - name_regex="velocity_per.*", - asset_id=asset_id, - ) - updated = False - calculated_channel = None - if calculated_channels: - print(f"Found calculated channels: {[cc.name for cc in calculated_channels]}") - for cc in calculated_channels: - if cc.name == "velocity_per_voltage": - calculated_channel = cc.update( - CalculatedChannelUpdate( - expression="$1 / $2 + 0.1", - expression_channel_references=cc.channel_references, - ) - ) - print("Updated calculated channel", calculated_channel) - else: - # Create a calculated channel that divides mainmotor.velocity by voltage - print("\nCreating calculated channel...") - calculated_channel = client.calculated_channels.create( - dict( - name="velocity_per_voltage", - description="Ratio of mainmotor velocity to voltage", - expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - units="velocity/voltage", - asset_ids=[asset_id], - user_notes="Created to monitor velocity-to-voltage ratio", - ) - ) - print( - f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.calculated_channel_id})" - ) - - # Create a rule that creates an annotation when the ratio is above 0.1 - rule_search = "high_velocity_voltage" - print(f"Looking for rule containing {rule_search}") - rules = client.rules.list( - name_contains=rule_search, - ) - if rules: - print(f"Found rules: {[rule.name for rule in rules]}") - # Example of batch get if you just had the rule ids: - rules = client.rules.batch_get(rule_ids=[rule.rule_id for rule in rules]) - print(f"Batch get on IDs also works: {[rule.name for rule in rules]}") - - rule = rules[0] - print(f"Updating rule: {rule.name}") - rule = rule.update( - RuleUpdate( - description=f"Alert when velocity-to-voltage ratio exceeds 0.1 (Updated at {datetime.now(tz=timezone.utc).isoformat()})", - asset_ids=[asset_id], - ) - ) - updated = True - else: - print(f"No rules found for {rule_search}") - rules = client.rules.list_( - asset_ids=[asset_id], - ) - if rules: - print(f"However these rules do exist: {[rule.name for rule in rules]}") - print("Attempting to create rule for high_velocity_voltage_ratio_alert") - rule = client.rules.create( - RuleCreate( - name="high_velocity_voltage_ratio_alert", - description="Alert when velocity-to-voltage ratio exceeds 0.1", - expression="$1 > 0.1", - channel_references=[ - ChannelReference( - channel_reference="$1", - channel_identifier=calculated_channel.name, - ), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["high_ratio", "alert"], - default_assignee_user_id=None, # You can set a user ID here if needed - ), - ) - ) - print(f"Created rule: {rule.name} (ID: {rule.rule_id})") - - if updated: - print("Second run through, deleting rule") - rule.delete() - - -if __name__ == "__main__": - asyncio.run(main()) +# import asyncio +# import os +# from datetime import datetime, timezone +# +# from sift_client.client import SiftClient +# +# # Import sift_client types for calculated channels and rules +# from sift_client.sift_types import ( +# CalculatedChannelUpdate, +# ChannelReference, +# RuleAction, +# RuleAnnotationType, +# RuleCreate, +# RuleUpdate, +# ) +# +# """ +# Placeholder for future examples. FD-67 +# """ +# +# +# async def main(): +# grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") +# api_key = os.getenv("SIFT_API_KEY", "") +# rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") +# client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) +# +# asset = client.assets.find(name="NostromoLV426") +# asset_id = asset.id_ +# print("Found asset", asset.name) +# +# calculated_channels = client.calculated_channels.list_( +# name_regex="velocity_per.*", +# asset_id=asset_id, +# ) +# updated = False +# calculated_channel = None +# if calculated_channels: +# print(f"Found calculated channels: {[cc.name for cc in calculated_channels]}") +# for cc in calculated_channels: +# if cc.name == "velocity_per_voltage": +# calculated_channel = cc.update( +# CalculatedChannelUpdate( +# expression="$1 / $2 + 0.1", +# expression_channel_references=cc.channel_references, +# ) +# ) +# print("Updated calculated channel", calculated_channel) +# else: +# # Create a calculated channel that divides mainmotor.velocity by voltage +# print("\nCreating calculated channel...") +# calculated_channel = client.calculated_channels.create( +# dict( +# name="velocity_per_voltage", +# description="Ratio of mainmotor velocity to voltage", +# expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage +# channel_references=[ +# ChannelReference( +# channel_reference="$1", channel_identifier="mainmotor.velocity" +# ), +# ChannelReference(channel_reference="$2", channel_identifier="voltage"), +# ], +# units="velocity/voltage", +# asset_ids=[asset_id], +# user_notes="Created to monitor velocity-to-voltage ratio", +# ) +# ) +# print( +# f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.calculated_channel_id})" +# ) +# +# # Create a rule that creates an annotation when the ratio is above 0.1 +# rule_search = "high_velocity_voltage" +# print(f"Looking for rule containing {rule_search}") +# rules = client.rules.list( +# name_contains=rule_search, +# ) +# if rules: +# print(f"Found rules: {[rule.name for rule in rules]}") +# # Example of batch get if you just had the rule ids: +# rules = client.rules.batch_get(rule_ids=[rule.rule_id for rule in rules]) +# print(f"Batch get on IDs also works: {[rule.name for rule in rules]}") +# +# rule = rules[0] +# print(f"Updating rule: {rule.name}") +# rule = rule.update( +# RuleUpdate( +# description=f"Alert when velocity-to-voltage ratio exceeds 0.1 (Updated at {datetime.now(tz=timezone.utc).isoformat()})", +# asset_ids=[asset_id], +# ) +# ) +# updated = True +# else: +# print(f"No rules found for {rule_search}") +# rules = client.rules.list_( +# asset_ids=[asset_id], +# ) +# if rules: +# print(f"However these rules do exist: {[rule.name for rule in rules]}") +# print("Attempting to create rule for high_velocity_voltage_ratio_alert") +# rule = client.rules.create( +# RuleCreate( +# name="high_velocity_voltage_ratio_alert", +# description="Alert when velocity-to-voltage ratio exceeds 0.1", +# expression="$1 > 0.1", +# channel_references=[ +# ChannelReference( +# channel_reference="$1", +# channel_identifier=calculated_channel.name, +# ), +# ], +# action=RuleAction.annotation( +# annotation_type=RuleAnnotationType.DATA_REVIEW, +# tags=["high_ratio", "alert"], +# default_assignee_user_id=None, # You can set a user ID here if needed +# ), +# ) +# ) +# print(f"Created rule: {rule.name} (ID: {rule.rule_id})") +# +# if updated: +# print("Second run through, deleting rule") +# rule.delete() +# +# +# if __name__ == "__main__": +# asyncio.run(main()) diff --git a/python/lib/sift_client/sift_types/ingestion.py b/python/lib/sift_client/sift_types/ingestion.py index 12d9251eb..a8f145b86 100644 --- a/python/lib/sift_client/sift_types/ingestion.py +++ b/python/lib/sift_client/sift_types/ingestion.py @@ -75,9 +75,7 @@ def _validate_enum_types(self): raise ValueError( f"Channel '{self.name}' has data_type ENUM but enum_types is not provided" ) - elif ( - self.data_type == ChannelDataType.BIT_FIELD and not self.bit_field_elements - ): + elif self.data_type == ChannelDataType.BIT_FIELD and not self.bit_field_elements: raise ValueError( f"Channel '{self.name}' has data_type BIT_FIELD but bit_field_elements is not provided" ) @@ -95,17 +93,12 @@ def _from_proto( description=proto.description if proto.description else None, unit=proto.unit if proto.unit else None, bit_field_elements=( - [ - ChannelBitFieldElement._from_proto(el) - for el in proto.bit_field_elements - ] + [ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements] if proto.bit_field_elements else None ), enum_types=( - {enum.name: enum.key for enum in proto.enum_types} - if proto.enum_types - else None + {enum.name: enum.key for enum in proto.enum_types} if proto.enum_types else None ), _client=sift_client, ) @@ -125,9 +118,7 @@ def from_channel(cls, channel: Channel) -> ChannelConfig: data_type=channel.data_type, description=channel.description, unit=channel.unit, - bit_field_elements=( - channel.bit_field_elements if channel.bit_field_elements else None - ), + bit_field_elements=(channel.bit_field_elements if channel.bit_field_elements else None), enum_types=channel.enum_types, ) @@ -173,9 +164,7 @@ class Flow(BaseType[FlowConfig, "Flow"]): run_id: str | None = None @classmethod - def _from_proto( - cls, proto: FlowConfig, sift_client: SiftClient | None = None - ) -> Flow: + def _from_proto(cls, proto: FlowConfig, sift_client: SiftClient | None = None) -> Flow: return cls( proto=proto, name=proto.name, @@ -235,9 +224,7 @@ def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: description=channel.description or "", unit=channel.unit or "", bit_field_elements=[ - ChannelBitFieldElementPy( - name=bfe.name, index=bfe.index, bit_count=bfe.bit_count - ) + ChannelBitFieldElementPy(name=bfe.name, index=bfe.index, bit_count=bfe.bit_count) for bfe in channel.bit_field_elements or [] ], enum_types=( @@ -288,9 +275,7 @@ def _rust_channel_value_from_bitfield( return IngestWithConfigDataChannelValuePy.bitfield(byte_array) -def _to_rust_value( - channel: ChannelConfig, value: Any -) -> IngestWithConfigDataChannelValuePy: +def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataChannelValuePy: if value is None: return IngestWithConfigDataChannelValuePy.empty() if channel.data_type == ChannelDataType.ENUM and channel.enum_types is not None: @@ -298,9 +283,7 @@ def _to_rust_value( enum_val = channel.enum_types.get(enum_name) if enum_val is None: # Try to find the enum value by value instead of string. - for enum_name, enum_key in ( - channel.enum_types.items() if channel.enum_types else [] - ): + for enum_name, enum_key in channel.enum_types.items() if channel.enum_types else []: if enum_key == value: enum_name = enum_name enum_val = enum_key @@ -354,9 +337,7 @@ def _to_rust_type(data_type: ChannelDataType) -> ChannelDataTypePy: raise ValueError(f"Unknown data type: {data_type}") -def _to_ingestion_value( - data_type: ChannelDataType, value: Any -) -> IngestWithConfigDataChannelValue: +def _to_ingestion_value(data_type: ChannelDataType, value: Any) -> IngestWithConfigDataChannelValue: if value is None: return IngestWithConfigDataChannelValue(empty=Empty()) ingestion_type_string = data_type.name.lower().replace("int_", "int") diff --git a/python/lib/sift_client/sift_types/run.py b/python/lib/sift_client/sift_types/run.py index 00f38ad06..d66ce2f42 100644 --- a/python/lib/sift_client/sift_types/run.py +++ b/python/lib/sift_client/sift_types/run.py @@ -66,9 +66,7 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> if proto.HasField("stop_time") else None ), - duration=( - proto.duration.ToTimedelta() if proto.HasField("duration") else None - ), + duration=(proto.duration.ToTimedelta() if proto.HasField("duration") else None), name=proto.name, description=proto.description, tags=list(proto.tags), From e589b2aacdadf3d9881e57879100fa6405b026b5 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 15:32:35 -0700 Subject: [PATCH 28/39] add base test --- .../sift_client/_tests/sift_types/__init__.py | 0 .../_tests/sift_types/test_base.py | 591 ++++++++++++++++++ python/lib/sift_client/sift_types/_base.py | 6 +- 3 files changed, 594 insertions(+), 3 deletions(-) create mode 100644 python/lib/sift_client/_tests/sift_types/__init__.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_base.py diff --git a/python/lib/sift_client/_tests/sift_types/__init__.py b/python/lib/sift_client/_tests/sift_types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_tests/sift_types/test_base.py b/python/lib/sift_client/_tests/sift_types/test_base.py new file mode 100644 index 000000000..c7f01e5de --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_base.py @@ -0,0 +1,591 @@ +"""Unit tests for sift_types._base module.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import ClassVar +from unittest.mock import MagicMock + +import pytest +from sift.calculated_channels.v2.calculated_channels_pb2 import ( + CalculatedChannel as CalculatedChannelProto, +) +from sift.calculated_channels.v2.calculated_channels_pb2 import ( + CreateCalculatedChannelRequest, +) + +from sift_client.sift_types._base import ( + BaseType, + MappingHelper, + ModelCreate, + ModelUpdate, +) + + +class SimpleCreateModel(ModelCreate[CreateCalculatedChannelRequest]): + """Simple model for testing basic field mapping.""" + + name: str + description: str | None = None + units: str | None = None + + def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: + return CreateCalculatedChannelRequest + + +class NestedCreateModel(ModelCreate[CreateCalculatedChannelRequest]): + """Model for testing nested field mapping with MappingHelper.""" + + name: str + description: str | None = None + expression: str | None = None + all_assets: bool | None = None + + _to_proto_helpers: ClassVar = { + "expression": MappingHelper( + proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression", + update_field="query_configuration", + ), + "all_assets": MappingHelper( + proto_attr_path="calculated_channel_configuration.asset_configuration.all_assets", + ), + } + + def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: + return CreateCalculatedChannelRequest + + +class SimpleUpdateModel(ModelUpdate[CalculatedChannelProto]): + """Simple model for testing update with field masks.""" + + name: str | None = None + description: str | None = None + units: str | None = None + + def _get_proto_class(self) -> type[CalculatedChannelProto]: + return CalculatedChannelProto + + def _add_resource_id_to_proto(self, proto_msg: CalculatedChannelProto): + if self._resource_id is None: + raise ValueError("Resource ID must be set before adding to proto") + proto_msg.calculated_channel_id = self._resource_id + + +class TestModelCreate: + """Tests for ModelCreate base class.""" + + def test_simple_create_with_all_fields(self): + """Test creating a proto with all fields set.""" + model = SimpleCreateModel( + name="test_name", description="test_description", units="test_units" + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + assert proto.units == "test_units" + + def test_simple_create_with_none_fields_excluded(self): + """Test that None fields are excluded from proto.""" + model = SimpleCreateModel(name="test_name", description=None, units=None) + proto = model.to_proto() + + assert proto.name == "test_name" + # Proto should not have description or units set + assert proto.description == "" # Proto default for string + assert proto.units == "" # Proto default for string + + def test_simple_create_with_unset_fields_excluded(self): + """Test that unset fields are excluded from proto.""" + model = SimpleCreateModel(name="test_name") + proto = model.to_proto() + + assert proto.name == "test_name" + # Proto should not have description or units set + assert proto.description == "" # Proto default for string + assert proto.units == "" # Proto default for string + + def test_nested_create_with_mapping_helper(self): + """Test creating a proto with nested field mapping.""" + model = NestedCreateModel( + name="test_name", + description="test_description", + expression="$1 + $2", + all_assets=True, + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Check nested fields + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + + def test_nested_create_with_none_nested_fields(self): + """Test that None values in nested fields are excluded.""" + model = NestedCreateModel( + name="test_name", + description="test_description", + expression=None, + all_assets=None, + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Nested fields should not be set + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "" + ) # Proto default + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Proto default for bool + + def test_nested_create_with_unset_nested_fields(self): + """Test that unset nested fields are excluded.""" + model = NestedCreateModel(name="test_name", description="test_description") + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Nested fields should not be set + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "" + ) # Proto default + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Proto default for bool + + def test_mixed_none_and_set_fields(self): + """Test model with mix of None, unset, and set fields.""" + model = NestedCreateModel( + name="test_name", + description=None, # Explicitly None + expression="$1 + $2", # Set + # all_assets is unset + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "" # None excluded, proto default + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Unset, proto default + + +class TestModelUpdate: + """Tests for ModelUpdate base class.""" + + def test_simple_update_with_field_mask(self): + """Test updating a proto with field mask.""" + model = SimpleUpdateModel(name="new_name", description="new_description") + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + assert proto.description == "new_description" + assert set(mask.paths) == {"name", "description"} + + def test_update_with_none_value_excluded(self): + """Test that explicitly setting a field to None excludes it in the mask.""" + model = SimpleUpdateModel(name="new_name", description=None) + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + + assert "description" not in mask.paths + assert "name" in mask.paths + + def test_update_with_unset_fields_excluded(self): + """Test that unset fields are excluded from the mask.""" + model = SimpleUpdateModel(name="new_name") + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + # Only name should be in the mask + assert mask.paths == ["name"] + + def test_update_requires_resource_id(self): + """Test that update fails without resource_id.""" + model = SimpleUpdateModel(name="new_name") + + with pytest.raises(ValueError, match="Resource ID must be set"): + model.to_proto_with_mask() + + +class TestMappingHelper: + """Tests for MappingHelper functionality.""" + + def test_mapping_helper_basic(self): + """Test basic MappingHelper creation.""" + helper = MappingHelper(proto_attr_path="field.nested.path") + assert helper.proto_attr_path == "field.nested.path" + assert helper.update_field is None + assert helper.converter is None + + def test_mapping_helper_with_update_field(self): + """Test MappingHelper with update_field.""" + helper = MappingHelper(proto_attr_path="field.nested.path", update_field="field") + assert helper.proto_attr_path == "field.nested.path" + assert helper.update_field == "field" + + def test_mapping_helper_with_converter(self): + """Test MappingHelper with converter function.""" + + def converter(x): + return x.upper() + + helper = MappingHelper(proto_attr_path="field.path", converter=converter) + assert helper.converter is not None + assert helper.converter("test") == "TEST" + + +class TestEdgeCases: + """Tests for edge cases and regression prevention.""" + + def test_empty_model_create(self): + """Test creating with only required fields.""" + model = SimpleCreateModel(name="test") + proto = model.to_proto() + assert proto.name == "test" + + def test_nested_path_expansion(self): + """Test that nested paths are properly expanded.""" + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = model.to_proto() + + # Verify the nested structure was created + assert proto.HasField("calculated_channel_configuration") + assert proto.calculated_channel_configuration.HasField("query_configuration") + assert proto.calculated_channel_configuration.query_configuration.HasField("sel") + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + + def test_multiple_nested_fields_same_parent(self): + """Test multiple nested fields that share a parent path.""" + model = NestedCreateModel(name="test", expression="$1 + $2", all_assets=True) + proto = model.to_proto() + + # Both fields should be set under calculated_channel_configuration + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + + def test_validation_error_on_invalid_helper_field(self): + """Test that MappingHelper validation catches mismatched fields.""" + with pytest.raises(ValueError, match="MappingHelper created for"): + + class InvalidModel(ModelCreate[CreateCalculatedChannelRequest]): + name: str + + _to_proto_helpers: ClassVar = { + "nonexistent_field": MappingHelper(proto_attr_path="some.path"), + } + + def _get_proto_class(self): + return CreateCalculatedChannelRequest + + # This should raise during __init__ + InvalidModel(name="test") + + +class TestBaseType: + """Tests for BaseType base class.""" + + def test_base_type_concrete_implementation(self): + """Test creating a concrete BaseType implementation.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, _client=sift_client) + + model = TestModel(name="test", id_="test_id") + assert model.name == "test" + assert model.id_ == "test_id" + + def test_id_or_error_with_id(self): + """Test _id_or_error property when ID is set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test", id_="test_id_123") + assert model._id_or_error == "test_id_123" + + def test_id_or_error_without_id(self): + """Test _id_or_error property raises when ID is not set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + with pytest.raises(ValueError, match="ID is not set"): + _ = model._id_or_error + + def test_client_property_without_client(self): + """Test client property raises when client is not set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + with pytest.raises(AttributeError, match="Sift client not set"): + _ = model.client + + def test_apply_client_to_instance(self): + """Test _apply_client_to_instance method.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + assert model._client is None + + mock_client = MagicMock() + model._apply_client_to_instance(mock_client) + assert model._client is mock_client + assert model.client is mock_client + + def test_update_method(self): + """Test _update method updates fields from another instance.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + description: str | None = None + version: int | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls( + name=proto.name, + description=proto.description, + version=proto.version, + proto=proto, + ) + + # Create original model + original = TestModel(name="original", description="old desc", version=1, id_="id1") + + # Create updated model + mock_proto = MagicMock() + updated = TestModel( + name="updated", + description="new desc", + version=2, + id_="id1", + proto=mock_proto, + ) + + # Update original with updated values + result = original._update(updated) + + assert result is original # Returns self + assert original.name == "updated" + assert original.description == "new desc" + assert original.version == 2 + assert original.proto is mock_proto + + def test_validate_timezones_with_valid_datetime(self): + """Test timezone validation passes with timezone-aware datetime.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + # Should not raise + model = TestModel(name="test", created_date=datetime.now(timezone.utc)) + assert model.created_date.tzinfo is not None + + def test_validate_timezones_with_naive_datetime(self): + """Test timezone validation fails with naive datetime.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + # Should raise validation error + with pytest.raises(ValueError, match="must have timezone information"): + TestModel(name="test", created_date=datetime.now(timezone.utc)) + + def test_validate_timezones_with_none_datetime(self): + """Test timezone validation passes when datetime is None.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + # Should not raise + model = TestModel(name="test", created_date=None) + assert model.created_date is None + + def test_proto_field_excluded_from_dump(self): + """Test that proto field is excluded from model_dump.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, proto=proto) + + mock_proto = MagicMock() + model = TestModel(name="test", proto=mock_proto) + + dumped = model.model_dump() + assert "proto" not in dumped + assert "name" in dumped + + def test_frozen_model_config(self): + """Test that BaseType models are frozen.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + + # Should not be able to modify frozen model + with pytest.raises(Exception): # Pydantic raises ValidationError # noqa: B017 + model.name = "new_name" + + +class TestBuildProtoAndPaths: + """Tests specifically for _build_proto_and_paths method.""" + + def test_build_proto_simple_fields(self): + """Test building proto with simple scalar fields.""" + model = SimpleCreateModel(name="test", description="desc") + proto = CreateCalculatedChannelRequest() + + paths = model._build_proto_and_paths(proto, {"name": "test", "description": "desc"}) + + assert proto.name == "test" + assert proto.description == "desc" + assert set(paths) == {"name", "description"} + + def test_build_proto_with_prefix(self): + """Test building proto with path prefix.""" + model = SimpleCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + paths = model._build_proto_and_paths(proto, {"name": "test"}, prefix="parent") + + assert proto.name == "test" + assert paths == ["parent.name"] + + def test_build_proto_with_nested_dict(self): + """Test building proto with nested dictionary through normal flow.""" + # This tests that the MappingHelper properly expands nested paths + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = model.to_proto() + + # Verify the nested structure was created correctly + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.name == "test" + + def test_build_proto_with_submessage_dict(self): + """Test building proto when data contains a dict for a submessage field.""" + + class SubmessageModel(ModelCreate[CreateCalculatedChannelRequest]): + name: str + + def _get_proto_class(self): + return CreateCalculatedChannelRequest + + model = SubmessageModel(name="test") + proto = CreateCalculatedChannelRequest() + + # Test that we can build nested structures by passing dict data + # This simulates what happens when processing nested proto messages + data = {"name": "test"} + paths = model._build_proto_and_paths(proto, data) + + assert proto.name == "test" + assert "name" in paths + + def test_build_proto_with_mapping_helper_update_field(self): + """Test that mapping helper's update_field is added to paths.""" + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = CreateCalculatedChannelRequest() + + data = {"name": "test", "expression": "$1 + $2"} + paths = model._build_proto_and_paths(proto, data) + + assert "query_configuration" in paths + assert "name" in paths + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + + def test_build_proto_error_on_invalid_field(self): + """Test that setting an invalid field raises TypeError.""" + model = SimpleCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + with pytest.raises(TypeError, match="Can't set"): + model._build_proto_and_paths(proto, {"nonexistent_field": "value"}) + + def test_build_proto_already_setting_path_override(self): + """Test that already_setting_path_override skips helper processing.""" + model = NestedCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + # When already_setting_path_override=True, it should skip the helper + # and try to process the field directly + data = {"name": "test"} + paths = model._build_proto_and_paths(proto, data, already_setting_path_override=True) + + assert proto.name == "test" + assert "name" in paths diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index 595285fda..7a3e9c813 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -189,7 +189,7 @@ def to_proto(self) -> ProtoT: proto_msg = proto_cls() # Get all fields - data = self.model_dump(exclude_unset=True) + data = self.model_dump(exclude_unset=True, exclude_none=True) self._build_proto_and_paths(proto_msg, data) return proto_msg @@ -214,8 +214,8 @@ def to_proto_with_mask(self) -> tuple[ProtoT, field_mask_pb2.FieldMask]: proto_cls: type[ProtoT] = self._get_proto_class() proto_msg = proto_cls() - # Get only explicitly set fields, including those set to None - data = self.model_dump(exclude_unset=True, exclude_none=False) + # Get only explicitly set fields + data = self.model_dump(exclude_unset=True, exclude_none=True) paths = self._build_proto_and_paths(proto_msg, data) self._add_resource_id_to_proto(proto_msg) From 56f2aecb6e9c3cb92c1028c2c3683df2b1a58127 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 15:53:40 -0700 Subject: [PATCH 29/39] linting --- python/lib/sift_client/_tests/sift_types/test_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/lib/sift_client/_tests/sift_types/test_base.py b/python/lib/sift_client/_tests/sift_types/test_base.py index c7f01e5de..c3ad99c53 100644 --- a/python/lib/sift_client/_tests/sift_types/test_base.py +++ b/python/lib/sift_client/_tests/sift_types/test_base.py @@ -429,7 +429,7 @@ class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): @classmethod def _from_proto(cls, proto, sift_client=None): - return cls(name=proto.name) + return cls(name=proto.name, created_date=proto.created_date) # Should not raise model = TestModel(name="test", created_date=datetime.now(timezone.utc)) @@ -444,11 +444,11 @@ class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): @classmethod def _from_proto(cls, proto, sift_client=None): - return cls(name=proto.name) + return cls(name=proto.name, created_date=proto.created_date) # Should raise validation error with pytest.raises(ValueError, match="must have timezone information"): - TestModel(name="test", created_date=datetime.now(timezone.utc)) + TestModel(name="test", created_date=datetime.now()) # noqa: DTZ005 def test_validate_timezones_with_none_datetime(self): """Test timezone validation passes when datetime is None.""" From 6a5aca3f7f3676593db59e461ab3e8b1d158a0dc Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 15:55:09 -0700 Subject: [PATCH 30/39] docs fix --- python/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 035f64551..d2aad0d8b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -64,7 +64,8 @@ docs = ["mkdocs", "mkdocs-include-markdown-plugin", "mkdocs-api-autonav", "mike", - "griffe-pydantic"] + "griffe-pydantic", + "mkdocs-jupyter"] # May be required for certain library functionality openssl = ["pyOpenSSL<24.0.0", "types-pyOpenSSL<24.0.0", "cffi~=1.14"] From 4be9e09cea963103bd0c156ade80a5b13617ce8c Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 19:05:05 -0700 Subject: [PATCH 31/39] - fix tests - Properly handle optional dependency of sift-stream --- .../_internal/low_level_wrappers/ingestion.py | 35 +++++++++++------- .../_tests/sift_types/test_base.py | 4 +- python/lib/sift_client/sift_types/_base.py | 37 +++++++++++++++---- python/lib/sift_client/sift_types/asset.py | 2 +- .../sift_types/calculated_channel.py | 2 +- .../lib/sift_client/sift_types/ingestion.py | 29 +++++++++++---- python/lib/sift_client/sift_types/run.py | 6 +-- python/pyproject.toml | 2 +- 8 files changed, 80 insertions(+), 37 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py index b7c4a4341..133143bbf 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py @@ -18,17 +18,14 @@ from queue import Queue from typing import TYPE_CHECKING, Any, cast -import sift_stream_bindings from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( GetIngestionConfigRequest, ListIngestionConfigFlowsResponse, ListIngestionConfigsRequest, ListIngestionConfigsResponse, ) -from sift.ingestion_configs.v2.ingestion_configs_pb2_grpc import IngestionConfigServiceStub -from sift_stream_bindings import ( - IngestionConfigFormPy, - IngestWithConfigDataStreamRequestPy, +from sift.ingestion_configs.v2.ingestion_configs_pb2_grpc import ( + IngestionConfigServiceStub, ) from sift_client._internal.low_level_wrappers.base import ( @@ -44,6 +41,12 @@ if TYPE_CHECKING: from datetime import datetime + from sift_stream_bindings import ( + IngestionConfigFormPy, + IngestWithConfigDataStreamRequestPy, + SiftStreamBuilderPy, + ) + class IngestionThread(threading.Thread): """Manages ingestion for a single ingestion config.""" @@ -54,7 +57,7 @@ class IngestionThread(threading.Thread): def __init__( self, - sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy, + sift_stream_builder: SiftStreamBuilderPy, data_queue: Queue, ingestion_config: IngestionConfigFormPy, no_data_timeout: int = 1, @@ -154,7 +157,7 @@ class IngestionLowLevelClient(LowLevelClientBase, WithGrpcClient): CacheEntry = namedtuple("CacheEntry", ["data_queue", "ingestion_config", "thread"]) - sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy + sift_stream_builder: SiftStreamBuilderPy stream_cache: dict[str, CacheEntry] def __init__(self, grpc_client: GrpcClient): @@ -163,21 +166,25 @@ def __init__(self, grpc_client: GrpcClient): Args: grpc_client: The gRPC client to use for making API calls. """ + from sift_stream_bindings import ( + RecoveryStrategyPy, + RetryPolicyPy, + SiftStreamBuilderPy, + ) + super().__init__(grpc_client=grpc_client) # Rust GRPC client expects URI to have http(s):// prefix. uri = grpc_client._config.uri if not uri.startswith("http"): uri = f"https://{uri}" if grpc_client._config.use_ssl else f"http://{uri}" - self.sift_stream_builder = sift_stream_bindings.SiftStreamBuilderPy( + self.sift_stream_builder = SiftStreamBuilderPy( uri=uri, apikey=grpc_client._config.api_key, ) self.sift_stream_builder.enable_tls = grpc_client._config.use_ssl # FD-177: Expose configuration for recovery strategy. - self.sift_stream_builder.recovery_strategy = ( - sift_stream_bindings.RecoveryStrategyPy.retry_only( - sift_stream_bindings.RetryPolicyPy.default() - ) + self.sift_stream_builder.recovery_strategy = RecoveryStrategyPy.retry_only( + RetryPolicyPy.default() ) self.stream_cache = {} @@ -229,7 +236,9 @@ async def get_ingestion_config_id_from_client_key(self, client_key: str) -> str return ingestion_configs[0].id_ def _new_ingestion_thread( - self, ingestion_config_id: str, ingestion_config: IngestionConfigFormPy | None = None + self, + ingestion_config_id: str, + ingestion_config: IngestionConfigFormPy | None = None, ): """Start a new ingestion thread. This allows ingestion to happen in the background regardless of if the user is using the sync or async client diff --git a/python/lib/sift_client/_tests/sift_types/test_base.py b/python/lib/sift_client/_tests/sift_types/test_base.py index c3ad99c53..cc266e7b9 100644 --- a/python/lib/sift_client/_tests/sift_types/test_base.py +++ b/python/lib/sift_client/_tests/sift_types/test_base.py @@ -41,7 +41,7 @@ class NestedCreateModel(ModelCreate[CreateCalculatedChannelRequest]): expression: str | None = None all_assets: bool | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "expression": MappingHelper( proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression", update_field="query_configuration", @@ -293,7 +293,7 @@ def test_validation_error_on_invalid_helper_field(self): class InvalidModel(ModelCreate[CreateCalculatedChannelRequest]): name: str - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "nonexistent_field": MappingHelper(proto_attr_path="some.path"), } diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index 7a3e9c813..1cbadeec0 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -2,7 +2,14 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Generic, + TypeVar, +) from google.protobuf import field_mask_pb2, message from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator @@ -84,10 +91,13 @@ class ModelCreateUpdateBase(BaseModel, ABC): """Base class for Pydantic models that generate proto messages.""" model_config = ConfigDict(frozen=False) - _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = PrivateAttr(default={}) + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = {} def __init__(self, **data: Any): super().__init__(**data) + + @model_validator(mode="after") + def _check_mapping_helpers(self): if self._to_proto_helpers: data = self.model_dump() for expected_field in self._to_proto_helpers.keys(): @@ -95,6 +105,17 @@ def __init__(self, **data: Any): raise ValueError( f"MappingHelper created for {expected_field} but {self.__class__.__name__} has no matching variable names." ) + return self + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + required_annotation = "ClassVar[dict[str, MappingHelper]]" + annotation = cls.__annotations__.get("_to_proto_helpers") + # Check for correct annotation otherwise pydantic will not populate this properly + if annotation and annotation != required_annotation: + raise TypeError( + f"{cls.__name__} must define _to_proto_helpers type as: {required_annotation}" + ) def _build_proto_and_paths( self, proto_msg, data, prefix="", already_setting_path_override=False @@ -123,13 +144,13 @@ def _build_proto_and_paths( if mapping_helper.update_field: paths.append(mapping_helper.update_field) elif isinstance(value, dict): - if field_name in self._to_proto_helpers: - assert self._to_proto_helpers[field_name].converter, ( + if field_name in self.__class__._to_proto_helpers: + assert self.__class__._to_proto_helpers[field_name].converter, ( f"Expecting to run a coverter given a helper was defined for: {field_name}" ) sub_paths = self._build_proto_and_paths( proto_msg, - {field_name: self._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] + {field_name: self.__class__._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] "", already_setting_path_override=True, ) @@ -151,13 +172,13 @@ def _build_proto_and_paths( try: repeated_field.extend(value) # Add all new values except TypeError as e: - if field_name in self._to_proto_helpers: - assert self._to_proto_helpers[field_name].converter, ( + if field_name in self.__class__._to_proto_helpers: + assert self.__class__._to_proto_helpers[field_name].converter, ( f"Expecting to run a coverter given a helper was defined for: {field_name}" ) for item in value: repeated_field.append( - self._to_proto_helpers[field_name].converter(**item) # type: ignore + self.__class__._to_proto_helpers[field_name].converter(**item) # type: ignore ) else: raise e diff --git a/python/lib/sift_client/sift_types/asset.py b/python/lib/sift_client/sift_types/asset.py index 7b02a3dce..dfab26c9f 100644 --- a/python/lib/sift_client/sift_types/asset.py +++ b/python/lib/sift_client/sift_types/asset.py @@ -113,7 +113,7 @@ class AssetUpdate(ModelUpdate[AssetProto]): metadata: dict[str, str | float | bool] | None = None is_archived: bool | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( proto_attr_path="metadata", update_field="metadata", diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 19dc472d7..1d05ea385 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -149,7 +149,7 @@ class CalculatedChannelBase(ModelCreateUpdateBase): metadata: dict[str, str | float | bool] | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "expression": MappingHelper( proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression", update_field="query_configuration", diff --git a/python/lib/sift_client/sift_types/ingestion.py b/python/lib/sift_client/sift_types/ingestion.py index a8f145b86..2f90e94f5 100644 --- a/python/lib/sift_client/sift_types/ingestion.py +++ b/python/lib/sift_client/sift_types/ingestion.py @@ -15,14 +15,6 @@ from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( IngestionConfig as IngestionConfigProto, ) -from sift_stream_bindings import ( - ChannelBitFieldElementPy, - ChannelConfigPy, - ChannelDataTypePy, - ChannelEnumTypePy, - FlowConfigPy, - IngestWithConfigDataChannelValuePy, -) from sift_client.sift_types._base import BaseType from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType @@ -30,6 +22,13 @@ if TYPE_CHECKING: from datetime import datetime + from sift_stream_bindings import ( + ChannelConfigPy, + ChannelDataTypePy, + FlowConfigPy, + IngestWithConfigDataChannelValuePy, + ) + from sift_client.client import SiftClient from sift_client.sift_types.channel import Channel @@ -179,6 +178,8 @@ def _to_proto(self) -> FlowConfig: ) def _to_rust_config(self) -> FlowConfigPy: + from sift_stream_bindings import FlowConfigPy + return FlowConfigPy( name=self.name, channels=[_channel_to_rust_config(channel) for channel in self.channels], @@ -218,6 +219,12 @@ def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]): # Converter functions. def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: + from sift_stream_bindings import ( + ChannelBitFieldElementPy, + ChannelConfigPy, + ChannelEnumTypePy, + ) + return ChannelConfigPy( name=channel.name, data_type=_to_rust_type(channel.data_type), @@ -252,6 +259,8 @@ def _rust_channel_value_from_bitfield( Returns: A ChannelValuePy object. """ + from sift_stream_bindings import IngestWithConfigDataChannelValuePy + assert channel.bit_field_elements is not None # We expect individual ints or bytes to represent full bitfield values. if isinstance(value, bytes) or isinstance(value, int): @@ -276,6 +285,8 @@ def _rust_channel_value_from_bitfield( def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataChannelValuePy: + from sift_stream_bindings import IngestWithConfigDataChannelValuePy + if value is None: return IngestWithConfigDataChannelValuePy.empty() if channel.data_type == ChannelDataType.ENUM and channel.enum_types is not None: @@ -314,6 +325,8 @@ def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataCh def _to_rust_type(data_type: ChannelDataType) -> ChannelDataTypePy: + from sift_stream_bindings import ChannelDataTypePy + if data_type == ChannelDataType.DOUBLE: return ChannelDataTypePy.Double elif data_type == ChannelDataType.FLOAT: diff --git a/python/lib/sift_client/sift_types/run.py b/python/lib/sift_client/sift_types/run.py index d66ce2f42..eb3e081b1 100644 --- a/python/lib/sift_client/sift_types/run.py +++ b/python/lib/sift_client/sift_types/run.py @@ -126,7 +126,7 @@ class RunBase(ModelCreateUpdateBase): tags: list[str] | None = None metadata: dict[str, str | float | bool] | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( proto_attr_path="metadata", update_field="metadata", @@ -158,7 +158,7 @@ class RunCreate(RunBase, ModelCreate[CreateRunRequestProto]): stop_time: datetime | None = None organization_id: str | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( proto_attr_path="metadata", update_field="metadata", @@ -180,7 +180,7 @@ class RunUpdate(RunBase, ModelUpdate[RunProto]): stop_time: datetime | None = None is_archived: bool | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( proto_attr_path="metadata", update_field="metadata", diff --git a/python/pyproject.toml b/python/pyproject.toml index d2aad0d8b..31dea1967 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "PyYAML~=6.0", "pandas~=2.0", "protobuf>=4.0", - "pydantic~=2.0", + "pydantic~=2.10", # Support python 3.9+ typing in older versons of python. "eval-type-backport~=0.2", "pydantic_core~=2.3", From 540cf1b14ad462ad88a737b0539f14f6539d1963 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 19:08:15 -0700 Subject: [PATCH 32/39] fix tests --- .../sift_client/_internal/low_level_wrappers/ingestion.py | 5 ++++- python/lib/sift_client/resources/ingestion.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py index 133143bbf..743e211a1 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py @@ -299,7 +299,6 @@ async def create_ingestion_config( asset_name: str, flows: list[Flow], client_key: str | None = None, - organization_id: str | None = None, ) -> str: """Create an ingestion config. @@ -312,6 +311,8 @@ async def create_ingestion_config( Returns: The id of the new or found ingestion config. """ + from sift_stream_bindings import IngestionConfigFormPy + ingestion_config_id = None if client_key: logger.debug(f"Getting ingestion config id for client key {client_key}") @@ -390,6 +391,8 @@ def ingest_flow( channel_values: The channel values to ingest. organization_id: The organization id to use for ingestion. Only relevant if the user is part of several organizations. """ + from sift_stream_bindings import IngestWithConfigDataStreamRequestPy + if not flow.ingestion_config_id: raise ValueError( "Flow has no ingestion config id -- have you created an ingestion config for this flow?" diff --git a/python/lib/sift_client/resources/ingestion.py b/python/lib/sift_client/resources/ingestion.py index 8abd088ae..a04368c21 100644 --- a/python/lib/sift_client/resources/ingestion.py +++ b/python/lib/sift_client/resources/ingestion.py @@ -67,7 +67,6 @@ async def create_ingestion_config( asset_name=asset_name, flows=flows, client_key=client_key, - organization_id=organization_id, ) for flow in flows: flow._apply_client_to_instance(self.client) From 23247546aabef3c06867ddf51049a1e2ac1e321c Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 19:11:38 -0700 Subject: [PATCH 33/39] remove environment --- .github/workflows/python_ci.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 4b7d539fa..836e00437 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -12,7 +12,6 @@ on: jobs: test-python: runs-on: ubuntu-latest - environment: python-integration-tests defaults: run: working-directory: python From b1cecdd3e8dd6583e941bf40761e6c4d6088315d Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Thu, 9 Oct 2025 21:21:54 -0700 Subject: [PATCH 34/39] remove invalid test --- .../_tests/resources/test_ingestion.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_ingestion.py b/python/lib/sift_client/_tests/resources/test_ingestion.py index ce4a059be..1858ff10b 100644 --- a/python/lib/sift_client/_tests/resources/test_ingestion.py +++ b/python/lib/sift_client/_tests/resources/test_ingestion.py @@ -377,33 +377,6 @@ async def test_ingest_highspeed_data(self, sift_client, test_run): class TestIngestionValidation: """Tests for ingestion validation and error handling.""" - @pytest.mark.asyncio - async def test_ingest_missing_channel_raises_error(self, sift_client, test_run): - """Test that ingesting without all channels raises an error.""" - flow = Flow( - name="test-validation-flow", - channels=[ - ChannelConfig(name="channel1", data_type=ChannelDataType.DOUBLE), - ChannelConfig(name="channel2", data_type=ChannelDataType.DOUBLE), - ], - ) - - await sift_client.async_.ingestion.create_ingestion_config( - asset_name=ASSET_NAME, - run_id=test_run.id_, - flows=[flow], - ) - - timestamp = datetime.now(tz=timezone.utc) - with pytest.raises( - Exception, - match="Expected all channels in flow to have a data point at same time", - ): - flow.ingest( - timestamp=timestamp, - channel_values={"channel1": 1.0}, # Missing channel2 - ) - @pytest.mark.asyncio async def test_ingest_invalid_enum_value_raises_error(self, sift_client, test_run): """Test that ingesting an invalid enum value raises an error.""" From ded61129347f3f7e53cca5364b3822da2aef0d4b Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 13:06:44 -0700 Subject: [PATCH 35/39] improve integration test coverage --- .../resources/test_calculated_channels.py | 204 +++++++++++++++-- .../_tests/resources/test_rules.py | 216 +++++++++++++++++- .../sift_client/_tests/resources/test_runs.py | 37 +++ 3 files changed, 430 insertions(+), 27 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_calculated_channels.py b/python/lib/sift_client/_tests/resources/test_calculated_channels.py index d75227823..ab71fd602 100644 --- a/python/lib/sift_client/_tests/resources/test_calculated_channels.py +++ b/python/lib/sift_client/_tests/resources/test_calculated_channels.py @@ -9,6 +9,7 @@ """ import uuid +from datetime import datetime, timezone import pytest @@ -28,7 +29,9 @@ def test_client_binding(sift_client): assert sift_client.calculated_channels assert isinstance(sift_client.calculated_channels, CalculatedChannelsAPI) assert sift_client.async_.calculated_channels - assert isinstance(sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync) + assert isinstance( + sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync + ) @pytest.fixture @@ -70,8 +73,12 @@ def new_calculated_channel(calculated_channels_api_sync, sift_client): description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), - ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), ], all_assets=True, ) @@ -86,7 +93,9 @@ class TestGet: """Tests for the async get method.""" @pytest.mark.asyncio - async def test_get_by_id(self, calculated_channels_api_async, test_calculated_channel): + async def test_get_by_id( + self, calculated_channels_api_async, test_calculated_channel + ): """Test getting a specific calculated channel by ID.""" retrieved_calc_channel = await calculated_channels_api_async.get( calculated_channel_id=test_calculated_channel.id_ @@ -139,9 +148,13 @@ async def test_list_with_name_filter(self, calculated_channels_api_async): assert calc_channel.name == test_calc_channel_name @pytest.mark.asyncio - async def test_list_with_name_contains_filter(self, calculated_channels_api_async): + async def test_list_with_name_contains_filter( + self, calculated_channels_api_async + ): """Test calculated channel listing with name contains filtering.""" - calc_channels = await calculated_channels_api_async.list_(name_contains="test", limit=5) + calc_channels = await calculated_channels_api_async.list_( + name_contains="test", limit=5 + ) assert isinstance(calc_channels, list) assert calc_channels @@ -192,7 +205,9 @@ async def test_find_calculated_channel( assert found_calc_channel.id_ == test_calculated_channel.id_ @pytest.mark.asyncio - async def test_find_nonexistent_calculated_channel(self, calculated_channels_api_async): + async def test_find_nonexistent_calculated_channel( + self, calculated_channels_api_async + ): """Test finding a non-existent calculated channel returns None.""" found_calc_channel = await calculated_channels_api_async.find( name="nonexistent-calculated-channel-name-12345" @@ -209,14 +224,20 @@ class TestCreate: """Tests for the async create method.""" @pytest.mark.asyncio - async def test_create_basic_calculated_channel(self, calculated_channels_api_async): + async def test_create_basic_calculated_channel( + self, calculated_channels_api_async + ): """Test creating a basic calculated channel with minimal fields.""" from datetime import datetime, timezone - calc_channel_name = f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" + calc_channel_name = ( + f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" + ) description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) + channels = await calculated_channels_api_async.client.async_.channels.list_( + limit=2 + ) assert len(channels) >= 2 calc_channel_create = CalculatedChannelCreate( @@ -224,13 +245,19 @@ async def test_create_basic_calculated_channel(self, calculated_channels_api_asy description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), - ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), ], all_assets=True, ) - created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) + created_calc_channel = await calculated_channels_api_async.create( + calc_channel_create + ) try: assert created_calc_channel is not None @@ -244,14 +271,20 @@ async def test_create_basic_calculated_channel(self, calculated_channels_api_asy await calculated_channels_api_async.archive(created_calc_channel) @pytest.mark.asyncio - async def test_create_calculated_channel_with_dict(self, calculated_channels_api_async): + async def test_create_calculated_channel_with_dict( + self, calculated_channels_api_async + ): """Test creating a calculated channel using a dictionary.""" from datetime import datetime, timezone - calc_channel_name = f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" + calc_channel_name = ( + f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" + ) description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) + channels = await calculated_channels_api_async.client.async_.channels.list_( + limit=2 + ) assert len(channels) >= 2 calc_channel_dict = { @@ -265,7 +298,9 @@ async def test_create_calculated_channel_with_dict(self, calculated_channels_api "all_assets": True, } - created_calc_channel = await calculated_channels_api_async.create(calc_channel_dict) + created_calc_channel = await calculated_channels_api_async.create( + calc_channel_dict + ) try: assert created_calc_channel.name == calc_channel_name @@ -341,6 +376,125 @@ async def test_update_with_id_string( finally: await calculated_channels_api_async.archive(new_calculated_channel.id_) + @pytest.mark.asyncio + async def test_update_calculated_channel_units( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's units.""" + try: + update = CalculatedChannelUpdate(units="percentage") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.units == "percentage" + assert updated_calc_channel.name == new_calculated_channel.name + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_complex_expression( + self, calculated_channels_api_async, sift_client + ): + """Test updating a calculated channel with a complex expression using multiple channel references.""" + # Get channels to reference + channels = await sift_client.async_.channels.list_(limit=3) + assert len(channels) >= 3 + + # Create a calculated channel + calc_channel_name = ( + f"test_calc_channel_complex_{datetime.now(timezone.utc).isoformat()}" + ) + calc_channel_create = CalculatedChannelCreate( + name=calc_channel_name, + description="Test calculated channel for complex expression update", + expression="$1 + $2", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ], + all_assets=True, + ) + created_calc_channel = await calculated_channels_api_async.create( + calc_channel_create + ) + + try: + # Update with complex expression + update = CalculatedChannelUpdate( + expression="($1 / $2) * 100 + ($3 * 0.1)", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ChannelReference( + channel_reference="$3", channel_identifier=channels[2].name + ), + ], + ) + updated_calc_channel = await calculated_channels_api_async.update( + created_calc_channel, update + ) + + assert updated_calc_channel.id_ == created_calc_channel.id_ + assert updated_calc_channel.expression == "($1 / $2) * 100 + ($3 * 0.1)" + assert len(updated_calc_channel.channel_references) == 3 + # Verify all three channel references are present + ref_identifiers = { + ref.channel_identifier + for ref in updated_calc_channel.channel_references + } + assert channels[0].name in ref_identifiers + assert channels[1].name in ref_identifiers + assert channels[2].name in ref_identifiers + finally: + await calculated_channels_api_async.archive(created_calc_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_invalid_expression( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel with an invalid expression. + + Note: The server may or may not validate expression syntax at update time. + This test documents the current behavior. + """ + try: + # Attempt to update with an invalid expression + update = CalculatedChannelUpdate( + expression="invalid_expression", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", + channel_identifier=new_calculated_channel.channel_references[ + 0 + ].channel_identifier, + ), + ], + ) + + # This may succeed or fail depending on server-side validation + # If it succeeds, the expression is stored but may fail at evaluation time + try: + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + # If update succeeds, verify the expression was set + assert updated_calc_channel.expression == "invalid_expression" + except Exception as e: + # If server validates and rejects, that's also acceptable behavior + assert "expression" in str(e).lower() or "invalid" in str(e).lower() + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + class TestArchive: """Tests for the async archive method.""" @@ -349,7 +503,9 @@ async def test_archive_calculated_channel( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel.""" - calc_channel = await calculated_channels_api_async.archive(new_calculated_channel) + calc_channel = await calculated_channels_api_async.archive( + new_calculated_channel + ) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -372,7 +528,9 @@ async def test_archive_with_id_string( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel by passing ID as string.""" - calc_channel = await calculated_channels_api_async.archive(new_calculated_channel.id_) + calc_channel = await calculated_channels_api_async.archive( + new_calculated_channel.id_ + ) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -389,7 +547,9 @@ async def test_unarchive_calculated_channel( try: await calculated_channels_api_async.archive(new_calculated_channel) - calc_channel = await calculated_channels_api_async.unarchive(new_calculated_channel) + calc_channel = await calculated_channels_api_async.unarchive( + new_calculated_channel + ) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -401,7 +561,9 @@ class TestListVersions: """Tests for the async list_versions method.""" @pytest.mark.asyncio - async def test_list_versions(self, calculated_channels_api_async, test_calculated_channel): + async def test_list_versions( + self, calculated_channels_api_async, test_calculated_channel + ): """Test listing versions of a calculated channel.""" versions = await calculated_channels_api_async.list_versions( calculated_channel=test_calculated_channel diff --git a/python/lib/sift_client/_tests/resources/test_rules.py b/python/lib/sift_client/_tests/resources/test_rules.py index bae0e6efa..ac767aa4b 100644 --- a/python/lib/sift_client/_tests/resources/test_rules.py +++ b/python/lib/sift_client/_tests/resources/test_rules.py @@ -8,6 +8,7 @@ """ import uuid +from datetime import datetime, timezone import pytest @@ -76,8 +77,12 @@ def new_rule(rules_api_sync, sift_client): description=description, expression="$1 > $2", channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), - ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -108,7 +113,9 @@ async def test_get_by_id(self, rules_api_async, test_rule): async def test_get_by_client_key(self, rules_api_async, test_rule): """Test getting a specific rule by client key.""" if test_rule.client_key: - retrieved_rule = await rules_api_async.get(client_key=test_rule.client_key) + retrieved_rule = await rules_api_async.get( + client_key=test_rule.client_key + ) assert retrieved_rule is not None assert retrieved_rule.id_ == test_rule.id_ @@ -259,8 +266,12 @@ async def test_create_basic_rule(self, rules_api_async): description=description, expression="$1 > $2", channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), - ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -339,7 +350,10 @@ async def test_update_rule_description(self, rules_api_async, new_rule): assert updated_rule.expression == new_rule.expression assert updated_rule.action.action_type == new_rule.action.action_type assert updated_rule.client_key == new_rule.client_key - assert updated_rule.rule_version.created_date > new_rule.rule_version.created_date + assert ( + updated_rule.rule_version.created_date + > new_rule.rule_version.created_date + ) finally: await rules_api_async.archive(new_rule.id_) @@ -393,6 +407,196 @@ async def test_update_with_version_notes(self, rules_api_async, new_rule): finally: await rules_api_async.archive(new_rule.id_) + @pytest.mark.asyncio + async def test_update_rule_action(self, rules_api_async, new_rule): + """Test updating a rule's action including annotation type, tags, and assignee.""" + try: + # Update the action with new annotation type, tags, and assignee + update = RuleUpdate( + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.PHASE, + tags=["sift-client-pytest"], + default_assignee_user_id=new_rule.created_by_user_id, + ), + ) + updated_rule = await rules_api_async.update(new_rule, update) + + # Verify the action was updated + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.action.action_type == RuleActionType.ANNOTATION + assert updated_rule.action.annotation_type == RuleAnnotationType.PHASE + assert set(updated_rule.action.tags) == {"sift-client-pytest"} + assert ( + updated_rule.action.default_assignee_user_id + == new_rule.created_by_user_id + ) + + # Verify other fields remain unchanged + assert updated_rule.name == new_rule.name + assert updated_rule.expression == new_rule.expression + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_complex_expression( + self, rules_api_async, sift_client + ): + """Test updating a rule with a complex expression (range check).""" + # Get channels and assets + channels = await sift_client.async_.channels.list_(limit=2) + assert len(channels) >= 1 + assets = await sift_client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + # Create a rule with simple expression + rule_name = ( + f"test_rule_complex_expr_{datetime.now(timezone.utc).isoformat()}" + ) + rule_create = RuleCreate( + name=rule_name, + description="Test rule for complex expression update", + expression="$1 > 0.5", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test"], + ), + asset_ids=[assets[0].id_], + ) + created_rule = await rules_api_async.create(rule_create) + + try: + # Update with complex expression (range check) + update = RuleUpdate( + expression="$1 > 0.3 && $1 < 0.8", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ], + ) + updated_rule = await rules_api_async.update(created_rule, update) + + # Verify the expression was updated + assert updated_rule.id_ == created_rule.id_ + assert updated_rule.expression == "$1 > 0.3 && $1 < 0.8" + assert len(updated_rule.channel_references) == 1 + assert ( + updated_rule.channel_references[0].channel_identifier + == channels[0].name + ) + + # Verify other fields remain unchanged + assert updated_rule.name == created_rule.name + assert ( + updated_rule.action.action_type == created_rule.action.action_type + ) + finally: + await rules_api_async.archive(created_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_multiple_channel_references( + self, rules_api_async, sift_client + ): + """Test updating a rule expression to use multiple channel references.""" + # Get channels and assets + channels = await sift_client.async_.channels.list_(limit=3) + assert len(channels) >= 3 + assets = await sift_client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + # Create a rule with simple expression + rule_name = f"test_rule_multi_refs_{datetime.now(timezone.utc).isoformat()}" + rule_create = RuleCreate( + name=rule_name, + description="Test rule for multiple channel references", + expression="$1 > $2", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test"], + ), + asset_ids=[assets[0].id_], + ) + created_rule = await rules_api_async.create(rule_create) + + try: + # Update with expression using three channel references + update = RuleUpdate( + expression="($1 > $2) && ($3 < 100)", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ChannelReference( + channel_reference="$3", channel_identifier=channels[2].name + ), + ], + ) + updated_rule = await rules_api_async.update(created_rule, update) + + # Verify the expression and channel references were updated + assert updated_rule.id_ == created_rule.id_ + assert updated_rule.expression == "($1 > $2) && ($3 < 100)" + assert len(updated_rule.channel_references) == 3 + + # Verify all three channel references are present + ref_identifiers = { + ref.channel_identifier for ref in updated_rule.channel_references + } + assert channels[0].name in ref_identifiers + assert channels[1].name in ref_identifiers + assert channels[2].name in ref_identifiers + finally: + await rules_api_async.archive(created_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_invalid_expression(self, rules_api_async, new_rule): + """Test updating a rule with an invalid expression. + + Note: The server may or may not validate expression syntax at update time. + This test documents the current behavior. + """ + try: + # Attempt to update with an invalid expression + update = RuleUpdate( + expression="invalid_expression", + channel_references=[ + ChannelReference( + channel_reference="$1", + channel_identifier=new_rule.channel_references[ + 0 + ].channel_identifier, + ), + ], + ) + + # This may succeed or fail depending on server-side validation + # If it succeeds, the expression is stored but may fail at evaluation time + try: + updated_rule = await rules_api_async.update(new_rule, update) + # If update succeeds, verify the expression was set + assert updated_rule.expression == "invalid_expression" + except Exception as e: + # If server validates and rejects, that's also acceptable behavior + assert "expression" in str(e).lower() or "invalid" in str(e).lower() + finally: + await rules_api_async.archive(new_rule.id_) + class TestArchive: """Tests for the async archive method.""" diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 8d99757a5..6153c63bb 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -480,6 +480,43 @@ async def test_stop_run_with_start_time(self, runs_api_async, new_run): finally: await runs_api_async.archive(new_run.id_) + class TestAssetAssociation: + """Tests for the async asset association methods.""" + + @pytest.mark.asyncio + async def test_create_automatic_association_for_assets( + self, runs_api_async, sift_client + ): + """Test associating assets with a run for automatic data ingestion.""" + # Create a test run + run_name = f"test_run_asset_assoc_{datetime.now(timezone.utc).isoformat()}" + run_create = RunCreate( + name=run_name, + description="Test run for asset association", + tags=["sift-client-pytest"], + ) + created_run = await runs_api_async.create(run_create) + + try: + # Get some assets to associate + assets = await sift_client.async_.assets.list_(limit=2) + assert len(assets) >= 1 + + asset_names = [asset.name for asset in assets[:2]] + + # Associate assets with the run + await runs_api_async.create_automatic_association_for_assets( + run=created_run, asset_names=asset_names + ) + + # Verify the association by getting the run and checking asset_ids + updated_run = await runs_api_async.get(run_id=created_run.id_) + assert updated_run.asset_ids is not None + assert len(updated_run.asset_ids) >= len(asset_names) + + finally: + await runs_api_async.archive(created_run) + class TestRunsAPISync: """Test suite for the synchronous Runs API functionality. From e3704f165f358614b5d75727fee30aeb95c80050 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 14:02:54 -0700 Subject: [PATCH 36/39] add unit tests for sift_types --- python/lib/sift_client/_tests/conftest.py | 17 + .../_tests/sift_types/test_asset.py | 175 ++++++++++ .../sift_types/test_calculated_channel.py | 300 ++++++++++++++++++ .../_tests/sift_types/test_channel.py | 123 +++++++ .../_tests/sift_types/test_ingestion.py | 182 +++++++++++ .../_tests/sift_types/test_rule.py | 148 +++++++++ .../sift_client/_tests/sift_types/test_run.py | 207 ++++++++++++ python/lib/sift_client/resources/runs.py | 14 +- python/lib/sift_client/sift_types/_base.py | 23 +- .../sift_types/calculated_channel.py | 11 +- python/pyproject.toml | 5 + 11 files changed, 1190 insertions(+), 15 deletions(-) create mode 100644 python/lib/sift_client/_tests/sift_types/test_asset.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_calculated_channel.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_channel.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_ingestion.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_rule.py create mode 100644 python/lib/sift_client/_tests/sift_types/test_run.py diff --git a/python/lib/sift_client/_tests/conftest.py b/python/lib/sift_client/_tests/conftest.py index a8ff39d55..1f218dedc 100644 --- a/python/lib/sift_client/_tests/conftest.py +++ b/python/lib/sift_client/_tests/conftest.py @@ -1,10 +1,12 @@ """Shared pytest fixtures for all tests.""" import os +from unittest.mock import MagicMock import pytest from sift_client import SiftClient +from sift_client.util.util import AsyncAPIs @pytest.fixture(scope="session") @@ -23,3 +25,18 @@ def sift_client() -> SiftClient: grpc_url=grpc_url, rest_url=rest_url, ) + + +@pytest.fixture +def mock_client(): + """Create a mock SiftClient for unit testing.""" + client = MagicMock(spec=SiftClient) + # Configure the mock to have the necessary API attributes + client.assets = MagicMock() + client.runs = MagicMock() + client.channels = MagicMock() + client.calculated_channels = MagicMock() + client.rules = MagicMock() + client.async_ = MagicMock(spec=AsyncAPIs) + client.async_.ingestion = MagicMock() + return client diff --git a/python/lib/sift_client/_tests/sift_types/test_asset.py b/python/lib/sift_client/_tests/sift_types/test_asset.py new file mode 100644 index 000000000..9c085b61a --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_asset.py @@ -0,0 +1,175 @@ +"""Tests for sift_types.Asset model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Asset +from sift_client.sift_types.asset import AssetUpdate + + +class TestAssetUpdate: + """Unit tests for AssetUpdate model - tests _to_proto_helpers.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": True} + update = AssetUpdate(metadata=metadata) + update.resource_id = "test_asset_id" + + proto, mask = update.to_proto_with_mask() + + assert proto.asset_id == "test_asset_id" + # Verify metadata was converted using the helper (returns a list) + assert len(proto.metadata) == 3 + + # Find each metadata value in the list + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is True + assert "metadata" in mask.paths + + +@pytest.fixture +def mock_asset(mock_client): + """Create a mock Asset instance for testing.""" + asset = Asset( + proto=MagicMock(), + id_="test_asset_id", + name="test_asset", + organization_id="org1", + created_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_date=datetime.now(timezone.utc), + modified_by_user_id="user1", + tags=[], + metadata={}, + is_archived=False, + archived_date=None, + ) + asset._apply_client_to_instance(mock_client) + return asset + + +class TestAsset: + """Unit tests for Asset model - tests properties and methods.""" + + def test_runs_property_calls_client(self, mock_asset, mock_client): + """Test that runs property calls client.runs.list_ with correct parameters.""" + mock_client.runs.list_.return_value = [] + + # Access runs property + _ = mock_asset.runs + + # Verify client method was called with correct asset + mock_client.runs.list_.assert_called_once_with(assets=[mock_asset]) + + def test_channels_method_calls_client(self, mock_asset, mock_client): + """Test that channels() method calls client.channels.list_ with correct parameters.""" + mock_client.channels.list_.return_value = [] + + # Call channels method + _ = mock_asset.channels(limit=5) + + # Verify client method was called with correct parameters + mock_client.channels.list_.assert_called_once_with( + asset=mock_asset, run=None, limit=5 + ) + + def test_channels_method_with_run_filter(self, mock_asset, mock_client): + """Test that channels() method passes run filter to client.""" + mock_client.channels.list_.return_value = [] + mock_run = MagicMock() + + # Call channels method with run filter + _ = mock_asset.channels(run=mock_run, limit=10) + + # Verify client method was called with run parameter + mock_client.channels.list_.assert_called_once_with( + asset=mock_asset, run=mock_run, limit=10 + ) + + def test_archive_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that archive() calls client.assets.archive and calls _update.""" + archived_asset = MagicMock() + archived_asset.is_archived = True + archived_asset.archived_date = datetime.now(timezone.utc) + mock_client.assets.archive.return_value = archived_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call archive + result = mock_asset.archive(archive_runs=False) + + # Verify client method was called + mock_client.assets.archive.assert_called_once_with( + asset=mock_asset, archive_runs=False + ) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(archived_asset) + # Verify it returns self + assert result is mock_asset + + def test_archive_with_runs(self, mock_asset, mock_client): + """Test that archive() passes archive_runs parameter correctly.""" + archived_asset = MagicMock() + mock_client.assets.archive.return_value = archived_asset + + # Mock the _update method + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call archive with archive_runs=True + mock_asset.archive(archive_runs=True) + + # Verify client method was called with archive_runs=True + mock_client.assets.archive.assert_called_once_with( + asset=mock_asset, archive_runs=True + ) + + def test_unarchive_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that unarchive() calls client.assets.unarchive and calls _update.""" + unarchived_asset = MagicMock() + unarchived_asset.is_archived = False + mock_client.assets.unarchive.return_value = unarchived_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call unarchive + result = mock_asset.unarchive() + + # Verify client method was called + mock_client.assets.unarchive.assert_called_once_with(asset=mock_asset) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(unarchived_asset) + # Verify it returns self + assert result is mock_asset + + def test_update_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that update() calls client.assets.update and calls _update.""" + updated_asset = MagicMock() + updated_asset.tags = ["updated"] + mock_client.assets.update.return_value = updated_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call update + update = AssetUpdate(tags=["updated"]) + result = mock_asset.update(update) + + # Verify client method was called with correct parameters + mock_client.assets.update.assert_called_once_with( + asset=mock_asset, update=update + ) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(updated_asset) + # Verify it returns self + assert result is mock_asset diff --git a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py new file mode 100644 index 000000000..e4c310c92 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py @@ -0,0 +1,300 @@ +"""Tests for sift_types.CalculatedChannel model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import CalculatedChannel +from sift_client.sift_types.calculated_channel import ( + CalculatedChannelBase, + CalculatedChannelUpdate, +) +from sift_client.sift_types.channel import ChannelReference + + +class TestCalculatedChannelBase: + """Unit tests for CalculatedChannelBase - tests _to_proto_helpers and validators shared by Create and Update.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": False} + update = CalculatedChannelUpdate(metadata=metadata) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is False + assert "metadata" in mask.paths + + def test_expression_helper(self): + """Test that expression is mapped to nested proto path.""" + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify expression is set in nested path + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression + == "$1 + $2" + ) + assert "query_configuration" in mask.paths + + def test_expression_channel_references_helper(self): + """Test that expression_channel_references are converted and mapped.""" + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify channel references are converted + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + assert len(refs) == 2 + assert refs[0].channel_reference == "$1" + assert refs[0].channel_identifier == "channel1" + assert refs[1].channel_reference == "$2" + assert refs[1].channel_identifier == "channel2" + assert "query_configuration" in mask.paths + + def test_tag_ids_helper(self): + """Test that tag_ids are mapped to nested proto path.""" + update = CalculatedChannelUpdate(tag_ids=["tag1", "tag2"]) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify tag_ids are set in nested path + assert list( + proto.calculated_channel_configuration.asset_configuration.selection.tag_ids + ) == ["tag1", "tag2"] + assert "asset_configuration" in mask.paths + + def test_asset_ids_helper(self): + """Test that asset_ids are mapped to nested proto path.""" + update = CalculatedChannelUpdate(asset_ids=["asset1", "asset2"]) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify asset_ids are set in nested path + assert list( + proto.calculated_channel_configuration.asset_configuration.selection.asset_ids + ) == ["asset1", "asset2"] + assert "asset_configuration" in mask.paths + + def test_all_assets_helper(self): + """Test that all_assets is mapped to nested proto path.""" + update = CalculatedChannelUpdate(all_assets=True) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify all_assets is set in nested path + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + # Verify update_field is in mask (same as tag_ids and asset_ids) + assert "asset_configuration" in mask.paths + + def test_asset_configuration_validator_rejects_all_assets_with_asset_ids(self): + """Test validator rejects all_assets=True with asset_ids.""" + with pytest.raises( + ValueError, match="Cannot specify both all_assets=True and asset_ids/tag_ids" + ): + CalculatedChannelUpdate( + all_assets=True, + asset_ids=["asset1"], + ) + + def test_asset_configuration_validator_rejects_all_assets_with_tag_ids(self): + """Test validator rejects all_assets=True with tag_ids.""" + with pytest.raises( + ValueError, match="Cannot specify both all_assets=True and asset_ids/tag_ids" + ): + CalculatedChannelUpdate( + all_assets=True, + tag_ids=["tag1"], + ) + + def test_expression_validator_rejects_expression_without_references(self): + """Test validator rejects expression without channel references.""" + with pytest.raises( + ValueError, match="Expression and channel references must be set together" + ): + CalculatedChannelUpdate(expression="$1 + $2") + + def test_expression_validator_rejects_references_without_expression(self): + """Test validator rejects channel references without expression.""" + with pytest.raises( + ValueError, match="Expression and channel references must be set together" + ): + CalculatedChannelUpdate( + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier="channel1" + ), + ], + ) + + def test_expression_validator_accepts_both_set(self): + """Test validator accepts expression and channel references together.""" + # Should not raise + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + assert update.expression == "$1 + $2" + assert len(update.expression_channel_references) == 2 + + +@pytest.fixture +def mock_calculated_channel(mock_client): + """Create a mock CalculatedChannel instance for testing.""" + calc_channel = CalculatedChannel( + proto=MagicMock(), + id_="test_calc_channel_id", + name="test_calc_channel", + description="test description", + expression="$1 + $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + is_archived=False, + units=None, + asset_ids=None, + tag_ids=None, + all_assets=True, + organization_id="org1", + client_key=None, + archived_date=None, + version_id="v1", + version=1, + change_message=None, + user_notes=None, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + ) + calc_channel._apply_client_to_instance(mock_client) + return calc_channel + + +class TestCalculatedChannel: + """Unit tests for CalculatedChannel model - tests properties and methods.""" + + def test_archive_calls_client_and_updates_self( + self, mock_calculated_channel, mock_client + ): + """Test that archive() calls client.calculated_channels.archive and calls _update.""" + archived_calc_channel = MagicMock() + archived_calc_channel.is_archived = True + mock_client.calculated_channels.archive.return_value = archived_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call archive + result = mock_calculated_channel.archive() + + # Verify client method was called + mock_client.calculated_channels.archive.assert_called_once_with( + calculated_channel=mock_calculated_channel + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(archived_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_unarchive_calls_client_and_updates_self( + self, mock_calculated_channel, mock_client + ): + """Test that unarchive() calls client.calculated_channels.unarchive and calls _update.""" + unarchived_calc_channel = MagicMock() + unarchived_calc_channel.is_archived = False + mock_client.calculated_channels.unarchive.return_value = unarchived_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call unarchive + result = mock_calculated_channel.unarchive() + + # Verify client method was called + mock_client.calculated_channels.unarchive.assert_called_once_with( + calculated_channel=mock_calculated_channel + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(unarchived_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_update_calls_client_and_updates_self( + self, mock_calculated_channel, mock_client + ): + """Test that update() calls client.calculated_channels.update and calls _update.""" + updated_calc_channel = MagicMock() + updated_calc_channel.description = "Updated description" + mock_client.calculated_channels.update.return_value = updated_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call update + update = CalculatedChannelUpdate(description="Updated description") + result = mock_calculated_channel.update(update) + + # Verify client method was called with correct parameters + mock_client.calculated_channels.update.assert_called_once_with( + calculated_channel=mock_calculated_channel, + update=update, + user_notes=None, + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(updated_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_update_with_user_notes(self, mock_calculated_channel, mock_client): + """Test that update() passes user_notes parameter correctly.""" + updated_calc_channel = MagicMock() + mock_client.calculated_channels.update.return_value = updated_calc_channel + + # Mock the _update method + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call update with user_notes + update = CalculatedChannelUpdate(description="Updated") + mock_calculated_channel.update(update, user_notes="Test notes") + + # Verify client method was called with user_notes + mock_client.calculated_channels.update.assert_called_once_with( + calculated_channel=mock_calculated_channel, + update=update, + user_notes="Test notes", + ) diff --git a/python/lib/sift_client/_tests/sift_types/test_channel.py b/python/lib/sift_client/_tests/sift_types/test_channel.py new file mode 100644 index 000000000..ab8fa01c2 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_channel.py @@ -0,0 +1,123 @@ +"""Tests for sift_types.Channel model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Channel +from sift_client.sift_types.channel import ChannelDataType + + +@pytest.fixture +def mock_channel(mock_client): + """Create a mock Channel instance for testing.""" + channel = Channel( + proto=MagicMock(), + id_="test_channel_id", + name="test_channel", + data_type=ChannelDataType.DOUBLE, + description="test description", + unit="m/s", + bit_field_elements=[], + enum_types={}, + asset_id="test_asset_id", + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + ) + channel._apply_client_to_instance(mock_client) + return channel + + +class TestChannel: + """Unit tests for Channel model - tests properties and methods.""" + + def test_asset_property_calls_client(self, mock_channel, mock_client): + """Test that asset property calls client.assets.get with correct parameters.""" + mock_asset = MagicMock() + mock_client.assets.get.return_value = mock_asset + + # Access asset property + result = mock_channel.asset + + # Verify client method was called with correct asset_id + mock_client.assets.get.assert_called_once_with(asset_id="test_asset_id") + assert result is mock_asset + + def test_runs_property_calls_asset_runs(self, mock_channel, mock_client): + """Test that runs property calls asset.runs.""" + mock_asset = MagicMock() + mock_runs = [MagicMock(), MagicMock()] + mock_asset.runs = mock_runs + mock_client.assets.get.return_value = mock_asset + + # Access runs property + result = mock_channel.runs + + # Verify it returns the asset's runs + assert result == mock_runs + + def test_data_method_calls_get_data(self, mock_channel, mock_client): + """Test that data() method calls client.channels.get_data with correct parameters.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data.return_value = mock_data + + # Call data method + result = mock_channel.data( + run_id="run123", + start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2024, 1, 2, tzinfo=timezone.utc), + limit=100, + ) + + # Verify client method was called with correct parameters + mock_client.channels.get_data.assert_called_once_with( + channels=[mock_channel], + run="run123", + start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2024, 1, 2, tzinfo=timezone.utc), + limit=100, + ) + assert result == mock_data + + def test_data_method_as_arrow(self, mock_channel, mock_client): + """Test that data() method calls get_data_as_arrow when as_arrow=True.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data_as_arrow.return_value = mock_data + + # Call data method with as_arrow=True + result = mock_channel.data( + run_id="run123", + as_arrow=True, + ) + + # Verify get_data_as_arrow was called instead of get_data + mock_client.channels.get_data_as_arrow.assert_called_once_with( + channels=[mock_channel], + run="run123", + start_time=None, + end_time=None, + limit=None, + ) + mock_client.channels.get_data.assert_not_called() + assert result == mock_data + + def test_data_method_with_minimal_params(self, mock_channel, mock_client): + """Test that data() method works with minimal parameters.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data.return_value = mock_data + + # Call data method with no parameters + result = mock_channel.data() + + # Verify client method was called with None values + mock_client.channels.get_data.assert_called_once_with( + channels=[mock_channel], + run=None, + start_time=None, + end_time=None, + limit=None, + ) + assert result == mock_data diff --git a/python/lib/sift_client/_tests/sift_types/test_ingestion.py b/python/lib/sift_client/_tests/sift_types/test_ingestion.py new file mode 100644 index 000000000..2c333a915 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_ingestion.py @@ -0,0 +1,182 @@ +"""Tests for sift_types.Ingestion models.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType +from sift_client.sift_types.ingestion import ChannelConfig, Flow, IngestionConfig + + +class TestChannelConfig: + """Unit tests for ChannelConfig model - tests validators.""" + + def test_enum_validator_rejects_enum_without_enum_types(self): + """Test validator rejects ENUM data_type without enum_types.""" + with pytest.raises( + ValueError, + match="Channel 'test_channel' has data_type ENUM but enum_types is not provided", + ): + ChannelConfig( + name="test_channel", + data_type=ChannelDataType.ENUM, + ) + + def test_enum_validator_accepts_enum_with_enum_types(self): + """Test validator accepts ENUM data_type with enum_types.""" + # Should not raise + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.ENUM, + enum_types={"LOW": 0, "HIGH": 1}, + ) + assert channel.data_type == ChannelDataType.ENUM + assert channel.enum_types == {"LOW": 0, "HIGH": 1} + + def test_bitfield_validator_rejects_bitfield_without_elements(self): + """Test validator rejects BIT_FIELD data_type without bit_field_elements.""" + with pytest.raises( + ValueError, + match="Channel 'test_channel' has data_type BIT_FIELD but bit_field_elements is not provided", + ): + ChannelConfig( + name="test_channel", + data_type=ChannelDataType.BIT_FIELD, + ) + + def test_bitfield_validator_accepts_bitfield_with_elements(self): + """Test validator accepts BIT_FIELD data_type with bit_field_elements.""" + # Should not raise + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="field1", index=0, bit_count=4), + ChannelBitFieldElement(name="field2", index=1, bit_count=4), + ], + ) + assert channel.data_type == ChannelDataType.BIT_FIELD + assert len(channel.bit_field_elements) == 2 + + def test_other_data_types_dont_require_special_fields(self): + """Test that other data types don't require enum_types or bit_field_elements.""" + # Should not raise for DOUBLE + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.DOUBLE, + ) + assert channel.data_type == ChannelDataType.DOUBLE + + +@pytest.fixture +def mock_flow(mock_client): + """Create a mock Flow instance for testing.""" + flow = Flow( + proto=MagicMock(), + name="test_flow", + channels=[ + ChannelConfig( + name="channel1", + data_type=ChannelDataType.DOUBLE, + description="Test channel 1", + ), + ChannelConfig( + name="channel2", + data_type=ChannelDataType.FLOAT, + description="Test channel 2", + ), + ], + ingestion_config_id="test_config_id", + run_id="test_run_id", + ) + flow._apply_client_to_instance(mock_client) + return flow + + +class TestFlow: + """Unit tests for Flow model - tests methods.""" + + def test_add_channel_success(self): + """Test that add_channel() adds a channel when no ingestion_config_id is set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id=None, + ) + + channel = ChannelConfig( + name="new_channel", + data_type=ChannelDataType.DOUBLE, + ) + + # Should not raise + flow.add_channel(channel) + + assert len(flow.channels) == 1 + assert flow.channels[0].name == "new_channel" + + def test_add_channel_raises_after_creation(self): + """Test that add_channel() raises ValueError when ingestion_config_id is set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id="config123", + ) + + channel = ChannelConfig( + name="new_channel", + data_type=ChannelDataType.DOUBLE, + ) + + with pytest.raises( + ValueError, match="Cannot add a channel to a flow after creation" + ): + flow.add_channel(channel) + + def test_ingest_calls_client(self, mock_flow, mock_client): + """Test that ingest() calls client.async_.ingestion.ingest with correct parameters.""" + timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + channel_values = {"channel1": 42.5, "channel2": 100.0} + + # Call ingest + mock_flow.ingest(timestamp=timestamp, channel_values=channel_values) + + # Verify client method was called with correct parameters + mock_client.async_.ingestion.ingest.assert_called_once_with( + flow=mock_flow, + timestamp=timestamp, + channel_values=channel_values, + ) + + def test_ingest_raises_without_config_id(self, mock_client): + """Test that ingest() raises ValueError when ingestion_config_id is not set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id=None, + ) + flow._apply_client_to_instance(mock_client) + + timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + channel_values = {"channel1": 42.5} + + with pytest.raises(ValueError, match="Ingestion config ID is not set"): + flow.ingest(timestamp=timestamp, channel_values=channel_values) + + +class TestIngestionConfig: + """Unit tests for IngestionConfig model.""" + + def test_ingestion_config_has_required_fields(self): + """Test that IngestionConfig can be created with required fields.""" + config = IngestionConfig( + proto=MagicMock(), + id_="config123", + asset_id="asset123", + client_key="client_key_123", + ) + + assert config.id_ == "config123" + assert config.asset_id == "asset123" + assert config.client_key == "client_key_123" diff --git a/python/lib/sift_client/_tests/sift_types/test_rule.py b/python/lib/sift_client/_tests/sift_types/test_rule.py new file mode 100644 index 000000000..de9dbe7b1 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_rule.py @@ -0,0 +1,148 @@ +"""Tests for sift_types.Rule model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Rule +from sift_client.sift_types.channel import ChannelReference +from sift_client.sift_types.rule import ( + RuleAction, + RuleActionType, + RuleAnnotationType, + RuleUpdate, +) + + +@pytest.fixture +def mock_rule(mock_client): + """Create a mock Rule instance for testing.""" + rule = Rule( + proto=MagicMock(), + id_="test_rule_id", + name="test_rule", + description="test description", + is_enabled=True, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + organization_id="org1", + is_archived=False, + is_external=False, + expression="$1 > 100", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ], + action=RuleAction( + action_type=RuleActionType.ANNOTATION, + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["tag1"], + ), + asset_ids=["asset1", "asset2"], + asset_tag_ids=["tag1"], + contextual_channels=["channel2"], + client_key=None, + rule_version=None, + archived_date=None, + ) + rule._apply_client_to_instance(mock_client) + return rule + + +class TestRule: + """Unit tests for Rule model - tests properties and methods.""" + + def test_assets_property_calls_client(self, mock_rule, mock_client): + """Test that assets property calls client.assets.list_ with correct parameters.""" + mock_client.assets.list_.return_value = [] + + # Access assets property + _ = mock_rule.assets + + # Verify client method was called with correct parameters + mock_client.assets.list_.assert_called_once_with( + asset_ids=["asset1", "asset2"], _tag_ids=["tag1"] + ) + + def test_update_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that update() calls client.rules.update and calls _update.""" + updated_rule = MagicMock() + updated_rule.description = "Updated description" + mock_client.rules.update.return_value = updated_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call update + update = RuleUpdate(description="Updated description") + result = mock_rule.update(update) + + # Verify client method was called with correct parameters + mock_client.rules.update.assert_called_once_with( + rule=mock_rule, update=update, version_notes=None + ) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(updated_rule) + # Verify it returns self + assert result is mock_rule + + def test_update_with_version_notes(self, mock_rule, mock_client): + """Test that update() passes version_notes parameter correctly.""" + updated_rule = MagicMock() + mock_client.rules.update.return_value = updated_rule + + # Mock the _update method + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call update with version_notes + update = RuleUpdate(description="Updated") + mock_rule.update(update, version_notes="Test notes") + + # Verify client method was called with version_notes + mock_client.rules.update.assert_called_once_with( + rule=mock_rule, update=update, version_notes="Test notes" + ) + + def test_archive_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that archive() calls client.rules.archive and calls _update.""" + archived_rule = MagicMock() + archived_rule.is_archived = True + mock_client.rules.archive.return_value = archived_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call archive + result = mock_rule.archive() + + # Verify client method was called + mock_client.rules.archive.assert_called_once_with(rule=mock_rule) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(archived_rule) + # Verify it returns self + assert result is mock_rule + + def test_unarchive_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that unarchive() calls client.rules.unarchive and calls _update.""" + unarchived_rule = MagicMock() + unarchived_rule.is_archived = False + mock_client.rules.unarchive.return_value = unarchived_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call unarchive + result = mock_rule.unarchive() + + # Verify client method was called + mock_client.rules.unarchive.assert_called_once_with(rule=mock_rule) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(unarchived_rule) + # Verify it returns self + assert result is mock_rule diff --git a/python/lib/sift_client/_tests/sift_types/test_run.py b/python/lib/sift_client/_tests/sift_types/test_run.py new file mode 100644 index 000000000..5f9e6b1a1 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_run.py @@ -0,0 +1,207 @@ +"""Tests for sift_types.Run model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Run +from sift_client.sift_types.run import RunCreate, RunUpdate + + +class TestRunCreate: + """Unit tests for RunCreate model - tests _to_proto_helpers and validators.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"string_key": "value", "number_key": 3.14, "bool_key": True} + create = RunCreate(name="test_run", metadata=metadata) + proto = create.to_proto() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["string_key"].string_value == "value" + assert metadata_dict["number_key"].number_value == 3.14 + assert metadata_dict["bool_key"].boolean_value is True + + def test_time_validator_start_before_stop(self): + """Test time validator accepts start_time before stop_time.""" + start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + # Should not raise + create = RunCreate(name="test_run", start_time=start_time, stop_time=stop_time) + assert create.start_time == start_time + assert create.stop_time == stop_time + + def test_time_validator_rejects_start_after_stop(self): + """Test time validator rejects start_time after stop_time.""" + start_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + with pytest.raises(ValueError, match="start_time must be before stop_time"): + RunCreate(name="test_run", start_time=start_time, stop_time=stop_time) + + def test_time_validator_rejects_stop_without_start(self): + """Test time validator rejects stop_time without start_time.""" + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + with pytest.raises( + ValueError, match="start_time must be provided if stop_time is provided" + ): + RunCreate(name="test_run", stop_time=stop_time) + + +class TestRunUpdate: + """Unit tests for RunUpdate model - tests _to_proto_helpers and validators.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": False} + update = RunUpdate(metadata=metadata) + update.resource_id = "test_run_id" + + proto, mask = update.to_proto_with_mask() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is False + assert "metadata" in mask.paths + + def test_time_validator_start_before_stop(self): + """Test time validator accepts start_time before stop_time.""" + start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + # Should not raise + update = RunUpdate(start_time=start_time, stop_time=stop_time) + assert update.start_time == start_time + assert update.stop_time == stop_time + + def test_time_validator_rejects_start_after_stop(self): + """Test time validator rejects start_time after stop_time.""" + start_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + with pytest.raises(ValueError, match="start_time must be before stop_time"): + RunUpdate(start_time=start_time, stop_time=stop_time) + + def test_time_validator_rejects_stop_without_start(self): + """Test time validator rejects stop_time without start_time.""" + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + with pytest.raises( + ValueError, match="start_time must be provided if stop_time is provided" + ): + RunUpdate(stop_time=stop_time) + + +@pytest.fixture +def mock_run(mock_client): + """Create a mock Run instance for testing.""" + run = Run( + proto=MagicMock(), + id_="test_run_id", + name="test_run", + description="test", + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + organization_id="org1", + metadata={}, + tags=[], + asset_ids=["asset1", "asset2"], + is_adhoc=False, + is_archived=False, + start_time=None, + stop_time=None, + duration=None, + default_report_id=None, + client_key=None, + archived_date=None, + ) + run._apply_client_to_instance(mock_client) + return run + + +class TestRun: + """Unit tests for Run model - tests properties and methods.""" + + def test_assets_property_calls_client(self, mock_run, mock_client): + """Test that assets property calls client.assets.list_ with correct parameters.""" + mock_client.assets.list_.return_value = [] + + # Access assets property + _ = mock_run.assets + + # Verify client method was called with correct asset_ids + mock_client.assets.list_.assert_called_once_with(asset_ids=["asset1", "asset2"]) + + def test_archive_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that archive() calls client.runs.archive and calls _update.""" + archived_run = MagicMock() + archived_run.is_archived = True + archived_run.archived_date = datetime.now(timezone.utc) + mock_client.runs.archive.return_value = archived_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call archive + result = mock_run.archive() + + # Verify client method was called + mock_client.runs.archive.assert_called_once_with(run=mock_run) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(archived_run) + # Verify it returns self + assert result is mock_run + + def test_unarchive_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that unarchive() calls client.runs.unarchive and calls _update.""" + unarchived_run = MagicMock() + unarchived_run.is_archived = False + mock_client.runs.unarchive.return_value = unarchived_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call unarchive + result = mock_run.unarchive() + + # Verify client method was called + mock_client.runs.unarchive.assert_called_once_with(run=mock_run) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(unarchived_run) + # Verify it returns self + assert result is mock_run + + def test_update_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that update() calls client.runs.update and calls _update.""" + updated_run = MagicMock() + updated_run.description = "Updated description" + mock_client.runs.update.return_value = updated_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call update + update = RunUpdate(description="Updated description") + result = mock_run.update(update) + + # Verify client method was called with correct parameters + mock_client.runs.update.assert_called_once_with(run=mock_run, update=update) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(updated_run) + # Verify it returns self + assert result is mock_run diff --git a/python/lib/sift_client/resources/runs.py b/python/lib/sift_client/resources/runs.py index 1aa34e267..45ec228a1 100644 --- a/python/lib/sift_client/resources/runs.py +++ b/python/lib/sift_client/resources/runs.py @@ -34,7 +34,9 @@ def __init__(self, sift_client: SiftClient): super().__init__(sift_client) self._low_level_client = RunsLowLevelClient(grpc_client=self.client.grpc_client) - async def get(self, *, run_id: str | None = None, client_key: str | None = None) -> Run: + async def get( + self, *, run_id: str | None = None, client_key: str | None = None + ) -> Run: """Get a Run. Args: @@ -148,14 +150,18 @@ async def list_( if assets: if all(isinstance(s, str) for s in assets): ids = cast("list[str]", assets) # linting - filter_parts.append(cel.in_("asset_ids", ids)) + filter_parts.append(cel.in_("asset_id", ids)) else: asset = cast("list[Asset]", assets) # linting - filter_parts.append(cel.in_("asset_ids", [a._id_or_error for a in asset])) + filter_parts.append( + cel.in_("asset_id", [a._id_or_error for a in asset]) + ) if duration_less_than: filter_parts.append(cel.less_than("duration_string", duration_less_than)) if duration_greater_than: - filter_parts.append(cel.greater_than("duration_string", duration_greater_than)) + filter_parts.append( + cel.greater_than("duration_string", duration_greater_than) + ) if start_time_after: filter_parts.append(cel.greater_than("start_time", start_time_after)) if start_time_before: diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index 1cbadeec0..8de933bce 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -46,7 +46,9 @@ def _id_or_error(self) -> str: @classmethod @abstractmethod - def _from_proto(cls, proto: ProtoT, sift_client: SiftClient | None = None) -> SelfT: ... + def _from_proto( + cls, proto: ProtoT, sift_client: SiftClient | None = None + ) -> SelfT: ... def _apply_client_to_instance(self, client: SiftClient) -> None: # This bypasses the frozen status of the model @@ -58,10 +60,10 @@ def _update(self, other: BaseType[ProtoT, SelfT]) -> BaseType[ProtoT, SelfT]: for key in other.__class__.model_fields.keys(): if key in self.model_fields: self.__dict__.update({key: getattr(other, key)}) + return self # Make sure we also update the proto since it is excluded self.__dict__["proto"] = other.proto - return self @model_validator(mode="after") def _validate_timezones(self): @@ -131,7 +133,10 @@ def _build_proto_and_paths( for field_name, value in data.items(): path = f"{prefix}.{field_name}" if prefix else field_name - if not already_setting_path_override and field_name in self._to_proto_helpers: + if ( + not already_setting_path_override + and field_name in self._to_proto_helpers + ): mapping_helper = self._to_proto_helpers[field_name] # Expand the proto path to a dictionary and parse recursively for layer in reversed(mapping_helper.proto_attr_path.split(".")): @@ -145,9 +150,9 @@ def _build_proto_and_paths( paths.append(mapping_helper.update_field) elif isinstance(value, dict): if field_name in self.__class__._to_proto_helpers: - assert self.__class__._to_proto_helpers[field_name].converter, ( - f"Expecting to run a coverter given a helper was defined for: {field_name}" - ) + assert self.__class__._to_proto_helpers[ + field_name + ].converter, f"Expecting to run a coverter given a helper was defined for: {field_name}" sub_paths = self._build_proto_and_paths( proto_msg, {field_name: self.__class__._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] @@ -173,9 +178,9 @@ def _build_proto_and_paths( repeated_field.extend(value) # Add all new values except TypeError as e: if field_name in self.__class__._to_proto_helpers: - assert self.__class__._to_proto_helpers[field_name].converter, ( - f"Expecting to run a coverter given a helper was defined for: {field_name}" - ) + assert self.__class__._to_proto_helpers[ + field_name + ].converter, f"Expecting to run a coverter given a helper was defined for: {field_name}" for item in value: repeated_field.append( self.__class__._to_proto_helpers[field_name].converter(**item) # type: ignore diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 1d05ea385..d1fca2523 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -63,12 +63,18 @@ def modified_by(self): def archive(self) -> CalculatedChannel: """Archive the calculated channel.""" - self.client.calculated_channels.archive(calculated_channel=self) + updated_calculated_channel = self.client.calculated_channels.archive( + calculated_channel=self + ) + self._update(updated_calculated_channel) return self def unarchive(self) -> CalculatedChannel: """Unarchive the calculated channel.""" - self.client.calculated_channels.unarchive(calculated_channel=self) + updated_calculated_channel = self.client.calculated_channels.unarchive( + calculated_channel=self + ) + self._update(updated_calculated_channel) return self def update( @@ -169,6 +175,7 @@ class CalculatedChannelBase(ModelCreateUpdateBase): ), "all_assets": MappingHelper( proto_attr_path="calculated_channel_configuration.asset_configuration.all_assets", + update_field="asset_configuration", ), "metadata": MappingHelper( proto_attr_path="metadata", diff --git a/python/pyproject.toml b/python/pyproject.toml index 31dea1967..993b4236b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -207,6 +207,11 @@ select = [ env_files = [ ".env" ] +testpaths = [ + "lib/sift_py", + "lib/sift_client/_tests", +] + markers = [ "integration: mark a test as an integration test (requires API)" ] From 519e61fa16ea515fbbd0d1ed254f982df7c0f08b Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 14:07:12 -0700 Subject: [PATCH 37/39] fmt and linting --- .../resources/test_calculated_channels.py | 99 +++++-------------- .../_tests/resources/test_rules.py | 67 +++---------- .../sift_client/_tests/resources/test_runs.py | 4 +- .../_tests/sift_types/test_asset.py | 20 +--- .../sift_types/test_calculated_channel.py | 16 ++- .../_tests/sift_types/test_ingestion.py | 4 +- python/lib/sift_client/resources/runs.py | 12 +-- python/lib/sift_client/sift_types/_base.py | 21 ++-- 8 files changed, 70 insertions(+), 173 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_calculated_channels.py b/python/lib/sift_client/_tests/resources/test_calculated_channels.py index ab71fd602..3a39bb4b6 100644 --- a/python/lib/sift_client/_tests/resources/test_calculated_channels.py +++ b/python/lib/sift_client/_tests/resources/test_calculated_channels.py @@ -29,9 +29,7 @@ def test_client_binding(sift_client): assert sift_client.calculated_channels assert isinstance(sift_client.calculated_channels, CalculatedChannelsAPI) assert sift_client.async_.calculated_channels - assert isinstance( - sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync - ) + assert isinstance(sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync) @pytest.fixture @@ -73,12 +71,8 @@ def new_calculated_channel(calculated_channels_api_sync, sift_client): description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], all_assets=True, ) @@ -93,9 +87,7 @@ class TestGet: """Tests for the async get method.""" @pytest.mark.asyncio - async def test_get_by_id( - self, calculated_channels_api_async, test_calculated_channel - ): + async def test_get_by_id(self, calculated_channels_api_async, test_calculated_channel): """Test getting a specific calculated channel by ID.""" retrieved_calc_channel = await calculated_channels_api_async.get( calculated_channel_id=test_calculated_channel.id_ @@ -148,13 +140,9 @@ async def test_list_with_name_filter(self, calculated_channels_api_async): assert calc_channel.name == test_calc_channel_name @pytest.mark.asyncio - async def test_list_with_name_contains_filter( - self, calculated_channels_api_async - ): + async def test_list_with_name_contains_filter(self, calculated_channels_api_async): """Test calculated channel listing with name contains filtering.""" - calc_channels = await calculated_channels_api_async.list_( - name_contains="test", limit=5 - ) + calc_channels = await calculated_channels_api_async.list_(name_contains="test", limit=5) assert isinstance(calc_channels, list) assert calc_channels @@ -205,9 +193,7 @@ async def test_find_calculated_channel( assert found_calc_channel.id_ == test_calculated_channel.id_ @pytest.mark.asyncio - async def test_find_nonexistent_calculated_channel( - self, calculated_channels_api_async - ): + async def test_find_nonexistent_calculated_channel(self, calculated_channels_api_async): """Test finding a non-existent calculated channel returns None.""" found_calc_channel = await calculated_channels_api_async.find( name="nonexistent-calculated-channel-name-12345" @@ -224,20 +210,14 @@ class TestCreate: """Tests for the async create method.""" @pytest.mark.asyncio - async def test_create_basic_calculated_channel( - self, calculated_channels_api_async - ): + async def test_create_basic_calculated_channel(self, calculated_channels_api_async): """Test creating a basic calculated channel with minimal fields.""" from datetime import datetime, timezone - calc_channel_name = ( - f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" - ) + calc_channel_name = f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_( - limit=2 - ) + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) assert len(channels) >= 2 calc_channel_create = CalculatedChannelCreate( @@ -245,19 +225,13 @@ async def test_create_basic_calculated_channel( description=description, expression="$1 + $2", expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], all_assets=True, ) - created_calc_channel = await calculated_channels_api_async.create( - calc_channel_create - ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) try: assert created_calc_channel is not None @@ -271,20 +245,14 @@ async def test_create_basic_calculated_channel( await calculated_channels_api_async.archive(created_calc_channel) @pytest.mark.asyncio - async def test_create_calculated_channel_with_dict( - self, calculated_channels_api_async - ): + async def test_create_calculated_channel_with_dict(self, calculated_channels_api_async): """Test creating a calculated channel using a dictionary.""" from datetime import datetime, timezone - calc_channel_name = ( - f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" - ) + calc_channel_name = f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" description = "Test calculated channel created by Sift Client pytest" - channels = await calculated_channels_api_async.client.async_.channels.list_( - limit=2 - ) + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) assert len(channels) >= 2 calc_channel_dict = { @@ -298,9 +266,7 @@ async def test_create_calculated_channel_with_dict( "all_assets": True, } - created_calc_channel = await calculated_channels_api_async.create( - calc_channel_dict - ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_dict) try: assert created_calc_channel.name == calc_channel_name @@ -411,18 +377,12 @@ async def test_update_with_complex_expression( description="Test calculated channel for complex expression update", expression="$1 + $2", expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], all_assets=True, ) - created_calc_channel = await calculated_channels_api_async.create( - calc_channel_create - ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) try: # Update with complex expression @@ -449,8 +409,7 @@ async def test_update_with_complex_expression( assert len(updated_calc_channel.channel_references) == 3 # Verify all three channel references are present ref_identifiers = { - ref.channel_identifier - for ref in updated_calc_channel.channel_references + ref.channel_identifier for ref in updated_calc_channel.channel_references } assert channels[0].name in ref_identifiers assert channels[1].name in ref_identifiers @@ -503,9 +462,7 @@ async def test_archive_calculated_channel( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel.""" - calc_channel = await calculated_channels_api_async.archive( - new_calculated_channel - ) + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -528,9 +485,7 @@ async def test_archive_with_id_string( self, calculated_channels_api_async, new_calculated_channel ): """Test archiving a calculated channel by passing ID as string.""" - calc_channel = await calculated_channels_api_async.archive( - new_calculated_channel.id_ - ) + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel.id_) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -547,9 +502,7 @@ async def test_unarchive_calculated_channel( try: await calculated_channels_api_async.archive(new_calculated_channel) - calc_channel = await calculated_channels_api_async.unarchive( - new_calculated_channel - ) + calc_channel = await calculated_channels_api_async.unarchive(new_calculated_channel) assert isinstance(calc_channel, CalculatedChannel) assert calc_channel.id_ == new_calculated_channel.id_ @@ -561,9 +514,7 @@ class TestListVersions: """Tests for the async list_versions method.""" @pytest.mark.asyncio - async def test_list_versions( - self, calculated_channels_api_async, test_calculated_channel - ): + async def test_list_versions(self, calculated_channels_api_async, test_calculated_channel): """Test listing versions of a calculated channel.""" versions = await calculated_channels_api_async.list_versions( calculated_channel=test_calculated_channel diff --git a/python/lib/sift_client/_tests/resources/test_rules.py b/python/lib/sift_client/_tests/resources/test_rules.py index ac767aa4b..6036bec19 100644 --- a/python/lib/sift_client/_tests/resources/test_rules.py +++ b/python/lib/sift_client/_tests/resources/test_rules.py @@ -77,12 +77,8 @@ def new_rule(rules_api_sync, sift_client): description=description, expression="$1 > $2", channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -113,9 +109,7 @@ async def test_get_by_id(self, rules_api_async, test_rule): async def test_get_by_client_key(self, rules_api_async, test_rule): """Test getting a specific rule by client key.""" if test_rule.client_key: - retrieved_rule = await rules_api_async.get( - client_key=test_rule.client_key - ) + retrieved_rule = await rules_api_async.get(client_key=test_rule.client_key) assert retrieved_rule is not None assert retrieved_rule.id_ == test_rule.id_ @@ -266,12 +260,8 @@ async def test_create_basic_rule(self, rules_api_async): description=description, expression="$1 > $2", channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -350,10 +340,7 @@ async def test_update_rule_description(self, rules_api_async, new_rule): assert updated_rule.expression == new_rule.expression assert updated_rule.action.action_type == new_rule.action.action_type assert updated_rule.client_key == new_rule.client_key - assert ( - updated_rule.rule_version.created_date - > new_rule.rule_version.created_date - ) + assert updated_rule.rule_version.created_date > new_rule.rule_version.created_date finally: await rules_api_async.archive(new_rule.id_) @@ -426,10 +413,7 @@ async def test_update_rule_action(self, rules_api_async, new_rule): assert updated_rule.action.action_type == RuleActionType.ANNOTATION assert updated_rule.action.annotation_type == RuleAnnotationType.PHASE assert set(updated_rule.action.tags) == {"sift-client-pytest"} - assert ( - updated_rule.action.default_assignee_user_id - == new_rule.created_by_user_id - ) + assert updated_rule.action.default_assignee_user_id == new_rule.created_by_user_id # Verify other fields remain unchanged assert updated_rule.name == new_rule.name @@ -438,9 +422,7 @@ async def test_update_rule_action(self, rules_api_async, new_rule): await rules_api_async.archive(new_rule.id_) @pytest.mark.asyncio - async def test_update_with_complex_expression( - self, rules_api_async, sift_client - ): + async def test_update_with_complex_expression(self, rules_api_async, sift_client): """Test updating a rule with a complex expression (range check).""" # Get channels and assets channels = await sift_client.async_.channels.list_(limit=2) @@ -449,17 +431,13 @@ async def test_update_with_complex_expression( assert len(assets) >= 1 # Create a rule with simple expression - rule_name = ( - f"test_rule_complex_expr_{datetime.now(timezone.utc).isoformat()}" - ) + rule_name = f"test_rule_complex_expr_{datetime.now(timezone.utc).isoformat()}" rule_create = RuleCreate( name=rule_name, description="Test rule for complex expression update", expression="$1 > 0.5", channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -485,23 +463,16 @@ async def test_update_with_complex_expression( assert updated_rule.id_ == created_rule.id_ assert updated_rule.expression == "$1 > 0.3 && $1 < 0.8" assert len(updated_rule.channel_references) == 1 - assert ( - updated_rule.channel_references[0].channel_identifier - == channels[0].name - ) + assert updated_rule.channel_references[0].channel_identifier == channels[0].name # Verify other fields remain unchanged assert updated_rule.name == created_rule.name - assert ( - updated_rule.action.action_type == created_rule.action.action_type - ) + assert updated_rule.action.action_type == created_rule.action.action_type finally: await rules_api_async.archive(created_rule.id_) @pytest.mark.asyncio - async def test_update_with_multiple_channel_references( - self, rules_api_async, sift_client - ): + async def test_update_with_multiple_channel_references(self, rules_api_async, sift_client): """Test updating a rule expression to use multiple channel references.""" # Get channels and assets channels = await sift_client.async_.channels.list_(limit=3) @@ -516,12 +487,8 @@ async def test_update_with_multiple_channel_references( description="Test rule for multiple channel references", expression="$1 > $2", channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=channels[0].name - ), - ChannelReference( - channel_reference="$2", channel_identifier=channels[1].name - ), + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), ], action=RuleAction.annotation( annotation_type=RuleAnnotationType.DATA_REVIEW, @@ -578,9 +545,7 @@ async def test_update_with_invalid_expression(self, rules_api_async, new_rule): channel_references=[ ChannelReference( channel_reference="$1", - channel_identifier=new_rule.channel_references[ - 0 - ].channel_identifier, + channel_identifier=new_rule.channel_references[0].channel_identifier, ), ], ) diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index 6153c63bb..b28c0cc06 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -484,9 +484,7 @@ class TestAssetAssociation: """Tests for the async asset association methods.""" @pytest.mark.asyncio - async def test_create_automatic_association_for_assets( - self, runs_api_async, sift_client - ): + async def test_create_automatic_association_for_assets(self, runs_api_async, sift_client): """Test associating assets with a run for automatic data ingestion.""" # Create a test run run_name = f"test_run_asset_assoc_{datetime.now(timezone.utc).isoformat()}" diff --git a/python/lib/sift_client/_tests/sift_types/test_asset.py b/python/lib/sift_client/_tests/sift_types/test_asset.py index 9c085b61a..e209e2aad 100644 --- a/python/lib/sift_client/_tests/sift_types/test_asset.py +++ b/python/lib/sift_client/_tests/sift_types/test_asset.py @@ -74,9 +74,7 @@ def test_channels_method_calls_client(self, mock_asset, mock_client): _ = mock_asset.channels(limit=5) # Verify client method was called with correct parameters - mock_client.channels.list_.assert_called_once_with( - asset=mock_asset, run=None, limit=5 - ) + mock_client.channels.list_.assert_called_once_with(asset=mock_asset, run=None, limit=5) def test_channels_method_with_run_filter(self, mock_asset, mock_client): """Test that channels() method passes run filter to client.""" @@ -87,9 +85,7 @@ def test_channels_method_with_run_filter(self, mock_asset, mock_client): _ = mock_asset.channels(run=mock_run, limit=10) # Verify client method was called with run parameter - mock_client.channels.list_.assert_called_once_with( - asset=mock_asset, run=mock_run, limit=10 - ) + mock_client.channels.list_.assert_called_once_with(asset=mock_asset, run=mock_run, limit=10) def test_archive_calls_client_and_updates_self(self, mock_asset, mock_client): """Test that archive() calls client.assets.archive and calls _update.""" @@ -106,9 +102,7 @@ def test_archive_calls_client_and_updates_self(self, mock_asset, mock_client): result = mock_asset.archive(archive_runs=False) # Verify client method was called - mock_client.assets.archive.assert_called_once_with( - asset=mock_asset, archive_runs=False - ) + mock_client.assets.archive.assert_called_once_with(asset=mock_asset, archive_runs=False) # Verify _update was called with the returned asset mock_update.assert_called_once_with(archived_asset) # Verify it returns self @@ -127,9 +121,7 @@ def test_archive_with_runs(self, mock_asset, mock_client): mock_asset.archive(archive_runs=True) # Verify client method was called with archive_runs=True - mock_client.assets.archive.assert_called_once_with( - asset=mock_asset, archive_runs=True - ) + mock_client.assets.archive.assert_called_once_with(asset=mock_asset, archive_runs=True) def test_unarchive_calls_client_and_updates_self(self, mock_asset, mock_client): """Test that unarchive() calls client.assets.unarchive and calls _update.""" @@ -166,9 +158,7 @@ def test_update_calls_client_and_updates_self(self, mock_asset, mock_client): result = mock_asset.update(update) # Verify client method was called with correct parameters - mock_client.assets.update.assert_called_once_with( - asset=mock_asset, update=update - ) + mock_client.assets.update.assert_called_once_with(asset=mock_asset, update=update) # Verify _update was called with the returned asset mock_update.assert_called_once_with(updated_asset) # Verify it returns self diff --git a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py index e4c310c92..d79fa34ba 100644 --- a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py +++ b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py @@ -7,7 +7,6 @@ from sift_client.sift_types import CalculatedChannel from sift_client.sift_types.calculated_channel import ( - CalculatedChannelBase, CalculatedChannelUpdate, ) from sift_client.sift_types.channel import ChannelReference @@ -67,7 +66,9 @@ def test_expression_channel_references_helper(self): proto, mask = update.to_proto_with_mask() # Verify channel references are converted - refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + refs = ( + proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + ) assert len(refs) == 2 assert refs[0].channel_reference == "$1" assert refs[0].channel_identifier == "channel1" @@ -109,14 +110,18 @@ def test_all_assets_helper(self): proto, mask = update.to_proto_with_mask() # Verify all_assets is set in nested path - assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets + is True + ) # Verify update_field is in mask (same as tag_ids and asset_ids) assert "asset_configuration" in mask.paths def test_asset_configuration_validator_rejects_all_assets_with_asset_ids(self): """Test validator rejects all_assets=True with asset_ids.""" with pytest.raises( - ValueError, match="Cannot specify both all_assets=True and asset_ids/tag_ids" + ValueError, + match="Cannot specify both all_assets=True and asset_ids/tag_ids", ): CalculatedChannelUpdate( all_assets=True, @@ -126,7 +131,8 @@ def test_asset_configuration_validator_rejects_all_assets_with_asset_ids(self): def test_asset_configuration_validator_rejects_all_assets_with_tag_ids(self): """Test validator rejects all_assets=True with tag_ids.""" with pytest.raises( - ValueError, match="Cannot specify both all_assets=True and asset_ids/tag_ids" + ValueError, + match="Cannot specify both all_assets=True and asset_ids/tag_ids", ): CalculatedChannelUpdate( all_assets=True, diff --git a/python/lib/sift_client/_tests/sift_types/test_ingestion.py b/python/lib/sift_client/_tests/sift_types/test_ingestion.py index 2c333a915..6b29abafe 100644 --- a/python/lib/sift_client/_tests/sift_types/test_ingestion.py +++ b/python/lib/sift_client/_tests/sift_types/test_ingestion.py @@ -129,9 +129,7 @@ def test_add_channel_raises_after_creation(self): data_type=ChannelDataType.DOUBLE, ) - with pytest.raises( - ValueError, match="Cannot add a channel to a flow after creation" - ): + with pytest.raises(ValueError, match="Cannot add a channel to a flow after creation"): flow.add_channel(channel) def test_ingest_calls_client(self, mock_flow, mock_client): diff --git a/python/lib/sift_client/resources/runs.py b/python/lib/sift_client/resources/runs.py index 45ec228a1..1e097612a 100644 --- a/python/lib/sift_client/resources/runs.py +++ b/python/lib/sift_client/resources/runs.py @@ -34,9 +34,7 @@ def __init__(self, sift_client: SiftClient): super().__init__(sift_client) self._low_level_client = RunsLowLevelClient(grpc_client=self.client.grpc_client) - async def get( - self, *, run_id: str | None = None, client_key: str | None = None - ) -> Run: + async def get(self, *, run_id: str | None = None, client_key: str | None = None) -> Run: """Get a Run. Args: @@ -153,15 +151,11 @@ async def list_( filter_parts.append(cel.in_("asset_id", ids)) else: asset = cast("list[Asset]", assets) # linting - filter_parts.append( - cel.in_("asset_id", [a._id_or_error for a in asset]) - ) + filter_parts.append(cel.in_("asset_id", [a._id_or_error for a in asset])) if duration_less_than: filter_parts.append(cel.less_than("duration_string", duration_less_than)) if duration_greater_than: - filter_parts.append( - cel.greater_than("duration_string", duration_greater_than) - ) + filter_parts.append(cel.greater_than("duration_string", duration_greater_than)) if start_time_after: filter_parts.append(cel.greater_than("start_time", start_time_after)) if start_time_before: diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index 8de933bce..db0370e36 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -46,9 +46,7 @@ def _id_or_error(self) -> str: @classmethod @abstractmethod - def _from_proto( - cls, proto: ProtoT, sift_client: SiftClient | None = None - ) -> SelfT: ... + def _from_proto(cls, proto: ProtoT, sift_client: SiftClient | None = None) -> SelfT: ... def _apply_client_to_instance(self, client: SiftClient) -> None: # This bypasses the frozen status of the model @@ -133,10 +131,7 @@ def _build_proto_and_paths( for field_name, value in data.items(): path = f"{prefix}.{field_name}" if prefix else field_name - if ( - not already_setting_path_override - and field_name in self._to_proto_helpers - ): + if not already_setting_path_override and field_name in self._to_proto_helpers: mapping_helper = self._to_proto_helpers[field_name] # Expand the proto path to a dictionary and parse recursively for layer in reversed(mapping_helper.proto_attr_path.split(".")): @@ -150,9 +145,9 @@ def _build_proto_and_paths( paths.append(mapping_helper.update_field) elif isinstance(value, dict): if field_name in self.__class__._to_proto_helpers: - assert self.__class__._to_proto_helpers[ - field_name - ].converter, f"Expecting to run a coverter given a helper was defined for: {field_name}" + assert self.__class__._to_proto_helpers[field_name].converter, ( + f"Expecting to run a coverter given a helper was defined for: {field_name}" + ) sub_paths = self._build_proto_and_paths( proto_msg, {field_name: self.__class__._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] @@ -178,9 +173,9 @@ def _build_proto_and_paths( repeated_field.extend(value) # Add all new values except TypeError as e: if field_name in self.__class__._to_proto_helpers: - assert self.__class__._to_proto_helpers[ - field_name - ].converter, f"Expecting to run a coverter given a helper was defined for: {field_name}" + assert self.__class__._to_proto_helpers[field_name].converter, ( + f"Expecting to run a coverter given a helper was defined for: {field_name}" + ) for item in value: repeated_field.append( self.__class__._to_proto_helpers[field_name].converter(**item) # type: ignore From 3bc0c2ded55882457b2e82018f4ca3ad72fb293c Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 14:12:48 -0700 Subject: [PATCH 38/39] update pre-push --- .githooks/pre-push | 5 +- .githooks/pre-push-python/fmt-lint.sh | 35 +++++++++++ .githooks/pre-push-python/stubs-check.sh | 62 ------------------- .../sift_client/_tests/resources/test_runs.py | 6 +- .../sift_types/test_calculated_channel.py | 28 +++------ 5 files changed, 51 insertions(+), 85 deletions(-) create mode 100644 .githooks/pre-push-python/fmt-lint.sh delete mode 100644 .githooks/pre-push-python/stubs-check.sh diff --git a/.githooks/pre-push b/.githooks/pre-push index 6b13b65f5..ba7f5ac86 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -8,7 +8,10 @@ GITHOOKS_DIR="$REPO_ROOT/.githooks" python_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^python/lib/sift_client/' || true)) if [[ -n "$python_changed_files" ]]; then - echo "Python files changed, running Python stub checks..." + echo "Python files changed, running Python formatting and linting..." + bash "$GITHOOKS_DIR/pre-push-python/fmt-lint.sh" + + echo "Running Python stub checks..." bash "$GITHOOKS_DIR/pre-push-python/stubs.sh" fi diff --git a/.githooks/pre-push-python/fmt-lint.sh b/.githooks/pre-push-python/fmt-lint.sh new file mode 100644 index 000000000..575a4475b --- /dev/null +++ b/.githooks/pre-push-python/fmt-lint.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -e + +# Store the root directory of the repository +REPO_ROOT="$(git rev-parse --show-toplevel)" +PYTHON_DIR="$REPO_ROOT/python" + +echo "Running Python formatting and linting with --fix..." + +# Change to Python directory +cd "$PYTHON_DIR" + +# Run ruff format (formatter) +echo "Running ruff format..." +bash ./scripts/dev fmt + +# Run ruff check with --fix (linter) +echo "Running ruff check --fix..." +bash ./scripts/dev lint-fix + +# Check if any files were modified by formatting/linting +cd "$REPO_ROOT" +changed_files=$(git status --porcelain python/lib/sift_client/ | grep -E '\.py$' || true) + +if [ -n "$changed_files" ]; then + echo "" + echo "ERROR: Formatting/linting made changes to the following files:" + echo "$changed_files" + echo "" + echo "Please commit these changes before pushing." + exit 1 +fi + +echo "Python formatting and linting completed successfully." diff --git a/.githooks/pre-push-python/stubs-check.sh b/.githooks/pre-push-python/stubs-check.sh deleted file mode 100644 index 2d7ab6d38..000000000 --- a/.githooks/pre-push-python/stubs-check.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash -set -e - - -# ensure generated python stubs are up-to-date, from sync clients and sift_stream_bindings - -REPO_ROOT="$(git rev-parse --show-toplevel)" -PYTHON_DIR="$REPO_ROOT/python" -BINDINGS_DIR="$REPO_ROOT/rust/crates/sift_stream_bindings" -STUBS_DIR="$PYTHON_DIR/lib/sift_client/resources/sync_stubs" - -# Function to check if generated stub files have changed -check_stub_changes() { - local target_path="$1" - local changed_files=$(git status --porcelain "$target_path" | grep -E '\.pyi$' || true) - - if [ -n "$changed_files" ]; then - echo "ERROR: Generated python stubs are not up-to-date. Please commit the changed files:" - echo "$changed_files" - exit 1 - fi -} - -# Function to generate Python stubs -generate_python_stubs() { - echo "Generating Python stubs..." - cd "$PYTHON_DIR" - - if [[ ! -d "$PYTHON_DIR/venv" ]]; then - echo "Running bootstrap script..." - bash ./scripts/dev bootstrap - fi - - bash ./scripts/dev gen-stubs - check_stub_changes "$STUBS_DIR" -} - -# Function to generate bindings stubs -generate_bindings_stubs() { - echo "Generating bindings stubs..." - cd "$BINDINGS_DIR" - cargo run --bin stub_gen - - # The stub file is generated in the bindings directory - local stub_file="$BINDINGS_DIR/sift_stream_bindings.pyi" - check_stub_changes "$stub_file" -} - -# Check for changes in relevant files -python_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^python/lib/sift_client/' || true)) -bindings_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^rust/crates/sift_stream_bindings/src/' || true)) - -# Generate stubs if needed -if [[ -n "$python_changed_files" ]]; then - generate_python_stubs -fi - -if [[ -n "$bindings_changed_files" ]]; then - generate_bindings_stubs -fi - -echo "All stubs are up-to-date." \ No newline at end of file diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py index b28c0cc06..f687ecaed 100644 --- a/python/lib/sift_client/_tests/resources/test_runs.py +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -325,9 +325,10 @@ async def test_update_run_name(self, runs_api_async, new_run): async def test_update_run_tags_and_metadata(self, runs_api_async, new_run): """Test updating a run's tags and metadata.""" try: - # Update tags + # Update tags and metadata update = RunUpdate( tags=["updated", "new-tag", "sift-client-pytest"], + metadata={"test_key": "test_value", "number": 42.5, "flag": True}, ) updated_run = await runs_api_async.update(new_run, update) @@ -337,6 +338,9 @@ async def test_update_run_tags_and_metadata(self, runs_api_async, new_run): "new-tag", "sift-client-pytest", } + assert updated_run.metadata["test_key"] == "test_value" + assert updated_run.metadata["number"] == 42.5 + assert updated_run.metadata["flag"] is True finally: await runs_api_async.archive(new_run.id_) diff --git a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py index d79fa34ba..e5b1059c5 100644 --- a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py +++ b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py @@ -47,8 +47,7 @@ def test_expression_helper(self): # Verify expression is set in nested path assert ( - proto.calculated_channel_configuration.query_configuration.sel.expression - == "$1 + $2" + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" ) assert "query_configuration" in mask.paths @@ -66,9 +65,7 @@ def test_expression_channel_references_helper(self): proto, mask = update.to_proto_with_mask() # Verify channel references are converted - refs = ( - proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references - ) + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references assert len(refs) == 2 assert refs[0].channel_reference == "$1" assert refs[0].channel_identifier == "channel1" @@ -110,10 +107,7 @@ def test_all_assets_helper(self): proto, mask = update.to_proto_with_mask() # Verify all_assets is set in nested path - assert ( - proto.calculated_channel_configuration.asset_configuration.all_assets - is True - ) + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True # Verify update_field is in mask (same as tag_ids and asset_ids) assert "asset_configuration" in mask.paths @@ -153,9 +147,7 @@ def test_expression_validator_rejects_references_without_expression(self): ): CalculatedChannelUpdate( expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="channel1" - ), + ChannelReference(channel_reference="$1", channel_identifier="channel1"), ], ) @@ -210,9 +202,7 @@ def mock_calculated_channel(mock_client): class TestCalculatedChannel: """Unit tests for CalculatedChannel model - tests properties and methods.""" - def test_archive_calls_client_and_updates_self( - self, mock_calculated_channel, mock_client - ): + def test_archive_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): """Test that archive() calls client.calculated_channels.archive and calls _update.""" archived_calc_channel = MagicMock() archived_calc_channel.is_archived = True @@ -234,9 +224,7 @@ def test_archive_calls_client_and_updates_self( # Verify it returns self assert result is mock_calculated_channel - def test_unarchive_calls_client_and_updates_self( - self, mock_calculated_channel, mock_client - ): + def test_unarchive_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): """Test that unarchive() calls client.calculated_channels.unarchive and calls _update.""" unarchived_calc_channel = MagicMock() unarchived_calc_channel.is_archived = False @@ -258,9 +246,7 @@ def test_unarchive_calls_client_and_updates_self( # Verify it returns self assert result is mock_calculated_channel - def test_update_calls_client_and_updates_self( - self, mock_calculated_channel, mock_client - ): + def test_update_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): """Test that update() calls client.calculated_channels.update and calls _update.""" updated_calc_channel = MagicMock() updated_calc_channel.description = "Updated description" From feacab2f6e6edafb15e362e9cdda574aad680d27 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 16:05:55 -0700 Subject: [PATCH 39/39] fix regression --- python/lib/sift_client/sift_types/_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index db0370e36..faa3382bb 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -58,11 +58,12 @@ def _update(self, other: BaseType[ProtoT, SelfT]) -> BaseType[ProtoT, SelfT]: for key in other.__class__.model_fields.keys(): if key in self.model_fields: self.__dict__.update({key: getattr(other, key)}) - return self # Make sure we also update the proto since it is excluded self.__dict__["proto"] = other.proto + return self + @model_validator(mode="after") def _validate_timezones(self): """Validate datetime fiels have timezone information."""