diff --git a/.githooks/pre-push b/.githooks/pre-push index 6b13b65f5..ba7f5ac86 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -8,7 +8,10 @@ GITHOOKS_DIR="$REPO_ROOT/.githooks" python_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^python/lib/sift_client/' || true)) if [[ -n "$python_changed_files" ]]; then - echo "Python files changed, running Python stub checks..." + echo "Python files changed, running Python formatting and linting..." + bash "$GITHOOKS_DIR/pre-push-python/fmt-lint.sh" + + echo "Running Python stub checks..." bash "$GITHOOKS_DIR/pre-push-python/stubs.sh" fi diff --git a/.githooks/pre-push-python/fmt-lint.sh b/.githooks/pre-push-python/fmt-lint.sh new file mode 100644 index 000000000..575a4475b --- /dev/null +++ b/.githooks/pre-push-python/fmt-lint.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -e + +# Store the root directory of the repository +REPO_ROOT="$(git rev-parse --show-toplevel)" +PYTHON_DIR="$REPO_ROOT/python" + +echo "Running Python formatting and linting with --fix..." + +# Change to Python directory +cd "$PYTHON_DIR" + +# Run ruff format (formatter) +echo "Running ruff format..." +bash ./scripts/dev fmt + +# Run ruff check with --fix (linter) +echo "Running ruff check --fix..." +bash ./scripts/dev lint-fix + +# Check if any files were modified by formatting/linting +cd "$REPO_ROOT" +changed_files=$(git status --porcelain python/lib/sift_client/ | grep -E '\.py$' || true) + +if [ -n "$changed_files" ]; then + echo "" + echo "ERROR: Formatting/linting made changes to the following files:" + echo "$changed_files" + echo "" + echo "Please commit these changes before pushing." + exit 1 +fi + +echo "Python formatting and linting completed successfully." diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index 4dbac8c5b..836e00437 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -25,9 +25,11 @@ jobs: python-version: "3.8" - name: Pip install + id: install run: | python -m pip install --upgrade pip pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' + - name: Lint run: | ruff check @@ -44,9 +46,17 @@ jobs: run: | pyright lib - - name: Pytest + - name: Pytest Unit Tests + run: | + pytest -m "not integration" + + - name: Pytest Integration Tests + env: + SIFT_GRPC_URI: ${{ vars.SIFT_GRPC_URI }} + SIFT_REST_URI: ${{ vars.SIFT_REST_URI }} + SIFT_API_KEY: ${{ secrets.SIFT_API_KEY }} run: | - pytest + pytest -m "integration" - name: Sync Stubs Mypy working-directory: python/lib diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index a8b93ed68..d349e51a2 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -22,7 +22,7 @@ async def _handle_pagination( page_size: The number of results to return per page. page_token: The token to use for the next page. order_by: How to order the retrieved results. - max_results: Maximum number of results to return. NOTE: Will be in increments of page_size or default page size defined by the call if no page_size is provided. + max_results: Maximum number of results to return. Returns: A list of all matching results. @@ -31,6 +31,8 @@ async def _handle_pagination( kwargs = {} results: list[Any] = [] + if max_results == 0: + return results if page_token is None: page_token = "" while True: @@ -45,4 +47,6 @@ async def _handle_pagination( results.extend(response) if page_token == "": break + if max_results and len(results) > max_results: + results = results[:max_results] return results diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py index b7c4a4341..743e211a1 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py @@ -18,17 +18,14 @@ from queue import Queue from typing import TYPE_CHECKING, Any, cast -import sift_stream_bindings from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( GetIngestionConfigRequest, ListIngestionConfigFlowsResponse, ListIngestionConfigsRequest, ListIngestionConfigsResponse, ) -from sift.ingestion_configs.v2.ingestion_configs_pb2_grpc import IngestionConfigServiceStub -from sift_stream_bindings import ( - IngestionConfigFormPy, - IngestWithConfigDataStreamRequestPy, +from sift.ingestion_configs.v2.ingestion_configs_pb2_grpc import ( + IngestionConfigServiceStub, ) from sift_client._internal.low_level_wrappers.base import ( @@ -44,6 +41,12 @@ if TYPE_CHECKING: from datetime import datetime + from sift_stream_bindings import ( + IngestionConfigFormPy, + IngestWithConfigDataStreamRequestPy, + SiftStreamBuilderPy, + ) + class IngestionThread(threading.Thread): """Manages ingestion for a single ingestion config.""" @@ -54,7 +57,7 @@ class IngestionThread(threading.Thread): def __init__( self, - sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy, + sift_stream_builder: SiftStreamBuilderPy, data_queue: Queue, ingestion_config: IngestionConfigFormPy, no_data_timeout: int = 1, @@ -154,7 +157,7 @@ class IngestionLowLevelClient(LowLevelClientBase, WithGrpcClient): CacheEntry = namedtuple("CacheEntry", ["data_queue", "ingestion_config", "thread"]) - sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy + sift_stream_builder: SiftStreamBuilderPy stream_cache: dict[str, CacheEntry] def __init__(self, grpc_client: GrpcClient): @@ -163,21 +166,25 @@ def __init__(self, grpc_client: GrpcClient): Args: grpc_client: The gRPC client to use for making API calls. """ + from sift_stream_bindings import ( + RecoveryStrategyPy, + RetryPolicyPy, + SiftStreamBuilderPy, + ) + super().__init__(grpc_client=grpc_client) # Rust GRPC client expects URI to have http(s):// prefix. uri = grpc_client._config.uri if not uri.startswith("http"): uri = f"https://{uri}" if grpc_client._config.use_ssl else f"http://{uri}" - self.sift_stream_builder = sift_stream_bindings.SiftStreamBuilderPy( + self.sift_stream_builder = SiftStreamBuilderPy( uri=uri, apikey=grpc_client._config.api_key, ) self.sift_stream_builder.enable_tls = grpc_client._config.use_ssl # FD-177: Expose configuration for recovery strategy. - self.sift_stream_builder.recovery_strategy = ( - sift_stream_bindings.RecoveryStrategyPy.retry_only( - sift_stream_bindings.RetryPolicyPy.default() - ) + self.sift_stream_builder.recovery_strategy = RecoveryStrategyPy.retry_only( + RetryPolicyPy.default() ) self.stream_cache = {} @@ -229,7 +236,9 @@ async def get_ingestion_config_id_from_client_key(self, client_key: str) -> str return ingestion_configs[0].id_ def _new_ingestion_thread( - self, ingestion_config_id: str, ingestion_config: IngestionConfigFormPy | None = None + self, + ingestion_config_id: str, + ingestion_config: IngestionConfigFormPy | None = None, ): """Start a new ingestion thread. This allows ingestion to happen in the background regardless of if the user is using the sync or async client @@ -290,7 +299,6 @@ async def create_ingestion_config( asset_name: str, flows: list[Flow], client_key: str | None = None, - organization_id: str | None = None, ) -> str: """Create an ingestion config. @@ -303,6 +311,8 @@ async def create_ingestion_config( Returns: The id of the new or found ingestion config. """ + from sift_stream_bindings import IngestionConfigFormPy + ingestion_config_id = None if client_key: logger.debug(f"Getting ingestion config id for client key {client_key}") @@ -381,6 +391,8 @@ def ingest_flow( channel_values: The channel values to ingest. organization_id: The organization id to use for ingestion. Only relevant if the user is part of several organizations. """ + from sift_stream_bindings import IngestWithConfigDataStreamRequestPy + if not flow.ingestion_config_id: raise ValueError( "Flow has no ingestion config id -- have you created an ingestion config for this flow?" diff --git a/python/lib/sift_client/_internal/low_level_wrappers/rules.py b/python/lib/sift_client/_internal/low_level_wrappers/rules.py index bf5bdb905..b578b6177 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/rules.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/rules.py @@ -135,7 +135,8 @@ async def create_rule( ) conditions_request = [ UpdateConditionRequest( - expression=expression_proto, actions=[create.action._to_update_request()] + expression=expression_proto, + actions=[create.action._to_update_request()], ) ] update_request = UpdateRuleRequest( @@ -183,9 +184,7 @@ def _update_rule_request_from_update( "asset_tag_ids", ] # Need to manually copy fields that will be reset even if not provided in update dict. - copy_unset_fields = [ - "description", - ] + copy_unset_fields = ["description", "name"] # Populate the trivial fields first. update_dict.update( @@ -214,15 +213,17 @@ def _update_rule_request_from_update( "Expression and channel_references must both be provided or both be None" ) expression_proto = RuleConditionExpression( - calculated_channel=CalculatedChannelConfig( - expression=expression, - channel_references={ - c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) - for c in channel_references - }, + calculated_channel=( + CalculatedChannelConfig( + expression=expression, + channel_references={ + c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) + for c in channel_references + }, + ) + if expression + else None ) - if expression - else None ) conditions_request = [ UpdateConditionRequest( @@ -238,10 +239,10 @@ def _update_rule_request_from_update( # This always needs to be set, so handle the defaults. update_dict["asset_configuration"] = RuleAssetConfiguration( # type: ignore - asset_ids=update.asset_ids if "asset_ids" in model_dump else rule.asset_ids or [], - tag_ids=update.asset_tag_ids - if "asset_tag_ids" in model_dump - else rule.asset_tag_ids or [], + asset_ids=(update.asset_ids if "asset_ids" in model_dump else rule.asset_ids or []), + tag_ids=( + update.asset_tag_ids if "asset_tag_ids" in model_dump else rule.asset_tag_ids or [] + ), ) update_request = UpdateRuleRequest( @@ -254,7 +255,7 @@ def _update_rule_request_from_update( async def update_rule( self, rule: Rule, update: RuleUpdate, version_notes: str | None = None ) -> Rule: - """Update a rule. + """Update a rule. Also handles archive/unarchive to behave similar to other low-level clients. Args: rule: The rule to update. @@ -264,14 +265,26 @@ async def update_rule( Returns: The updated Rule. """ + + should_update_archive = "is_archived" in update.model_fields_set + update.resource_id = rule.id_ + if not should_update_archive or ( + should_update_archive and len(update.model_fields_set) > 1 + ): + update_request = self._update_rule_request_from_update(rule, update, version_notes) + + response = await self._grpc_client.get_stub(RuleServiceStub).UpdateRule(update_request) + _ = cast("UpdateRuleResponse", response) - update_request = self._update_rule_request_from_update(rule, update, version_notes) + if should_update_archive: + if update.is_archived: + await self.archive_rule(rule_id=rule.id_) + else: + await self.unarchive_rule(rule_id=rule.id_) - response = await self._grpc_client.get_stub(RuleServiceStub).UpdateRule(update_request) - updated_grpc_rule = cast("UpdateRuleResponse", response) # Get the updated rule - return await self.get_rule(rule_id=updated_grpc_rule.rule_id) + return await self.get_rule(rule_id=rule.id_) async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse: """Batch update rules. diff --git a/python/lib/sift_client/_tests/integrated/__init__.py b/python/lib/sift_client/_tests/_internal/low_level_wrappers/__init__.py similarity index 100% rename from python/lib/sift_client/_tests/integrated/__init__.py rename to python/lib/sift_client/_tests/_internal/low_level_wrappers/__init__.py diff --git a/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py new file mode 100644 index 000000000..02ecf8142 --- /dev/null +++ b/python/lib/sift_client/_tests/_internal/low_level_wrappers/test_base.py @@ -0,0 +1,165 @@ +"""Tests for LowLevelClientBase. + +These tests validate the functionality of the LowLevelClientBase class including: +- Pagination handling with various scenarios +- Edge cases and error handling +- Parameter validation and behavior +""" + +from unittest.mock import AsyncMock + +import pytest + +from sift_client._internal.low_level_wrappers.base import LowLevelClientBase + + +class TestLowLevelClientBase: + """Test suite for LowLevelClientBase functionality.""" + + class TestHandlePagination: + """Tests for the _handle_pagination static method.""" + + @pytest.mark.asyncio + async def test_basic_pagination_single_page(self): + """Test pagination with a single page of results.""" + # Mock function that returns results and empty page token (indicating no more pages) + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [1, 2, 3] + mock_func.assert_called_once_with(page_size=None, page_token="", order_by=None) + + @pytest.mark.asyncio + async def test_pagination_multiple_pages(self): + """Test pagination across multiple pages.""" + # Mock function that returns different results for different page tokens + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page + ([4, 5, 6], "token2"), # Second page + ([7, 8, 9], ""), # Last page (empty token) + ] + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9] + assert mock_func.call_count == 3 + + # Verify the calls were made with correct page tokens + calls = mock_func.call_args_list + assert calls[0][1]["page_token"] == "" + assert calls[1][1]["page_token"] == "token1" + assert calls[2][1]["page_token"] == "token2" + + @pytest.mark.asyncio + async def test_pagination_with_page_size(self): + """Test pagination with specified page size.""" + mock_func = AsyncMock(return_value=([1, 2], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func, page_size=2) + + assert results == [1, 2] + mock_func.assert_called_once_with(page_size=2, page_token="", order_by=None) + + @pytest.mark.asyncio + async def test_pagination_with_order_by(self): + """Test pagination with order_by parameter.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func, order_by="name asc") + + assert results == [1, 2, 3] + mock_func.assert_called_once_with(page_size=None, page_token="", order_by="name asc") + + @pytest.mark.asyncio + async def test_pagination_with_initial_page_token(self): + """Test pagination starting with a specific page token.""" + mock_func = AsyncMock(return_value=([4, 5, 6], "")) + + results = await LowLevelClientBase._handle_pagination( + mock_func, page_token="start_token" + ) + + assert results == [4, 5, 6] + mock_func.assert_called_once_with( + page_size=None, page_token="start_token", order_by=None + ) + + @pytest.mark.asyncio + async def test_pagination_with_kwargs(self): + """Test pagination with additional keyword arguments.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + kwargs = {"filter": "active", "include_archived": False} + + results = await LowLevelClientBase._handle_pagination(mock_func, kwargs=kwargs) + + assert results == [1, 2, 3] + mock_func.assert_called_once_with( + page_size=None, + page_token="", + order_by=None, + filter="active", + include_archived=False, + ) + + @pytest.mark.asyncio + async def test_pagination_with_max_results_single_page(self): + """Test pagination with max_results that fits in a single page.""" + mock_func = AsyncMock(return_value=([1, 2, 3, 4, 5], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=3) + + # Should return only the max results + assert results == [1, 2, 3] + mock_func.assert_called_once() + + @pytest.mark.asyncio + async def test_pagination_with_max_results_multiple_pages(self): + """Test pagination with max_results across multiple pages.""" + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page (3 items) + ([4, 5, 6], "token2"), # Second page (6 total items, exceeds max_results=5) + ] + + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=5) + + # Should include 2 pages and return the full first page but limited 2nd page + assert results == [1, 2, 3, 4, 5] + assert mock_func.call_count == 2 + + @pytest.mark.asyncio + async def test_pagination_with_max_results_exact_match(self): + """Test pagination when results exactly match max_results.""" + mock_func = AsyncMock() + mock_func.side_effect = [ + ([1, 2, 3], "token1"), # First page + ([4, 5], ""), # Second page, total = 5 + ] + + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=5) + + assert results == [1, 2, 3, 4, 5] + assert mock_func.call_count == 2 + + @pytest.mark.asyncio + async def test_pagination_empty_results(self): + """Test pagination when function returns empty results.""" + mock_func = AsyncMock(return_value=([], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func) + + assert results == [] + mock_func.assert_called_once() + + @pytest.mark.asyncio + async def test_pagination_max_results_zero(self): + """Test pagination with max_results=0.""" + mock_func = AsyncMock(return_value=([1, 2, 3], "")) + + results = await LowLevelClientBase._handle_pagination(mock_func, max_results=0) + + # Should return empty list without calling the function + assert results == [] + mock_func.assert_not_called() diff --git a/python/lib/sift_client/_tests/conftest.py b/python/lib/sift_client/_tests/conftest.py new file mode 100644 index 000000000..1f218dedc --- /dev/null +++ b/python/lib/sift_client/_tests/conftest.py @@ -0,0 +1,42 @@ +"""Shared pytest fixtures for all tests.""" + +import os +from unittest.mock import MagicMock + +import pytest + +from sift_client import SiftClient +from sift_client.util.util import AsyncAPIs + + +@pytest.fixture(scope="session") +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing. + + This fixture is shared across all test files and is session-scoped + to avoid creating multiple client instances. + """ + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + ) + + +@pytest.fixture +def mock_client(): + """Create a mock SiftClient for unit testing.""" + client = MagicMock(spec=SiftClient) + # Configure the mock to have the necessary API attributes + client.assets = MagicMock() + client.runs = MagicMock() + client.channels = MagicMock() + client.calculated_channels = MagicMock() + client.rules = MagicMock() + client.async_ = MagicMock(spec=AsyncAPIs) + client.async_.ingestion = MagicMock() + return client diff --git a/python/lib/sift_client/_tests/integrated/calculated_channels.py b/python/lib/sift_client/_tests/integrated/calculated_channels.py deleted file mode 100644 index 2d423b17c..000000000 --- a/python/lib/sift_client/_tests/integrated/calculated_channels.py +++ /dev/null @@ -1,251 +0,0 @@ -import asyncio -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - CalculatedChannelUpdate, - ChannelReference, -) -from sift_client.sift_types.calculated_channel import CalculatedChannelCreate - -""" -Comprehensive test script for calculated channels with extensive update field exercises. - -This test demonstrates all available update fields for calculated channels: -- name: Update the channel name -- description: Update the channel description -- units: Update the units of measurement -- expression: Update the calculation expression -- expression_channel_references: Update channel references (must be updated with expression) -- asset_ids: Update which assets the channel applies to -- tag_ids: Update associated tags - -The test also includes: -- Edge case testing (minimal updates, invalid expressions) -- Batch operations demonstration -- Comprehensive validation -- Error handling and graceful fallbacks -- Archive operations - -TODO: TBD if we move this to an example or keep it here as a test expected to be used just by us. - -If we keep it as a test, we should ideally have a setup that populates data, and then ensure we teardown all the test assets/channels/rules etc. -""" - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - # Find assets to work with - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # Create example calculated channels that will be unique to this test run in case things don't cleanup. - num_channels = 7 - unique_name_suffix = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S") - print( - f"\n=== Creating {num_channels} calculated channels with unique suffix: {unique_name_suffix} ===" - ) - - created_channels = [] - for i in range(num_channels): - new_chan = CalculatedChannelCreate( - name=f"test_channel_{unique_name_suffix}_{i}", - description=f"Test calculated channel {i} - initial description", - expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - units="velocity/voltage", - asset_ids=[asset_id], - user_notes=f"Created for testing update fields - channel {i}", - ) - calculated_channel = client.calculated_channels.create(new_chan) - created_channels.append(calculated_channel) - print( - f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.id_})" - ) - - # Find the channels we just created - search_results = client.calculated_channels.list( - name_regex="test_channel.*", - asset_id=asset_id, - ) - print(f"Found {len(search_results)} calculated channels: {[cc.name for cc in search_results]}") - - print("\n=== Testing comprehensive update scenarios ===") - - # Test 1: Update expression and channel references together - print("\n--- Test 1: Update expression and channel references ---") - channel_1 = created_channels[0] - updated_channel_1 = channel_1.update( - CalculatedChannelUpdate( - expression="$1 / $2 * 100", # Convert to percentage - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - ) - ) - print(f"Updated {updated_channel_1.name}: expression = {updated_channel_1.expression}") - - # Test 2: Update description - print("\n--- Test 2: Update description ---") - channel_2 = created_channels[1] - updated_channel_2 = channel_2.update( - CalculatedChannelUpdate( - description="Updated description with more details about velocity-to-voltage ratio calculation", - ) - ) - print(f"Updated {updated_channel_2.name}: description = {updated_channel_2.description}") - - # Test 3: Update units - print("\n--- Test 3: Update units ---") - channel_3 = created_channels[2] - updated_channel_3 = channel_3.update( - CalculatedChannelUpdate( - units="percentage", - ) - ) - print(f"Updated {updated_channel_3.name}: units = {updated_channel_3.units}") - - # Test 4: Update name - print("\n--- Test 4: Update name ---") - channel_4 = created_channels[3] - new_name = f"renamed_channel_{unique_name_suffix}_5" - updated_channel_4 = channel_4.update( - CalculatedChannelUpdate( - name=new_name, - ) - ) - print(f"Updated {channel_4.name} -> {updated_channel_4.name}") - - # Test 5: Update multiple fields at once - print("\n--- Test 5: Update multiple fields simultaneously ---") - channel_5 = created_channels[4] - updated_channel_5 = channel_5.update( - CalculatedChannelUpdate( - description="Multi-field update test", - units="ratio", - ), - user_notes="Updated via multi-field update", - ) - print(f"Updated {updated_channel_5.name}:") - print(f" - description: {updated_channel_5.description}") - print(f" - units: {updated_channel_5.units}") - print(f" - user_notes: {updated_channel_5.user_notes}") - - # Test 6: Update with complex expression - print("\n--- Test 6: Update with complex expression ---") - channel_6 = created_channels[5] - updated_channel_6 = channel_6.update( - CalculatedChannelUpdate( - expression="($1 / $2) * 100 + ($3 * 0.1)", # Complex expression with 3 variables - expression_channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ChannelReference(channel_reference="$3", channel_identifier="temperature"), - ], - ) - ) - print(f"Updated {updated_channel_6.name}: complex expression = {updated_channel_6.expression}") - - # Test 7: Update tag_ids (if tags are available) - print("\n--- Test 7: Update tag_ids ---") - channel_7 = created_channels[6] - # Note: This would require actual tag IDs from the system - # For now, we'll test with an empty list to show the capability - updated_channel_7 = channel_7.update( - CalculatedChannelUpdate( - tag_ids=[], # Empty list - in practice you'd use actual tag IDs - ) - ) - print(f"Updated {updated_channel_7.name}: tag_ids = {updated_channel_7.tag_ids}") - - # Test 7b: Edge case - Update with invalid expression (should fail gracefully) - print("\n--- Test 7b: Edge case - Invalid expression test ---") - try: - invalid_update = channel_7.update( - CalculatedChannelUpdate( - expression="invalid_expression", - expression_channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - ) - ) - print(f"Invalid expression update succeeded (unexpected): {invalid_update.expression}") - # TODO: Ticket this? - except Exception as e: - print(f"Invalid expression update failed as expected: {e}") - - # Test 8: Archive channels - print("\n--- Test 8: Archive channels ---") - archived_count = 0 - for cc in created_channels: - cc.archive() - print(f"Archived: {cc.name}") - archived_count += 1 - - print("\n=== Test Summary ===") - print(f"Created: {len(created_channels)} channels") - print(f"Archived: {archived_count} channels") - - # Verify all channels were processed - assert len(created_channels) == num_channels, ( - f"Expected {num_channels} created channels, got {len(created_channels)}" - ) - assert archived_count == num_channels, ( - f"Expected {num_channels} archived channels, got {archived_count}" - ) - - # Additional validation - print("\n=== Validation Checks ===") - - # Verify that updates actually changed the values - assert updated_channel_1.expression == "$1 / $2 * 100", ( - f"Expression update failed: {updated_channel_1.expression}" - ) - assert "more details" in updated_channel_2.description, ( - f"Description update failed: {updated_channel_2.description}" - ) - assert updated_channel_3.units == "percentage", ( - f"Units update failed: {updated_channel_3.units}" - ) - assert updated_channel_4.name == new_name, f"Name update failed: {updated_channel_4.name}" - assert updated_channel_5.description == "Multi-field update test", ( - f"Description update failed: {updated_channel_5.description}" - ) - assert updated_channel_5.units == "ratio", f"Units update failed: {updated_channel_5.units}" - assert updated_channel_5.user_notes == "Updated via multi-field update", ( - f"User notes update failed: {updated_channel_5.user_notes}" - ) - assert updated_channel_6.expression == "($1 / $2) * 100 + ($3 * 0.1)", ( - f"Complex expression update failed: {updated_channel_6.expression}" - ) - assert len(updated_channel_6.channel_references) == 3, ( - f"Complex expression should have 3 references, got {len(updated_channel_6.channel_references)}" - ) - assert updated_channel_7.tag_ids == [], f"Tag IDs update failed: {updated_channel_7.tag_ids}" - - versions = client.calculated_channels.list_versions( - calculated_channel=channel_1.id_, - limit=10, - ) - print(f"Found {len(versions)} versions for {created_channels[0].name}") - - print("All validation checks passed!") - print("\n=== Test completed successfully ===") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/integrated/channels.py b/python/lib/sift_client/_tests/integrated/channels.py deleted file mode 100644 index 73c54e0df..000000000 --- a/python/lib/sift_client/_tests/integrated/channels.py +++ /dev/null @@ -1,205 +0,0 @@ -import asyncio -import os -import time -from datetime import datetime, timezone - -import numpy as np -import pandas as pd -import pyarrow as pa - -from sift_client.client import SiftClient - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # List runs for this asset - runs = asset.runs - print( - f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" - ) - - # Pick one. - run = runs[0] - run_id = run.id_ - print(f"Using run: {run.name} (ID: {run_id})") - - # List other assets for this run. - all_assets = run.assets - other_assets = [asset for asset in all_assets if asset.id_ != asset_id] - print( - f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" - ) - - # List channels for this asset (find a run w/ data) - channels = [] - for run in runs: - asset_channels = asset.channels(run=run.id_, limit=10) - other_channels = [] - for c in asset_channels: - if c.name in {"voltage", "gpio", "temperature", "mainmotor.velocity"}: - channels.append(c) - else: - other_channels.append(c) - - if len(channels) > 3: - print( - f"Found {len(channels)} channel(s): {[channel.name for channel in channels]} for asset {asset.name} on run {run.name}" - ) - if len(other_channels) > 0: - print( - f"Found {len(other_channels)} other channel(s): {[c.name for c in other_channels]} for asset {asset.name} on run {run.name}" - ) - break - - # Get the channel data during a specific run - channel = channels[0] - channel_data = channel.data(run_id="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", limit=100) - print(f"Channel data for {channel.name} has {len(channel_data)} points") - - # Get data for multiple channels - print("Getting data for multiple channels:") - perf_start = time.perf_counter() - channel_data = client.channels.get_data( - run="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", channels=channels, limit=100 - ) - first_time = time.perf_counter() - perf_start - start_time = None - end_time = None - for i, (channel_name, data) in enumerate(channel_data.items()): - print(f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name])}") - - # Pick a random channel and grab the start end times so we can test the cache - if i == 1: - start_time = data.index[0] - end_time = data.index[-1] - print(f"Start time: {start_time}, End time: {end_time}") - - # Test cache with varying start_time and end_time parameters - if start_time and end_time: - print("\n=== Testing cache with varying time ranges ===") - - # Test 1: Exact same time range (should hit cache) - print("\nTest 1: Exact same time range no run_id (should hit cache)") - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=start_time, - end_time=end_time, - ) - exact_time = time.perf_counter() - perf_start - - # Test 2: Subset of time range (should hit cache if overlapping) - print("\nTest 2: Subset of time range (should hit cache if overlapping)") - mid_time = start_time + (end_time - start_time) / 2 - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=start_time, - end_time=mid_time, - ) - subset_time = time.perf_counter() - perf_start - - # Test 3: Extended time range (should hit cache for overlapping portion) - print("\nTest 3: Extended time range earlier (should hit cache for overlapping portion)") - extended_start = start_time - (end_time - start_time) * 0.1 - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=extended_start, - end_time=end_time, - ) - extended_time = time.perf_counter() - perf_start - - # Test 4: Different time range (should not hit cache) - print("\nTest 4: Different time encompassed range (should hit cache)") - different_start = extended_start + pd.Timedelta(seconds=2) - different_end = start_time + pd.Timedelta(seconds=3) - perf_start = time.perf_counter() - _ = client.channels.get_data( - channels=channels, - start_time=different_start, - end_time=different_end, - ) - different_time = time.perf_counter() - perf_start - - # Test 5: No time range specified (should miss cache from original call) - print("\nTest 5: No time range specified (should miss cache)") - # Since None end time is treated as now, we capture now so we can repeat it. - fake_no_end_time = datetime.now(timezone.utc) - perf_start = time.perf_counter() - channel_data_no_time = client.channels.get_data( - channels=channels, - end_time=fake_no_end_time, - limit=100, - ) - no_time_time = time.perf_counter() - perf_start - for i, (channel_name, data) in enumerate(channel_data_no_time.items()): - print(f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name])}") - - # Test 6: No time range specified again (should hit cache) - # NOTE: We're not comparing the results since limit combines with cache results. - print("\nTest 6: No time range specified again (should hit cache)") - perf_start = time.perf_counter() - channel_data_no_time = client.channels.get_data( - channels=channels, - end_time=fake_no_end_time, - limit=100, - ) - for i, (channel_name, data) in enumerate(channel_data_no_time.items()): - print( - f"{i}: {channel_name}: {len(data)} points. Avg: {np.mean(data[channel_name]) if channel_name in data else np.nan}" - ) - no_time_time_repeat = time.perf_counter() - perf_start - - # Test 7: Get data as arrow - print("\nTest 7: Get data as arrow") - perf_start = time.perf_counter() - channel_data_arrow = client.channels.get_data_as_arrow( - channels=channels, - end_time=fake_no_end_time, - ) - arrow_time = time.perf_counter() - perf_start - for i, (channel_name, data) in enumerate(channel_data_arrow.items()): - print( - f"{i}: {channel_name}: {len(data)} points. Avg: {pa.compute.mean(data[channel_name])}" - ) - - # Summary of cache performance - print("\n=== Cache Performance Summary ===") - print(f"Original call: {first_time:.4f} seconds") - print( - f"Exact time range no run_id: {exact_time:.4f} seconds ({(first_time / exact_time):.1f}x faster)" - ) - print( - f"Subset time range: {subset_time:.4f} seconds ({(first_time / subset_time):.1f}x faster)" - ) - print( - f"Extended time range earlier: {extended_time:.4f} seconds ({(first_time / extended_time):.1f}x faster)" - ) - print( - f"Different time range: {different_time:.4f} seconds ({(first_time / different_time):.1f}x faster)" - ) - print( - f"No time range: {no_time_time:.4f} seconds ({(no_time_time / first_time):.1f}x slower)" - ) - print( - f"No time range repeat: {no_time_time_repeat:.4f} seconds ({(no_time_time / no_time_time_repeat):.1f}x faster)" - ) - print(f"Arrow: {arrow_time:.4f} seconds ({(arrow_time / no_time_time_repeat):.1f}x faster)") - assert exact_time < first_time - assert subset_time < first_time - assert extended_time < first_time - assert different_time < first_time - assert no_time_time_repeat < no_time_time - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/integrated/ingestion.py b/python/lib/sift_client/_tests/integrated/ingestion.py deleted file mode 100644 index b5f49d894..000000000 --- a/python/lib/sift_client/_tests/integrated/ingestion.py +++ /dev/null @@ -1,218 +0,0 @@ -import asyncio -import math -import os -import random -import time -from datetime import datetime, timedelta, timezone - -from sift_client._tests import setup_logger -from sift_client.client import SiftClient -from sift_client.sift_types.channel import ( - ChannelBitFieldElement, - ChannelDataType, -) -from sift_client.sift_types.ingestion import ChannelConfig, Flow -from sift_client.transport import SiftConnectionConfig - -setup_logger() - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient( - connection_config=SiftConnectionConfig( - grpc_url=grpc_url, - api_key=api_key, - rest_url=rest_url, - use_ssl=True, - cert_via_openssl=True, - ) - ) - - asset = "ian-test-asset" - - # TODO:Get user id from current user - previously_created_runs = client.runs.list_( - name_regex="test-run-.*", created_by="1eba461b-fa36-4e98-8fe8-ff32d3e43a6e" - ) - if previously_created_runs: - print(f" Deleting previously created runs: {previously_created_runs}") - for run in previously_created_runs: - print(f" Deleting run: {run.name}") - client.runs.archive(run=run) - - run = client.runs.create( - dict( - name=f"test-run-{datetime.now(tz=timezone.utc).timestamp()}", - description="A test run created via the API", - tags=["api-created", "test"], - ) - ) - - regular_flow = Flow( - name="test-flow", - channels=[ - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), - ChannelConfig( - name="test-enum-channel", - data_type=ChannelDataType.ENUM, - enum_types={"enum1": 1, "enum2": 2}, - ), - ], - ) - regular_flow.add_channelConfig( - ChannelConfig( - name="test-bit-field-channel", - data_type=ChannelDataType.BIT_FIELD, - bit_field_elements=[ - ChannelBitFieldElement(name="12v", index=0, bit_count=4), - ChannelBitFieldElement(name="charge", index=4, bit_count=2), - ChannelBitFieldElement(name="led", index=6, bit_count=1), - ChannelBitFieldElement(name="heater", index=7, bit_count=1), - ], - ) - ) - - highspeed_flow = Flow( - name="highspeed-flow", - channels=[ - ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), - ], - ) - # This seals the flow and ingestion config - config_id = await client.async_.ingestion.create_ingestion_config( - asset_name=asset, - run_id=run.id_, - flows=[regular_flow, highspeed_flow], - ) - print(f"config_id: {config_id}") - try: - regular_flow.add_channelConfig( - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE) - ) - except ValueError as e: - assert repr(e) == "ValueError('Cannot add a channel to a flow after creation')" - - other_asset_flows = [ - Flow( - name="new-asset-flow", - channels=[ - # Same channel name as the regular flow, but on a different asset. - ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), - ], - ) - ] - await client.async_.ingestion.create_ingestion_config( - asset_name="test-asset-ian2", - run_id=run.id_, - flows=other_asset_flows, - ) - sleep_time = 0.05 # Time between outer loop iterations to simulate real-time latency between ingestion calls. - simulated_duration = 50 - fake_hs_rate = 50 # Hz - fake_hs_period = 1 / fake_hs_rate - start = datetime.now(tz=timezone.utc) - for i in range(simulated_duration): - now = start + timedelta(seconds=i) - regular_flow.ingest( - timestamp=now, - channel_values={ - "test-channel": 3.0 * math.sin(2 * math.pi * fake_hs_rate * i + 0.07), - "test-enum-channel": i % 2 + 1, - "test-bit-field-channel": { - "12v": random.randint(3, 13), - "charge": random.randint(1, 3), - "led": random.choice([0, 1]), - "heater": random.choice([0, 1]), - }, - }, - ) - for j in range(fake_hs_rate): - val = 3.0 * math.sin(2 * math.pi * fake_hs_rate * (i + j * 0.001) + 0) - timestamp = now + timedelta(milliseconds=j * fake_hs_period * 1000) - channel_values = { - "highspeed-channel": val, - } - # Alternative way to ingest - client.ingestion.ingest( - flow=highspeed_flow, timestamp=timestamp, channel_values=channel_values - ) - time.sleep(sleep_time) - - other_asset_flows[0].ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": -6.66, - }, - ) - - # Test ingestion of a flow without all channels specified - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": 2, - # "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert "Expected all channels in flow to have a data point at same time." in repr(e) - - # Test ingestion of a bad enum value (string and int) - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": -3, - "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert "Could not find enum value: -3 in enum options: {'enum1': 1, 'enum2': 2}" in repr(e) - try: - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration), - channel_values={ - "test-channel": 0, - "test-enum-channel": "nonexistent-enum", - "test-bit-field-channel": bytes([0b01010101]), - }, - ) - except ValueError as e: - assert ( - "Could not find enum value: nonexistent-enum in enum options: {'enum1': 1, 'enum2': 2}" - in repr(e) - ) - - client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) - end = datetime.now(tz=timezone.utc) - # Test ingesting more data after letting a thread finish. Also exercise ingesting bitfield values as bytes. - time.sleep(1) - print("Restarting ingestion") - regular_flow.ingest( - timestamp=start + timedelta(seconds=simulated_duration + 1), - channel_values={ - "test-channel": 7.77, - "test-enum-channel": 1, - "test-bit-field-channel": bytes([0b11111111]), - }, - ) - # Wait less time than threads nominal no_data_timeout so we can exercise forced cleanup. - client.async_.ingestion.wait_for_ingestion_to_complete(timeout=0.01) - client.runs.archive(run=run.id_) - - num_datapoints = fake_hs_rate * len( - highspeed_flow.channels - ) * simulated_duration + simulated_duration * len(regular_flow.channels) - print(f"Ingestion time: {end - start} seconds") - print(f"Ingested {num_datapoints} datapoints") - total_time = (end - start).total_seconds() - print(f"Ingestion rate: {num_datapoints / total_time:.2f} datapoints/second") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/integrated/rules.py b/python/lib/sift_client/_tests/integrated/rules.py deleted file mode 100644 index dbe88aeea..000000000 --- a/python/lib/sift_client/_tests/integrated/rules.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - ChannelReference, - RuleAction, - RuleAnnotationType, - RuleUpdate, -) -from sift_client.sift_types.rule import RuleCreate - -""" -Comprehensive test script for rules with extensive update field exercises. - -This test demonstrates all available update fields for rules: -- name: Update the rule name -- description: Update the rule description -- expression: Update the rule expression -- channel_references: Update channel references (must be updated with expression) -- action: Update the rule action (annotation, notification, webhook) -- tag_ids: Update associated tags (TBD) -- contextual_channels: Update contextual channels -- version_notes: Update version notes - -The test also includes: -- Edge case testing (invalid expressions) -- Batch operations demonstration -- Comprehensive validation -- Archive operations - - -If we keep it as a test, we should ideally have a setup that populates data, and then ensure we teardown all the test assets/channels/rules etc. -""" - - -def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - unique_name_suffix = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S") - num_rules = 8 - print(f"\n=== Creating {num_rules} rules with unique suffix: {unique_name_suffix} ===") - created_rules = [] - for i in range(num_rules): - rule = client.rules.create( - RuleCreate( - name=f"test_rule_{unique_name_suffix}_{i}", - description=f"Test rule {i} - initial description", - expression="$1 > 0.1", # Simple threshold check - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["test", "initial"], - default_assignee_user_id=None, - ), - asset_ids=[asset_id], - ) - ) - created_rules.append(rule) - print(f"Created rule: {rule.name} (ID: {rule.id_})") - - # Find the rules we just created - search_results = client.rules.list_( - name_regex=f"test_rule_{unique_name_suffix}.*", - ) - assert len(search_results) == num_rules, ( - f"Expected {num_rules} created rules, got {len(search_results)}" - ) - - print("\n=== Testing comprehensive update scenarios ===") - - # Test 1: Update expression and channel references together - print("\n--- Test 1: Update expression and channel references ---") - rule_1 = created_rules[0] - rule_1_model_dump = rule_1.model_dump() - updated_rule_1 = rule_1.update( - RuleUpdate( - expression="$1 > 0.5", # Higher threshold - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ], - ) - ) - updated_rule_1_model_dump = updated_rule_1.model_dump() - print(f"Updated {updated_rule_1.name}: expression = {updated_rule_1.expression}") - - # Test 2: Update description - print("\n--- Test 2: Update description ---") - rule_2 = created_rules[1] - updated_rule_2 = rule_2.update( - RuleUpdate( - description="Updated description with more details about velocity-to-voltage ratio monitoring", - ) - ) - print(f"Updated {updated_rule_2.name}: description = {updated_rule_2.description}") - - # Test 3: Update action (change annotation type and tags) - print("\n--- Test 3: Update action ---") - rule_3 = created_rules[2] - updated_rule_3 = rule_3.update( - RuleUpdate( - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.PHASE, - tags=["updated", "phase", "alert"], - default_assignee_user_id=rule_3.created_by_user_id, - ), - ) - ) - print(f"Updated {updated_rule_3.name}: action type = {updated_rule_3.action.action_type}") - print(f" - annotation type: {updated_rule_3.action.annotation_type}") - print(f" - tags: {updated_rule_3.action.tags}") - print(f" - assignee: {updated_rule_3.action.default_assignee_user_id}") - - # Test 4: Update name - print("\n--- Test 4: Update name ---") - rule_4 = created_rules[3] - new_name = f"renamed_rule_{unique_name_suffix}_4" - updated_rule_4 = rule_4.update( - RuleUpdate( - name=new_name, - ) - ) - print(f"Updated {rule_4.name} -> {updated_rule_4.name}") - - # Test 5: Update multiple fields at once - print("\n--- Test 5: Update multiple fields simultaneously ---") - rule_5 = created_rules[4] - updated_rule_5 = rule_5.update( - RuleUpdate( - description="Multi-field update test", - ), - version_notes="Updated via multi-field update", - ) - print(f"Updated {updated_rule_5.name}:") - print(f" - description: {updated_rule_5.description}") - print( - f" - version_notes: {updated_rule_5.rule_version.version_notes if updated_rule_5.rule_version else None}" - ) - - # Test 6: Update with complex expression - print("\n--- Test 6: Update with complex expression ---") - rule_6 = created_rules[5] - updated_rule_6 = rule_6.update( - RuleUpdate( - expression="$1 > 0.3 && $1 < 0.8", # Range check - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ], - ) - ) - print(f"Updated {updated_rule_6.name}: complex expression = {updated_rule_6.expression}") - - # Test 7: Update action to notification type - # print("\n--- Test 7: Update action to notification ---") - # rule_7 = created_rules[6] - # updated_rule_7 = rule_7 - # Note: Notification actions are not supported yet. - # updated_rule_7 = rule_7.update( - # RuleUpdate( - # action=RuleAction.notification( - # notify_recipients=[rule_7.created_by_user_id] - # ), - # ) - # ) - # print(f"Updated {updated_rule_7.name}: action type = {updated_rule_7.action.action_type}") - # print(f" - notification recipients: {updated_rule_7.action.notification_recipients}") - - # Test 8: Update tag_ids and contextual_channels - print("\n--- Test 8: Update tag_ids and contextual_channels ---") - rule_8 = created_rules[7] - updated_rule_8 = rule_8.update( - RuleUpdate( - # tag_ids=["tag-123", "tag-456"], # Example tag IDs # TODO: Where are these IDs supposed to come from? They're supposed to be uuids? {grpc_message:"invalid argument: invalid input syntax for type uuid: \"tag-123\" - contextual_channels=[ - "temperature", - "pressure", - ], # Example contextual channels - ) - ) - print(f"Updated {updated_rule_8.name}:") - print(f" - asset_tag_ids: {updated_rule_8.asset_tag_ids}") - print(f" - contextual_channels: {updated_rule_8.contextual_channels}") - - # Test 8b: Edge case - Update with invalid expression (should fail gracefully) - print("\n--- Test 8b: Edge case - Invalid expression test ---") - try: - invalid_update = rule_8.update( - RuleUpdate( - expression="invalid_expression", - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ], - ) - ) - print(f"Invalid expression update succeeded (unexpected): {invalid_update.expression}") - except Exception as e: - print(f"Invalid expression update failed as expected: {e}") - - # Additional validation - print("\n=== Validation Checks ===") - - # Verify that updates actually changed the values - assert updated_rule_1.expression == "$1 > 0.5", ( - f"Expression update failed: {updated_rule_1.expression}" - ) - # For update 1, also verify that the fields that were not updated are not reset. - assert updated_rule_1_model_dump["description"] == rule_1_model_dump["description"], ( - f"Expected no description change, got {rule_1_model_dump['description']} -> {updated_rule_1.description}" - ) - assert ( - updated_rule_1_model_dump["channel_references"] == rule_1_model_dump["channel_references"] - ), ( - f"Expected no channel references change, got {rule_1_model_dump['channel_references']} -> {updated_rule_1.channel_references}" - ) - assert updated_rule_1_model_dump["asset_ids"] == rule_1_model_dump["asset_ids"], ( - f"Expected no asset IDs change, got {rule_1_model_dump['asset_ids']} -> {updated_rule_1.asset_ids}" - ) - assert updated_rule_1_model_dump["asset_tag_ids"] == rule_1_model_dump["asset_tag_ids"], ( - f"Expected no tag IDs change, got {rule_1_model_dump['asset_tag_ids']} -> {updated_rule_1.asset_tag_ids}" - ) - assert ( - updated_rule_1_model_dump["contextual_channels"] == rule_1_model_dump["contextual_channels"] - ), f"Contextual channels update failed: {updated_rule_1.contextual_channels}" - assert "more details" in updated_rule_2.description, ( - f"Description update failed: {updated_rule_2.description}" - ) - assert updated_rule_3.action.annotation_type == RuleAnnotationType.PHASE, ( - f"Action update failed: {updated_rule_3.action.annotation_type}" - ) - assert updated_rule_4.name == new_name, f"Name update failed: {updated_rule_4.name}" - - assert updated_rule_6.expression == "$1 > 0.3 && $1 < 0.8", ( - f"Complex expression update failed: {updated_rule_6.expression}" - ) - # assert updated_rule_7.action.action_type == RuleActionType.NOTIFICATION, f"Action type update failed: {updated_rule_7.action.action_type}" - # assert len(updated_rule_8.tag_ids) == 2, f"Tag IDs update failed: {updated_rule_8.tag_ids}" - assert len(updated_rule_8.contextual_channels) == 2, ( - f"Contextual channels update failed: {updated_rule_8.contextual_channels}" - ) - - print("All validation checks passed!") - print("\n=== Test completed successfully ===") - - -if __name__ == "__main__": - main() diff --git a/python/lib/sift_client/_tests/integrated/runs.py b/python/lib/sift_client/_tests/integrated/runs.py deleted file mode 100644 index 9c55c2aa1..000000000 --- a/python/lib/sift_client/_tests/integrated/runs.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -"""This test demonstrates the usage of the Runs API. - -It creates a new run, updates it, and associates assets with it. -It also lists runs, filters them, and deletes the run. - -It uses the SiftClient to interact with the API. -""" - -import asyncio -import os -from datetime import datetime, timedelta, timezone - -from sift_client import SiftClient - - -async def main(): - """Main function demonstrating the Runs API usage.""" - # Initialize the client - # You can set these environment variables or pass them directly - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - api_key = os.getenv("SIFT_API_KEY", "") - client = SiftClient( - api_key=api_key, - grpc_url=grpc_url, - rest_url=rest_url, - ) - - # Use a known asset to fetch a run. - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print(f"Using asset: {asset.name} (ID: {asset_id})") - - # List runs for this asset - runs = asset.runs - print( - f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" - ) - - # Pick one. - run = runs[0] - run_id = run.id_ - print(f"Using run: {run.name} (ID: {run_id})") - - # List other assets for this run. - all_assets = run.assets - other_assets = [asset for asset in all_assets if asset.id_ != asset_id] - print( - f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" - ) - - # Example 1: List all runs - print("\n1. Listing all runs...") - runs = client.runs.list_(limit=5) - print(f" Found {len(runs)} runs:") - for run in runs: - print(f" - {run.name} (ID: {run.id_}), Organization ID: {run.organization_id}") - - # Example 2: Test different filter options - print("\n2. Testing different filter options...") - - # Get a sample run for testing filters - sample_runs = client.runs.list_(limit=3) - if not sample_runs: - print(" No runs available for filter testing") - return - - sample_run = sample_runs[0] - - # 2a: Filter by exact name - print("\n 2a. Filter by exact name...") - run_name = sample_run.name - runs = client.runs.list_(name=run_name, limit=5) - print(f" Found {len(runs)} runs with exact name '{run_name}':") - for run in runs: - print(f" - {run.name} (ID: {run.id_})") - - # 2b: Filter by name containing text - print("\n 2b. Filter by name containing text...") - runs = client.runs.list_(name_contains="test", limit=5) - print(f" Found {len(runs)} runs with 'test' in name:") - for run in runs: - print(f" - {run.name}") - - # 2c: Filter by name using regex - print("\n 2c. Filter by name using regex...") - runs = client.runs.list_(name_regex=".*test.*", limit=5) - print(f" Found {len(runs)} runs with 'test' in name (regex):") - for run in runs: - print(f" - {run.name}") - - # 2d: Filter by exact description - print("\n 2d. Filter by description contains...") - if sample_run.description: - runs = client.runs.list_(description_contains=sample_run.description, limit=5) - print(f" Found {len(runs)} runs with description contains '{sample_run.description}':") - for run in runs: - print(f" - {run.name}: {run.description}") - else: - print(" No description available for testing") - - # 2e: Filter by description containing text - print("\n 2e. Filter by description containing text...") - runs = client.runs.list_(description_contains="test", limit=5) - print(f" Found {len(runs)} runs with 'test' in description:") - for run in runs: - print(f" - {run.name}: {run.description}") - - # 2f: Filter by duration seconds - print("\n 2f. Filter by duration seconds...") - # Calculate duration for sample run if it has start and stop times - if sample_run.start_time and sample_run.stop_time: - duration_seconds = int((sample_run.stop_time - sample_run.start_time).total_seconds()) - runs = client.runs.list_(duration_greater_than=timedelta(seconds=duration_seconds), limit=5) - print(f" Found {len(runs)} runs with duration greater than {duration_seconds} seconds:") - for run in runs: - if run.start_time and run.stop_time: - run_duration = int((run.stop_time - run.start_time).total_seconds()) - print(f" - {run.name} (duration: {run_duration}s)") - else: - print(" No start/stop times available for duration testing") - - # 2g: Filter by client key - print("\n 2g. Filter by client key...") - if sample_run.client_key: - run = client.runs.get(client_key=sample_run.client_key) - print(f" Found run with client key '{run.name}'") - else: - print(" No client key available for testing") - - # 2h: Filter by asset ID - print("\n 2h. Filter by asset ID...") - if sample_run.asset_ids: - asset_id = sample_run.asset_ids[0] - runs = client.runs.list_(assets=[asset_id], limit=5) - print(f" Found {len(runs)} runs associated with asset {asset_id}:") - for run in runs: - print(f" - {run.name} (asset_ids: {list(run.asset_ids)})") - else: - print(" No asset IDs available for testing") - - # 2i: Filter by asset name - print("\n 2i. Filter by asset name...") - runs = client.runs.list_(assets=[client.assets.find(name="NostromoLV426")], limit=5) - print(f" Found {len(runs)} runs associated with asset 'NostromoLV426':") - for run in runs: - print(f" - {run.name}") - - # 2j: Filter by created by user ID - print("\n 2j. Filter by created by user ID...") - created_by = sample_run.created_by_user_id - runs = client.runs.list_(created_by=created_by, limit=5) - print(f" Found {len(runs)} runs created by user {created_by}:") - for run in runs: - print(f" - {run.name} (created by: {run.created_by})") - - # 2l: Test ordering options - print("\n 2l. Testing ordering options...") - - # Order by name ascending - runs = client.runs.list_(order_by="name", limit=3) - print(" First 3 runs ordered by name (ascending):") - for run in runs: - print(f" - {run.name}") - - # Order by name descending - runs = client.runs.list_(order_by="name desc", limit=3) - print(" First 3 runs ordered by name (descending):") - for run in runs: - print(f" - {run.name}") - - # Order by creation date (newest first - default) - runs = client.runs.list_(order_by="created_date desc", limit=3) - print(" First 3 runs ordered by creation date (newest first):") - for run in runs: - print(f" - {run.name} (created: {run.created_date})") - - # Order by creation date (oldest first) - runs = client.runs.list_(order_by="created_date", limit=3) - print(" First 3 runs ordered by creation date (oldest first):") - for run in runs: - print(f" - {run.name} (created: {run.created_date})") - - # Example 3: Find a single run by name - print("\n3. Finding a single run by name...") - run_name = "test-run" # Replace with an actual run name - run = client.runs.find(name=run_name) - if run: - print(f" Found run: {run.name}") - print(f" Description: {run.description}") - else: - print(f" No run found with name '{run_name}'") - - # Example 4: Create a new run - print("\n4. Creating a new run...") - # Create metadata for the run - metadata = { - "environment": "production", - "test_type": "integration", - } - - # Create a run with start and stop times - start_time = datetime.now(timezone.utc) - stop_time = start_time + timedelta(minutes=2) - - previously_created_runs = client.runs.list_(name_regex="Example Test Run.*") - if previously_created_runs: - print(f" Deleting previously created runs: {previously_created_runs}") - for run in previously_created_runs: - print(f" Deleting run: {run.name}") - client.runs.archive(run=run) - - new_run = client.runs.create( - dict( - name=f"Example Test Run {datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}", - description="A test run created via the API", - tags=["api-created", "test"], - start_time=start_time, - stop_time=stop_time, - # Use a unique client key for each run - client_key=f"example-run-key-{datetime.now(tz=timezone.utc).timestamp()}", - metadata=metadata, - ) - ) - print(f" Created run: {new_run.name} (ID: {new_run.id_})") - print(f" Client key: {new_run.client_key}") - print(f" Tags: {new_run.tags}") - - # Example 5: Update a run - print("\n5. Updating a run...") - - run_to_update = new_run - print(f" Updating run: {run_to_update.name}") - - # Update the run - new_description = "Updated description via API" - new_metadata = { - "test_type": "ci", - } - new_tags = ["updated", "api-modified"] - updated_run = client.runs.update( - run=run_to_update, - update={ - "description": new_description, - "tags": new_tags, - "metadata": new_metadata, - }, - ) - print(f" Updated run: {updated_run.name}") - print(f" New description: {updated_run.description}") - print(f" New tags: {updated_run.tags}") - print(f" New metadata: {updated_run.metadata}") - assert updated_run.description == new_description - assert sorted(updated_run.tags) == sorted(new_tags) - assert updated_run.metadata == new_metadata - - # Example 6: Associate assets with a run - print("\n6. Associating assets with a run...") - ongoing_runs = client.runs.list_( - name_regex="Example Test Run.*", include_archived=True, is_stopped=False - ) - if ongoing_runs: - print(" Ensuring previously created runs are stopped:") - for run in ongoing_runs: - if run.stop_time is None: - print(f" Stopping run: {run.name}") - client.runs.stop(run=run) - - # Get a run to associate assets with - asset_names = ["asset1", "asset2"] # Replace with actual asset names - print(f" Associating assets {asset_names} with run: {new_run.name}") - - client.runs.create_automatic_association_for_assets(run=new_run, asset_names=asset_names) - print(f" Successfully associated assets with run: {new_run.name}") - - # Example 7: Delete a run - print("\n7. Deleting a run") - run_to_delete = new_run - print(f" Deleting run: {run_to_delete.name}") - client.runs.archive(run=run_to_delete) - print(f" Successfully archived run: {run_to_delete.name}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/lib/sift_client/_tests/resources/__init__.py b/python/lib/sift_client/_tests/resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_tests/resources/test_assets.py b/python/lib/sift_client/_tests/resources/test_assets.py new file mode 100644 index 000000000..5fca01fa8 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_assets.py @@ -0,0 +1,198 @@ +"""Pytest tests for the Assets API. + +These tests demonstrate and validate the usage of the Assets API including: +- Basic asset operations (get, list, find) +- Asset filtering and searching +- Asset updates and archiving +- Error handling and edge cases +""" + +import pytest + +from sift_client import SiftClient +from sift_client.resources import AssetsAPI, AssetsAPIAsync +from sift_client.sift_types import Asset + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.assets + assert isinstance(sift_client.assets, AssetsAPI) + assert sift_client.async_.assets + assert isinstance(sift_client.async_.assets, AssetsAPIAsync) + + +@pytest.fixture +def assets_api_async(sift_client: SiftClient): + """Get the async assets API instance.""" + return sift_client.async_.assets + + +@pytest.fixture +def assets_api_sync(sift_client: SiftClient): + """Get the synchronous assets API instance.""" + return sift_client.assets + + +@pytest.fixture +def test_asset(assets_api_sync): + assets = assets_api_sync.list_(limit=1) + assert assets + assert len(assets) >= 1 + return assets[0] + + +class TestAssetsAPIAsync: + """Test suite for the async Assets API functionality.""" + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, assets_api_async): + """Test basic asset listing functionality.""" + assets = await assets_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(assets, list) + assert assets + + # If we have assets, verify their structure + asset = assets[0] + assert isinstance(asset, Asset) + assert asset.id_ is not None + assert asset.name is not None + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, assets_api_async): + """Test asset listing with name filtering.""" + # First get some assets to work with + all_assets = await assets_api_async.list_(limit=10) + + if all_assets: + # Use the first asset's name for filtering + test_asset_name = all_assets[0].name + filtered_assets = await assets_api_async.list_(name=test_asset_name) + + # Should find at least one asset with exact name match + assert isinstance(filtered_assets, list) + assert len(filtered_assets) >= 1 + + # All returned assets should have the exact name + for asset in filtered_assets: + assert asset.name == test_asset_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, assets_api_async): + """Test asset listing with name contains filtering.""" + # Test with a common substring that might exist in asset names + assets = await assets_api_async.list_(name_contains="test", limit=5) + + assert isinstance(assets, list) + + # If we found assets, verify they contain the substring + for asset in assets: + assert "test" in asset.name.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, assets_api_async): + """Test asset listing with different limits.""" + # Test with limit of 1 + assets_1 = await assets_api_async.list_(limit=1) + assert isinstance(assets_1, list) + assert len(assets_1) <= 1 + + # Test with limit of 3 + assets_3 = await assets_api_async.list_(limit=3) + assert isinstance(assets_3, list) + assert len(assets_3) <= 3 + + @pytest.mark.asyncio + async def test_list_include_archived(self, assets_api_async): + """Test asset listing with archived assets included.""" + # Test without archived assets (default) + assets_active = await assets_api_async.list_(limit=5, include_archived=False) + assert isinstance(assets_active, list) + + # Test with archived assets included + assets_all = await assets_api_async.list_(limit=5, include_archived=True) + assert isinstance(assets_all, list) + + # Should have at least as many assets when including archived + assert len(assets_all) >= len(assets_active) + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_name(self, assets_api_async, test_asset): + """Test getting a specific asset by name.""" + retrieved_asset = await assets_api_async.get(name=test_asset.name) + + assert retrieved_asset is not None + assert retrieved_asset.id_ == test_asset.id_ + assert retrieved_asset.name == test_asset.name + + @pytest.mark.asyncio + async def test_get_by_id(self, assets_api_async, test_asset): + """Test getting a specific asset by ID.""" + retrieved_asset = await assets_api_async.get(asset_id=test_asset.id_) + + assert retrieved_asset is not None + assert retrieved_asset.id_ == test_asset.id_ + + @pytest.mark.asyncio + async def test_get_without_params_raises_error(self, assets_api_async): + """Test that getting an asset without parameters raises an error.""" + with pytest.raises(ValueError, match="Either asset_id or name must be provided"): + await assets_api_async.get() + + @pytest.mark.asyncio + async def test_get_nonexistent_asset_raises_error(self, assets_api_async): + """Test that getting a non-existent asset raises an error.""" + with pytest.raises(ValueError, match="No asset found"): + await assets_api_async.get(name="nonexistent-asset-name-12345") + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_asset(self, assets_api_async, test_asset): + """Test finding a single asset.""" + # Find the same asset by name + found_asset = await assets_api_async.find(name=test_asset.name) + + assert found_asset is not None + assert found_asset.id_ == test_asset.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_asset(self, assets_api_async): + """Test finding a non-existent asset returns None.""" + found_asset = await assets_api_async.find(name="nonexistent-asset-name-12345") + assert found_asset is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, assets_api_async): + """Test finding multiple assets raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await assets_api_async.find(name_contains="a") + + +class TestAssetsAPISync: + """Test suite for the synchronous Assets API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestList: + """Tests for the sync list_ method.""" + + def test_basic_list(self, assets_api_sync): + """Test basic synchronous asset listing functionality.""" + assets = assets_api_sync.list_(limit=5) + + # Verify we get a list + assert isinstance(assets, list) + assert assets + assert isinstance(assets[0], Asset) diff --git a/python/lib/sift_client/_tests/resources/test_calculated_channels.py b/python/lib/sift_client/_tests/resources/test_calculated_channels.py new file mode 100644 index 000000000..3a39bb4b6 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_calculated_channels.py @@ -0,0 +1,544 @@ +"""Pytest tests for the Calculated Channels API. + +These tests demonstrate and validate the usage of the Calculated Channels API including: +- Basic calculated channel operations (get, list, find) +- Calculated channel filtering and searching +- Calculated channel creation, updates, and archiving +- Version management +- Error handling and edge cases +""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from sift_client import SiftClient +from sift_client.resources import CalculatedChannelsAPI, CalculatedChannelsAPIAsync +from sift_client.sift_types import CalculatedChannel +from sift_client.sift_types.calculated_channel import ( + CalculatedChannelCreate, + CalculatedChannelUpdate, +) +from sift_client.sift_types.channel import ChannelReference + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.calculated_channels + assert isinstance(sift_client.calculated_channels, CalculatedChannelsAPI) + assert sift_client.async_.calculated_channels + assert isinstance(sift_client.async_.calculated_channels, CalculatedChannelsAPIAsync) + + +@pytest.fixture +def calculated_channels_api_async(sift_client: SiftClient): + """Get the async calculated channels API instance.""" + return sift_client.async_.calculated_channels + + +@pytest.fixture +def calculated_channels_api_sync(sift_client: SiftClient): + """Get the synchronous calculated channels API instance.""" + return sift_client.calculated_channels + + +@pytest.fixture +def test_calculated_channel(calculated_channels_api_sync): + calculated_channels = calculated_channels_api_sync.list_(limit=1) + assert calculated_channels + assert len(calculated_channels) >= 1 + return calculated_channels[0] + + +@pytest.fixture(scope="function") +def new_calculated_channel(calculated_channels_api_sync, sift_client): + """Create a test calculated channel for update tests.""" + from datetime import datetime, timezone + + calc_channel_name = f"test_calc_channel_{datetime.now(timezone.utc).isoformat()}" + description = "Test calculated channel created by Sift Client pytest" + + # Get some channels to reference + channels = sift_client.channels.list_(limit=2) + assert len(channels) >= 2 + + created_calc_channel = calculated_channels_api_sync.create( + CalculatedChannelCreate( + name=calc_channel_name, + client_key=f"test_calc_chan_{str(uuid.uuid4())[-8:]}", + description=description, + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + all_assets=True, + ) + ) + return created_calc_channel + + +class TestCalculatedChannelsAPIAsync: + """Test suite for the async Calculated Channels API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, calculated_channels_api_async, test_calculated_channel): + """Test getting a specific calculated channel by ID.""" + retrieved_calc_channel = await calculated_channels_api_async.get( + calculated_channel_id=test_calculated_channel.id_ + ) + + assert isinstance(retrieved_calc_channel, CalculatedChannel) + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ + assert retrieved_calc_channel.name == test_calculated_channel.name + + @pytest.mark.asyncio + async def test_get_by_client_key( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test getting a specific calculated channel by client key.""" + retrieved_calc_channel = await calculated_channels_api_async.get( + client_key=test_calculated_channel.client_key + ) + + assert retrieved_calc_channel is not None + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, calculated_channels_api_async): + """Test basic calculated channel listing functionality.""" + calc_channels = await calculated_channels_api_async.list_(limit=5) + + assert isinstance(calc_channels, list) + assert len(calc_channels) == 5 + + calc_channel = calc_channels[0] + assert isinstance(calc_channel, CalculatedChannel) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, calculated_channels_api_async): + """Test calculated channel listing with name filtering.""" + all_calc_channels = await calculated_channels_api_async.list_(limit=10) + + test_calc_channel_name = all_calc_channels[0].name + filtered_calc_channels = await calculated_channels_api_async.list_( + name=test_calc_channel_name + ) + + assert isinstance(filtered_calc_channels, list) + assert len(filtered_calc_channels) >= 1 + + for calc_channel in filtered_calc_channels: + assert calc_channel.name == test_calc_channel_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, calculated_channels_api_async): + """Test calculated channel listing with name contains filtering.""" + calc_channels = await calculated_channels_api_async.list_(name_contains="test", limit=5) + + assert isinstance(calc_channels, list) + assert calc_channels + + for calc_channel in calc_channels: + assert "test" in calc_channel.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, calculated_channels_api_async): + """Test calculated channel listing with regex name filtering.""" + calc_channels = await calculated_channels_api_async.list_( + name_regex=r".*test.*", limit=5 + ) + + assert isinstance(calc_channels, list) + assert calc_channels + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for calc_channel in calc_channels: + assert pattern.match(calc_channel.name) + + @pytest.mark.asyncio + async def test_list_with_limit(self, calculated_channels_api_async): + """Test calculated channel listing with different limits.""" + calc_channels_1 = await calculated_channels_api_async.list_(limit=1) + assert isinstance(calc_channels_1, list) + assert len(calc_channels_1) <= 1 + + calc_channels_3 = await calculated_channels_api_async.list_(limit=3) + assert isinstance(calc_channels_3, list) + assert len(calc_channels_3) <= 3 + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_calculated_channel( + self, calculated_channels_api_async, test_calculated_channel + ): + """Test finding a single calculated channel.""" + found_calc_channel = await calculated_channels_api_async.find( + name=test_calculated_channel.name + ) + + assert found_calc_channel is not None + assert found_calc_channel.id_ == test_calculated_channel.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_calculated_channel(self, calculated_channels_api_async): + """Test finding a non-existent calculated channel returns None.""" + found_calc_channel = await calculated_channels_api_async.find( + name="nonexistent-calculated-channel-name-12345" + ) + assert found_calc_channel is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, calculated_channels_api_async): + """Test finding multiple calculated channels raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await calculated_channels_api_async.find(name_contains="test", limit=5) + + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_calculated_channel(self, calculated_channels_api_async): + """Test creating a basic calculated channel with minimal fields.""" + from datetime import datetime, timezone + + calc_channel_name = f"test_calc_channel_create_{datetime.now(timezone.utc).isoformat()}" + description = "Test calculated channel created by Sift Client pytest" + + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + calc_channel_create = CalculatedChannelCreate( + name=calc_channel_name, + description=description, + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + all_assets=True, + ) + + created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) + + try: + assert created_calc_channel is not None + assert isinstance(created_calc_channel, CalculatedChannel) + assert created_calc_channel.id_ is not None + assert created_calc_channel.name == calc_channel_name + assert created_calc_channel.description == description + assert created_calc_channel.created_date is not None + assert created_calc_channel.modified_date is not None + finally: + await calculated_channels_api_async.archive(created_calc_channel) + + @pytest.mark.asyncio + async def test_create_calculated_channel_with_dict(self, calculated_channels_api_async): + """Test creating a calculated channel using a dictionary.""" + from datetime import datetime, timezone + + calc_channel_name = f"test_calc_channel_dict_{datetime.now(timezone.utc).isoformat()}" + description = "Test calculated channel created by Sift Client pytest" + + channels = await calculated_channels_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + calc_channel_dict = { + "name": calc_channel_name, + "description": description, + "expression": "$1 + $2", + "expression_channel_references": [ + {"channel_reference": "$1", "channel_identifier": channels[0].name}, + {"channel_reference": "$2", "channel_identifier": channels[1].name}, + ], + "all_assets": True, + } + + created_calc_channel = await calculated_channels_api_async.create(calc_channel_dict) + + try: + assert created_calc_channel.name == calc_channel_name + assert created_calc_channel.description == description + finally: + await calculated_channels_api_async.archive(created_calc_channel) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_calculated_channel_description( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's description.""" + try: + update = CalculatedChannelUpdate(description="Updated description") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.description == "Updated description" + assert updated_calc_channel.name == new_calculated_channel.name + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_calculated_channel_name( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's name.""" + try: + new_name = f"updated_{new_calculated_channel.name}" + update = CalculatedChannelUpdate(name=new_name) + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.name == new_name + assert updated_calc_channel.id_ == new_calculated_channel.id_ + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_dict( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel using a dictionary.""" + try: + update_dict = {"description": "Updated via dict"} + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update_dict + ) + + assert updated_calc_channel.description == "Updated via dict" + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_id_string( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel by passing ID as string.""" + try: + update = CalculatedChannelUpdate(description="Updated via ID string") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel.id_, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.description == "Updated via ID string" + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_calculated_channel_units( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel's units.""" + try: + update = CalculatedChannelUpdate(units="percentage") + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + + assert updated_calc_channel.id_ == new_calculated_channel.id_ + assert updated_calc_channel.units == "percentage" + assert updated_calc_channel.name == new_calculated_channel.name + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_complex_expression( + self, calculated_channels_api_async, sift_client + ): + """Test updating a calculated channel with a complex expression using multiple channel references.""" + # Get channels to reference + channels = await sift_client.async_.channels.list_(limit=3) + assert len(channels) >= 3 + + # Create a calculated channel + calc_channel_name = ( + f"test_calc_channel_complex_{datetime.now(timezone.utc).isoformat()}" + ) + calc_channel_create = CalculatedChannelCreate( + name=calc_channel_name, + description="Test calculated channel for complex expression update", + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + all_assets=True, + ) + created_calc_channel = await calculated_channels_api_async.create(calc_channel_create) + + try: + # Update with complex expression + update = CalculatedChannelUpdate( + expression="($1 / $2) * 100 + ($3 * 0.1)", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ChannelReference( + channel_reference="$3", channel_identifier=channels[2].name + ), + ], + ) + updated_calc_channel = await calculated_channels_api_async.update( + created_calc_channel, update + ) + + assert updated_calc_channel.id_ == created_calc_channel.id_ + assert updated_calc_channel.expression == "($1 / $2) * 100 + ($3 * 0.1)" + assert len(updated_calc_channel.channel_references) == 3 + # Verify all three channel references are present + ref_identifiers = { + ref.channel_identifier for ref in updated_calc_channel.channel_references + } + assert channels[0].name in ref_identifiers + assert channels[1].name in ref_identifiers + assert channels[2].name in ref_identifiers + finally: + await calculated_channels_api_async.archive(created_calc_channel.id_) + + @pytest.mark.asyncio + async def test_update_with_invalid_expression( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test updating a calculated channel with an invalid expression. + + Note: The server may or may not validate expression syntax at update time. + This test documents the current behavior. + """ + try: + # Attempt to update with an invalid expression + update = CalculatedChannelUpdate( + expression="invalid_expression", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", + channel_identifier=new_calculated_channel.channel_references[ + 0 + ].channel_identifier, + ), + ], + ) + + # This may succeed or fail depending on server-side validation + # If it succeeds, the expression is stored but may fail at evaluation time + try: + updated_calc_channel = await calculated_channels_api_async.update( + new_calculated_channel, update + ) + # If update succeeds, verify the expression was set + assert updated_calc_channel.expression == "invalid_expression" + except Exception as e: + # If server validates and rejects, that's also acceptable behavior + assert "expression" in str(e).lower() or "invalid" in str(e).lower() + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_calculated_channel( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test archiving a calculated channel.""" + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is True + + calc_channels_without_archived = await calculated_channels_api_async.list_( + name=new_calculated_channel.name, include_archived=False + ) + assert len(calc_channels_without_archived) == 0 + + calc_channels_with_archived = await calculated_channels_api_async.list_( + name=new_calculated_channel.name, include_archived=True + ) + assert len(calc_channels_with_archived) == 1 + assert calc_channels_with_archived[0].id_ == new_calculated_channel.id_ + assert calc_channels_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_id_string( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test archiving a calculated channel by passing ID as string.""" + calc_channel = await calculated_channels_api_async.archive(new_calculated_channel.id_) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is True + + class TestUnarchive: + """Tests for the async unarchive method.""" + + @pytest.mark.asyncio + async def test_unarchive_calculated_channel( + self, calculated_channels_api_async, new_calculated_channel + ): + """Test unarchiving a calculated channel.""" + try: + await calculated_channels_api_async.archive(new_calculated_channel) + + calc_channel = await calculated_channels_api_async.unarchive(new_calculated_channel) + + assert isinstance(calc_channel, CalculatedChannel) + assert calc_channel.id_ == new_calculated_channel.id_ + assert calc_channel.is_archived is False + finally: + await calculated_channels_api_async.archive(new_calculated_channel.id_) + + class TestListVersions: + """Tests for the async list_versions method.""" + + @pytest.mark.asyncio + async def test_list_versions(self, calculated_channels_api_async, test_calculated_channel): + """Test listing versions of a calculated channel.""" + versions = await calculated_channels_api_async.list_versions( + calculated_channel=test_calculated_channel + ) + + assert isinstance(versions, list) + assert len(versions) >= 1 + + for version in versions: + assert isinstance(version, CalculatedChannel) + assert version.name == test_calculated_channel.name + + +class TestCalculatedChannelsAPISync: + """Test suite for the synchronous Calculated Channels API functionality.""" + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, calculated_channels_api_sync, test_calculated_channel): + """Test getting a specific calculated channel by ID synchronously.""" + retrieved_calc_channel = calculated_channels_api_sync.get( + calculated_channel_id=test_calculated_channel.id_ + ) + + assert retrieved_calc_channel is not None + assert retrieved_calc_channel.id_ == test_calculated_channel.id_ diff --git a/python/lib/sift_client/_tests/resources/test_channels.py b/python/lib/sift_client/_tests/resources/test_channels.py new file mode 100644 index 000000000..8afdb60be --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_channels.py @@ -0,0 +1,358 @@ +"""Pytest tests for the Channels API. + +These tests demonstrate and validate the usage of the Channels API including: +- Basic channel operations (get, list, find) +- Channel filtering and searching +- Channel data retrieval +- Error handling and edge cases +""" + +import pytest + +from sift_client import SiftClient +from sift_client.resources import ChannelsAPI, ChannelsAPIAsync +from sift_client.sift_types import Channel + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.channels + assert isinstance(sift_client.channels, ChannelsAPI) + assert sift_client.async_.channels + assert isinstance(sift_client.async_.channels, ChannelsAPIAsync) + + +@pytest.fixture +def channels_api_async(sift_client: SiftClient): + """Get the async channels API instance.""" + return sift_client.async_.channels + + +@pytest.fixture +def channels_api_sync(sift_client: SiftClient): + """Get the synchronous channels API instance.""" + return sift_client.channels + + +@pytest.fixture +def test_channel(channels_api_sync): + channels = channels_api_sync.list_(limit=1) + assert channels + assert len(channels) >= 1 + return channels[0] + + +class TestChannelsAPIAsync: + """Test suite for the async Channels API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, channels_api_async, test_channel): + """Test getting a specific channel by ID.""" + retrieved_channel = await channels_api_async.get(channel_id=test_channel.id_) + + assert isinstance(retrieved_channel, Channel) + assert retrieved_channel.id_ == test_channel.id_ + assert retrieved_channel.name == test_channel.name + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, channels_api_async): + """Test basic channel listing functionality.""" + channels = await channels_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(channels, list) + assert len(channels) == 5 + + # If we have channels, verify their structure + channel = channels[0] + assert isinstance(channel, Channel) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, channels_api_async): + """Test channel listing with name filtering.""" + # First get some channels to work with + all_channels = await channels_api_async.list_(limit=10) + + test_channel_name = all_channels[0].name + filtered_channels = await channels_api_async.list_(name=test_channel_name) + + # Should find at least one channel with exact name match + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= 1 + + # All returned channels should have the exact name + for channel in filtered_channels: + assert channel.name == test_channel_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, channels_api_async): + """Test channel listing with name contains filtering.""" + # Test with a common substring that might exist in channel names + channels = await channels_api_async.list_(name_contains="test", limit=5) + + assert isinstance(channels, list) + assert channels + + for channel in channels: + assert "test" in channel.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, channels_api_async): + """Test channel listing with regex name filtering.""" + # Test with a regex pattern + channels = await channels_api_async.list_(name_regex=r".*test.*", limit=5) + + assert isinstance(channels, list) + assert channels + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for channel in channels: + assert pattern.match(channel.name) + + @pytest.mark.asyncio + async def test_list_with_channel_ids_filter(self, channels_api_async): + """Test channel listing with channel IDs filter.""" + all_channels = await channels_api_async.list_(limit=3) + + if all_channels: + channel_ids = [ch.id_ for ch in all_channels] + filtered_channels = await channels_api_async.list_(channel_ids=channel_ids) + + # Should find at least the channels we specified + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= len(all_channels) + + # All returned channels should have IDs in our list + for channel in filtered_channels: + assert channel.id_ in channel_ids + + @pytest.mark.asyncio + async def test_list_with_asset_filter(self, channels_api_async): + """Test channel listing with asset filter.""" + # First get a channel to get its asset + all_channels = await channels_api_async.list_(limit=1) + + if all_channels: + test_channel = all_channels[0] + # Filter by asset ID + filtered_channels = await channels_api_async.list_(asset=test_channel.asset_id) + + # Should find at least one channel for this asset + assert isinstance(filtered_channels, list) + assert len(filtered_channels) >= 1 + + # All returned channels should belong to the same asset + for channel in filtered_channels: + assert channel.asset_id == test_channel.asset_id + + @pytest.mark.asyncio + async def test_list_with_description_contains_filter(self, channels_api_async): + """Test channel listing with description contains filtering.""" + # Test with a common substring that might exist in descriptions + channels = await channels_api_async.list_(description_contains="test", limit=5) + + assert isinstance(channels, list) + assert channels + + # If we found channels, verify they contain the substring in description + for channel in channels: + assert "test" in channel.description.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, channels_api_async): + """Test channel listing with different limits.""" + # Test with limit of 1 + channels_1 = await channels_api_async.list_(limit=1) + assert isinstance(channels_1, list) + assert len(channels_1) <= 1 + + # Test with limit of 3 + channels_3 = await channels_api_async.list_(limit=3) + assert isinstance(channels_3, list) + assert len(channels_3) <= 3 + + # TODO: active channel test + # @pytest.mark.asyncio + # async def test_list_include_archived(self, channels_api_async): + # """Test channel listing with archived channels included.""" + # # Test without archived channels (default) + # channels_active = await channels_api_async.list_(limit=5, include_archived=False) + # assert isinstance(channels_active, list) + # + # # Test with archived channels included + # channels_all = await channels_api_async.list_(limit=5, include_archived=True) + # assert isinstance(channels_all, list) + # + # # Should have at least as many channels when including archived + # assert len(channels_all) >= len(channels_active) + + @pytest.mark.asyncio + async def test_list_with_time_filters(self, channels_api_async): + """Test channel listing with time-based filters.""" + from datetime import datetime, timedelta, timezone + + # Get channels created in the last year + one_year_ago = datetime.now(timezone.utc) - timedelta(days=365) + channels = await channels_api_async.list_(created_after=one_year_ago, limit=5) + + assert isinstance(channels, list) + assert channels + + # If we found channels, verify they were created after the specified time + for channel in channels: + assert channel.created_date >= one_year_ago + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_channel(self, channels_api_async, test_channel): + """Test finding a single channel.""" + # Find the same channel by name and asset + found_channel = await channels_api_async.find( + name=test_channel.name, asset=test_channel.asset_id + ) + + assert found_channel is not None + assert found_channel.id_ == test_channel.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_channel(self, channels_api_async): + """Test finding a non-existent channel returns None.""" + found_channel = await channels_api_async.find(name="nonexistent-channel-name-12345") + assert found_channel is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, channels_api_async): + """Test finding multiple channels raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await channels_api_async.find(name_contains="test", limit=5) + + # TODO: data retrieval tests + # class TestGetData: + # """Tests for the async get_data method.""" + # + # @pytest.mark.asyncio + # async def test_get_data_basic(self, channels_api_async, test_channel): + # """Test getting channel data.""" + # # Get the channel's asset to find a run + # from sift_client.sift_types import Asset + # + # asset = await channels_api_async.client.async_.assets.get( + # asset_id=test_channel.asset_id + # ) + # assert isinstance(asset, Asset) + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Get data for the channel + # data = await channels_api_async.get_data( + # channels=[test_channel], run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # assert data + # + # # Should have an entry for our channel + # assert test_channel.name in data or len(data) > 0 + # + # @pytest.mark.asyncio + # async def test_get_data_with_time_range(self, channels_api_async, test_channel): + # """Test getting channel data with time range.""" + # from datetime import datetime, timedelta, timezone + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Define a time range + # end_time = datetime.now(timezone.utc) + # start_time = end_time - timedelta(hours=1) + # + # # Get data for the channel with time range + # data = await channels_api_async.get_data( + # channels=[test_channel], + # run=run.id_, + # start_time=start_time, + # end_time=end_time, + # limit=10, + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # + # @pytest.mark.asyncio + # async def test_get_data_as_arrow(self, channels_api_async, test_channel): + # """Test getting channel data as Arrow table.""" + # import pyarrow as pa + # + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # assert runs + # run = runs[0] + # # Get data as Arrow table + # data = await channels_api_async.get_data_as_arrow( + # channels=[test_channel], run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + # assert data + # for table in data.values(): + # assert isinstance(table, pa.Table) + # + # @pytest.mark.asyncio + # async def test_get_data_multiple_channels(self, channels_api_async): + # """Test getting data for multiple channels.""" + # # Get multiple channels from the same asset + # channels = await channels_api_async.list_(limit=3) + # + # if len(channels) >= 2: + # # Get the first asset's channels + # first_asset_id = channels[0].asset_id + # asset_channels = [ + # ch for ch in channels if ch.asset_id == first_asset_id + # ][:2] + # + # if len(asset_channels) >= 2: + # # Get runs for this asset + # runs = await channels_api_async.client.async_.runs.list_(limit=1) + # + # if runs: + # run = runs[0] + # # Get data for multiple channels + # data = await channels_api_async.get_data( + # channels=asset_channels, run=run.id_, limit=10 + # ) + # + # # Verify we get a dictionary + # assert isinstance(data, dict) + + +class TestChannelsAPISync: + """Test suite for the synchronous Channels API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, channels_api_sync, test_channel): + """Test getting a specific channel by ID synchronously.""" + retrieved_channel = channels_api_sync.get(channel_id=test_channel.id_) + + assert retrieved_channel is not None + assert retrieved_channel.id_ == test_channel.id_ diff --git a/python/lib/sift_client/_tests/resources/test_ingestion.py b/python/lib/sift_client/_tests/resources/test_ingestion.py new file mode 100644 index 000000000..1858ff10b --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_ingestion.py @@ -0,0 +1,444 @@ +"""Pytest tests for the Ingestion API. + +These tests demonstrate and validate the usage of the Ingestion API including: +- Creating ingestion configurations +- Ingesting data with various channel types (double, enum, bit field) +- Flow management and validation +- High-speed and regular flow ingestion +- Error handling and edge cases +""" + +import math +import random +import time +from datetime import datetime, timedelta, timezone + +import pytest + +from sift_client import SiftClient +from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType +from sift_client.sift_types.ingestion import ChannelConfig, Flow + +pytestmark = pytest.mark.integration + +ASSET_NAME = "test-ingestion-asset" + + +def test_client_binding(sift_client): + assert getattr(sift_client, "ingestion", None) is None # Only async! + assert sift_client.async_.ingestion + + +@pytest.fixture(scope="function") +def test_run(sift_client: SiftClient): + """Create a test run for ingestion tests.""" + run = sift_client.runs.create( + { + "name": f"test-ingestion-run-{datetime.now(tz=timezone.utc).timestamp()}", + "description": "Test run for ingestion integration tests", + "tags": ["test", "ingestion", "pytest"], + } + ) + yield run + # Cleanup + sift_client.runs.archive(run=run) + + +class TestIngestionAPIAsync: + """Test suite for the async Ingestion API functionality.""" + + class TestCreateIngestionConfig: + """Tests for creating ingestion configurations.""" + + @pytest.mark.asyncio + async def test_create_basic_config(self, sift_client, test_run): + """Test creating a basic ingestion configuration.""" + flow = Flow( + name="test-basic-flow", + channels=[ + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + assert isinstance(config_id, str) + + @pytest.mark.asyncio + async def test_create_config_with_multiple_flows(self, sift_client, test_run): + """Test creating an ingestion configuration with multiple flows.""" + regular_flow = Flow( + name="test-regular-flow", + channels=[ + ChannelConfig(name="regular-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + highspeed_flow = Flow( + name="test-highspeed-flow", + channels=[ + ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[regular_flow, highspeed_flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_create_config_with_enum_channel(self, sift_client, test_run): + """Test creating an ingestion configuration with enum channel.""" + flow = Flow( + name="test-enum-flow", + channels=[ + ChannelConfig( + name="test-enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"state1": 1, "state2": 2, "state3": 3}, + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_create_config_with_bit_field_channel(self, sift_client, test_run): + """Test creating an ingestion configuration with bit field channel.""" + flow = Flow( + name="test-bitfield-flow", + channels=[ + ChannelConfig( + name="test-bit-field-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="voltage", index=0, bit_count=4), + ChannelBitFieldElement(name="current", index=4, bit_count=2), + ChannelBitFieldElement(name="status", index=6, bit_count=2), + ], + ), + ], + ) + + config_id = await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + assert config_id is not None + + @pytest.mark.asyncio + async def test_flow_sealed_after_config_creation(self, sift_client, test_run): + """Test that flows are sealed after ingestion config creation.""" + flow = Flow( + name="test-sealed-flow", + channels=[ + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + # Try to add a channel after config creation + with pytest.raises(ValueError, match="Cannot add a channel to a flow after creation"): + flow.add_channel( + ChannelConfig(name="new-channel", data_type=ChannelDataType.DOUBLE) + ) + + class TestIngestData: + """Tests for ingesting data.""" + + @pytest.mark.asyncio + async def test_ingest_double_data(self, sift_client, test_run): + """Test ingesting double data.""" + flow = Flow( + name="test-double-flow", + channels=[ + ChannelConfig(name="double-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={"double-channel": float(i)}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_enum_data(self, sift_client, test_run): + """Test ingesting enum data.""" + flow = Flow( + name="test-enum-ingest-flow", + channels=[ + ChannelConfig( + name="enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"low": 1, "medium": 2, "high": 3}, + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": (i % 3) + 1}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_bit_field_data_as_dict(self, sift_client, test_run): + """Test ingesting bit field data as dictionary.""" + flow = Flow( + name="test-bitfield-ingest-flow", + channels=[ + ChannelConfig( + name="bitfield-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="voltage", index=0, bit_count=4), + ChannelBitFieldElement(name="current", index=4, bit_count=2), + ChannelBitFieldElement(name="led", index=6, bit_count=1), + ChannelBitFieldElement(name="heater", index=7, bit_count=1), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(10): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={ + "bitfield-channel": { + "voltage": random.randint(3, 13), + "current": random.randint(1, 3), + "led": random.choice([0, 1]), + "heater": random.choice([0, 1]), + } + }, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_bit_field_data_as_bytes(self, sift_client, test_run): + """Test ingesting bit field data as bytes.""" + flow = Flow( + name="test-bitfield-bytes-flow", + channels=[ + ChannelConfig( + name="bitfield-channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="field1", index=0, bit_count=4), + ChannelBitFieldElement(name="field2", index=4, bit_count=4), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + timestamp = datetime.now(tz=timezone.utc) + flow.ingest( + timestamp=timestamp, + channel_values={"bitfield-channel": bytes([0b11110000])}, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_multiple_channels(self, sift_client, test_run): + """Test ingesting data for multiple channels simultaneously.""" + flow = Flow( + name="test-multi-channel-flow", + channels=[ + ChannelConfig(name="channel1", data_type=ChannelDataType.DOUBLE), + ChannelConfig( + name="channel2", + data_type=ChannelDataType.ENUM, + enum_types={"a": 1, "b": 2}, + ), + ChannelConfig( + name="channel3", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="bit1", index=0, bit_count=4), + ChannelBitFieldElement(name="bit2", index=4, bit_count=4), + ], + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + for i in range(5): + timestamp = start_time + timedelta(seconds=i) + flow.ingest( + timestamp=timestamp, + channel_values={ + "channel1": float(i), + "channel2": (i % 2) + 1, + "channel3": {"bit1": i % 16, "bit2": (i * 2) % 16}, + }, + ) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + @pytest.mark.asyncio + async def test_ingest_highspeed_data(self, sift_client, test_run): + """Test ingesting high-speed data.""" + flow = Flow( + name="test-highspeed-data-flow", + channels=[ + ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + start_time = datetime.now(tz=timezone.utc) + fake_hs_rate = 50 # Hz + fake_hs_period = 1 / fake_hs_rate + duration = 2 # seconds + + for i in range(duration): + for j in range(fake_hs_rate): + val = 3.0 * math.sin(2 * math.pi * fake_hs_rate * (i + j * 0.001)) + timestamp = start_time + timedelta( + seconds=i, milliseconds=j * fake_hs_period * 1000 + ) + flow.ingest( + timestamp=timestamp, + channel_values={"highspeed-channel": val}, + ) + time.sleep(0.01) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + class TestIngestionValidation: + """Tests for ingestion validation and error handling.""" + + @pytest.mark.asyncio + async def test_ingest_invalid_enum_value_raises_error(self, sift_client, test_run): + """Test that ingesting an invalid enum value raises an error.""" + flow = Flow( + name="test-enum-validation-flow", + channels=[ + ChannelConfig( + name="enum-channel", + data_type=ChannelDataType.ENUM, + enum_types={"valid1": 1, "valid2": 2}, + ), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + timestamp = datetime.now(tz=timezone.utc) + # Test with invalid integer + with pytest.raises(ValueError, match="Could not find enum value"): + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": 99}, + ) + + # Test with invalid string + with pytest.raises(ValueError, match="Could not find enum value"): + flow.ingest( + timestamp=timestamp, + channel_values={"enum-channel": "invalid-enum"}, + ) + + @pytest.mark.asyncio + async def test_resume_ingestion_after_wait(self, sift_client, test_run): + """Test that ingestion can resume after waiting for completion.""" + flow = Flow( + name="test-resume-flow", + channels=[ + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), + ], + ) + + await sift_client.async_.ingestion.create_ingestion_config( + asset_name=ASSET_NAME, + run_id=test_run.id_, + flows=[flow], + ) + + # First batch + timestamp1 = datetime.now(tz=timezone.utc) + flow.ingest(timestamp=timestamp1, channel_values={"test-channel": 1.0}) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) + + # Wait a bit + time.sleep(0.1) + + # Second batch after wait + timestamp2 = timestamp1 + timedelta(seconds=2) + flow.ingest(timestamp=timestamp2, channel_values={"test-channel": 2.0}) + + sift_client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py new file mode 100644 index 000000000..587d8a7f0 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -0,0 +1,62 @@ +"""Pytest tests for the Ping API. + +These tests demonstrate and validate the usage of the Ping API including: +- Basic ping functionality +- Connection health checks +- Error handling and edge cases +""" + +import pytest + +from sift_client import SiftClient +from sift_client.resources import PingAPI, PingAPIAsync + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.ping + assert isinstance(sift_client.ping, PingAPI) + assert sift_client.async_.ping + assert isinstance(sift_client.async_.ping, PingAPIAsync) + + +@pytest.fixture +def ping_api_async(sift_client: SiftClient): + """Get the ping async API instance.""" + return sift_client.async_.ping + + +@pytest.fixture +def ping_api_sync(sift_client: SiftClient): + """Get the synchronous ping API instance.""" + return sift_client.ping + + +class TestPingAPIAsync: + """Test suite for the Ping API functionality.""" + + @pytest.mark.asyncio + async def test_basic_ping(self, ping_api_async): + """Test basic ping functionality.""" + response = await ping_api_async.ping() + + # Verify response is a string + assert isinstance(response, str) + + # Verify response is not empty + assert len(response) > 0 + + +class TestPingAPISync: + """Test suite for the Ping API functionality.""" + + def test_basic_ping(self, ping_api_sync): + """Test basic synchronous ping functionality.""" + response = ping_api_sync.ping() + + # Verify response is a string + assert isinstance(response, str) + + # Verify response is not empty + assert len(response) > 0 diff --git a/python/lib/sift_client/_tests/resources/test_rules.py b/python/lib/sift_client/_tests/resources/test_rules.py new file mode 100644 index 000000000..6036bec19 --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_rules.py @@ -0,0 +1,627 @@ +"""Pytest tests for the Rules API. + +These tests demonstrate and validate the usage of the Rules API including: +- Basic rule operations (get, list, find) +- Rule filtering and searching +- Rule creation, updates, and archiving +- Error handling and edge cases +""" + +import uuid +from datetime import datetime, timezone + +import pytest + +from sift_client import SiftClient +from sift_client.resources import RulesAPI, RulesAPIAsync +from sift_client.sift_types import Rule +from sift_client.sift_types.channel import ChannelReference +from sift_client.sift_types.rule import ( + RuleAction, + RuleActionType, + RuleAnnotationType, + RuleCreate, + RuleUpdate, +) + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.rules + assert isinstance(sift_client.rules, RulesAPI) + assert sift_client.async_.rules + assert isinstance(sift_client.async_.rules, RulesAPIAsync) + + +@pytest.fixture +def rules_api_async(sift_client: SiftClient): + """Get the async rules API instance.""" + return sift_client.async_.rules + + +@pytest.fixture +def rules_api_sync(sift_client: SiftClient): + """Get the synchronous rules API instance.""" + return sift_client.rules + + +@pytest.fixture +def test_rule(rules_api_sync): + rules = rules_api_sync.list_(limit=1) + assert rules + assert len(rules) >= 1 + return rules[0] + + +@pytest.fixture(scope="function") +def new_rule(rules_api_sync, sift_client): + """Create a test rule for update tests.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + # Get some channels to reference + channels = sift_client.channels.list_(limit=2) + assert len(channels) >= 2 + + # Get an asset to apply the rule to + assets = sift_client.assets.list_(limit=1) + assert len(assets) >= 1 + + created_rule = rules_api_sync.create( + RuleCreate( + name=rule_name, + client_key=f"test_rule_{str(uuid.uuid4())[-8:]}", + description=description, + expression="$1 > $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=[], + ), + asset_ids=[assets[0].id_], + ) + ) + return created_rule + + +class TestRulesAPIAsync: + """Test suite for the async Rules API functionality.""" + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, rules_api_async, test_rule): + """Test getting a specific rule by ID.""" + retrieved_rule = await rules_api_async.get(rule_id=test_rule.id_) + + assert isinstance(retrieved_rule, Rule) + assert retrieved_rule.id_ == test_rule.id_ + assert retrieved_rule.name == test_rule.name + + @pytest.mark.asyncio + async def test_get_by_client_key(self, rules_api_async, test_rule): + """Test getting a specific rule by client key.""" + if test_rule.client_key: + retrieved_rule = await rules_api_async.get(client_key=test_rule.client_key) + + assert retrieved_rule is not None + assert retrieved_rule.id_ == test_rule.id_ + + class TestList: + """Tests for the async list_ method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, rules_api_async): + """Test basic rule listing functionality.""" + rules = await rules_api_async.list_(limit=5) + + assert isinstance(rules, list) + assert len(rules) == 5 + + rule = rules[0] + assert isinstance(rule, Rule) + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, rules_api_async, test_rule): + """Test rule listing with name filtering.""" + filtered_rules = await rules_api_async.list_(name=test_rule.name) + + assert isinstance(filtered_rules, list) + assert len(filtered_rules) >= 1 + + for rule in filtered_rules: + assert rule.name == test_rule.name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, rules_api_async): + """Test rule listing with name contains filtering.""" + rules = await rules_api_async.list_(name_contains="test", limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert "test" in rule.name.lower() + + @pytest.mark.asyncio + async def test_list_with_name_regex_filter(self, rules_api_async): + """Test rule listing with regex name filtering.""" + rules = await rules_api_async.list_(name_regex=r".*test.*", limit=5) + + assert isinstance(rules, list) + assert rules + + import re + + pattern = re.compile(r".*test.*", re.IGNORECASE) + for rule in rules: + assert pattern.match(rule.name) + + @pytest.mark.asyncio + async def test_list_with_rule_ids_filter(self, rules_api_async): + """Test rule listing with rule IDs filter.""" + all_rules = await rules_api_async.list_(limit=3) + + if all_rules: + rule_ids = [r.id_ for r in all_rules] + filtered_rules = await rules_api_async.list_(rule_ids=rule_ids) + + assert isinstance(filtered_rules, list) + assert len(filtered_rules) >= len(all_rules) + + for rule in filtered_rules: + assert rule.id_ in rule_ids + + @pytest.mark.asyncio + async def test_list_with_description_contains_filter(self, rules_api_async): + """Test rule listing with description contains filtering.""" + rules = await rules_api_async.list_(description_contains="test", limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert "test" in rule.description.lower() + + @pytest.mark.asyncio + async def test_list_with_limit(self, rules_api_async): + """Test rule listing with different limits.""" + rules_1 = await rules_api_async.list_(limit=1) + assert isinstance(rules_1, list) + assert len(rules_1) <= 1 + + rules_3 = await rules_api_async.list_(limit=3) + assert isinstance(rules_3, list) + assert len(rules_3) <= 3 + + @pytest.mark.asyncio + async def test_list_with_time_filters(self, rules_api_async): + """Test rule listing with time-based filters.""" + from datetime import datetime, timedelta, timezone + + one_year_ago = datetime.now(timezone.utc) - timedelta(days=365) + rules = await rules_api_async.list_(created_after=one_year_ago, limit=5) + + assert isinstance(rules, list) + assert rules + + for rule in rules: + assert rule.created_date >= one_year_ago + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_rule(self, rules_api_async, test_rule): + """Test finding a single rule.""" + found_rule = await rules_api_async.find(rule_ids=[test_rule.id_]) + + assert found_rule is not None + assert found_rule.id_ == test_rule.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_rule(self, rules_api_async): + """Test finding a non-existent rule returns None.""" + found_rule = await rules_api_async.find(name="nonexistent-rule-name-12345") + assert found_rule is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, rules_api_async): + """Test finding multiple rules raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await rules_api_async.find(name_contains="test", limit=5) + + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_rule(self, rules_api_async): + """Test creating a basic rule with minimal fields.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_create_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + channels = await rules_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + assets = await rules_api_async.client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + rule_create = RuleCreate( + name=rule_name, + description=description, + expression="$1 > $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=[], + ), + asset_ids=[assets[0].id_], + ) + + created_rule = await rules_api_async.create(rule_create) + + try: + assert created_rule is not None + assert isinstance(created_rule, Rule) + assert created_rule.id_ is not None + assert created_rule.name == rule_name + assert created_rule.description == description + assert created_rule.created_date is not None + assert created_rule.modified_date is not None + finally: + await rules_api_async.archive(created_rule) + + @pytest.mark.asyncio + async def test_create_rule_with_dict(self, rules_api_async): + """Test creating a rule using a dictionary.""" + from datetime import datetime, timezone + + rule_name = f"test_rule_dict_{datetime.now(timezone.utc).isoformat()}" + description = "Test rule created by Sift Client pytest" + + channels = await rules_api_async.client.async_.channels.list_(limit=2) + assert len(channels) >= 2 + + assets = await rules_api_async.client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + rule_dict = { + "name": rule_name, + "description": description, + "expression": "$1 > $2", + "channel_references": [ + {"channel_reference": "$1", "channel_identifier": channels[0].name}, + {"channel_reference": "$2", "channel_identifier": channels[1].name}, + ], + "action": { + "action_type": RuleActionType.ANNOTATION, + "annotation_type": RuleAnnotationType.PHASE, + "tags": [], + }, + "asset_ids": [assets[0].id_], + } + + created_rule = await rules_api_async.create(rule_dict) + + try: + assert created_rule.name == rule_name + assert created_rule.description == description + finally: + await rules_api_async.archive(created_rule) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_rule_description(self, rules_api_async, new_rule): + """Test updating a rule's description.""" + try: + update = RuleUpdate(description="Updated description") + updated_rule = await rules_api_async.update(new_rule, update) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated description" + # Validate that things we didn't intentionally change didn't change + assert updated_rule.name == new_rule.name + assert updated_rule.is_enabled == new_rule.is_enabled + assert updated_rule.is_external == new_rule.is_external + assert updated_rule.expression == new_rule.expression + assert updated_rule.action.action_type == new_rule.action.action_type + assert updated_rule.client_key == new_rule.client_key + assert updated_rule.rule_version.created_date > new_rule.rule_version.created_date + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_rule_name(self, rules_api_async, new_rule): + """Test updating a rule's name.""" + try: + new_name = f"updated_{new_rule.name}" + update = RuleUpdate(name=new_name) + updated_rule = await rules_api_async.update(new_rule, update) + + assert updated_rule.name == new_name + assert updated_rule.id_ == new_rule.id_ + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_dict(self, rules_api_async, new_rule): + """Test updating a rule using a dictionary.""" + try: + update_dict = {"description": "Updated via dict"} + updated_rule = await rules_api_async.update(new_rule, update_dict) + + assert updated_rule.description == "Updated via dict" + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_id_string(self, rules_api_async, new_rule): + """Test updating a rule by passing ID as string.""" + try: + update = RuleUpdate(description="Updated via ID string") + updated_rule = await rules_api_async.update(new_rule.id_, update) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated via ID string" + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_version_notes(self, rules_api_async, new_rule): + """Test updating a rule with version notes.""" + try: + update = RuleUpdate(description="Updated with version notes") + updated_rule = await rules_api_async.update( + new_rule, update, version_notes="Test version notes" + ) + + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.description == "Updated with version notes" + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_rule_action(self, rules_api_async, new_rule): + """Test updating a rule's action including annotation type, tags, and assignee.""" + try: + # Update the action with new annotation type, tags, and assignee + update = RuleUpdate( + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.PHASE, + tags=["sift-client-pytest"], + default_assignee_user_id=new_rule.created_by_user_id, + ), + ) + updated_rule = await rules_api_async.update(new_rule, update) + + # Verify the action was updated + assert updated_rule.id_ == new_rule.id_ + assert updated_rule.action.action_type == RuleActionType.ANNOTATION + assert updated_rule.action.annotation_type == RuleAnnotationType.PHASE + assert set(updated_rule.action.tags) == {"sift-client-pytest"} + assert updated_rule.action.default_assignee_user_id == new_rule.created_by_user_id + + # Verify other fields remain unchanged + assert updated_rule.name == new_rule.name + assert updated_rule.expression == new_rule.expression + finally: + await rules_api_async.archive(new_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_complex_expression(self, rules_api_async, sift_client): + """Test updating a rule with a complex expression (range check).""" + # Get channels and assets + channels = await sift_client.async_.channels.list_(limit=2) + assert len(channels) >= 1 + assets = await sift_client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + # Create a rule with simple expression + rule_name = f"test_rule_complex_expr_{datetime.now(timezone.utc).isoformat()}" + rule_create = RuleCreate( + name=rule_name, + description="Test rule for complex expression update", + expression="$1 > 0.5", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test"], + ), + asset_ids=[assets[0].id_], + ) + created_rule = await rules_api_async.create(rule_create) + + try: + # Update with complex expression (range check) + update = RuleUpdate( + expression="$1 > 0.3 && $1 < 0.8", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ], + ) + updated_rule = await rules_api_async.update(created_rule, update) + + # Verify the expression was updated + assert updated_rule.id_ == created_rule.id_ + assert updated_rule.expression == "$1 > 0.3 && $1 < 0.8" + assert len(updated_rule.channel_references) == 1 + assert updated_rule.channel_references[0].channel_identifier == channels[0].name + + # Verify other fields remain unchanged + assert updated_rule.name == created_rule.name + assert updated_rule.action.action_type == created_rule.action.action_type + finally: + await rules_api_async.archive(created_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_multiple_channel_references(self, rules_api_async, sift_client): + """Test updating a rule expression to use multiple channel references.""" + # Get channels and assets + channels = await sift_client.async_.channels.list_(limit=3) + assert len(channels) >= 3 + assets = await sift_client.async_.assets.list_(limit=1) + assert len(assets) >= 1 + + # Create a rule with simple expression + rule_name = f"test_rule_multi_refs_{datetime.now(timezone.utc).isoformat()}" + rule_create = RuleCreate( + name=rule_name, + description="Test rule for multiple channel references", + expression="$1 > $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier=channels[0].name), + ChannelReference(channel_reference="$2", channel_identifier=channels[1].name), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test"], + ), + asset_ids=[assets[0].id_], + ) + created_rule = await rules_api_async.create(rule_create) + + try: + # Update with expression using three channel references + update = RuleUpdate( + expression="($1 > $2) && ($3 < 100)", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier=channels[0].name + ), + ChannelReference( + channel_reference="$2", channel_identifier=channels[1].name + ), + ChannelReference( + channel_reference="$3", channel_identifier=channels[2].name + ), + ], + ) + updated_rule = await rules_api_async.update(created_rule, update) + + # Verify the expression and channel references were updated + assert updated_rule.id_ == created_rule.id_ + assert updated_rule.expression == "($1 > $2) && ($3 < 100)" + assert len(updated_rule.channel_references) == 3 + + # Verify all three channel references are present + ref_identifiers = { + ref.channel_identifier for ref in updated_rule.channel_references + } + assert channels[0].name in ref_identifiers + assert channels[1].name in ref_identifiers + assert channels[2].name in ref_identifiers + finally: + await rules_api_async.archive(created_rule.id_) + + @pytest.mark.asyncio + async def test_update_with_invalid_expression(self, rules_api_async, new_rule): + """Test updating a rule with an invalid expression. + + Note: The server may or may not validate expression syntax at update time. + This test documents the current behavior. + """ + try: + # Attempt to update with an invalid expression + update = RuleUpdate( + expression="invalid_expression", + channel_references=[ + ChannelReference( + channel_reference="$1", + channel_identifier=new_rule.channel_references[0].channel_identifier, + ), + ], + ) + + # This may succeed or fail depending on server-side validation + # If it succeeds, the expression is stored but may fail at evaluation time + try: + updated_rule = await rules_api_async.update(new_rule, update) + # If update succeeds, verify the expression was set + assert updated_rule.expression == "invalid_expression" + except Exception as e: + # If server validates and rejects, that's also acceptable behavior + assert "expression" in str(e).lower() or "invalid" in str(e).lower() + finally: + await rules_api_async.archive(new_rule.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_rule(self, rules_api_async, new_rule): + """Test archiving a rule.""" + rule = await rules_api_async.archive(new_rule) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is True + + rules_without_archived = await rules_api_async.list_( + name=new_rule.name, include_archived=False + ) + assert len(rules_without_archived) == 0 + + rules_with_archived = await rules_api_async.list_( + name=new_rule.name, include_archived=True + ) + assert len(rules_with_archived) == 1 + assert rules_with_archived[0].id_ == new_rule.id_ + assert rules_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_id_string(self, rules_api_async, new_rule): + """Test archiving a rule by passing ID as string.""" + rule = await rules_api_async.archive(new_rule.id_) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is True + + class TestUnarchive: + """Tests for the async unarchive method.""" + + @pytest.mark.asyncio + async def test_unarchive_rule(self, rules_api_async, new_rule): + """Test unarchiving a rule.""" + try: + await rules_api_async.archive(new_rule) + + rule = await rules_api_async.unarchive(new_rule) + + assert isinstance(rule, Rule) + assert rule.id_ == new_rule.id_ + assert rule.is_archived is False + finally: + await rules_api_async.archive(new_rule.id_) + + +class TestRulesAPISync: + """Test suite for the synchronous Rules API functionality.""" + + class TestGet: + """Tests for the sync get method.""" + + def test_get_by_id(self, rules_api_sync, test_rule): + """Test getting a specific rule by ID synchronously.""" + retrieved_rule = rules_api_sync.get(rule_id=test_rule.id_) + + assert retrieved_rule is not None + assert retrieved_rule.id_ == test_rule.id_ diff --git a/python/lib/sift_client/_tests/resources/test_runs.py b/python/lib/sift_client/_tests/resources/test_runs.py new file mode 100644 index 000000000..f687ecaed --- /dev/null +++ b/python/lib/sift_client/_tests/resources/test_runs.py @@ -0,0 +1,539 @@ +"""Pytest tests for the Runs API. + +These tests demonstrate and validate the usage of the Runs API including: +- Basic run operations (get, list, find) +- Run filtering and searching +- Run creation, updates, and archiving +- Error handling and edge cases +""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from sift_client import SiftClient +from sift_client.resources import RunsAPI, RunsAPIAsync +from sift_client.sift_types import Run +from sift_client.sift_types.run import RunCreate, RunUpdate + +pytestmark = pytest.mark.integration + + +def test_client_binding(sift_client): + assert sift_client.runs + assert isinstance(sift_client.runs, RunsAPI) + assert sift_client.async_.runs + assert isinstance(sift_client.async_.runs, RunsAPIAsync) + + +@pytest.fixture +def runs_api_async(sift_client: SiftClient): + """Get the async runs API instance.""" + return sift_client.async_.runs + + +@pytest.fixture +def runs_api_sync(sift_client: SiftClient): + """Get the synchronous runs API instance.""" + return sift_client.runs + + +@pytest.fixture +def test_run(runs_api_sync): + runs = runs_api_sync.list_(limit=1) + assert runs + assert len(runs) >= 1 + return runs[0] + + +@pytest.fixture(scope="function") +def new_run(runs_api_sync): + """Create a test run for update tests.""" + run_name = f"test_run_update_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + created_run = runs_api_sync.create( + RunCreate( + name=run_name, + description=description, + tags=["sift-client-pytest"], + ) + ) + return created_run + + +class TestRunsAPIAsync: + """Test suite for the async Runs API functionality.""" + + class TestList: + """Tests for the async list method.""" + + @pytest.mark.asyncio + async def test_basic_list(self, runs_api_async): + """Test basic run listing functionality.""" + runs = await runs_api_async.list_(limit=5) + + # Verify we get a list + assert isinstance(runs, list) + assert runs + + # If we have runs, verify their structure + run = runs[0] + assert isinstance(run, Run) + assert run.id_ is not None + assert run.name is not None + + @pytest.mark.asyncio + async def test_list_with_name_filter(self, runs_api_async): + """Test run listing with name filtering.""" + # First get some runs to work with + all_runs = await runs_api_async.list_(limit=10) + + if all_runs: + # Use the first run's name for filtering + test_run_name = all_runs[0].name + filtered_runs = await runs_api_async.list_(name=test_run_name) + + # Should find at least one run with exact name match + assert isinstance(filtered_runs, list) + assert len(filtered_runs) >= 1 + + # All returned runs should have the exact name + for run in filtered_runs: + assert run.name == test_run_name + + @pytest.mark.asyncio + async def test_list_with_name_contains_filter(self, runs_api_async): + """Test run listing with name contains filtering.""" + # Test with a common substring that might exist in run names + runs = await runs_api_async.list_(name_contains="test", limit=5) + + assert isinstance(runs, list) + + # If we found runs, verify they contain the substring + for run in runs: + assert "test" in run.name.lower() + + # TODO: test run-specific filters + + @pytest.mark.asyncio + async def test_list_with_limit(self, runs_api_async): + """Test run listing with different limits.""" + # Test with limit of 1 + runs_1 = await runs_api_async.list_(limit=1) + assert isinstance(runs_1, list) + assert len(runs_1) <= 1 + + # Test with limit of 3 + runs_3 = await runs_api_async.list_(limit=3) + assert isinstance(runs_3, list) + assert len(runs_3) <= 3 + + @pytest.mark.asyncio + async def test_list_include_archived(self, runs_api_async): + """Test run listing with archived runs included.""" + # Test without archived runs (default) + runs_active = await runs_api_async.list_(limit=5, include_archived=False) + assert isinstance(runs_active, list) + + # Test with archived runs included + runs_all = await runs_api_async.list_(limit=5, include_archived=True) + assert isinstance(runs_all, list) + + # Should have at least as many runs when including archived + assert len(runs_all) >= len(runs_active) + + class TestGet: + """Tests for the async get method.""" + + @pytest.mark.asyncio + async def test_get_by_id(self, runs_api_async, test_run): + """Test getting a specific run by ID.""" + retrieved_run = await runs_api_async.get(run_id=test_run.id_) + + assert retrieved_run is not None + assert retrieved_run.id_ == test_run.id_ + + # TODO: test for client key + # @pytest.mark.asyncio + # async def test_get_by_id_with_client_key(self, runs_api_async, test_run): + # """Test getting a specific run by client key.""" + # assert test_run.client_key is not None + # retrieved_run = await runs_api_async.get(client_key=test_run.client_key) + # + # assert retrieved_run is not None + # assert retrieved_run.id_ == test_run.id_ + + @pytest.mark.asyncio + async def test_get_without_params_raises_error(self, runs_api_async): + """Test that getting a run without parameters raises an error.""" + with pytest.raises(ValueError, match="must be provided"): + await runs_api_async.get() + + @pytest.mark.asyncio + async def test_get_nonexistent_run_raises_error(self, runs_api_async): + """Test that getting a non-existent run raises an error.""" + with pytest.raises(ValueError, match="not found"): + await runs_api_async.get(client_key="nonexistent-run-name-12345") + + class TestFind: + """Tests for the async find method.""" + + @pytest.mark.asyncio + async def test_find_run(self, runs_api_async, test_run): + """Test finding a single run.""" + # Find the same run by name + found_run = await runs_api_async.find(name=test_run.name) + + assert found_run is not None + assert found_run.id_ == test_run.id_ + + @pytest.mark.asyncio + async def test_find_nonexistent_run(self, runs_api_async): + """Test finding a non-existent run returns None.""" + found_run = await runs_api_async.find(name="nonexistent-run-name-12345") + assert found_run is None + + @pytest.mark.asyncio + async def test_find_multiple_raises_error(self, runs_api_async): + """Test finding multiple runs raises an error.""" + with pytest.raises(ValueError, match="Multiple"): + await runs_api_async.find(name_contains="a") + + class TestCreate: + """Tests for the async create method.""" + + @pytest.mark.asyncio + async def test_create_basic_run(self, runs_api_async): + """Test creating a basic run with minimal fields.""" + run_name = f"test_run_create_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + run_create = RunCreate( + name=run_name, + description=description, + tags=["sift-client-pytest"], + ) + + created_run = await runs_api_async.create(run_create) + + try: + # Verify the run was created + assert created_run is not None + assert isinstance(created_run, Run) + assert created_run.id_ is not None + assert created_run.name == run_name + assert created_run.description == description + assert created_run.created_date is not None + assert created_run.modified_date is not None + finally: + # Clean up: archive the test run + await runs_api_async.archive(created_run) + + @pytest.mark.asyncio + async def test_create_run_with_all_fields(self, runs_api_async): + """Test creating a run with all optional fields.""" + run_name = f"test_run_full_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + start_time = datetime.now(timezone.utc) - timedelta(hours=1) + stop_time = datetime.now(timezone.utc) + + run_create = RunCreate( + name=run_name, + description=description, + client_key=f"client_key_{datetime.now(timezone.utc).timestamp()}", + start_time=start_time, + stop_time=stop_time, + tags=["test", "pytest", "integration", "sift-client-pytest"], + metadata={"pytest_type": "integration"}, + ) + + created_run = await runs_api_async.create(run_create) + + try: + # Verify all fields + assert created_run.name == run_name + assert created_run.description == description + assert created_run.client_key is not None + assert created_run.start_time is not None + assert created_run.stop_time is not None + assert created_run.tags == [ + "test", + "pytest", + "integration", + "sift-client-pytest", + ] + assert created_run.metadata["pytest_type"] == "integration" + + finally: + # Clean up + await runs_api_async.archive(created_run) + + @pytest.mark.asyncio + async def test_create_run_with_dict(self, runs_api_async): + """Test creating a run using a dictionary instead of RunCreate object.""" + run_name = f"test_run_dict_{datetime.now(timezone.utc).isoformat()}" + description = "Test run created by Sift Client pytest" + + run_dict = { + "name": run_name, + "description": description, + "tags": ["sift-client-pytest"], + } + + created_run = await runs_api_async.create(run_dict) + + try: + assert created_run.name == run_name + assert created_run.description == description + assert created_run.tags == ["sift-client-pytest"] + finally: + await runs_api_async.archive(created_run) + + class TestUpdate: + """Tests for the async update method.""" + + @pytest.mark.asyncio + async def test_update_run_description(self, runs_api_async, new_run): + """Test updating a run's description.""" + try: + # Update the description + update = RunUpdate(description="Updated description") + updated_run = await runs_api_async.update(new_run, update) + + # Verify the update + assert updated_run.id_ == new_run.id_ + assert updated_run.description == "Updated description" + assert updated_run.name == new_run.name # Name should remain unchanged + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_name(self, runs_api_async, new_run): + """Test updating a run's name.""" + try: + # Update the name + new_name = f"updated_{new_run.name}" + update = RunUpdate(name=new_name) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the update + assert updated_run.name == new_name + assert updated_run.id_ == new_run.id_ + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_tags_and_metadata(self, runs_api_async, new_run): + """Test updating a run's tags and metadata.""" + try: + # Update tags and metadata + update = RunUpdate( + tags=["updated", "new-tag", "sift-client-pytest"], + metadata={"test_key": "test_value", "number": 42.5, "flag": True}, + ) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the updates + assert set(updated_run.tags) == { + "updated", + "new-tag", + "sift-client-pytest", + } + assert updated_run.metadata["test_key"] == "test_value" + assert updated_run.metadata["number"] == 42.5 + assert updated_run.metadata["flag"] is True + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_run_times(self, runs_api_async, new_run): + """Test updating a run's start and stop times.""" + try: + # Update with start and stop times + start_time = datetime.now(timezone.utc) - timedelta(hours=2) + stop_time = datetime.now(timezone.utc) - timedelta(hours=1) + update = RunUpdate(start_time=start_time, stop_time=stop_time) + updated_run = await runs_api_async.update(new_run, update) + + # Verify the times were set + assert updated_run.start_time is not None + assert updated_run.stop_time is not None + # Allow for small time differences due to serialization + assert abs((updated_run.start_time - start_time).total_seconds()) < 1 + assert abs((updated_run.stop_time - stop_time).total_seconds()) < 1 + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_with_dict(self, runs_api_async, new_run): + """Test updating a run using a dictionary instead of RunUpdate object.""" + try: + # Update using dict + update_dict = {"description": "Updated via dict"} + updated_run = await runs_api_async.update(new_run, update_dict) + + assert updated_run.description == "Updated via dict" + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_update_with_run_id_string(self, runs_api_async, new_run): + """Test updating a run by passing run ID as string.""" + try: + # Update using run ID string + update = RunUpdate(description="Updated via ID string") + updated_run = await runs_api_async.update(new_run.id_, update) + + assert updated_run.id_ == new_run.id_ + assert updated_run.description == "Updated via ID string" + finally: + await runs_api_async.archive(new_run.id_) + + class TestArchive: + """Tests for the async archive method.""" + + @pytest.mark.asyncio + async def test_archive_run(self, runs_api_async, new_run): + """Test archiving a run.""" + run = await runs_api_async.archive(new_run) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + # Verify it's archived by checking it doesn't appear in normal list + runs_without_archived = await runs_api_async.list_( + name=new_run.name, include_archived=False + ) + assert len(runs_without_archived) == 0 + + # Verify it appears when including archived + runs_with_archived = await runs_api_async.list_( + name=new_run.name, include_archived=True + ) + assert len(runs_with_archived) == 1 + assert runs_with_archived[0].id_ == new_run.id_ + assert runs_with_archived[0].archived_date is not None + + @pytest.mark.asyncio + async def test_archive_with_run_id_string(self, runs_api_async, new_run): + """Test archiving a run by passing run ID as string.""" + # Archive using run ID string + run = await runs_api_async.archive(new_run.id_) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + @pytest.mark.asyncio + async def test_get_archived_run_by_id(self, runs_api_async, new_run): + """Test that we can still get an archived run by ID.""" + # Archive the test run + run = await runs_api_async.archive(new_run) + + assert isinstance(run, Run) + assert run.id_ == new_run.id_ + assert run.is_archived is True + + class TestStop: + """Tests for the async stop method.""" + + @pytest.mark.asyncio + async def test_stop_run(self, runs_api_async, new_run): + """Test stopping a run.""" + try: + # Stop the run + stopped_run = await runs_api_async.stop(new_run) + + # Verify the run was stopped + assert isinstance(stopped_run, Run) + assert stopped_run.id_ == new_run.id_ + assert stopped_run.stop_time is not None + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_stop_run_with_id_string(self, runs_api_async, new_run): + """Test stopping a run by passing run ID as string.""" + try: + # Stop using run ID string + stopped_run = await runs_api_async.stop(new_run.id_) + + # Verify the run was stopped + assert isinstance(stopped_run, Run) + assert stopped_run.id_ == new_run.id_ + assert stopped_run.stop_time is not None + finally: + await runs_api_async.archive(new_run.id_) + + @pytest.mark.asyncio + async def test_stop_run_with_start_time(self, runs_api_async, new_run): + """Test stopping a run that has a start time.""" + try: + # Set start time first + start_time = datetime.now(timezone.utc) - timedelta(hours=1) + update = RunUpdate(start_time=start_time) + await runs_api_async.update(new_run, update) + + # Stop the run + stopped_run = await runs_api_async.stop(new_run) + + # Verify the run was stopped and times are valid + assert stopped_run.stop_time is not None + assert stopped_run.start_time is not None + assert stopped_run.stop_time > stopped_run.start_time + finally: + await runs_api_async.archive(new_run.id_) + + class TestAssetAssociation: + """Tests for the async asset association methods.""" + + @pytest.mark.asyncio + async def test_create_automatic_association_for_assets(self, runs_api_async, sift_client): + """Test associating assets with a run for automatic data ingestion.""" + # Create a test run + run_name = f"test_run_asset_assoc_{datetime.now(timezone.utc).isoformat()}" + run_create = RunCreate( + name=run_name, + description="Test run for asset association", + tags=["sift-client-pytest"], + ) + created_run = await runs_api_async.create(run_create) + + try: + # Get some assets to associate + assets = await sift_client.async_.assets.list_(limit=2) + assert len(assets) >= 1 + + asset_names = [asset.name for asset in assets[:2]] + + # Associate assets with the run + await runs_api_async.create_automatic_association_for_assets( + run=created_run, asset_names=asset_names + ) + + # Verify the association by getting the run and checking asset_ids + updated_run = await runs_api_async.get(run_id=created_run.id_) + assert updated_run.asset_ids is not None + assert len(updated_run.asset_ids) >= len(asset_names) + + finally: + await runs_api_async.archive(created_run) + + +class TestRunsAPISync: + """Test suite for the synchronous Runs API functionality. + + Only includes a single test for basic sync generation. No specific sync behavior difference tests are needed. + """ + + class TestList: + """Tests for the sync list method.""" + + def test_basic_list(self, runs_api_sync): + """Test basic synchronous run listing functionality.""" + runs = runs_api_sync.list_(limit=5) + + # Verify we get a list + assert isinstance(runs, list) + assert runs + assert isinstance(runs[0], Run) diff --git a/python/lib/sift_client/_tests/sift_types/__init__.py b/python/lib/sift_client/_tests/sift_types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_tests/sift_types/test_asset.py b/python/lib/sift_client/_tests/sift_types/test_asset.py new file mode 100644 index 000000000..e209e2aad --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_asset.py @@ -0,0 +1,165 @@ +"""Tests for sift_types.Asset model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Asset +from sift_client.sift_types.asset import AssetUpdate + + +class TestAssetUpdate: + """Unit tests for AssetUpdate model - tests _to_proto_helpers.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": True} + update = AssetUpdate(metadata=metadata) + update.resource_id = "test_asset_id" + + proto, mask = update.to_proto_with_mask() + + assert proto.asset_id == "test_asset_id" + # Verify metadata was converted using the helper (returns a list) + assert len(proto.metadata) == 3 + + # Find each metadata value in the list + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is True + assert "metadata" in mask.paths + + +@pytest.fixture +def mock_asset(mock_client): + """Create a mock Asset instance for testing.""" + asset = Asset( + proto=MagicMock(), + id_="test_asset_id", + name="test_asset", + organization_id="org1", + created_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_date=datetime.now(timezone.utc), + modified_by_user_id="user1", + tags=[], + metadata={}, + is_archived=False, + archived_date=None, + ) + asset._apply_client_to_instance(mock_client) + return asset + + +class TestAsset: + """Unit tests for Asset model - tests properties and methods.""" + + def test_runs_property_calls_client(self, mock_asset, mock_client): + """Test that runs property calls client.runs.list_ with correct parameters.""" + mock_client.runs.list_.return_value = [] + + # Access runs property + _ = mock_asset.runs + + # Verify client method was called with correct asset + mock_client.runs.list_.assert_called_once_with(assets=[mock_asset]) + + def test_channels_method_calls_client(self, mock_asset, mock_client): + """Test that channels() method calls client.channels.list_ with correct parameters.""" + mock_client.channels.list_.return_value = [] + + # Call channels method + _ = mock_asset.channels(limit=5) + + # Verify client method was called with correct parameters + mock_client.channels.list_.assert_called_once_with(asset=mock_asset, run=None, limit=5) + + def test_channels_method_with_run_filter(self, mock_asset, mock_client): + """Test that channels() method passes run filter to client.""" + mock_client.channels.list_.return_value = [] + mock_run = MagicMock() + + # Call channels method with run filter + _ = mock_asset.channels(run=mock_run, limit=10) + + # Verify client method was called with run parameter + mock_client.channels.list_.assert_called_once_with(asset=mock_asset, run=mock_run, limit=10) + + def test_archive_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that archive() calls client.assets.archive and calls _update.""" + archived_asset = MagicMock() + archived_asset.is_archived = True + archived_asset.archived_date = datetime.now(timezone.utc) + mock_client.assets.archive.return_value = archived_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call archive + result = mock_asset.archive(archive_runs=False) + + # Verify client method was called + mock_client.assets.archive.assert_called_once_with(asset=mock_asset, archive_runs=False) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(archived_asset) + # Verify it returns self + assert result is mock_asset + + def test_archive_with_runs(self, mock_asset, mock_client): + """Test that archive() passes archive_runs parameter correctly.""" + archived_asset = MagicMock() + mock_client.assets.archive.return_value = archived_asset + + # Mock the _update method + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call archive with archive_runs=True + mock_asset.archive(archive_runs=True) + + # Verify client method was called with archive_runs=True + mock_client.assets.archive.assert_called_once_with(asset=mock_asset, archive_runs=True) + + def test_unarchive_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that unarchive() calls client.assets.unarchive and calls _update.""" + unarchived_asset = MagicMock() + unarchived_asset.is_archived = False + mock_client.assets.unarchive.return_value = unarchived_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call unarchive + result = mock_asset.unarchive() + + # Verify client method was called + mock_client.assets.unarchive.assert_called_once_with(asset=mock_asset) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(unarchived_asset) + # Verify it returns self + assert result is mock_asset + + def test_update_calls_client_and_updates_self(self, mock_asset, mock_client): + """Test that update() calls client.assets.update and calls _update.""" + updated_asset = MagicMock() + updated_asset.tags = ["updated"] + mock_client.assets.update.return_value = updated_asset + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_asset._update = mock_update + + # Call update + update = AssetUpdate(tags=["updated"]) + result = mock_asset.update(update) + + # Verify client method was called with correct parameters + mock_client.assets.update.assert_called_once_with(asset=mock_asset, update=update) + # Verify _update was called with the returned asset + mock_update.assert_called_once_with(updated_asset) + # Verify it returns self + assert result is mock_asset diff --git a/python/lib/sift_client/_tests/sift_types/test_base.py b/python/lib/sift_client/_tests/sift_types/test_base.py new file mode 100644 index 000000000..cc266e7b9 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_base.py @@ -0,0 +1,591 @@ +"""Unit tests for sift_types._base module.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import ClassVar +from unittest.mock import MagicMock + +import pytest +from sift.calculated_channels.v2.calculated_channels_pb2 import ( + CalculatedChannel as CalculatedChannelProto, +) +from sift.calculated_channels.v2.calculated_channels_pb2 import ( + CreateCalculatedChannelRequest, +) + +from sift_client.sift_types._base import ( + BaseType, + MappingHelper, + ModelCreate, + ModelUpdate, +) + + +class SimpleCreateModel(ModelCreate[CreateCalculatedChannelRequest]): + """Simple model for testing basic field mapping.""" + + name: str + description: str | None = None + units: str | None = None + + def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: + return CreateCalculatedChannelRequest + + +class NestedCreateModel(ModelCreate[CreateCalculatedChannelRequest]): + """Model for testing nested field mapping with MappingHelper.""" + + name: str + description: str | None = None + expression: str | None = None + all_assets: bool | None = None + + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { + "expression": MappingHelper( + proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression", + update_field="query_configuration", + ), + "all_assets": MappingHelper( + proto_attr_path="calculated_channel_configuration.asset_configuration.all_assets", + ), + } + + def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: + return CreateCalculatedChannelRequest + + +class SimpleUpdateModel(ModelUpdate[CalculatedChannelProto]): + """Simple model for testing update with field masks.""" + + name: str | None = None + description: str | None = None + units: str | None = None + + def _get_proto_class(self) -> type[CalculatedChannelProto]: + return CalculatedChannelProto + + def _add_resource_id_to_proto(self, proto_msg: CalculatedChannelProto): + if self._resource_id is None: + raise ValueError("Resource ID must be set before adding to proto") + proto_msg.calculated_channel_id = self._resource_id + + +class TestModelCreate: + """Tests for ModelCreate base class.""" + + def test_simple_create_with_all_fields(self): + """Test creating a proto with all fields set.""" + model = SimpleCreateModel( + name="test_name", description="test_description", units="test_units" + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + assert proto.units == "test_units" + + def test_simple_create_with_none_fields_excluded(self): + """Test that None fields are excluded from proto.""" + model = SimpleCreateModel(name="test_name", description=None, units=None) + proto = model.to_proto() + + assert proto.name == "test_name" + # Proto should not have description or units set + assert proto.description == "" # Proto default for string + assert proto.units == "" # Proto default for string + + def test_simple_create_with_unset_fields_excluded(self): + """Test that unset fields are excluded from proto.""" + model = SimpleCreateModel(name="test_name") + proto = model.to_proto() + + assert proto.name == "test_name" + # Proto should not have description or units set + assert proto.description == "" # Proto default for string + assert proto.units == "" # Proto default for string + + def test_nested_create_with_mapping_helper(self): + """Test creating a proto with nested field mapping.""" + model = NestedCreateModel( + name="test_name", + description="test_description", + expression="$1 + $2", + all_assets=True, + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Check nested fields + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + + def test_nested_create_with_none_nested_fields(self): + """Test that None values in nested fields are excluded.""" + model = NestedCreateModel( + name="test_name", + description="test_description", + expression=None, + all_assets=None, + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Nested fields should not be set + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "" + ) # Proto default + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Proto default for bool + + def test_nested_create_with_unset_nested_fields(self): + """Test that unset nested fields are excluded.""" + model = NestedCreateModel(name="test_name", description="test_description") + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "test_description" + # Nested fields should not be set + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "" + ) # Proto default + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Proto default for bool + + def test_mixed_none_and_set_fields(self): + """Test model with mix of None, unset, and set fields.""" + model = NestedCreateModel( + name="test_name", + description=None, # Explicitly None + expression="$1 + $2", # Set + # all_assets is unset + ) + proto = model.to_proto() + + assert proto.name == "test_name" + assert proto.description == "" # None excluded, proto default + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert ( + proto.calculated_channel_configuration.asset_configuration.all_assets is False + ) # Unset, proto default + + +class TestModelUpdate: + """Tests for ModelUpdate base class.""" + + def test_simple_update_with_field_mask(self): + """Test updating a proto with field mask.""" + model = SimpleUpdateModel(name="new_name", description="new_description") + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + assert proto.description == "new_description" + assert set(mask.paths) == {"name", "description"} + + def test_update_with_none_value_excluded(self): + """Test that explicitly setting a field to None excludes it in the mask.""" + model = SimpleUpdateModel(name="new_name", description=None) + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + + assert "description" not in mask.paths + assert "name" in mask.paths + + def test_update_with_unset_fields_excluded(self): + """Test that unset fields are excluded from the mask.""" + model = SimpleUpdateModel(name="new_name") + model.resource_id = "test_id" + + proto, mask = model.to_proto_with_mask() + + assert proto.calculated_channel_id == "test_id" + assert proto.name == "new_name" + # Only name should be in the mask + assert mask.paths == ["name"] + + def test_update_requires_resource_id(self): + """Test that update fails without resource_id.""" + model = SimpleUpdateModel(name="new_name") + + with pytest.raises(ValueError, match="Resource ID must be set"): + model.to_proto_with_mask() + + +class TestMappingHelper: + """Tests for MappingHelper functionality.""" + + def test_mapping_helper_basic(self): + """Test basic MappingHelper creation.""" + helper = MappingHelper(proto_attr_path="field.nested.path") + assert helper.proto_attr_path == "field.nested.path" + assert helper.update_field is None + assert helper.converter is None + + def test_mapping_helper_with_update_field(self): + """Test MappingHelper with update_field.""" + helper = MappingHelper(proto_attr_path="field.nested.path", update_field="field") + assert helper.proto_attr_path == "field.nested.path" + assert helper.update_field == "field" + + def test_mapping_helper_with_converter(self): + """Test MappingHelper with converter function.""" + + def converter(x): + return x.upper() + + helper = MappingHelper(proto_attr_path="field.path", converter=converter) + assert helper.converter is not None + assert helper.converter("test") == "TEST" + + +class TestEdgeCases: + """Tests for edge cases and regression prevention.""" + + def test_empty_model_create(self): + """Test creating with only required fields.""" + model = SimpleCreateModel(name="test") + proto = model.to_proto() + assert proto.name == "test" + + def test_nested_path_expansion(self): + """Test that nested paths are properly expanded.""" + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = model.to_proto() + + # Verify the nested structure was created + assert proto.HasField("calculated_channel_configuration") + assert proto.calculated_channel_configuration.HasField("query_configuration") + assert proto.calculated_channel_configuration.query_configuration.HasField("sel") + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + + def test_multiple_nested_fields_same_parent(self): + """Test multiple nested fields that share a parent path.""" + model = NestedCreateModel(name="test", expression="$1 + $2", all_assets=True) + proto = model.to_proto() + + # Both fields should be set under calculated_channel_configuration + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + + def test_validation_error_on_invalid_helper_field(self): + """Test that MappingHelper validation catches mismatched fields.""" + with pytest.raises(ValueError, match="MappingHelper created for"): + + class InvalidModel(ModelCreate[CreateCalculatedChannelRequest]): + name: str + + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { + "nonexistent_field": MappingHelper(proto_attr_path="some.path"), + } + + def _get_proto_class(self): + return CreateCalculatedChannelRequest + + # This should raise during __init__ + InvalidModel(name="test") + + +class TestBaseType: + """Tests for BaseType base class.""" + + def test_base_type_concrete_implementation(self): + """Test creating a concrete BaseType implementation.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, _client=sift_client) + + model = TestModel(name="test", id_="test_id") + assert model.name == "test" + assert model.id_ == "test_id" + + def test_id_or_error_with_id(self): + """Test _id_or_error property when ID is set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test", id_="test_id_123") + assert model._id_or_error == "test_id_123" + + def test_id_or_error_without_id(self): + """Test _id_or_error property raises when ID is not set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + with pytest.raises(ValueError, match="ID is not set"): + _ = model._id_or_error + + def test_client_property_without_client(self): + """Test client property raises when client is not set.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + with pytest.raises(AttributeError, match="Sift client not set"): + _ = model.client + + def test_apply_client_to_instance(self): + """Test _apply_client_to_instance method.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + assert model._client is None + + mock_client = MagicMock() + model._apply_client_to_instance(mock_client) + assert model._client is mock_client + assert model.client is mock_client + + def test_update_method(self): + """Test _update method updates fields from another instance.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + description: str | None = None + version: int | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls( + name=proto.name, + description=proto.description, + version=proto.version, + proto=proto, + ) + + # Create original model + original = TestModel(name="original", description="old desc", version=1, id_="id1") + + # Create updated model + mock_proto = MagicMock() + updated = TestModel( + name="updated", + description="new desc", + version=2, + id_="id1", + proto=mock_proto, + ) + + # Update original with updated values + result = original._update(updated) + + assert result is original # Returns self + assert original.name == "updated" + assert original.description == "new desc" + assert original.version == 2 + assert original.proto is mock_proto + + def test_validate_timezones_with_valid_datetime(self): + """Test timezone validation passes with timezone-aware datetime.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, created_date=proto.created_date) + + # Should not raise + model = TestModel(name="test", created_date=datetime.now(timezone.utc)) + assert model.created_date.tzinfo is not None + + def test_validate_timezones_with_naive_datetime(self): + """Test timezone validation fails with naive datetime.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, created_date=proto.created_date) + + # Should raise validation error + with pytest.raises(ValueError, match="must have timezone information"): + TestModel(name="test", created_date=datetime.now()) # noqa: DTZ005 + + def test_validate_timezones_with_none_datetime(self): + """Test timezone validation passes when datetime is None.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + created_date: datetime | None = None + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + # Should not raise + model = TestModel(name="test", created_date=None) + assert model.created_date is None + + def test_proto_field_excluded_from_dump(self): + """Test that proto field is excluded from model_dump.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name, proto=proto) + + mock_proto = MagicMock() + model = TestModel(name="test", proto=mock_proto) + + dumped = model.model_dump() + assert "proto" not in dumped + assert "name" in dumped + + def test_frozen_model_config(self): + """Test that BaseType models are frozen.""" + + class TestModel(BaseType[CalculatedChannelProto, "TestModel"]): + name: str + + @classmethod + def _from_proto(cls, proto, sift_client=None): + return cls(name=proto.name) + + model = TestModel(name="test") + + # Should not be able to modify frozen model + with pytest.raises(Exception): # Pydantic raises ValidationError # noqa: B017 + model.name = "new_name" + + +class TestBuildProtoAndPaths: + """Tests specifically for _build_proto_and_paths method.""" + + def test_build_proto_simple_fields(self): + """Test building proto with simple scalar fields.""" + model = SimpleCreateModel(name="test", description="desc") + proto = CreateCalculatedChannelRequest() + + paths = model._build_proto_and_paths(proto, {"name": "test", "description": "desc"}) + + assert proto.name == "test" + assert proto.description == "desc" + assert set(paths) == {"name", "description"} + + def test_build_proto_with_prefix(self): + """Test building proto with path prefix.""" + model = SimpleCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + paths = model._build_proto_and_paths(proto, {"name": "test"}, prefix="parent") + + assert proto.name == "test" + assert paths == ["parent.name"] + + def test_build_proto_with_nested_dict(self): + """Test building proto with nested dictionary through normal flow.""" + # This tests that the MappingHelper properly expands nested paths + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = model.to_proto() + + # Verify the nested structure was created correctly + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert proto.name == "test" + + def test_build_proto_with_submessage_dict(self): + """Test building proto when data contains a dict for a submessage field.""" + + class SubmessageModel(ModelCreate[CreateCalculatedChannelRequest]): + name: str + + def _get_proto_class(self): + return CreateCalculatedChannelRequest + + model = SubmessageModel(name="test") + proto = CreateCalculatedChannelRequest() + + # Test that we can build nested structures by passing dict data + # This simulates what happens when processing nested proto messages + data = {"name": "test"} + paths = model._build_proto_and_paths(proto, data) + + assert proto.name == "test" + assert "name" in paths + + def test_build_proto_with_mapping_helper_update_field(self): + """Test that mapping helper's update_field is added to paths.""" + model = NestedCreateModel(name="test", expression="$1 + $2") + proto = CreateCalculatedChannelRequest() + + data = {"name": "test", "expression": "$1 + $2"} + paths = model._build_proto_and_paths(proto, data) + + assert "query_configuration" in paths + assert "name" in paths + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + + def test_build_proto_error_on_invalid_field(self): + """Test that setting an invalid field raises TypeError.""" + model = SimpleCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + with pytest.raises(TypeError, match="Can't set"): + model._build_proto_and_paths(proto, {"nonexistent_field": "value"}) + + def test_build_proto_already_setting_path_override(self): + """Test that already_setting_path_override skips helper processing.""" + model = NestedCreateModel(name="test") + proto = CreateCalculatedChannelRequest() + + # When already_setting_path_override=True, it should skip the helper + # and try to process the field directly + data = {"name": "test"} + paths = model._build_proto_and_paths(proto, data, already_setting_path_override=True) + + assert proto.name == "test" + assert "name" in paths diff --git a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py new file mode 100644 index 000000000..e5b1059c5 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py @@ -0,0 +1,292 @@ +"""Tests for sift_types.CalculatedChannel model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import CalculatedChannel +from sift_client.sift_types.calculated_channel import ( + CalculatedChannelUpdate, +) +from sift_client.sift_types.channel import ChannelReference + + +class TestCalculatedChannelBase: + """Unit tests for CalculatedChannelBase - tests _to_proto_helpers and validators shared by Create and Update.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": False} + update = CalculatedChannelUpdate(metadata=metadata) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is False + assert "metadata" in mask.paths + + def test_expression_helper(self): + """Test that expression is mapped to nested proto path.""" + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify expression is set in nested path + assert ( + proto.calculated_channel_configuration.query_configuration.sel.expression == "$1 + $2" + ) + assert "query_configuration" in mask.paths + + def test_expression_channel_references_helper(self): + """Test that expression_channel_references are converted and mapped.""" + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify channel references are converted + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + assert len(refs) == 2 + assert refs[0].channel_reference == "$1" + assert refs[0].channel_identifier == "channel1" + assert refs[1].channel_reference == "$2" + assert refs[1].channel_identifier == "channel2" + assert "query_configuration" in mask.paths + + def test_tag_ids_helper(self): + """Test that tag_ids are mapped to nested proto path.""" + update = CalculatedChannelUpdate(tag_ids=["tag1", "tag2"]) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify tag_ids are set in nested path + assert list( + proto.calculated_channel_configuration.asset_configuration.selection.tag_ids + ) == ["tag1", "tag2"] + assert "asset_configuration" in mask.paths + + def test_asset_ids_helper(self): + """Test that asset_ids are mapped to nested proto path.""" + update = CalculatedChannelUpdate(asset_ids=["asset1", "asset2"]) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify asset_ids are set in nested path + assert list( + proto.calculated_channel_configuration.asset_configuration.selection.asset_ids + ) == ["asset1", "asset2"] + assert "asset_configuration" in mask.paths + + def test_all_assets_helper(self): + """Test that all_assets is mapped to nested proto path.""" + update = CalculatedChannelUpdate(all_assets=True) + update.resource_id = "test_calc_channel_id" + + proto, mask = update.to_proto_with_mask() + + # Verify all_assets is set in nested path + assert proto.calculated_channel_configuration.asset_configuration.all_assets is True + # Verify update_field is in mask (same as tag_ids and asset_ids) + assert "asset_configuration" in mask.paths + + def test_asset_configuration_validator_rejects_all_assets_with_asset_ids(self): + """Test validator rejects all_assets=True with asset_ids.""" + with pytest.raises( + ValueError, + match="Cannot specify both all_assets=True and asset_ids/tag_ids", + ): + CalculatedChannelUpdate( + all_assets=True, + asset_ids=["asset1"], + ) + + def test_asset_configuration_validator_rejects_all_assets_with_tag_ids(self): + """Test validator rejects all_assets=True with tag_ids.""" + with pytest.raises( + ValueError, + match="Cannot specify both all_assets=True and asset_ids/tag_ids", + ): + CalculatedChannelUpdate( + all_assets=True, + tag_ids=["tag1"], + ) + + def test_expression_validator_rejects_expression_without_references(self): + """Test validator rejects expression without channel references.""" + with pytest.raises( + ValueError, match="Expression and channel references must be set together" + ): + CalculatedChannelUpdate(expression="$1 + $2") + + def test_expression_validator_rejects_references_without_expression(self): + """Test validator rejects channel references without expression.""" + with pytest.raises( + ValueError, match="Expression and channel references must be set together" + ): + CalculatedChannelUpdate( + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ], + ) + + def test_expression_validator_accepts_both_set(self): + """Test validator accepts expression and channel references together.""" + # Should not raise + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + ) + assert update.expression == "$1 + $2" + assert len(update.expression_channel_references) == 2 + + +@pytest.fixture +def mock_calculated_channel(mock_client): + """Create a mock CalculatedChannel instance for testing.""" + calc_channel = CalculatedChannel( + proto=MagicMock(), + id_="test_calc_channel_id", + name="test_calc_channel", + description="test description", + expression="$1 + $2", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", channel_identifier="channel2"), + ], + is_archived=False, + units=None, + asset_ids=None, + tag_ids=None, + all_assets=True, + organization_id="org1", + client_key=None, + archived_date=None, + version_id="v1", + version=1, + change_message=None, + user_notes=None, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + ) + calc_channel._apply_client_to_instance(mock_client) + return calc_channel + + +class TestCalculatedChannel: + """Unit tests for CalculatedChannel model - tests properties and methods.""" + + def test_archive_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): + """Test that archive() calls client.calculated_channels.archive and calls _update.""" + archived_calc_channel = MagicMock() + archived_calc_channel.is_archived = True + mock_client.calculated_channels.archive.return_value = archived_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call archive + result = mock_calculated_channel.archive() + + # Verify client method was called + mock_client.calculated_channels.archive.assert_called_once_with( + calculated_channel=mock_calculated_channel + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(archived_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_unarchive_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): + """Test that unarchive() calls client.calculated_channels.unarchive and calls _update.""" + unarchived_calc_channel = MagicMock() + unarchived_calc_channel.is_archived = False + mock_client.calculated_channels.unarchive.return_value = unarchived_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call unarchive + result = mock_calculated_channel.unarchive() + + # Verify client method was called + mock_client.calculated_channels.unarchive.assert_called_once_with( + calculated_channel=mock_calculated_channel + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(unarchived_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_update_calls_client_and_updates_self(self, mock_calculated_channel, mock_client): + """Test that update() calls client.calculated_channels.update and calls _update.""" + updated_calc_channel = MagicMock() + updated_calc_channel.description = "Updated description" + mock_client.calculated_channels.update.return_value = updated_calc_channel + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call update + update = CalculatedChannelUpdate(description="Updated description") + result = mock_calculated_channel.update(update) + + # Verify client method was called with correct parameters + mock_client.calculated_channels.update.assert_called_once_with( + calculated_channel=mock_calculated_channel, + update=update, + user_notes=None, + ) + # Verify _update was called with the returned calculated channel + mock_update.assert_called_once_with(updated_calc_channel) + # Verify it returns self + assert result is mock_calculated_channel + + def test_update_with_user_notes(self, mock_calculated_channel, mock_client): + """Test that update() passes user_notes parameter correctly.""" + updated_calc_channel = MagicMock() + mock_client.calculated_channels.update.return_value = updated_calc_channel + + # Mock the _update method + with MagicMock() as mock_update: + mock_calculated_channel._update = mock_update + + # Call update with user_notes + update = CalculatedChannelUpdate(description="Updated") + mock_calculated_channel.update(update, user_notes="Test notes") + + # Verify client method was called with user_notes + mock_client.calculated_channels.update.assert_called_once_with( + calculated_channel=mock_calculated_channel, + update=update, + user_notes="Test notes", + ) diff --git a/python/lib/sift_client/_tests/sift_types/test_channel.py b/python/lib/sift_client/_tests/sift_types/test_channel.py new file mode 100644 index 000000000..ab8fa01c2 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_channel.py @@ -0,0 +1,123 @@ +"""Tests for sift_types.Channel model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Channel +from sift_client.sift_types.channel import ChannelDataType + + +@pytest.fixture +def mock_channel(mock_client): + """Create a mock Channel instance for testing.""" + channel = Channel( + proto=MagicMock(), + id_="test_channel_id", + name="test_channel", + data_type=ChannelDataType.DOUBLE, + description="test description", + unit="m/s", + bit_field_elements=[], + enum_types={}, + asset_id="test_asset_id", + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + ) + channel._apply_client_to_instance(mock_client) + return channel + + +class TestChannel: + """Unit tests for Channel model - tests properties and methods.""" + + def test_asset_property_calls_client(self, mock_channel, mock_client): + """Test that asset property calls client.assets.get with correct parameters.""" + mock_asset = MagicMock() + mock_client.assets.get.return_value = mock_asset + + # Access asset property + result = mock_channel.asset + + # Verify client method was called with correct asset_id + mock_client.assets.get.assert_called_once_with(asset_id="test_asset_id") + assert result is mock_asset + + def test_runs_property_calls_asset_runs(self, mock_channel, mock_client): + """Test that runs property calls asset.runs.""" + mock_asset = MagicMock() + mock_runs = [MagicMock(), MagicMock()] + mock_asset.runs = mock_runs + mock_client.assets.get.return_value = mock_asset + + # Access runs property + result = mock_channel.runs + + # Verify it returns the asset's runs + assert result == mock_runs + + def test_data_method_calls_get_data(self, mock_channel, mock_client): + """Test that data() method calls client.channels.get_data with correct parameters.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data.return_value = mock_data + + # Call data method + result = mock_channel.data( + run_id="run123", + start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2024, 1, 2, tzinfo=timezone.utc), + limit=100, + ) + + # Verify client method was called with correct parameters + mock_client.channels.get_data.assert_called_once_with( + channels=[mock_channel], + run="run123", + start_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2024, 1, 2, tzinfo=timezone.utc), + limit=100, + ) + assert result == mock_data + + def test_data_method_as_arrow(self, mock_channel, mock_client): + """Test that data() method calls get_data_as_arrow when as_arrow=True.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data_as_arrow.return_value = mock_data + + # Call data method with as_arrow=True + result = mock_channel.data( + run_id="run123", + as_arrow=True, + ) + + # Verify get_data_as_arrow was called instead of get_data + mock_client.channels.get_data_as_arrow.assert_called_once_with( + channels=[mock_channel], + run="run123", + start_time=None, + end_time=None, + limit=None, + ) + mock_client.channels.get_data.assert_not_called() + assert result == mock_data + + def test_data_method_with_minimal_params(self, mock_channel, mock_client): + """Test that data() method works with minimal parameters.""" + mock_data = {"test_channel": MagicMock()} + mock_client.channels.get_data.return_value = mock_data + + # Call data method with no parameters + result = mock_channel.data() + + # Verify client method was called with None values + mock_client.channels.get_data.assert_called_once_with( + channels=[mock_channel], + run=None, + start_time=None, + end_time=None, + limit=None, + ) + assert result == mock_data diff --git a/python/lib/sift_client/_tests/sift_types/test_ingestion.py b/python/lib/sift_client/_tests/sift_types/test_ingestion.py new file mode 100644 index 000000000..6b29abafe --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_ingestion.py @@ -0,0 +1,180 @@ +"""Tests for sift_types.Ingestion models.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType +from sift_client.sift_types.ingestion import ChannelConfig, Flow, IngestionConfig + + +class TestChannelConfig: + """Unit tests for ChannelConfig model - tests validators.""" + + def test_enum_validator_rejects_enum_without_enum_types(self): + """Test validator rejects ENUM data_type without enum_types.""" + with pytest.raises( + ValueError, + match="Channel 'test_channel' has data_type ENUM but enum_types is not provided", + ): + ChannelConfig( + name="test_channel", + data_type=ChannelDataType.ENUM, + ) + + def test_enum_validator_accepts_enum_with_enum_types(self): + """Test validator accepts ENUM data_type with enum_types.""" + # Should not raise + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.ENUM, + enum_types={"LOW": 0, "HIGH": 1}, + ) + assert channel.data_type == ChannelDataType.ENUM + assert channel.enum_types == {"LOW": 0, "HIGH": 1} + + def test_bitfield_validator_rejects_bitfield_without_elements(self): + """Test validator rejects BIT_FIELD data_type without bit_field_elements.""" + with pytest.raises( + ValueError, + match="Channel 'test_channel' has data_type BIT_FIELD but bit_field_elements is not provided", + ): + ChannelConfig( + name="test_channel", + data_type=ChannelDataType.BIT_FIELD, + ) + + def test_bitfield_validator_accepts_bitfield_with_elements(self): + """Test validator accepts BIT_FIELD data_type with bit_field_elements.""" + # Should not raise + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.BIT_FIELD, + bit_field_elements=[ + ChannelBitFieldElement(name="field1", index=0, bit_count=4), + ChannelBitFieldElement(name="field2", index=1, bit_count=4), + ], + ) + assert channel.data_type == ChannelDataType.BIT_FIELD + assert len(channel.bit_field_elements) == 2 + + def test_other_data_types_dont_require_special_fields(self): + """Test that other data types don't require enum_types or bit_field_elements.""" + # Should not raise for DOUBLE + channel = ChannelConfig( + name="test_channel", + data_type=ChannelDataType.DOUBLE, + ) + assert channel.data_type == ChannelDataType.DOUBLE + + +@pytest.fixture +def mock_flow(mock_client): + """Create a mock Flow instance for testing.""" + flow = Flow( + proto=MagicMock(), + name="test_flow", + channels=[ + ChannelConfig( + name="channel1", + data_type=ChannelDataType.DOUBLE, + description="Test channel 1", + ), + ChannelConfig( + name="channel2", + data_type=ChannelDataType.FLOAT, + description="Test channel 2", + ), + ], + ingestion_config_id="test_config_id", + run_id="test_run_id", + ) + flow._apply_client_to_instance(mock_client) + return flow + + +class TestFlow: + """Unit tests for Flow model - tests methods.""" + + def test_add_channel_success(self): + """Test that add_channel() adds a channel when no ingestion_config_id is set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id=None, + ) + + channel = ChannelConfig( + name="new_channel", + data_type=ChannelDataType.DOUBLE, + ) + + # Should not raise + flow.add_channel(channel) + + assert len(flow.channels) == 1 + assert flow.channels[0].name == "new_channel" + + def test_add_channel_raises_after_creation(self): + """Test that add_channel() raises ValueError when ingestion_config_id is set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id="config123", + ) + + channel = ChannelConfig( + name="new_channel", + data_type=ChannelDataType.DOUBLE, + ) + + with pytest.raises(ValueError, match="Cannot add a channel to a flow after creation"): + flow.add_channel(channel) + + def test_ingest_calls_client(self, mock_flow, mock_client): + """Test that ingest() calls client.async_.ingestion.ingest with correct parameters.""" + timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + channel_values = {"channel1": 42.5, "channel2": 100.0} + + # Call ingest + mock_flow.ingest(timestamp=timestamp, channel_values=channel_values) + + # Verify client method was called with correct parameters + mock_client.async_.ingestion.ingest.assert_called_once_with( + flow=mock_flow, + timestamp=timestamp, + channel_values=channel_values, + ) + + def test_ingest_raises_without_config_id(self, mock_client): + """Test that ingest() raises ValueError when ingestion_config_id is not set.""" + flow = Flow( + name="test_flow", + channels=[], + ingestion_config_id=None, + ) + flow._apply_client_to_instance(mock_client) + + timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + channel_values = {"channel1": 42.5} + + with pytest.raises(ValueError, match="Ingestion config ID is not set"): + flow.ingest(timestamp=timestamp, channel_values=channel_values) + + +class TestIngestionConfig: + """Unit tests for IngestionConfig model.""" + + def test_ingestion_config_has_required_fields(self): + """Test that IngestionConfig can be created with required fields.""" + config = IngestionConfig( + proto=MagicMock(), + id_="config123", + asset_id="asset123", + client_key="client_key_123", + ) + + assert config.id_ == "config123" + assert config.asset_id == "asset123" + assert config.client_key == "client_key_123" diff --git a/python/lib/sift_client/_tests/sift_types/test_rule.py b/python/lib/sift_client/_tests/sift_types/test_rule.py new file mode 100644 index 000000000..de9dbe7b1 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_rule.py @@ -0,0 +1,148 @@ +"""Tests for sift_types.Rule model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Rule +from sift_client.sift_types.channel import ChannelReference +from sift_client.sift_types.rule import ( + RuleAction, + RuleActionType, + RuleAnnotationType, + RuleUpdate, +) + + +@pytest.fixture +def mock_rule(mock_client): + """Create a mock Rule instance for testing.""" + rule = Rule( + proto=MagicMock(), + id_="test_rule_id", + name="test_rule", + description="test description", + is_enabled=True, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + organization_id="org1", + is_archived=False, + is_external=False, + expression="$1 > 100", + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ], + action=RuleAction( + action_type=RuleActionType.ANNOTATION, + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["tag1"], + ), + asset_ids=["asset1", "asset2"], + asset_tag_ids=["tag1"], + contextual_channels=["channel2"], + client_key=None, + rule_version=None, + archived_date=None, + ) + rule._apply_client_to_instance(mock_client) + return rule + + +class TestRule: + """Unit tests for Rule model - tests properties and methods.""" + + def test_assets_property_calls_client(self, mock_rule, mock_client): + """Test that assets property calls client.assets.list_ with correct parameters.""" + mock_client.assets.list_.return_value = [] + + # Access assets property + _ = mock_rule.assets + + # Verify client method was called with correct parameters + mock_client.assets.list_.assert_called_once_with( + asset_ids=["asset1", "asset2"], _tag_ids=["tag1"] + ) + + def test_update_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that update() calls client.rules.update and calls _update.""" + updated_rule = MagicMock() + updated_rule.description = "Updated description" + mock_client.rules.update.return_value = updated_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call update + update = RuleUpdate(description="Updated description") + result = mock_rule.update(update) + + # Verify client method was called with correct parameters + mock_client.rules.update.assert_called_once_with( + rule=mock_rule, update=update, version_notes=None + ) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(updated_rule) + # Verify it returns self + assert result is mock_rule + + def test_update_with_version_notes(self, mock_rule, mock_client): + """Test that update() passes version_notes parameter correctly.""" + updated_rule = MagicMock() + mock_client.rules.update.return_value = updated_rule + + # Mock the _update method + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call update with version_notes + update = RuleUpdate(description="Updated") + mock_rule.update(update, version_notes="Test notes") + + # Verify client method was called with version_notes + mock_client.rules.update.assert_called_once_with( + rule=mock_rule, update=update, version_notes="Test notes" + ) + + def test_archive_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that archive() calls client.rules.archive and calls _update.""" + archived_rule = MagicMock() + archived_rule.is_archived = True + mock_client.rules.archive.return_value = archived_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call archive + result = mock_rule.archive() + + # Verify client method was called + mock_client.rules.archive.assert_called_once_with(rule=mock_rule) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(archived_rule) + # Verify it returns self + assert result is mock_rule + + def test_unarchive_calls_client_and_updates_self(self, mock_rule, mock_client): + """Test that unarchive() calls client.rules.unarchive and calls _update.""" + unarchived_rule = MagicMock() + unarchived_rule.is_archived = False + mock_client.rules.unarchive.return_value = unarchived_rule + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_rule._update = mock_update + + # Call unarchive + result = mock_rule.unarchive() + + # Verify client method was called + mock_client.rules.unarchive.assert_called_once_with(rule=mock_rule) + # Verify _update was called with the returned rule + mock_update.assert_called_once_with(unarchived_rule) + # Verify it returns self + assert result is mock_rule diff --git a/python/lib/sift_client/_tests/sift_types/test_run.py b/python/lib/sift_client/_tests/sift_types/test_run.py new file mode 100644 index 000000000..5f9e6b1a1 --- /dev/null +++ b/python/lib/sift_client/_tests/sift_types/test_run.py @@ -0,0 +1,207 @@ +"""Tests for sift_types.Run model.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sift_client.sift_types import Run +from sift_client.sift_types.run import RunCreate, RunUpdate + + +class TestRunCreate: + """Unit tests for RunCreate model - tests _to_proto_helpers and validators.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"string_key": "value", "number_key": 3.14, "bool_key": True} + create = RunCreate(name="test_run", metadata=metadata) + proto = create.to_proto() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["string_key"].string_value == "value" + assert metadata_dict["number_key"].number_value == 3.14 + assert metadata_dict["bool_key"].boolean_value is True + + def test_time_validator_start_before_stop(self): + """Test time validator accepts start_time before stop_time.""" + start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + # Should not raise + create = RunCreate(name="test_run", start_time=start_time, stop_time=stop_time) + assert create.start_time == start_time + assert create.stop_time == stop_time + + def test_time_validator_rejects_start_after_stop(self): + """Test time validator rejects start_time after stop_time.""" + start_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + with pytest.raises(ValueError, match="start_time must be before stop_time"): + RunCreate(name="test_run", start_time=start_time, stop_time=stop_time) + + def test_time_validator_rejects_stop_without_start(self): + """Test time validator rejects stop_time without start_time.""" + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + with pytest.raises( + ValueError, match="start_time must be provided if stop_time is provided" + ): + RunCreate(name="test_run", stop_time=stop_time) + + +class TestRunUpdate: + """Unit tests for RunUpdate model - tests _to_proto_helpers and validators.""" + + def test_metadata_converter(self): + """Test that metadata is converted using _to_proto_helpers.""" + metadata = {"key1": "value1", "key2": 42.5, "key3": False} + update = RunUpdate(metadata=metadata) + update.resource_id = "test_run_id" + + proto, mask = update.to_proto_with_mask() + + assert len(proto.metadata) == 3 + + # Convert list to dict for easier assertion + metadata_dict = {md.key.name: md for md in proto.metadata} + assert metadata_dict["key1"].string_value == "value1" + assert metadata_dict["key2"].number_value == 42.5 + assert metadata_dict["key3"].boolean_value is False + assert "metadata" in mask.paths + + def test_time_validator_start_before_stop(self): + """Test time validator accepts start_time before stop_time.""" + start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + # Should not raise + update = RunUpdate(start_time=start_time, stop_time=stop_time) + assert update.start_time == start_time + assert update.stop_time == stop_time + + def test_time_validator_rejects_start_after_stop(self): + """Test time validator rejects start_time after stop_time.""" + start_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + stop_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + with pytest.raises(ValueError, match="start_time must be before stop_time"): + RunUpdate(start_time=start_time, stop_time=stop_time) + + def test_time_validator_rejects_stop_without_start(self): + """Test time validator rejects stop_time without start_time.""" + stop_time = datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc) + + with pytest.raises( + ValueError, match="start_time must be provided if stop_time is provided" + ): + RunUpdate(stop_time=stop_time) + + +@pytest.fixture +def mock_run(mock_client): + """Create a mock Run instance for testing.""" + run = Run( + proto=MagicMock(), + id_="test_run_id", + name="test_run", + description="test", + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="user1", + modified_by_user_id="user1", + organization_id="org1", + metadata={}, + tags=[], + asset_ids=["asset1", "asset2"], + is_adhoc=False, + is_archived=False, + start_time=None, + stop_time=None, + duration=None, + default_report_id=None, + client_key=None, + archived_date=None, + ) + run._apply_client_to_instance(mock_client) + return run + + +class TestRun: + """Unit tests for Run model - tests properties and methods.""" + + def test_assets_property_calls_client(self, mock_run, mock_client): + """Test that assets property calls client.assets.list_ with correct parameters.""" + mock_client.assets.list_.return_value = [] + + # Access assets property + _ = mock_run.assets + + # Verify client method was called with correct asset_ids + mock_client.assets.list_.assert_called_once_with(asset_ids=["asset1", "asset2"]) + + def test_archive_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that archive() calls client.runs.archive and calls _update.""" + archived_run = MagicMock() + archived_run.is_archived = True + archived_run.archived_date = datetime.now(timezone.utc) + mock_client.runs.archive.return_value = archived_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call archive + result = mock_run.archive() + + # Verify client method was called + mock_client.runs.archive.assert_called_once_with(run=mock_run) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(archived_run) + # Verify it returns self + assert result is mock_run + + def test_unarchive_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that unarchive() calls client.runs.unarchive and calls _update.""" + unarchived_run = MagicMock() + unarchived_run.is_archived = False + mock_client.runs.unarchive.return_value = unarchived_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call unarchive + result = mock_run.unarchive() + + # Verify client method was called + mock_client.runs.unarchive.assert_called_once_with(run=mock_run) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(unarchived_run) + # Verify it returns self + assert result is mock_run + + def test_update_calls_client_and_updates_self(self, mock_run, mock_client): + """Test that update() calls client.runs.update and calls _update.""" + updated_run = MagicMock() + updated_run.description = "Updated description" + mock_client.runs.update.return_value = updated_run + + # Mock the _update method to verify it's called + with MagicMock() as mock_update: + mock_run._update = mock_update + + # Call update + update = RunUpdate(description="Updated description") + result = mock_run.update(update) + + # Verify client method was called with correct parameters + mock_client.runs.update.assert_called_once_with(run=mock_run, update=update) + # Verify _update was called with the returned run + mock_update.assert_called_once_with(updated_run) + # Verify it returns self + assert result is mock_run diff --git a/python/lib/sift_client/_tests/test_client.py b/python/lib/sift_client/_tests/test_client.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index 427e4def5..e77ab5e20 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -127,7 +127,6 @@ def __init__( self.assets = AssetsAPI(self) self.calculated_channels = CalculatedChannelsAPI(self) self.channels = ChannelsAPI(self) - self.ingestion = IngestionAPIAsync(self) self.rules = RulesAPI(self) self.runs = RunsAPI(self) diff --git a/python/lib/sift_client/examples/generic_workflow_example.py b/python/lib/sift_client/examples/generic_workflow_example.py index 08901f70a..1263928ad 100644 --- a/python/lib/sift_client/examples/generic_workflow_example.py +++ b/python/lib/sift_client/examples/generic_workflow_example.py @@ -1,127 +1,127 @@ -import asyncio -import os -from datetime import datetime, timezone - -from sift_client.client import SiftClient - -# Import sift_client types for calculated channels and rules -from sift_client.sift_types import ( - CalculatedChannelUpdate, - ChannelReference, - RuleAction, - RuleAnnotationType, - RuleCreate, - RuleUpdate, -) - -""" -Placeholder for future examples. FD-67 -""" - - -async def main(): - grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") - api_key = os.getenv("SIFT_API_KEY", "") - rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") - client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) - - asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id_ - print("Found asset", asset.name) - - calculated_channels = client.calculated_channels.list( - name_regex="velocity_per.*", - asset_id=asset_id, - ) - updated = False - calculated_channel = None - if calculated_channels: - print(f"Found calculated channels: {[cc.name for cc in calculated_channels]}") - for cc in calculated_channels: - if cc.name == "velocity_per_voltage": - calculated_channel = cc.update( - CalculatedChannelUpdate( - expression="$1 / $2 + 0.1", - expression_channel_references=cc.channel_references, - ) - ) - print("Updated calculated channel", calculated_channel) - else: - # Create a calculated channel that divides mainmotor.velocity by voltage - print("\nCreating calculated channel...") - calculated_channel = client.calculated_channels.create( - dict( - name="velocity_per_voltage", - description="Ratio of mainmotor velocity to voltage", - expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier="mainmotor.velocity" - ), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - units="velocity/voltage", - asset_ids=[asset_id], - user_notes="Created to monitor velocity-to-voltage ratio", - ) - ) - print( - f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.calculated_channel_id})" - ) - - # Create a rule that creates an annotation when the ratio is above 0.1 - rule_search = "high_velocity_voltage" - print(f"Looking for rule containing {rule_search}") - rules = client.rules.list( - name_contains=rule_search, - ) - if rules: - print(f"Found rules: {[rule.name for rule in rules]}") - # Example of batch get if you just had the rule ids: - rules = client.rules.batch_get(rule_ids=[rule.rule_id for rule in rules]) - print(f"Batch get on IDs also works: {[rule.name for rule in rules]}") - - rule = rules[0] - print(f"Updating rule: {rule.name}") - rule = rule.update( - RuleUpdate( - description=f"Alert when velocity-to-voltage ratio exceeds 0.1 (Updated at {datetime.now(tz=timezone.utc).isoformat()})", - asset_ids=[asset_id], - ) - ) - updated = True - else: - print(f"No rules found for {rule_search}") - rules = client.rules.list_( - asset_ids=[asset_id], - ) - if rules: - print(f"However these rules do exist: {[rule.name for rule in rules]}") - print("Attempting to create rule for high_velocity_voltage_ratio_alert") - rule = client.rules.create( - RuleCreate( - name="high_velocity_voltage_ratio_alert", - description="Alert when velocity-to-voltage ratio exceeds 0.1", - expression="$1 > 0.1", - channel_references=[ - ChannelReference( - channel_reference="$1", - channel_identifier=calculated_channel.name, - ), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["high_ratio", "alert"], - default_assignee_user_id=None, # You can set a user ID here if needed - ), - ) - ) - print(f"Created rule: {rule.name} (ID: {rule.rule_id})") - - if updated: - print("Second run through, deleting rule") - rule.delete() - - -if __name__ == "__main__": - asyncio.run(main()) +# import asyncio +# import os +# from datetime import datetime, timezone +# +# from sift_client.client import SiftClient +# +# # Import sift_client types for calculated channels and rules +# from sift_client.sift_types import ( +# CalculatedChannelUpdate, +# ChannelReference, +# RuleAction, +# RuleAnnotationType, +# RuleCreate, +# RuleUpdate, +# ) +# +# """ +# Placeholder for future examples. FD-67 +# """ +# +# +# async def main(): +# grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") +# api_key = os.getenv("SIFT_API_KEY", "") +# rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") +# client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) +# +# asset = client.assets.find(name="NostromoLV426") +# asset_id = asset.id_ +# print("Found asset", asset.name) +# +# calculated_channels = client.calculated_channels.list_( +# name_regex="velocity_per.*", +# asset_id=asset_id, +# ) +# updated = False +# calculated_channel = None +# if calculated_channels: +# print(f"Found calculated channels: {[cc.name for cc in calculated_channels]}") +# for cc in calculated_channels: +# if cc.name == "velocity_per_voltage": +# calculated_channel = cc.update( +# CalculatedChannelUpdate( +# expression="$1 / $2 + 0.1", +# expression_channel_references=cc.channel_references, +# ) +# ) +# print("Updated calculated channel", calculated_channel) +# else: +# # Create a calculated channel that divides mainmotor.velocity by voltage +# print("\nCreating calculated channel...") +# calculated_channel = client.calculated_channels.create( +# dict( +# name="velocity_per_voltage", +# description="Ratio of mainmotor velocity to voltage", +# expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage +# channel_references=[ +# ChannelReference( +# channel_reference="$1", channel_identifier="mainmotor.velocity" +# ), +# ChannelReference(channel_reference="$2", channel_identifier="voltage"), +# ], +# units="velocity/voltage", +# asset_ids=[asset_id], +# user_notes="Created to monitor velocity-to-voltage ratio", +# ) +# ) +# print( +# f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.calculated_channel_id})" +# ) +# +# # Create a rule that creates an annotation when the ratio is above 0.1 +# rule_search = "high_velocity_voltage" +# print(f"Looking for rule containing {rule_search}") +# rules = client.rules.list( +# name_contains=rule_search, +# ) +# if rules: +# print(f"Found rules: {[rule.name for rule in rules]}") +# # Example of batch get if you just had the rule ids: +# rules = client.rules.batch_get(rule_ids=[rule.rule_id for rule in rules]) +# print(f"Batch get on IDs also works: {[rule.name for rule in rules]}") +# +# rule = rules[0] +# print(f"Updating rule: {rule.name}") +# rule = rule.update( +# RuleUpdate( +# description=f"Alert when velocity-to-voltage ratio exceeds 0.1 (Updated at {datetime.now(tz=timezone.utc).isoformat()})", +# asset_ids=[asset_id], +# ) +# ) +# updated = True +# else: +# print(f"No rules found for {rule_search}") +# rules = client.rules.list_( +# asset_ids=[asset_id], +# ) +# if rules: +# print(f"However these rules do exist: {[rule.name for rule in rules]}") +# print("Attempting to create rule for high_velocity_voltage_ratio_alert") +# rule = client.rules.create( +# RuleCreate( +# name="high_velocity_voltage_ratio_alert", +# description="Alert when velocity-to-voltage ratio exceeds 0.1", +# expression="$1 > 0.1", +# channel_references=[ +# ChannelReference( +# channel_reference="$1", +# channel_identifier=calculated_channel.name, +# ), +# ], +# action=RuleAction.annotation( +# annotation_type=RuleAnnotationType.DATA_REVIEW, +# tags=["high_ratio", "alert"], +# default_assignee_user_id=None, # You can set a user ID here if needed +# ), +# ) +# ) +# print(f"Created rule: {rule.name} (ID: {rule.rule_id})") +# +# if updated: +# print("Second run through, deleting rule") +# rule.delete() +# +# +# if __name__ == "__main__": +# asyncio.run(main()) diff --git a/python/lib/sift_client/resources/_base.py b/python/lib/sift_client/resources/_base.py index 8708411be..a46ea0fa7 100644 --- a/python/lib/sift_client/resources/_base.py +++ b/python/lib/sift_client/resources/_base.py @@ -91,7 +91,9 @@ def _build_time_cel_filters( return filter_parts def _build_tags_metadata_cel_filters( - self, tags: list[Any] | list[str] | None = None, metadata: list[Any] | None = None + self, + tags: list[Any] | list[str] | None = None, + metadata: list[Any] | None = None, ) -> list[str]: filter_parts = [] if tags: @@ -106,14 +108,14 @@ def _build_tags_metadata_cel_filters( def _build_common_cel_filters( self, description_contains: str | None = None, - include_archived: bool = False, + include_archived: bool | None = None, filter_query: str | None = None, ) -> list[str]: filter_parts = [] if description_contains: filter_parts.append(cel.contains("description", description_contains)) - if not include_archived: - filter_parts.append(cel.equals("is_archived", False)) + if include_archived is not None: + filter_parts.append(cel.equals("is_archived", include_archived)) if filter_query: filter_parts.append(filter_query) return filter_parts diff --git a/python/lib/sift_client/resources/ingestion.py b/python/lib/sift_client/resources/ingestion.py index 8abd088ae..a04368c21 100644 --- a/python/lib/sift_client/resources/ingestion.py +++ b/python/lib/sift_client/resources/ingestion.py @@ -67,7 +67,6 @@ async def create_ingestion_config( asset_name=asset_name, flows=flows, client_key=client_key, - organization_id=organization_id, ) for flow in flows: flow._apply_client_to_instance(self.client) diff --git a/python/lib/sift_client/resources/runs.py b/python/lib/sift_client/resources/runs.py index cf62b6b29..1e097612a 100644 --- a/python/lib/sift_client/resources/runs.py +++ b/python/lib/sift_client/resources/runs.py @@ -148,10 +148,10 @@ async def list_( if assets: if all(isinstance(s, str) for s in assets): ids = cast("list[str]", assets) # linting - filter_parts.append(cel.in_("asset_ids", ids)) + filter_parts.append(cel.in_("asset_id", ids)) else: asset = cast("list[Asset]", assets) # linting - filter_parts.append(cel.in_("asset_ids", [a._id_or_error for a in asset])) + filter_parts.append(cel.in_("asset_id", [a._id_or_error for a in asset])) if duration_less_than: filter_parts.append(cel.less_than("duration_string", duration_less_than)) if duration_greater_than: @@ -252,7 +252,7 @@ async def unarchive( async def stop( self, run: str | Run, - ) -> None: + ) -> Run: """Stop a run by setting its stop time to the current time. Args: @@ -260,6 +260,7 @@ async def stop( """ run_id = run._id_or_error if isinstance(run, Run) else run await self._low_level_client.stop_run(run_id=run_id or "") + return await self.get(run_id=run_id) async def create_automatic_association_for_assets( self, diff --git a/python/lib/sift_client/resources/sync_stubs/__init__.pyi b/python/lib/sift_client/resources/sync_stubs/__init__.pyi index 15302378b..85c49bad2 100644 --- a/python/lib/sift_client/resources/sync_stubs/__init__.pyi +++ b/python/lib/sift_client/resources/sync_stubs/__init__.pyi @@ -782,7 +782,7 @@ class RunsAPI: """ ... - def stop(self, run: str | Run) -> None: + def stop(self, run: str | Run) -> Run: """Stop a run by setting its stop time to the current time. Args: diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index adbdf37e7..faa3382bb 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -1,10 +1,18 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Generic, + TypeVar, +) from google.protobuf import field_mask_pb2, message -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator if TYPE_CHECKING: from sift_client.client import SiftClient @@ -53,6 +61,17 @@ def _update(self, other: BaseType[ProtoT, SelfT]) -> BaseType[ProtoT, SelfT]: # Make sure we also update the proto since it is excluded self.__dict__["proto"] = other.proto + + return self + + @model_validator(mode="after") + def _validate_timezones(self): + """Validate datetime fiels have timezone information.""" + for field_name in self.model_fields.keys(): + val = getattr(self, field_name) + if isinstance(val, datetime) and val.tzinfo is None: + raise ValueError(f"{field_name} must have timezone information") + return self @@ -73,10 +92,13 @@ class ModelCreateUpdateBase(BaseModel, ABC): """Base class for Pydantic models that generate proto messages.""" model_config = ConfigDict(frozen=False) - _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = PrivateAttr(default={}) + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = {} def __init__(self, **data: Any): super().__init__(**data) + + @model_validator(mode="after") + def _check_mapping_helpers(self): if self._to_proto_helpers: data = self.model_dump() for expected_field in self._to_proto_helpers.keys(): @@ -84,6 +106,17 @@ def __init__(self, **data: Any): raise ValueError( f"MappingHelper created for {expected_field} but {self.__class__.__name__} has no matching variable names." ) + return self + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + required_annotation = "ClassVar[dict[str, MappingHelper]]" + annotation = cls.__annotations__.get("_to_proto_helpers") + # Check for correct annotation otherwise pydantic will not populate this properly + if annotation and annotation != required_annotation: + raise TypeError( + f"{cls.__name__} must define _to_proto_helpers type as: {required_annotation}" + ) def _build_proto_and_paths( self, proto_msg, data, prefix="", already_setting_path_override=False @@ -112,13 +145,13 @@ def _build_proto_and_paths( if mapping_helper.update_field: paths.append(mapping_helper.update_field) elif isinstance(value, dict): - if field_name in self._to_proto_helpers: - assert self._to_proto_helpers[field_name].converter, ( + if field_name in self.__class__._to_proto_helpers: + assert self.__class__._to_proto_helpers[field_name].converter, ( f"Expecting to run a coverter given a helper was defined for: {field_name}" ) sub_paths = self._build_proto_and_paths( proto_msg, - {field_name: self._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] + {field_name: self.__class__._to_proto_helpers[field_name].converter(value)}, # type: ignore[misc] "", already_setting_path_override=True, ) @@ -140,13 +173,13 @@ def _build_proto_and_paths( try: repeated_field.extend(value) # Add all new values except TypeError as e: - if field_name in self._to_proto_helpers: - assert self._to_proto_helpers[field_name].converter, ( + if field_name in self.__class__._to_proto_helpers: + assert self.__class__._to_proto_helpers[field_name].converter, ( f"Expecting to run a coverter given a helper was defined for: {field_name}" ) for item in value: repeated_field.append( - self._to_proto_helpers[field_name].converter(**item) # type: ignore + self.__class__._to_proto_helpers[field_name].converter(**item) # type: ignore ) else: raise e @@ -155,7 +188,7 @@ def _build_proto_and_paths( try: setattr(proto_msg, field_name, value) paths.append(path) - except TypeError as e: + except (TypeError, AttributeError) as e: raise TypeError( f"Can't set {field_name} to {value} on {proto_msg.__class__.__name__}" ) from e @@ -178,7 +211,7 @@ def to_proto(self) -> ProtoT: proto_msg = proto_cls() # Get all fields - data = self.model_dump(exclude_none=False) + data = self.model_dump(exclude_unset=True, exclude_none=True) self._build_proto_and_paths(proto_msg, data) return proto_msg @@ -203,8 +236,8 @@ def to_proto_with_mask(self) -> tuple[ProtoT, field_mask_pb2.FieldMask]: proto_cls: type[ProtoT] = self._get_proto_class() proto_msg = proto_cls() - # Get only explicitly set fields, including those set to None - data = self.model_dump(exclude_unset=True, exclude_none=False) + # Get only explicitly set fields + data = self.model_dump(exclude_unset=True, exclude_none=True) paths = self._build_proto_and_paths(proto_msg, data) self._add_resource_id_to_proto(proto_msg) diff --git a/python/lib/sift_client/sift_types/asset.py b/python/lib/sift_client/sift_types/asset.py index 7b02a3dce..dfab26c9f 100644 --- a/python/lib/sift_client/sift_types/asset.py +++ b/python/lib/sift_client/sift_types/asset.py @@ -113,7 +113,7 @@ class AssetUpdate(ModelUpdate[AssetProto]): metadata: dict[str, str | float | bool] | None = None is_archived: bool | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( proto_attr_path="metadata", update_field="metadata", diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 987bb386a..d1fca2523 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -63,12 +63,18 @@ def modified_by(self): def archive(self) -> CalculatedChannel: """Archive the calculated channel.""" - self.client.calculated_channels.archive(calculated_channel=self) + updated_calculated_channel = self.client.calculated_channels.archive( + calculated_channel=self + ) + self._update(updated_calculated_channel) return self def unarchive(self) -> CalculatedChannel: """Unarchive the calculated channel.""" - self.client.calculated_channels.unarchive(calculated_channel=self) + updated_calculated_channel = self.client.calculated_channels.unarchive( + calculated_channel=self + ) + self._update(updated_calculated_channel) return self def update( @@ -96,6 +102,7 @@ def _from_proto( cls, proto: CalculatedChannelProto, sift_client: SiftClient | None = None ) -> CalculatedChannel: return cls( + proto=proto, id_=proto.calculated_channel_id, name=proto.name, description=proto.description, @@ -135,7 +142,6 @@ class CalculatedChannelBase(ModelCreateUpdateBase): """Base class for CalculatedChannel create and update models with shared fields and validation.""" description: str | None = None - user_notes: str | None = None units: str | None = None expression: str | None = None @@ -149,7 +155,7 @@ class CalculatedChannelBase(ModelCreateUpdateBase): metadata: dict[str, str | float | bool] | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "expression": MappingHelper( proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression", update_field="query_configuration", @@ -169,6 +175,7 @@ class CalculatedChannelBase(ModelCreateUpdateBase): ), "all_assets": MappingHelper( proto_attr_path="calculated_channel_configuration.asset_configuration.all_assets", + update_field="asset_configuration", ), "metadata": MappingHelper( proto_attr_path="metadata", @@ -198,6 +205,7 @@ class CalculatedChannelCreate(CalculatedChannelBase, ModelCreate[CreateCalculate """Create model for a Calculated Channel.""" name: str + user_notes: str | None = None client_key: str | None = None def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: @@ -210,15 +218,6 @@ class CalculatedChannelUpdate(CalculatedChannelBase, ModelUpdate[CalculatedChann name: str | None = None is_archived: bool | None = None - @model_validator(mode="after") - def _validate_non_updatable_fields(self): - """Validate that the fields that cannot be updated are not set.""" - if self.user_notes is not None: - raise ValueError("Cannot update user notes") - if self.client_key is not None: - raise ValueError("Cannot update client key") - return self - def _get_proto_class(self) -> type[CalculatedChannelProto]: return CalculatedChannelProto diff --git a/python/lib/sift_client/sift_types/ingestion.py b/python/lib/sift_client/sift_types/ingestion.py index 2d6d22693..2f90e94f5 100644 --- a/python/lib/sift_client/sift_types/ingestion.py +++ b/python/lib/sift_client/sift_types/ingestion.py @@ -15,14 +15,6 @@ from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( IngestionConfig as IngestionConfigProto, ) -from sift_stream_bindings import ( - ChannelBitFieldElementPy, - ChannelConfigPy, - ChannelDataTypePy, - ChannelEnumTypePy, - FlowConfigPy, - IngestWithConfigDataChannelValuePy, -) from sift_client.sift_types._base import BaseType from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType @@ -30,6 +22,13 @@ if TYPE_CHECKING: from datetime import datetime + from sift_stream_bindings import ( + ChannelConfigPy, + ChannelDataTypePy, + FlowConfigPy, + IngestWithConfigDataChannelValuePy, + ) + from sift_client.client import SiftClient from sift_client.sift_types.channel import Channel @@ -92,14 +91,14 @@ def _from_proto( data_type=ChannelDataType(proto.data_type), description=proto.description if proto.description else None, unit=proto.unit if proto.unit else None, - bit_field_elements=[ - ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements - ] - if proto.bit_field_elements - else None, - enum_types={enum.name: enum.key for enum in proto.enum_types} - if proto.enum_types - else None, + bit_field_elements=( + [ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements] + if proto.bit_field_elements + else None + ), + enum_types=( + {enum.name: enum.key for enum in proto.enum_types} if proto.enum_types else None + ), _client=sift_client, ) @@ -118,7 +117,7 @@ def from_channel(cls, channel: Channel) -> ChannelConfig: data_type=channel.data_type, description=channel.description, unit=channel.unit, - bit_field_elements=channel.bit_field_elements if channel.bit_field_elements else None, + bit_field_elements=(channel.bit_field_elements if channel.bit_field_elements else None), enum_types=channel.enum_types, ) @@ -127,7 +126,9 @@ def _to_config_proto(self) -> ChannelConfigProto: from sift.common.type.v1.channel_bit_field_element_pb2 import ( ChannelBitFieldElement as ChannelBitFieldElementPb, ) - from sift.common.type.v1.channel_enum_type_pb2 import ChannelEnumType as ChannelEnumTypePb + from sift.common.type.v1.channel_enum_type_pb2 import ( + ChannelEnumType as ChannelEnumTypePb, + ) return ChannelConfigProto( name=self.name, @@ -177,6 +178,8 @@ def _to_proto(self) -> FlowConfig: ) def _to_rust_config(self) -> FlowConfigPy: + from sift_stream_bindings import FlowConfigPy + return FlowConfigPy( name=self.name, channels=[_channel_to_rust_config(channel) for channel in self.channels], @@ -207,7 +210,7 @@ def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]): """ if self.ingestion_config_id is None: raise ValueError("Ingestion config ID is not set.") - self.client.ingestion.ingest( + self.client.async_.ingestion.ingest( flow=self, timestamp=timestamp, channel_values=channel_values, @@ -216,6 +219,12 @@ def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]): # Converter functions. def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: + from sift_stream_bindings import ( + ChannelBitFieldElementPy, + ChannelConfigPy, + ChannelEnumTypePy, + ) + return ChannelConfigPy( name=channel.name, data_type=_to_rust_type(channel.data_type), @@ -225,12 +234,14 @@ def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: ChannelBitFieldElementPy(name=bfe.name, index=bfe.index, bit_count=bfe.bit_count) for bfe in channel.bit_field_elements or [] ], - enum_types=[ - ChannelEnumTypePy(key=enum_key, name=enum_name) - for enum_name, enum_key in channel.enum_types.items() - ] - if channel.enum_types - else [], + enum_types=( + [ + ChannelEnumTypePy(key=enum_key, name=enum_name) + for enum_name, enum_key in channel.enum_types.items() + ] + if channel.enum_types + else [] + ), ) @@ -248,6 +259,8 @@ def _rust_channel_value_from_bitfield( Returns: A ChannelValuePy object. """ + from sift_stream_bindings import IngestWithConfigDataChannelValuePy + assert channel.bit_field_elements is not None # We expect individual ints or bytes to represent full bitfield values. if isinstance(value, bytes) or isinstance(value, int): @@ -272,6 +285,8 @@ def _rust_channel_value_from_bitfield( def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataChannelValuePy: + from sift_stream_bindings import IngestWithConfigDataChannelValuePy + if value is None: return IngestWithConfigDataChannelValuePy.empty() if channel.data_type == ChannelDataType.ENUM and channel.enum_types is not None: @@ -310,6 +325,8 @@ def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataCh def _to_rust_type(data_type: ChannelDataType) -> ChannelDataTypePy: + from sift_stream_bindings import ChannelDataTypePy + if data_type == ChannelDataType.DOUBLE: return ChannelDataTypePy.Double elif data_type == ChannelDataType.FLOAT: diff --git a/python/lib/sift_client/sift_types/rule.py b/python/lib/sift_client/sift_types/rule.py index ac4a4e3d6..df045f105 100644 --- a/python/lib/sift_client/sift_types/rule.py +++ b/python/lib/sift_client/sift_types/rule.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING @@ -32,8 +33,6 @@ from sift_client.sift_types.channel import ChannelReference if TYPE_CHECKING: - from datetime import datetime - from sift_client.client import SiftClient from sift_client.sift_types.asset import Asset @@ -135,8 +134,8 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> ], action=RuleAction._from_proto(proto.conditions[0].actions[0]), is_enabled=proto.is_enabled, - created_date=proto.created_date.ToDatetime(), - modified_date=proto.modified_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), + modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, organization_id=proto.organization_id, @@ -147,7 +146,9 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> asset_ids=proto.asset_configuration.asset_ids, # type: ignore asset_tag_ids=proto.asset_configuration.tag_ids, # type: ignore contextual_channels=[c.name for c in proto.contextual_channels.channels], - archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) if proto.archived_date else None + ), is_archived=proto.is_archived, is_external=proto.is_external, _client=sift_client, @@ -296,9 +297,10 @@ def _from_proto( ) -> RuleAction: action_type = RuleActionType(proto.action_type) return cls( + proto=proto, condition_id=proto.rule_condition_id, - created_date=proto.created_date.ToDatetime(), - modified_date=proto.modified_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), + modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, version_id=proto.rule_action_version_id, @@ -358,14 +360,17 @@ def _from_proto( cls, proto: RuleVersionProto, sift_client: SiftClient | None = None ) -> RuleVersion: return cls( + proto=proto, rule_id=proto.rule_id, rule_version_id=proto.rule_version_id, version=proto.version, - created_date=proto.created_date.ToDatetime(), + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), created_by_user_id=proto.created_by_user_id, version_notes=proto.version_notes, generated_change_message=proto.generated_change_message, - archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) if proto.archived_date else None + ), is_archived=proto.is_archived, _client=sift_client, ) diff --git a/python/lib/sift_client/sift_types/run.py b/python/lib/sift_client/sift_types/run.py index e8292b8d3..eb3e081b1 100644 --- a/python/lib/sift_client/sift_types/run.py +++ b/python/lib/sift_client/sift_types/run.py @@ -56,13 +56,17 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> created_by_user_id=proto.created_by_user_id, modified_by_user_id=proto.modified_by_user_id, organization_id=proto.organization_id, - start_time=proto.start_time.ToDatetime(tzinfo=timezone.utc) - if proto.HasField("start_time") - else None, - stop_time=proto.stop_time.ToDatetime(tzinfo=timezone.utc) - if proto.HasField("stop_time") - else None, - duration=proto.duration.ToTimedelta() if proto.HasField("duration") else None, + start_time=( + proto.start_time.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("start_time") + else None + ), + stop_time=( + proto.stop_time.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("stop_time") + else None + ), + duration=(proto.duration.ToTimedelta() if proto.HasField("duration") else None), name=proto.name, description=proto.description, tags=list(proto.tags), @@ -70,9 +74,11 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> client_key=proto.client_key if proto.HasField("client_key") else None, metadata=metadata_proto_to_dict(proto.metadata), # type: ignore asset_ids=list(proto.asset_ids), - archived_date=proto.archived_date.ToDatetime() - if proto.HasField("archived_date") - else None, + archived_date=( + proto.archived_date.ToDatetime(tzinfo=timezone.utc) + if proto.HasField("archived_date") + else None + ), is_archived=proto.is_archived, is_adhoc=proto.is_adhoc, _client=sift_client, @@ -120,9 +126,11 @@ class RunBase(ModelCreateUpdateBase): tags: list[str] | None = None metadata: dict[str, str | float | bool] | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } @@ -150,9 +158,11 @@ class RunCreate(RunBase, ModelCreate[CreateRunRequestProto]): stop_time: datetime | None = None organization_id: str | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } @@ -170,9 +180,11 @@ class RunUpdate(RunBase, ModelUpdate[RunProto]): stop_time: datetime | None = None is_archived: bool | None = None - _to_proto_helpers: ClassVar = { + _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } diff --git a/python/mkdocs.yml b/python/mkdocs.yml index 00ae85827..1644828f4 100644 --- a/python/mkdocs.yml +++ b/python/mkdocs.yml @@ -35,6 +35,7 @@ theme: code: IBM Plex Mono features: - navigation.instant + - navigation.instant.progress - navigation.tabs - navigation.sections - navigation.tracking @@ -52,12 +53,21 @@ extra: nav: - Home: index.md + - Examples: + - examples/sift_client.ipynb + - Sift Py API + - Sift Client API (New) +# - Guides: +# - Logging +# - Error Handling plugins: - search - autorefs - mike: # For docs versioning deploy_prefix: 'python' # In case we want to use doc sites for other client libs too + - mkdocs-jupyter: + include_source: True - mkdocstrings: default_handler: python handlers: diff --git a/python/pyproject.toml b/python/pyproject.toml index cc833a58c..993b4236b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "PyYAML~=6.0", "pandas~=2.0", "protobuf>=4.0", - "pydantic~=2.0", + "pydantic~=2.10", # Support python 3.9+ typing in older versons of python. "eval-type-backport~=0.2", "pydantic_core~=2.3", @@ -54,6 +54,7 @@ development = [ "pytest-asyncio==0.23.7", "pytest-benchmark==4.0.0", "pytest-mock==3.14.0", + "pytest-dotenv==0.5.2", "ruff~=0.12.10", ] build = ["pdoc==14.5.0", "build==1.2.1"] @@ -63,7 +64,8 @@ docs = ["mkdocs", "mkdocs-include-markdown-plugin", "mkdocs-api-autonav", "mike", - "griffe-pydantic"] + "griffe-pydantic", + "mkdocs-jupyter"] # May be required for certain library functionality openssl = ["pyOpenSSL<24.0.0", "types-pyOpenSSL<24.0.0", "cffi~=1.14"] @@ -200,3 +202,17 @@ select = [ "N", # pep8-naming "TID", # flake8-tidy-imports ] + +[tool.pytest.ini_options] +env_files = [ + ".env" +] +testpaths = [ + "lib/sift_py", + "lib/sift_client/_tests", +] + +markers = [ + "integration: mark a test as an integration test (requires API)" +] + diff --git a/python/scripts/build_utils.py b/python/scripts/build_utils.py index 877658faa..43225ac8e 100644 --- a/python/scripts/build_utils.py +++ b/python/scripts/build_utils.py @@ -4,7 +4,6 @@ import os import subprocess import venv -from itertools import combinations from pathlib import Path from typing import List, Optional from zipfile import ZipFile @@ -35,7 +34,7 @@ def get_extras_from_wheel(wheel_path: str) -> List[str]: def get_extra_combinations(extras: List[str], exclude: Optional[List[str]] = None) -> List[str]: - """Generate all possible combinations of extras. + """Generate different extras permutations for install testing. Args: extras: List of extra names to generate combinations from. @@ -50,10 +49,8 @@ def get_extra_combinations(extras: List[str], exclude: Optional[List[str]] = Non else: filtered_extras = extras - all_combinations = [] - for r in range(len(filtered_extras) + 1): - all_combinations.extend(",".join(c) for c in combinations(filtered_extras, r)) - return all_combinations + # Get only the full extras lists, no additional permutations at the moment + return [",".join(filtered_extras)] def test_install( @@ -107,7 +104,7 @@ def main(): # Get all extras from the wheel extras = get_extras_from_wheel(str(wheel_file)) - combinations = get_extra_combinations(extras, ["development", "docs"]) + combinations = get_extra_combinations(extras, exclude=["development", "docs"]) # Test base installation first test_install( diff --git a/python/scripts/dev b/python/scripts/dev index 1ecf01ba9..7a95a5ced 100755 --- a/python/scripts/dev +++ b/python/scripts/dev @@ -23,6 +23,8 @@ Subcommands: mypy-stubs Runs stubtest (mypy) on the generated pyi stubs pip-install Install project dependencies test Execute tests + test-integration Execute integration tests + test-all Execute all non-integration tests update-dev Copies changes to the dev script over to the venv bin Options: @@ -90,9 +92,18 @@ doc_build() { } run_tests() { + pytest -m "not integration" +} + +run_tests_integration() { + pytest -m "integration" +} + +run_tests_all() { pytest } + gen_stubs() { source venv/bin/activate cd lib @@ -148,6 +159,12 @@ case "$1" in test) run_tests ;; + test-integration) + run_tests_integration + ;; + test-all) + run_tests_all + ;; gen-stubs) gen_stubs ;;