diff --git a/.githooks/pre-push b/.githooks/pre-push index 917b0ee8c..6b13b65f5 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -1,61 +1,24 @@ #!/usr/bin/env bash -# ensure generated python stubs are up-to-date, from sync clients and sift_stream_bindings - # Store the root directory of the repository REPO_ROOT="$(git rev-parse --show-toplevel)" -PYTHON_DIR="$REPO_ROOT/python" -BINDINGS_DIR="$REPO_ROOT/rust/crates/sift_stream_bindings" -STUBS_DIR="$PYTHON_DIR/lib/sift_client/resources/sync_stubs" - -# Function to check if generated stub files have changed -check_stub_changes() { - local target_path="$1" - local changed_files=$(git status --porcelain "$target_path" | grep -E '\.pyi$' || true) - - if [ -n "$changed_files" ]; then - echo "ERROR: Generated python stubs are not up-to-date. Please commit the changed files:" - echo "$changed_files" - exit 1 - fi -} - -# Function to generate Python stubs -generate_python_stubs() { - echo "Generating Python stubs..." - cd "$PYTHON_DIR" - - if [[ ! -d "$PYTHON_DIR/venv" ]]; then - echo "Running bootstrap script..." - bash ./scripts/dev bootstrap - fi - - bash ./scripts/dev gen-stubs - check_stub_changes "$STUBS_DIR" -} +GITHOOKS_DIR="$REPO_ROOT/.githooks" -# Function to generate bindings stubs -generate_bindings_stubs() { - echo "Generating bindings stubs..." - cd "$BINDINGS_DIR" - cargo run --bin stub_gen - - # The stub file is generated in the bindings directory - local stub_file="$BINDINGS_DIR/sift_stream_bindings.pyi" - check_stub_changes "$stub_file" -} - -# Check for changes in relevant files +# Check for changes in Python files python_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^python/lib/sift_client/' || true)) -bindings_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^rust/crates/sift_stream_bindings/src/' || true)) -# Generate stubs if needed if [[ -n "$python_changed_files" ]]; then - generate_python_stubs + echo "Python files changed, running Python stub checks..." + bash "$GITHOOKS_DIR/pre-push-python/stubs.sh" fi +# Check for changes in Rust bindings files +bindings_changed_files=($(git diff --name-only --diff-filter=ACM | grep '^rust/crates/sift_stream_bindings/src/' || true)) + if [[ -n "$bindings_changed_files" ]]; then - generate_bindings_stubs + echo "Rust bindings files changed, running Rust stub checks..." + bash "$GITHOOKS_DIR/pre-push-rust/stubs.sh" fi -echo "All stubs are up-to-date." \ No newline at end of file +echo "Pre-push checks completed successfully." + diff --git a/.githooks/pre-push-python/stubs.sh b/.githooks/pre-push-python/stubs.sh new file mode 100644 index 000000000..42d6d712b --- /dev/null +++ b/.githooks/pre-push-python/stubs.sh @@ -0,0 +1,36 @@ +# ensure generated python stubs are up-to-date, from sync clients + +# Store the root directory of the repository +REPO_ROOT="$(git rev-parse --show-toplevel)" +PYTHON_DIR="$REPO_ROOT/python" +STUBS_DIR="$PYTHON_DIR/lib/sift_client/resources/sync_stubs" + +# Function to check if generated stub files have changed +check_stub_changes() { + local target_path="$1" + local changed_files=$(git status --porcelain "$target_path" | grep -E '\.pyi$' || true) + + if [ -n "$changed_files" ]; then + echo "ERROR: Generated python stubs are not up-to-date. Please commit the changed files:" + echo "$changed_files" + exit 1 + fi +} + +# Function to generate Python stubs +generate_python_stubs() { + echo "Generating Python stubs..." + cd "$PYTHON_DIR" + + if [[ ! -d "$PYTHON_DIR/venv" ]]; then + echo "Running bootstrap script..." + bash ./scripts/dev bootstrap + fi + + bash ./scripts/dev gen-stubs + check_stub_changes "$STUBS_DIR" +} + +generate_python_stubs + +echo "All stubs are up-to-date." \ No newline at end of file diff --git a/.githooks/pre-push-rust/stubs.sh b/.githooks/pre-push-rust/stubs.sh new file mode 100644 index 000000000..1af106713 --- /dev/null +++ b/.githooks/pre-push-rust/stubs.sh @@ -0,0 +1,32 @@ +# ensure generated python stubs are up-to-date, from sift_stream_bindings + +# Store the root directory of the repository +REPO_ROOT="$(git rev-parse --show-toplevel)" +BINDINGS_DIR="$REPO_ROOT/rust/crates/sift_stream_bindings" + +# Function to check if generated stub files have changed +check_stub_changes() { + local target_path="$1" + local changed_files=$(git status --porcelain "$target_path" | grep -E '\.pyi$' || true) + + if [ -n "$changed_files" ]; then + echo "ERROR: Generated python stubs are not up-to-date. Please commit the changed files:" + echo "$changed_files" + exit 1 + fi +} + +# Function to generate bindings stubs +generate_bindings_stubs() { + echo "Generating bindings stubs..." + cd "$BINDINGS_DIR" + cargo run --bin stub_gen + + # The stub file is generated in the bindings directory + local stub_file="$BINDINGS_DIR/sift_stream_bindings.pyi" + check_stub_changes "$stub_file" +} + +generate_bindings_stubs + +echo "All stubs are up-to-date." \ No newline at end of file diff --git a/python/lib/sift_client/.ruff.toml b/python/lib/sift_client/.ruff.toml index e5e6a152f..6786523a6 100644 --- a/python/lib/sift_client/.ruff.toml +++ b/python/lib/sift_client/.ruff.toml @@ -49,6 +49,7 @@ ignore = ["W191", "D206", "D300", # https://docs.astral.sh/ruff/formatter/#confl "D105", # Missing docstring in magic method "D205", # 1 blank line required between summary line and description "D100", # Missing docstring in public module + "C408", # Allow dict() ] diff --git a/python/lib/sift_client/_internal/CONTRIBUTING.md b/python/lib/sift_client/_internal/CONTRIBUTING.md index a5aff728c..fc6aa088e 100644 --- a/python/lib/sift_client/_internal/CONTRIBUTING.md +++ b/python/lib/sift_client/_internal/CONTRIBUTING.md @@ -46,12 +46,48 @@ All low-level clients should implement `LowLevelClientBase` from `sift_client/_i ### Sift Types -New Sift types can be implemented in `sift_client/types`. +New Sift types can be implemented in `sift_client/sift_types`. These types are used to define Pydantic models for all domain objects and to convert between protocol buffers and Python. Additional -update models can be implemented for performing updates with field masks. +update and create models can be implemented for performing updates with field masks. -All Sift types should inherit from `BaseType` and model updates from `ModelUpdate` in `sift_client/types/_base.py` +All Sift types should inherit from `BaseType` in `sift_client/sift_types/_base.py` + +#### Create/Update Pydantic Model Inheritance Pattern + +The Sift client uses a composition-based inheritance pattern for Pydantic models to avoid complex multiple inheritance issues: + +1. **Base Classes** (`sift_client/sift_types/_base.py`): + - `ModelCreateUpdateBase`: Base class containing shared functionality for proto conversion and field mapping + - `ModelCreate`: Inherits from `ModelCreateUpdateBase` with generic typing for creation operations + - `ModelUpdate`: Inherits from `ModelCreateUpdateBase` with additional field mask support for updates + +2. **Domain-Specific Base Classes**: + Create a base class that inherits from `ModelCreateUpdateBase` and contains: + - All shared field definitions + - Shared `_to_proto_helpers` configuration for complex proto mappings + - Common validation logic using `@model_validator` + It may not always make sense to implement a base class if there is little/no overlap in fields or protos. + +3. **Create and Update Models**: + - `{Domain}Create`: Inherits from both `{Domain}Base` and `ModelCreate[{CreateProto}]` + - Include create only fields and validators + - `{Domain}Update`: Inherits from both `{Domain}Base` and `ModelUpdate[{UpdateProto}]` + - Include update only fields and validators + +#### Proto Mapping Helpers + +Use `MappingHelper` for complex proto field mappings when the Pydantic model doesn't match the proto model exactly: +- `proto_attr_path`: Dot-separated path to the proto field +- `update_field`: Field name for update masks (optional) +- `converter`: Function/class to convert the value (optional) + +#### Validation Guidelines + +- Use `@model_validator(mode="after")` for cross-field validation +- Prefix validation method names with `_` (e.g., `_validate_time_fields`) since these don't need to be user visible +- Keep validation logic in the base class when shared between create/update +- Add specific validation in create/update classes as needed ### High-Level APIs @@ -62,6 +98,62 @@ Static and class methods should be avoided since these cannot have associated sy All high-level APIs should inherit from `ResourceBase` from `sift_client/resources/_base.py`. +#### Resource Method Patterns + +Resource classes should implement consistent patterns for common operations. Use the helper methods from `ResourceBase` to build standard filter arguments. + +**Important:** Arguments that represent another Sift Type should always accept both the object instance and its ID string. This provides flexibility for users who may have either form. + + +**Note**: If the proto API does not support filters for a resource, the API should be updated to make the resource filterable in a consistent way with other resources. + +Examples: +```python +# Accept either Asset object or asset ID string +async def update(self, asset: Asset | str, ...) -> Asset: +``` + +##### Standard Method Signatures + +**`get(resource_id: str) -> {Type}`** +- Single required positional argument for the resource ID +- Returns the specific resource instance + +**`list_(...) -> list[{Type}]`** +- Use `list_` (with underscore) to avoid conflicts with Python's built-in `list` +- Standard filter arguments in consistent order (as applicable:) + 1. Name filters: `name`, `name_contains`, `name_regex` + 2. Self IDs: Resource-specific ID filters (e.g., `run_ids`, `asset_ids`, `client_keys`) + 3. Created/modified ranges: `created_after`, `created_before`, `modified_after`, `modified_before` + 4. Created/modified users: `created_by`, `modified_by` + 5. Metadata: `metadata`, `tags` + 6. Resource-specific filters: Domain-specific filters (e.g., `assets`, `duration_less_than`, `start_time_after`) + 7. Common filters: `description_contains`, `include_archived`, `filter_query` + 8. Ordering and pagination: `order_by`, `limit`, `page_size`, `page_token` + +**`find(...) -> {Type} | None`** +- Similar signature to `list_` but returns single result or None +- Should use the same filter arguments as `list_` + +**`create(create: {Type}Create | dict, **kwargs) -> {Type}`** +- Accept both Pydantic model and dict +- Additional keyword arguments for operation-specific options + +**`update({resource}: str | {Type}, update: {Type}Update | dict, **kwargs) -> {Type}`** +- First argument accepts either ID string or resource instance +- Update model as second argument +- Additional keyword arguments for operation-specific options + +##### Using ResourceBase Helper Methods + +The `ResourceBase` class provides helper methods to build consistent CEL filter expressions: + +- `_build_name_cel_filters()`: Handles `name`, `name_contains`, `name_regex` +- `_build_time_cel_filters()`: Handles time-based filters and user filters +- `_build_tags_metadata_cel_filters()`: Handles `tags` and `metadata` filters +- `_build_common_cel_filters()`: Handles `description_contains`, `include_archived`, `filter_query` + + #### Sync API Generation To generate a sync API from an async API, add a `generate_sync_api` function call in `sift_client/resources/sync_stubs/__init__.py` and diff --git a/python/lib/sift_client/_internal/gen_pyi.py b/python/lib/sift_client/_internal/gen_pyi.py index e798f997f..6e2594b0f 100644 --- a/python/lib/sift_client/_internal/gen_pyi.py +++ b/python/lib/sift_client/_internal/gen_pyi.py @@ -143,7 +143,9 @@ def generate_stubs_for_module(path_arg: str | pathlib.Path) -> dict[pathlib.Path raw_doc = inspect.getdoc(cls) or "" if raw_doc: doc = ( - ' """\n' + "\n".join(f" {l}" for l in raw_doc.splitlines()) + '\n """' + ' """\n' + + "\n".join(f" {l.strip()}" for l in raw_doc.splitlines()) + + '\n """' ) else: doc = " ..." diff --git a/python/lib/sift_client/_internal/low_level_wrappers/assets.py b/python/lib/sift_client/_internal/low_level_wrappers/assets.py index cede17db9..9c42bd74e 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/assets.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/assets.py @@ -3,7 +3,8 @@ from typing import Any, cast from sift.assets.v1.assets_pb2 import ( - DeleteAssetRequest, + ArchiveAssetRequest, + ArchiveAssetResponse, GetAssetRequest, GetAssetResponse, ListAssetsRequest, @@ -96,6 +97,8 @@ async def update_asset(self, update: AssetUpdate) -> Asset: updated_grpc_asset = cast("UpdateAssetResponse", response).asset return Asset._from_proto(updated_grpc_asset) - async def delete_asset(self, asset_id: str, archive_runs: bool = False) -> None: - request = DeleteAssetRequest(asset_id=asset_id, archive_runs=archive_runs) - await self._grpc_client.get_stub(AssetServiceStub).DeleteAsset(request) + async def archive_asset(self, asset_id: str, archive_runs: bool = False) -> list[str] | None: + request = ArchiveAssetRequest(asset_id=asset_id, archive_runs=archive_runs) + response = await self._grpc_client.get_stub(AssetServiceStub).ArchiveAsset(request) + response = cast("ArchiveAssetResponse", response) + return response.archived_runs diff --git a/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py b/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py index 70af5f93e..890146d78 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py @@ -1,15 +1,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, cast +from typing import Any, cast from sift.calculated_channels.v2.calculated_channels_pb2 import ( - CalculatedChannelAbstractChannelReference, - CalculatedChannelAssetConfiguration, - CalculatedChannelConfiguration, - CalculatedChannelQueryConfiguration, CalculatedChannelValidationResult, - CreateCalculatedChannelRequest, CreateCalculatedChannelResponse, GetCalculatedChannelRequest, GetCalculatedChannelResponse, @@ -25,13 +20,11 @@ from sift_client._internal.low_level_wrappers.base import LowLevelClientBase from sift_client.sift_types.calculated_channel import ( CalculatedChannel, + CalculatedChannelCreate, CalculatedChannelUpdate, ) from sift_client.transport import GrpcClient, WithGrpcClient -if TYPE_CHECKING: - from sift_client.sift_types.channel import ChannelReference - logger = logging.getLogger(__name__) @@ -79,72 +72,9 @@ async def get_calculated_channel( return CalculatedChannel._from_proto(grpc_calculated_channel) async def create_calculated_channel( - self, - *, - name: str, - all_assets: bool = False, - asset_ids: list[str] | None = None, - tag_ids: list[str] | None = None, - expression: str = "", - channel_references: list[ChannelReference] | None = None, - description: str = "", - user_notes: str = "", - units: str | None = None, - client_key: str | None = None, + self, *, create: CalculatedChannelCreate ) -> tuple[CalculatedChannel, list[Any]]: - """Create a calculated channel. - - Args: - name: The name of the calculated channel. - all_assets: Whether to include all assets in the calculated channel. - asset_ids: The IDs of the assets to include in the calculated channel. - tag_ids: The IDs of the tags to include in the calculated channel. - expression: The CEL expression for the calculated channel. - channel_references: The channel references to include in the calculated channel. - description: The description of the calculated channel. - user_notes: User notes for the calculated channel. - units: The units for the calculated channel. - client_key: A user-defined unique identifier for the calculated channel. - - Returns: - A tuple of (CalculatedChannel, list of inapplicable assets). - """ - if channel_references is None: - channel_references = [] - - asset_config = CalculatedChannelAssetConfiguration( - all_assets=all_assets, - selection=CalculatedChannelAssetConfiguration.AssetSelection( - asset_ids=asset_ids, - tag_ids=tag_ids, - ), - ) - request_kwargs: dict[str, Any] = { - "name": name, - "description": description, - "user_notes": user_notes, - "calculated_channel_configuration": CalculatedChannelConfiguration( - asset_configuration=asset_config, - query_configuration=CalculatedChannelQueryConfiguration( - sel=CalculatedChannelQueryConfiguration.Sel( - expression=expression, - expression_channel_references=[ - CalculatedChannelAbstractChannelReference( - channel_identifier=ref.channel_identifier, - channel_reference=ref.channel_reference, - ) - for ref in channel_references - ], - ), - ), - ), - } - if units is not None: - request_kwargs["units"] = units - if client_key is not None: - request_kwargs["client_key"] = client_key - - request = CreateCalculatedChannelRequest(**request_kwargs) + request = create.to_proto() response = await self._grpc_client.get_stub( CalculatedChannelServiceStub ).CreateCalculatedChannel(request) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py index da82e106a..b7c4a4341 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ingestion.py @@ -34,10 +34,10 @@ from sift_client._internal.low_level_wrappers.base import ( LowLevelClientBase, ) +from sift_client._internal.util.timestamp import to_rust_py_timestamp from sift_client.sift_types.ingestion import Flow, IngestionConfig, _to_rust_value from sift_client.transport import GrpcClient, WithGrpcClient from sift_client.util import cel_utils as cel -from sift_client.util.timestamp import to_rust_py_timestamp logger = logging.getLogger(__name__) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/rules.py b/python/lib/sift_client/_internal/low_level_wrappers/rules.py index 676b5fa71..bf5bdb905 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/rules.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/rules.py @@ -4,23 +4,22 @@ from typing import TYPE_CHECKING, Any, cast from sift.rules.v1.rules_pb2 import ( - BatchDeleteRulesRequest, + ArchiveRuleRequest, + BatchArchiveRulesRequest, BatchGetRulesRequest, - BatchGetRulesResponse, - BatchUndeleteRulesRequest, + BatchUnarchiveRulesRequest, BatchUpdateRulesRequest, BatchUpdateRulesResponse, CalculatedChannelConfig, ContextualChannels, CreateRuleRequest, CreateRuleResponse, - DeleteRuleRequest, GetRuleRequest, GetRuleResponse, ListRulesRequest, RuleAssetConfiguration, RuleConditionExpression, - UndeleteRuleRequest, + UnarchiveRuleRequest, UpdateConditionRequest, UpdateRuleRequest, UpdateRuleResponse, @@ -33,7 +32,7 @@ from sift_client._internal.low_level_wrappers.base import LowLevelClientBase from sift_client.sift_types.rule import ( Rule, - RuleAction, + RuleCreate, RuleUpdate, ) from sift_client.transport import GrpcClient, WithGrpcClient @@ -109,66 +108,50 @@ async def batch_get_rules( request = BatchGetRulesRequest(**request_kwargs) response = await self._grpc_client.get_stub(RuleServiceStub).BatchGetRules(request) - response = cast("BatchGetRulesResponse", response) return [Rule._from_proto(rule) for rule in response.rules] async def create_rule( self, *, - name: str, - description: str, - organization_id: str | None = None, - client_key: str | None = None, - asset_ids: list[str] | None = None, - tag_ids: list[str] | None = None, - contextual_channels: list[str] | None = None, - is_external: bool, - expression: str, - channel_references: list[ChannelReference], - action: RuleAction, + create: RuleCreate, ) -> Rule: """Create a new rule. Args: - name: The name of the rule. - description: The description of the rule. - organization_id: The organization ID of the rule. - client_key: The client key of the rule. - asset_ids: The asset IDs of the rule. - contextual_channels: Optional contextual channels of the rule. + create: The RuleCreate model with the rule configuration. Returns: - The rule ID of the created rule. + The created Rule. """ # Convert rule to UpdateRuleRequest expression_proto = RuleConditionExpression( calculated_channel=CalculatedChannelConfig( - expression=expression, + expression=create.expression, channel_references={ c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) - for c in channel_references + for c in create.channel_references }, ) ) conditions_request = [ UpdateConditionRequest( - expression=expression_proto, actions=[action._to_update_request()] + expression=expression_proto, actions=[create.action._to_update_request()] ) ] update_request = UpdateRuleRequest( - name=name, - description=description, + name=create.name, + description=create.description, is_enabled=True, - organization_id=organization_id or "", - client_key=client_key, - is_external=is_external, + organization_id=create.organization_id or "", + client_key=create.client_key, + is_external=create.is_external, conditions=conditions_request, asset_configuration=RuleAssetConfiguration( - asset_ids=asset_ids or [], - tag_ids=tag_ids or [], + asset_ids=create.asset_ids or [], + tag_ids=create.asset_tag_ids or [], ), contextual_channels=ContextualChannels( - channels=[ChannelReferenceProto(name=c) for c in contextual_channels or []] + channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []] ), # type: ignore ) @@ -177,7 +160,7 @@ async def create_rule( "CreateRuleResponse", await self._grpc_client.get_stub(RuleServiceStub).CreateRule(request), ) - return await self.get_rule(rule_id=created_rule.rule_id, client_key=client_key) + return await self.get_rule(rule_id=created_rule.rule_id, client_key=create.client_key) def _update_rule_request_from_update( self, rule: Rule, update: RuleUpdate, version_notes: str | None = None @@ -328,7 +311,7 @@ async def archive_rule(self, rule_id: str | None = None, client_key: str | None if client_key is not None: request_kwargs["client_key"] = client_key - request = DeleteRuleRequest(**request_kwargs) + request = ArchiveRuleRequest(**request_kwargs) await self._grpc_client.get_stub(RuleServiceStub).ArchiveRule(request) async def batch_archive_rules( @@ -338,7 +321,7 @@ async def batch_archive_rules( Args: rule_ids: List of rule IDs to archive. - client_keys: List of client keys to delete. If both are provided, rule_ids will be used. + client_keys: List of client keys to archive. If both are provided, rule_ids will be used. Raises: ValueError: If neither rule_ids nor client_keys is provided. @@ -352,59 +335,55 @@ async def batch_archive_rules( if client_keys is not None: request_kwargs["client_keys"] = client_keys - request = BatchDeleteRulesRequest(**request_kwargs) - await self._grpc_client.get_stub(RuleServiceStub).BatchDeleteRules(request) + request = BatchArchiveRulesRequest(**request_kwargs) + await self._grpc_client.get_stub(RuleServiceStub).BatchArchiveRules(request) - async def restore_rule(self, rule_id: str | None = None, client_key: str | None = None) -> Rule: - """Restore a rule. + async def unarchive_rule( + self, rule_id: str | None = None, client_key: str | None = None + ) -> Rule: + """Unarchive a rule. Args: - rule_id: The rule ID to restore. - client_key: The client key to restore. + rule_id: The rule ID to unarchive. + client_key: The client key to unarchive. Returns: - The restored Rule. + The unarchived Rule. Raises: ValueError: If neither rule_id nor client_key is provided. """ - if rule_id is None and client_key is None: - raise ValueError("Either rule_id or client_key must be provided") - request_kwargs: dict[str, Any] = {} if rule_id is not None: request_kwargs["rule_id"] = rule_id if client_key is not None: request_kwargs["client_key"] = client_key - request = UndeleteRuleRequest(**request_kwargs) - await self._grpc_client.get_stub(RuleServiceStub).UndeleteRule(request) - # Get the restored rule + request = UnarchiveRuleRequest(**request_kwargs) + await self._grpc_client.get_stub(RuleServiceStub).UnarchiveRule(request) + # Get the unarchived rule return await self.get_rule(rule_id=rule_id, client_key=client_key) - async def batch_restore_rules( + async def batch_unarchive_rules( self, rule_ids: list[str] | None = None, client_keys: list[str] | None = None ) -> None: - """Batch restore rules. + """Batch unarchive rules. Args: - rule_ids: List of rule IDs to restore. - client_keys: List of client keys to restore. + rule_ids: List of rule IDs to unarchive. + client_keys: List of client keys to unarchive. Raises: ValueError: If neither rule_ids nor client_keys is provided. """ - if rule_ids is None and client_keys is None: - raise ValueError("Either rule_ids or client_keys must be provided") - request_kwargs: dict[str, Any] = {} if rule_ids is not None: request_kwargs["rule_ids"] = rule_ids if client_keys is not None: request_kwargs["client_keys"] = client_keys - request = BatchUndeleteRulesRequest(**request_kwargs) - await self._grpc_client.get_stub(RuleServiceStub).BatchUndeleteRules(request) + request = BatchUnarchiveRulesRequest(**request_kwargs) + await self._grpc_client.get_stub(RuleServiceStub).BatchUnarchiveRules(request) async def list_rules( self, diff --git a/python/lib/sift_client/_internal/low_level_wrappers/runs.py b/python/lib/sift_client/_internal/low_level_wrappers/runs.py index 067d99c7f..38c020454 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/runs.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/runs.py @@ -5,9 +5,7 @@ from sift.runs.v2.runs_pb2 import ( CreateAutomaticRunAssociationForAssetsRequest, - CreateRunRequest, CreateRunResponse, - DeleteRunRequest, GetRunRequest, GetRunResponse, ListRunsRequest, @@ -19,9 +17,8 @@ from sift.runs.v2.runs_pb2_grpc import RunServiceStub from sift_client._internal.low_level_wrappers.base import LowLevelClientBase -from sift_client.sift_types.run import Run, RunUpdate +from sift_client.sift_types.run import Run, RunCreate, RunUpdate from sift_client.transport import WithGrpcClient -from sift_client.util.metadata import metadata_dict_to_proto if TYPE_CHECKING: from sift_client.transport.grpc_transport import GrpcClient @@ -56,9 +53,6 @@ async def get_run(self, run_id: str) -> Run: Raises: ValueError: If run_id is not provided. """ - if not run_id: - raise ValueError("run_id must be provided") - request = GetRunRequest(run_id=run_id) response = await self._grpc_client.get_stub(RunServiceStub).GetRun(request) grpc_run = cast("GetRunResponse", response).run @@ -124,88 +118,18 @@ async def list_all_runs( max_results=max_results, ) - async def create_run( - self, - *, - name: str, - description: str, - tags: list[str] | None = None, - start_time: Any | None = None, - stop_time: Any | None = None, - organization_id: str | None = None, - client_key: str | None = None, - metadata: dict[str, str | float | bool] | None = None, - ) -> Run: - """Create a new run. - - Args: - name: The name of the run. - description: The description of the run. - tags: Tags to associate with the run. - start_time: The start time of the run. - stop_time: The stop time of the run. - organization_id: The organization ID. - client_key: A unique client key for the run. - metadata: Metadata values for the run. - - Returns: - The created Run. - """ - request_kwargs: dict[str, Any] = { - "name": name, - "description": description, - } - - if tags is not None: - request_kwargs["tags"] = tags - if start_time is not None: - request_kwargs["start_time"] = start_time - if stop_time is not None: - request_kwargs["stop_time"] = stop_time - if organization_id is not None: - request_kwargs["organization_id"] = organization_id - if client_key is not None: - request_kwargs["client_key"] = client_key - if metadata is not None: - metadata_proto = metadata_dict_to_proto(metadata) - request_kwargs["metadata"] = metadata_proto - - request = CreateRunRequest(**request_kwargs) - response = await self._grpc_client.get_stub(RunServiceStub).CreateRun(request) + async def create_run(self, *, create: RunCreate) -> Run: + request_proto = create.to_proto() + response = await self._grpc_client.get_stub(RunServiceStub).CreateRun(request_proto) grpc_run = cast("CreateRunResponse", response).run return Run._from_proto(grpc_run) - async def update_run(self, run: Run, update: RunUpdate) -> Run: - """Update an existing run. - - Args: - run: The run to update. - update: The updates to apply. - - Returns: - The updated Run. - """ - run_proto, field_mask = update.to_proto_with_mask() - - request = UpdateRunRequest(run=run_proto, update_mask=field_mask) + async def update_run(self, update: RunUpdate) -> Run: + grpc_run, update_mask = update.to_proto_with_mask() + request = UpdateRunRequest(run=grpc_run, update_mask=update_mask) response = await self._grpc_client.get_stub(RunServiceStub).UpdateRun(request) - grpc_run = cast("UpdateRunResponse", response).run - return Run._from_proto(grpc_run) - - async def archive_run(self, run_id: str) -> None: - """Archive a run. - - Args: - run_id: The ID of the run to archive. - - Raises: - ValueError: If run_id is not provided. - """ - if not run_id: - raise ValueError("run_id must be provided") - - request = DeleteRunRequest(run_id=run_id) - await self._grpc_client.get_stub(RunServiceStub).DeleteRun(request) + updated_grpc_run = cast("UpdateRunResponse", response).run + return Run._from_proto(updated_grpc_run) async def stop_run(self, run_id: str) -> None: """Stop a run by setting its stop time to the current time. diff --git a/python/lib/sift_client/_internal/sync_wrapper.py b/python/lib/sift_client/_internal/sync_wrapper.py index b5ce3f43d..eb6d0240e 100644 --- a/python/lib/sift_client/_internal/sync_wrapper.py +++ b/python/lib/sift_client/_internal/sync_wrapper.py @@ -57,7 +57,7 @@ def _run(self, coro): namespace = { "__module__": module, - "__doc__": f"Sync counterpart to `{name}`.\n\n{cls.__doc__ or ''}", + "__doc__": f"Sync counterpart to `{name}`.\n\n{(cls.__doc__ or '').strip()}", "__init__": __init__, "_run": _run, "__qualname__": sync_name, # Add __qualname__ to help static analyzers diff --git a/python/lib/sift_client/_internal/utils/__init__.py b/python/lib/sift_client/_internal/util/__init__.py similarity index 100% rename from python/lib/sift_client/_internal/utils/__init__.py rename to python/lib/sift_client/_internal/util/__init__.py diff --git a/python/lib/sift_client/util/timestamp.py b/python/lib/sift_client/_internal/util/timestamp.py similarity index 100% rename from python/lib/sift_client/util/timestamp.py rename to python/lib/sift_client/_internal/util/timestamp.py diff --git a/python/lib/sift_client/_tests/integrated/calculated_channels.py b/python/lib/sift_client/_tests/integrated/calculated_channels.py index 41a01318a..2d423b17c 100644 --- a/python/lib/sift_client/_tests/integrated/calculated_channels.py +++ b/python/lib/sift_client/_tests/integrated/calculated_channels.py @@ -9,6 +9,7 @@ CalculatedChannelUpdate, ChannelReference, ) +from sift_client.sift_types.calculated_channel import CalculatedChannelCreate """ Comprehensive test script for calculated channels with extensive update field exercises. @@ -55,11 +56,11 @@ async def main(): created_channels = [] for i in range(num_channels): - calculated_channel = client.calculated_channels.create( + new_chan = CalculatedChannelCreate( name=f"test_channel_{unique_name_suffix}_{i}", description=f"Test calculated channel {i} - initial description", expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - channel_references=[ + expression_channel_references=[ ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), ChannelReference(channel_reference="$2", channel_identifier="voltage"), ], @@ -67,6 +68,7 @@ async def main(): asset_ids=[asset_id], user_notes=f"Created for testing update fields - channel {i}", ) + calculated_channel = client.calculated_channels.create(new_chan) created_channels.append(calculated_channel) print( f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.id_})" @@ -236,7 +238,7 @@ async def main(): assert updated_channel_7.tag_ids == [], f"Tag IDs update failed: {updated_channel_7.tag_ids}" versions = client.calculated_channels.list_versions( - calculated_channel_id=channel_1.id_, + calculated_channel=channel_1.id_, limit=10, ) print(f"Found {len(versions)} versions for {created_channels[0].name}") diff --git a/python/lib/sift_client/_tests/integrated/channels.py b/python/lib/sift_client/_tests/integrated/channels.py index 2a721a4bf..73c54e0df 100644 --- a/python/lib/sift_client/_tests/integrated/channels.py +++ b/python/lib/sift_client/_tests/integrated/channels.py @@ -41,7 +41,7 @@ async def main(): # List channels for this asset (find a run w/ data) channels = [] for run in runs: - asset_channels = asset.channels(run_id=run.id_, limit=10) + asset_channels = asset.channels(run=run.id_, limit=10) other_channels = [] for c in asset_channels: if c.name in {"voltage", "gpio", "temperature", "mainmotor.velocity"}: @@ -68,7 +68,7 @@ async def main(): print("Getting data for multiple channels:") perf_start = time.perf_counter() channel_data = client.channels.get_data( - run_id="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", channels=channels, limit=100 + run="1d5f5c93-eaaa-48f2-94ff-7ec4337faec7", channels=channels, limit=100 ) first_time = time.perf_counter() - perf_start start_time = None diff --git a/python/lib/sift_client/_tests/integrated/ingestion.py b/python/lib/sift_client/_tests/integrated/ingestion.py index 87e9e54e2..b5f49d894 100644 --- a/python/lib/sift_client/_tests/integrated/ingestion.py +++ b/python/lib/sift_client/_tests/integrated/ingestion.py @@ -8,11 +8,10 @@ from sift_client._tests import setup_logger from sift_client.client import SiftClient from sift_client.sift_types.channel import ( - Channel, ChannelBitFieldElement, ChannelDataType, ) -from sift_client.sift_types.ingestion import Flow +from sift_client.sift_types.ingestion import ChannelConfig, Flow from sift_client.transport import SiftConnectionConfig setup_logger() @@ -35,8 +34,8 @@ async def main(): asset = "ian-test-asset" # TODO:Get user id from current user - previously_created_runs = client.runs.list( - name_regex="test-run-.*", created_by_user_id="1eba461b-fa36-4e98-8fe8-ff32d3e43a6e" + previously_created_runs = client.runs.list_( + name_regex="test-run-.*", created_by="1eba461b-fa36-4e98-8fe8-ff32d3e43a6e" ) if previously_created_runs: print(f" Deleting previously created runs: {previously_created_runs}") @@ -45,24 +44,26 @@ async def main(): client.runs.archive(run=run) run = client.runs.create( - name=f"test-run-{datetime.now(tz=timezone.utc).timestamp()}", - description="A test run created via the API", - tags=["api-created", "test"], + dict( + name=f"test-run-{datetime.now(tz=timezone.utc).timestamp()}", + description="A test run created via the API", + tags=["api-created", "test"], + ) ) regular_flow = Flow( name="test-flow", channels=[ - Channel(name="test-channel", data_type=ChannelDataType.DOUBLE), - Channel( + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), + ChannelConfig( name="test-enum-channel", data_type=ChannelDataType.ENUM, enum_types={"enum1": 1, "enum2": 2}, ), ], ) - regular_flow.add_channel( - Channel( + regular_flow.add_channelConfig( + ChannelConfig( name="test-bit-field-channel", data_type=ChannelDataType.BIT_FIELD, bit_field_elements=[ @@ -77,7 +78,7 @@ async def main(): highspeed_flow = Flow( name="highspeed-flow", channels=[ - Channel(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), + ChannelConfig(name="highspeed-channel", data_type=ChannelDataType.DOUBLE), ], ) # This seals the flow and ingestion config @@ -88,7 +89,9 @@ async def main(): ) print(f"config_id: {config_id}") try: - regular_flow.add_channel(Channel(name="test-channel", data_type=ChannelDataType.DOUBLE)) + regular_flow.add_channelConfig( + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE) + ) except ValueError as e: assert repr(e) == "ValueError('Cannot add a channel to a flow after creation')" @@ -97,7 +100,7 @@ async def main(): name="new-asset-flow", channels=[ # Same channel name as the regular flow, but on a different asset. - Channel(name="test-channel", data_type=ChannelDataType.DOUBLE), + ChannelConfig(name="test-channel", data_type=ChannelDataType.DOUBLE), ], ) ] diff --git a/python/lib/sift_client/_tests/integrated/rules.py b/python/lib/sift_client/_tests/integrated/rules.py index 85c1e1ddc..dbe88aeea 100644 --- a/python/lib/sift_client/_tests/integrated/rules.py +++ b/python/lib/sift_client/_tests/integrated/rules.py @@ -10,6 +10,7 @@ RuleAnnotationType, RuleUpdate, ) +from sift_client.sift_types.rule import RuleCreate """ Comprehensive test script for rules with extensive update field exercises. @@ -51,24 +52,28 @@ def main(): created_rules = [] for i in range(num_rules): rule = client.rules.create( - name=f"test_rule_{unique_name_suffix}_{i}", - description=f"Test rule {i} - initial description", - expression="$1 > 0.1", # Simple threshold check - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["test", "initial"], - default_assignee_user_id=None, - ), - asset_ids=[asset_id], + RuleCreate( + name=f"test_rule_{unique_name_suffix}_{i}", + description=f"Test rule {i} - initial description", + expression="$1 > 0.1", # Simple threshold check + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier="mainmotor.velocity" + ), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test", "initial"], + default_assignee_user_id=None, + ), + asset_ids=[asset_id], + ) ) created_rules.append(rule) print(f"Created rule: {rule.name} (ID: {rule.id_})") # Find the rules we just created - search_results = client.rules.list( + search_results = client.rules.list_( name_regex=f"test_rule_{unique_name_suffix}.*", ) assert len(search_results) == num_rules, ( @@ -159,9 +164,9 @@ def main(): print(f"Updated {updated_rule_6.name}: complex expression = {updated_rule_6.expression}") # Test 7: Update action to notification type - print("\n--- Test 7: Update action to notification ---") - rule_7 = created_rules[6] - updated_rule_7 = rule_7 + # print("\n--- Test 7: Update action to notification ---") + # rule_7 = created_rules[6] + # updated_rule_7 = rule_7 # Note: Notification actions are not supported yet. # updated_rule_7 = rule_7.update( # RuleUpdate( @@ -179,7 +184,10 @@ def main(): updated_rule_8 = rule_8.update( RuleUpdate( # tag_ids=["tag-123", "tag-456"], # Example tag IDs # TODO: Where are these IDs supposed to come from? They're supposed to be uuids? {grpc_message:"invalid argument: invalid input syntax for type uuid: \"tag-123\" - contextual_channels=["temperature", "pressure"], # Example contextual channels + contextual_channels=[ + "temperature", + "pressure", + ], # Example contextual channels ) ) print(f"Updated {updated_rule_8.name}:") @@ -203,42 +211,6 @@ def main(): except Exception as e: print(f"Invalid expression update failed as expected: {e}") - # Test 9: Batch operations demonstration - print("\n--- Test 9: Batch operations demonstration ---") - all_updated_rules = [ - updated_rule_1, - updated_rule_2, - updated_rule_3, - updated_rule_4, - updated_rule_5, - updated_rule_6, - updated_rule_7, - updated_rule_8, - ] - - # Batch get the updated rules - rule_ids = [rule.id_ for rule in all_updated_rules] - batch_rules = client.rules.batch_get(rule_ids=rule_ids) - print(f"Batch retrieved {len(batch_rules)} rules:") - for rule in batch_rules: - print(f" - {rule.name}: {rule.expression}") - - # Test 10: Archive rules - print("\n--- Test 10: Archive rules ---") - client.rules.archive(rules=created_rules) - - print("\n=== Test Summary ===") - print(f"Created: {len(created_rules)} rules") - print(f"Updated: {len(all_updated_rules)} rules") - - # Verify all rules were processed - assert len(created_rules) == num_rules, ( - f"Expected {num_rules} created rules, got {len(created_rules)}" - ) - assert len(all_updated_rules) == num_rules, ( - f"Expected {num_rules} updated rules, got {len(all_updated_rules)}" - ) - # Additional validation print("\n=== Validation Checks ===") diff --git a/python/lib/sift_client/_tests/integrated/runs.py b/python/lib/sift_client/_tests/integrated/runs.py index 7a8cdf19b..9c55c2aa1 100644 --- a/python/lib/sift_client/_tests/integrated/runs.py +++ b/python/lib/sift_client/_tests/integrated/runs.py @@ -52,7 +52,7 @@ async def main(): # Example 1: List all runs print("\n1. Listing all runs...") - runs = client.runs.list(limit=5) + runs = client.runs.list_(limit=5) print(f" Found {len(runs)} runs:") for run in runs: print(f" - {run.name} (ID: {run.id_}), Organization ID: {run.organization_id}") @@ -61,7 +61,7 @@ async def main(): print("\n2. Testing different filter options...") # Get a sample run for testing filters - sample_runs = client.runs.list(limit=3) + sample_runs = client.runs.list_(limit=3) if not sample_runs: print(" No runs available for filter testing") return @@ -71,30 +71,30 @@ async def main(): # 2a: Filter by exact name print("\n 2a. Filter by exact name...") run_name = sample_run.name - runs = client.runs.list(name=run_name, limit=5) + runs = client.runs.list_(name=run_name, limit=5) print(f" Found {len(runs)} runs with exact name '{run_name}':") for run in runs: print(f" - {run.name} (ID: {run.id_})") # 2b: Filter by name containing text print("\n 2b. Filter by name containing text...") - runs = client.runs.list(name_contains="test", limit=5) + runs = client.runs.list_(name_contains="test", limit=5) print(f" Found {len(runs)} runs with 'test' in name:") for run in runs: print(f" - {run.name}") # 2c: Filter by name using regex print("\n 2c. Filter by name using regex...") - runs = client.runs.list(name_regex=".*test.*", limit=5) + runs = client.runs.list_(name_regex=".*test.*", limit=5) print(f" Found {len(runs)} runs with 'test' in name (regex):") for run in runs: print(f" - {run.name}") # 2d: Filter by exact description - print("\n 2d. Filter by exact description...") + print("\n 2d. Filter by description contains...") if sample_run.description: - runs = client.runs.list(description=sample_run.description, limit=5) - print(f" Found {len(runs)} runs with exact description '{sample_run.description}':") + runs = client.runs.list_(description_contains=sample_run.description, limit=5) + print(f" Found {len(runs)} runs with description contains '{sample_run.description}':") for run in runs: print(f" - {run.name}: {run.description}") else: @@ -102,7 +102,7 @@ async def main(): # 2e: Filter by description containing text print("\n 2e. Filter by description containing text...") - runs = client.runs.list(description_contains="test", limit=5) + runs = client.runs.list_(description_contains="test", limit=5) print(f" Found {len(runs)} runs with 'test' in description:") for run in runs: print(f" - {run.name}: {run.description}") @@ -112,8 +112,8 @@ async def main(): # Calculate duration for sample run if it has start and stop times if sample_run.start_time and sample_run.stop_time: duration_seconds = int((sample_run.stop_time - sample_run.start_time).total_seconds()) - runs = client.runs.list(duration_seconds=duration_seconds, limit=5) - print(f" Found {len(runs)} runs with duration {duration_seconds} seconds:") + runs = client.runs.list_(duration_greater_than=timedelta(seconds=duration_seconds), limit=5) + print(f" Found {len(runs)} runs with duration greater than {duration_seconds} seconds:") for run in runs: if run.start_time and run.stop_time: run_duration = int((run.stop_time - run.start_time).total_seconds()) @@ -124,10 +124,8 @@ async def main(): # 2g: Filter by client key print("\n 2g. Filter by client key...") if sample_run.client_key: - runs = client.runs.list(client_key=sample_run.client_key, limit=5) - print(f" Found {len(runs)} runs with client key '{sample_run.client_key}':") - for run in runs: - print(f" - {run.name} (client_key: {run.client_key})") + run = client.runs.get(client_key=sample_run.client_key) + print(f" Found run with client key '{run.name}'") else: print(" No client key available for testing") @@ -135,7 +133,7 @@ async def main(): print("\n 2h. Filter by asset ID...") if sample_run.asset_ids: asset_id = sample_run.asset_ids[0] - runs = client.runs.list(asset_id=asset_id, limit=5) + runs = client.runs.list_(assets=[asset_id], limit=5) print(f" Found {len(runs)} runs associated with asset {asset_id}:") for run in runs: print(f" - {run.name} (asset_ids: {list(run.asset_ids)})") @@ -144,42 +142,42 @@ async def main(): # 2i: Filter by asset name print("\n 2i. Filter by asset name...") - runs = client.runs.list(asset_name="NostromoLV426", limit=5) + runs = client.runs.list_(assets=[client.assets.find(name="NostromoLV426")], limit=5) print(f" Found {len(runs)} runs associated with asset 'NostromoLV426':") for run in runs: print(f" - {run.name}") # 2j: Filter by created by user ID print("\n 2j. Filter by created by user ID...") - created_by_user_id = sample_run.created_by_user_id - runs = client.runs.list(created_by_user_id=created_by_user_id, limit=5) - print(f" Found {len(runs)} runs created by user {created_by_user_id}:") + created_by = sample_run.created_by_user_id + runs = client.runs.list_(created_by=created_by, limit=5) + print(f" Found {len(runs)} runs created by user {created_by}:") for run in runs: - print(f" - {run.name} (created by: {run.created_by_user_id})") + print(f" - {run.name} (created by: {run.created_by})") # 2l: Test ordering options print("\n 2l. Testing ordering options...") # Order by name ascending - runs = client.runs.list(order_by="name", limit=3) + runs = client.runs.list_(order_by="name", limit=3) print(" First 3 runs ordered by name (ascending):") for run in runs: print(f" - {run.name}") # Order by name descending - runs = client.runs.list(order_by="name desc", limit=3) + runs = client.runs.list_(order_by="name desc", limit=3) print(" First 3 runs ordered by name (descending):") for run in runs: print(f" - {run.name}") # Order by creation date (newest first - default) - runs = client.runs.list(order_by="created_date desc", limit=3) + runs = client.runs.list_(order_by="created_date desc", limit=3) print(" First 3 runs ordered by creation date (newest first):") for run in runs: print(f" - {run.name} (created: {run.created_date})") # Order by creation date (oldest first) - runs = client.runs.list(order_by="created_date", limit=3) + runs = client.runs.list_(order_by="created_date", limit=3) print(" First 3 runs ordered by creation date (oldest first):") for run in runs: print(f" - {run.name} (created: {run.created_date})") @@ -206,7 +204,7 @@ async def main(): start_time = datetime.now(timezone.utc) stop_time = start_time + timedelta(minutes=2) - previously_created_runs = client.runs.list(name_regex="Example Test Run.*") + previously_created_runs = client.runs.list_(name_regex="Example Test Run.*") if previously_created_runs: print(f" Deleting previously created runs: {previously_created_runs}") for run in previously_created_runs: @@ -214,14 +212,16 @@ async def main(): client.runs.archive(run=run) new_run = client.runs.create( - name=f"Example Test Run {datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}", - description="A test run created via the API", - tags=["api-created", "test"], - start_time=start_time, - stop_time=stop_time, - # Use a unique client key for each run - client_key=f"example-run-key-{datetime.now(tz=timezone.utc).timestamp()}", - metadata=metadata, + dict( + name=f"Example Test Run {datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}", + description="A test run created via the API", + tags=["api-created", "test"], + start_time=start_time, + stop_time=stop_time, + # Use a unique client key for each run + client_key=f"example-run-key-{datetime.now(tz=timezone.utc).timestamp()}", + metadata=metadata, + ) ) print(f" Created run: {new_run.name} (ID: {new_run.id_})") print(f" Client key: {new_run.client_key}") @@ -257,7 +257,7 @@ async def main(): # Example 6: Associate assets with a run print("\n6. Associating assets with a run...") - ongoing_runs = client.runs.list( + ongoing_runs = client.runs.list_( name_regex="Example Test Run.*", include_archived=True, is_stopped=False ) if ongoing_runs: diff --git a/python/lib/sift_client/_tests/util/test_cel_utils.py b/python/lib/sift_client/_tests/util/test_cel_utils.py index f2f0cac9d..1ba1ccbe5 100644 --- a/python/lib/sift_client/_tests/util/test_cel_utils.py +++ b/python/lib/sift_client/_tests/util/test_cel_utils.py @@ -1,5 +1,5 @@ import re -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from sift_client.util.cel_utils import ( and_, @@ -51,8 +51,8 @@ def test_equals_number(self): def test_equals_boolean(self): """Test equals function with boolean value.""" - assert equals("field", True) == "field == True" - assert equals("field", False) == "field == False" + assert equals("field", True) == "field == true" + assert equals("field", False) == "field == false" def test_equals_all_empty(self): """Test equals_all function with empty dict.""" @@ -154,7 +154,7 @@ def test_greater_than_number(self): def test_greater_than_datetime(self): """Test greater_than function with datetime value.""" dt = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - assert greater_than("field", dt) == f"field > {dt.isoformat()}" + assert greater_than("field", dt) == f"field > timestamp('{dt.isoformat()}')" def test_less_than_number(self): """Test less_than function with numeric value.""" @@ -164,4 +164,16 @@ def test_less_than_number(self): def test_less_than_datetime(self): """Test less_than function with datetime value.""" dt = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - assert less_than("field", dt) == f"field < {dt.isoformat()}" + assert less_than("field", dt) == f"field < timestamp('{dt.isoformat()}')" + + def test_greater_than_duration(self): + """Test greater_than function with timedelta value.""" + duration = timedelta(hours=2, minutes=30, seconds=15) + expected_seconds = duration.total_seconds() + assert greater_than("field", duration) == f"field > duration('{expected_seconds}s')" + + def test_less_than_duration(self): + """Test less_than function with timedelta value.""" + duration = timedelta(hours=1, minutes=15, seconds=30) + expected_seconds = duration.total_seconds() + assert less_than("field", duration) == f"field < duration('{expected_seconds}s')" diff --git a/python/lib/sift_client/examples/generic_workflow_example.py b/python/lib/sift_client/examples/generic_workflow_example.py index 307c00547..08901f70a 100644 --- a/python/lib/sift_client/examples/generic_workflow_example.py +++ b/python/lib/sift_client/examples/generic_workflow_example.py @@ -10,6 +10,7 @@ ChannelReference, RuleAction, RuleAnnotationType, + RuleCreate, RuleUpdate, ) @@ -49,16 +50,20 @@ async def main(): # Create a calculated channel that divides mainmotor.velocity by voltage print("\nCreating calculated channel...") calculated_channel = client.calculated_channels.create( - name="velocity_per_voltage", - description="Ratio of mainmotor velocity to voltage", - expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage - channel_references=[ - ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), - ChannelReference(channel_reference="$2", channel_identifier="voltage"), - ], - units="velocity/voltage", - asset_ids=[asset_id], - user_notes="Created to monitor velocity-to-voltage ratio", + dict( + name="velocity_per_voltage", + description="Ratio of mainmotor velocity to voltage", + expression="$1 / $2", # $1 = mainmotor.velocity, $2 = voltage + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier="mainmotor.velocity" + ), + ChannelReference(channel_reference="$2", channel_identifier="voltage"), + ], + units="velocity/voltage", + asset_ids=[asset_id], + user_notes="Created to monitor velocity-to-voltage ratio", + ) ) print( f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.calculated_channel_id})" @@ -87,26 +92,29 @@ async def main(): updated = True else: print(f"No rules found for {rule_search}") - rules = client.rules.search( + rules = client.rules.list_( asset_ids=[asset_id], ) if rules: print(f"However these rules do exist: {[rule.name for rule in rules]}") print("Attempting to create rule for high_velocity_voltage_ratio_alert") rule = client.rules.create( - name="high_velocity_voltage_ratio_alert", - description="Alert when velocity-to-voltage ratio exceeds 0.1", - expression="$1 > 0.1", - channel_references=[ - ChannelReference( - channel_reference="$1", channel_identifier=calculated_channel.name + RuleCreate( + name="high_velocity_voltage_ratio_alert", + description="Alert when velocity-to-voltage ratio exceeds 0.1", + expression="$1 > 0.1", + channel_references=[ + ChannelReference( + channel_reference="$1", + channel_identifier=calculated_channel.name, + ), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["high_ratio", "alert"], + default_assignee_user_id=None, # You can set a user ID here if needed ), - ], - action=RuleAction.annotation( - annotation_type=RuleAnnotationType.DATA_REVIEW, - tags=["high_ratio", "alert"], - default_assignee_user_id=None, # You can set a user ID here if needed - ), + ) ) print(f"Created rule: {rule.name} (ID: {rule.rule_id})") diff --git a/python/lib/sift_client/resources/_base.py b/python/lib/sift_client/resources/_base.py index 2170aad64..8708411be 100644 --- a/python/lib/sift_client/resources/_base.py +++ b/python/lib/sift_client/resources/_base.py @@ -1,13 +1,17 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from sift_client.errors import _sift_client_experimental_warning +from sift_client.util import cel_utils as cel _sift_client_experimental_warning() if TYPE_CHECKING: + import re + from datetime import datetime + from sift_client.client import SiftClient from sift_client.sift_types._base import BaseType from sift_client.transport.base_connection import GrpcClient, RestClient @@ -39,3 +43,77 @@ def _apply_client_to_instance(self, instance: T) -> T: def _apply_client_to_instances(self, instances: list[T]) -> list[T]: return [self._apply_client_to_instance(i) for i in instances] + + # Common CEL filters used in resources + def _build_name_cel_filters( + self, + name: str | None = None, + name_contains: str | None = None, + name_regex: str | re.Pattern | None = None, + ) -> list[str]: + filter_parts = [] + if name: + filter_parts.append(cel.equals("name", name)) + if name_contains: + filter_parts.append(cel.contains("name", name_contains)) + if name_regex: + filter_parts.append(cel.match("name", name_regex)) + return filter_parts + + def _build_time_cel_filters( + self, + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + ) -> list[str]: + filter_parts = [] + if created_after: + filter_parts.append(cel.greater_than("created_date", created_after)) + if created_before: + filter_parts.append(cel.less_than("created_date", created_before)) + if modified_after: + filter_parts.append(cel.greater_than("modified_date", modified_after)) + if modified_before: + filter_parts.append(cel.less_than("modified_date", modified_before)) + if created_by: + if isinstance(created_by, str): + filter_parts.append(cel.equals("created_by_user_id", created_by)) + else: + raise NotImplementedError + if modified_by: + if isinstance(modified_by, str): + filter_parts.append(cel.equals("modified_by_user_id", created_by)) + else: + raise NotImplementedError + return filter_parts + + def _build_tags_metadata_cel_filters( + self, tags: list[Any] | list[str] | None = None, metadata: list[Any] | None = None + ) -> list[str]: + filter_parts = [] + if tags: + if all(isinstance(tag, str) for tag in tags): + filter_parts.append(cel.in_("tag_name", tags)) + else: + raise NotImplementedError + if metadata: + raise NotImplementedError + return filter_parts + + def _build_common_cel_filters( + self, + description_contains: str | None = None, + include_archived: bool = False, + filter_query: str | None = None, + ) -> list[str]: + filter_parts = [] + if description_contains: + filter_parts.append(cel.contains("description", description_contains)) + if not include_archived: + filter_parts.append(cel.equals("is_archived", False)) + if filter_query: + filter_parts.append(filter_query) + return filter_parts diff --git a/python/lib/sift_client/resources/assets.py b/python/lib/sift_client/resources/assets.py index dec34bfcb..6d1727766 100644 --- a/python/lib/sift_client/resources/assets.py +++ b/python/lib/sift_client/resources/assets.py @@ -5,7 +5,7 @@ from sift_client._internal.low_level_wrappers.assets import AssetsLowLevelClient from sift_client.resources._base import ResourceBase from sift_client.sift_types.asset import Asset, AssetUpdate -from sift_client.util import cel_utils +from sift_client.util import cel_utils as cel if TYPE_CHECKING: import re @@ -48,21 +48,13 @@ async def get( Returns: The Asset. """ - if asset_id: + asset: Asset | None + if asset_id is not None: asset = await self._low_level_client.get_asset(asset_id) - - elif name: - assets = await self._low_level_client.list_all_assets( - query_filter=cel_utils.equals("name", name) - ) - if len(assets) < 1: + elif name is not None: + asset = await self.find(name=name) + if asset is None: raise ValueError(f"No asset found with name '{name}'") - if len(assets) > 1: - raise ValueError( - f"Multiple ({len(assets)}) assets found with name '{name}'" - ) # should not happen - asset = assets[0] - else: raise ValueError("Either asset_id or name must be provided") @@ -71,19 +63,28 @@ async def get( async def list_( self, *, + # name name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, + # self ids asset_ids: list[str] | None = None, + # created/modified ranges created_after: datetime | None = None, created_before: datetime | None = None, modified_after: datetime | None = None, modified_before: datetime | None = None, - created_by: Any | None = None, - modified_by: Any | None = None, - tags: list[str] | None = None, - tag_ids: list[str] | None = None, + # created/modified users + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + # tags + tags: list[Any] | list[str] | None = None, + _tag_ids: list[str] + | None = None, # For compatibility until first class Tag support is added + # metadata metadata: list[Any] | None = None, + # common filters + description_contains: str | None = None, include_archived: bool = False, filter_query: str | None = None, order_by: str | None = None, @@ -92,63 +93,54 @@ async def list_( """List assets with optional filtering. Args: - asset_ids: List of asset IDs to filter by. name: Exact name of the asset. name_contains: Partial name of the asset. - name_regex: Regular expression string to filter assets by name. - asset_ids: List of asset IDs to filter by. - created_after: Created after this date. - created_before: Created before this date. - modified_after: Modified after this date. - modified_before: Modified before this date. - created_by: Assets created by this user. - modified_by: Assets last modified by this user. - tags: Assets with these tags. - tag_ids: List of asset tag IDs to filter by. - metadata: metadata filter - include_archived: Include archived assets. + name_regex: Regular expression to filter assets by name. + asset_ids: Filter to assets with any of these Ids. + created_after: Filter assets created after this datetime. + created_before: Filter assets created before this datetime. + modified_after: Filter assets modified after this datetime. + modified_before: Filter assets modified before this datetime. + created_by: Filter assets created by this User or user ID. + modified_by: Filter assets last modified by this User or user ID. + tags: Filter assets with any of these Tags or tag names. + metadata: Filter assets by metadata criteria. + description_contains: Partial description of the asset. + include_archived: If True, include archived assets in results. filter_query: Explicit CEL query to filter assets. - order_by: How to order the retrieved assets. # TODO: tooling for this? - limit: How many assets to retrieve. If None, retrieves all matches. + order_by: Field and direction to order results by. + limit: Maximum number of assets to return. If None, returns all matches. Returns: - A list of Assets that matches the filter. - + A list of Asset objects that match the filter criteria. """ - if not filter_query: - filters = [] - if name: - filters.append(cel_utils.equals("name", name)) - if name_contains: - filters.append(cel_utils.contains("name", name_contains)) - if name_regex: - filters.append(cel_utils.match("name", name_regex)) - if asset_ids: - filters.append(cel_utils.in_("asset_id", asset_ids)) - if created_after: - filters.append(cel_utils.greater_than("created_date", created_after)) - if created_before: - filters.append(cel_utils.less_than("created_date", created_before)) - if modified_after: - filters.append(cel_utils.greater_than("modified_date", modified_after)) - if modified_before: - filters.append(cel_utils.less_than("modified_date", modified_before)) - if created_by: - raise NotImplementedError - if modified_by: - raise NotImplementedError - if tags: - filters.append(cel_utils.in_("tag_name", tags)) - if tag_ids: - filters.append(cel_utils.in_("tag_ids", tag_ids)) - if metadata: - raise NotImplementedError - if not include_archived: - filters.append(cel_utils.equals_null("archived_date")) - filter_query = cel_utils.and_(*filters) + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + created_by=created_by, + modified_by=modified_by, + ), + *self._build_tags_metadata_cel_filters(tags=tags, metadata=metadata), + *self._build_common_cel_filters( + description_contains=description_contains, + include_archived=include_archived, + filter_query=filter_query, + ), + ] + if asset_ids: + filter_parts.append(cel.in_("asset_id", asset_ids)) + if _tag_ids: + filter_parts.append(cel.in_("tag_id", _tag_ids)) + filter_query = cel.and_(*filter_parts) assets = await self._low_level_client.list_all_assets( - query_filter=filter_query, + query_filter=filter_query or None, order_by=order_by, max_results=limit, ) @@ -171,6 +163,24 @@ async def find(self, **kwargs) -> Asset | None: return assets[0] return None + async def update(self, asset: str | Asset, update: AssetUpdate | dict) -> Asset: + """Update an Asset. + + Args: + asset: The Asset or asset ID to update. + update: Updates to apply to the Asset. + + Returns: + The updated Asset. + + """ + asset_id = asset._id_or_error if isinstance(asset, Asset) else asset + if isinstance(update, dict): + update = AssetUpdate.model_validate(update) + update.resource_id = asset_id + asset = await self._low_level_client.update_asset(update=update) + return self._apply_client_to_instance(asset) + async def archive(self, asset: str | Asset, *, archive_runs: bool = False) -> Asset: """Archive an asset. @@ -181,26 +191,19 @@ async def archive(self, asset: str | Asset, *, archive_runs: bool = False) -> As Returns: The archived Asset. """ - asset_id = asset.id_ or "" if isinstance(asset, Asset) else asset + asset_id = asset._id_or_error if isinstance(asset, Asset) else asset - await self._low_level_client.delete_asset(asset_id or "", archive_runs=archive_runs) + await self._low_level_client.archive_asset(asset_id, archive_runs=archive_runs) return await self.get(asset_id=asset_id) - async def update(self, asset: str | Asset, update: AssetUpdate | dict) -> Asset: - """Update an Asset. + async def unarchive(self, asset: str | Asset) -> Asset: + """Unarchive an asset. Args: - asset: The Asset or asset ID to update. - update: Updates to apply to the Asset. + asset: The Asset or asset ID to unarchive. Returns: - The updated Asset. - + The unarchived Asset. """ - asset_id = asset.id_ or "" if isinstance(asset, Asset) else asset - if isinstance(update, dict): - update = AssetUpdate.model_validate(update) - update.resource_id = asset_id - asset = await self._low_level_client.update_asset(update=update) - return self._apply_client_to_instance(asset) + return await self.update(asset, AssetUpdate(is_archived=False)) diff --git a/python/lib/sift_client/resources/calculated_channels.py b/python/lib/sift_client/resources/calculated_channels.py index 7c786ebf9..2c1f30750 100644 --- a/python/lib/sift_client/resources/calculated_channels.py +++ b/python/lib/sift_client/resources/calculated_channels.py @@ -1,23 +1,25 @@ from __future__ import annotations -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from sift_client._internal.low_level_wrappers.calculated_channels import ( CalculatedChannelsLowLevelClient, ) from sift_client.resources._base import ResourceBase +from sift_client.sift_types.asset import Asset from sift_client.sift_types.calculated_channel import ( CalculatedChannel, + CalculatedChannelCreate, CalculatedChannelUpdate, ) +from sift_client.sift_types.run import Run from sift_client.util import cel_utils as cel if TYPE_CHECKING: import re + from datetime import datetime from sift_client.client import SiftClient - from sift_client.sift_types.channel import ChannelReference class CalculatedChannelsAPIAsync(ResourceBase): @@ -46,14 +48,12 @@ async def get( *, calculated_channel_id: str | None = None, client_key: str | None = None, - organization_id: str | None = None, ) -> CalculatedChannel: """Get a Calculated Channel. Args: calculated_channel_id: The ID of the calculated channel. client_key: The client key of the calculated channel. - organization_id: The organization ID (required if using client_key and user belongs to multiple organizations). Returns: The CalculatedChannel. @@ -61,13 +61,9 @@ async def get( Raises: ValueError: If neither calculated_channel_id nor client_key is provided. """ - if not calculated_channel_id and not client_key: - raise ValueError("Either calculated_channel_id or client_key must be provided") - calculated_channel = await self._low_level_client.get_calculated_channel( calculated_channel_id=calculated_channel_id, client_key=client_key, - organization_id=organization_id, ) return self._apply_client_to_instance(calculated_channel) @@ -78,92 +74,98 @@ async def list_( name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, + # self ids + calculated_channel_ids: list[str] | None = None, + client_keys: list[str] | None = None, + # created/modified ranges created_after: datetime | None = None, created_before: datetime | None = None, modified_after: datetime | None = None, modified_before: datetime | None = None, - created_by: Any | None = None, - modified_by: Any | None = None, - client_key: str | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - tag_id: str | None = None, - tag_name: str | None = None, + # created/modified users + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + # tags + tags: list[Any] | list[str] | None = None, + # metadata + metadata: list[Any] | None = None, + # calculated channel specific + asset: Asset | str | None = None, + run: Run | str | None = None, version: int | None = None, + # common filters + description_contains: str | None = None, include_archived: bool = False, filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, - organization_id: str | None = None, ) -> list[CalculatedChannel]: - """List calculated channels with optional filtering. + """List calculated channels with optional filtering. This will return the latest version. To find all versions, use `list_versions`. Args: name: Exact name of the calculated channel. name_contains: Partial name of the calculated channel. name_regex: Regular expression string to filter calculated channels by name. + calculated_channel_ids: Filter to calculated channels with any of these IDs. + client_keys: Filter to calculated channels with any of these client keys. created_after: Created after this date. created_before: Created before this date. modified_after: Modified after this date. modified_before: Modified before this date. created_by: Calculated channels created by this user. modified_by: Calculated channels last modified by this user. - client_key: The client key of the calculated channel. - asset_id: The asset ID associated with the calculated channel. - asset_name: The asset name associated with the calculated channel. - tag_id: The tag ID associated with the calculated channel. - tag_name: The tag name associated with the calculated channel. + tags: Filter calculated channels with any of these Tags or tag names. + metadata: Filter calculated channels by metadata criteria. + asset: Filter calculated channels associated with this Asset or asset ID. + run: Filter calculated channels associated with this Run or run ID. version: The version of the calculated channel. + description_contains: Partial description of the calculated channel. include_archived: Include archived calculated channels. filter_query: Explicit CEL query to filter calculated channels. order_by: How to order the retrieved calculated channels. limit: How many calculated channels to retrieve. If None, retrieves all matches. - organization_id: The organization ID (required if user belongs to multiple organizations). Returns: A list of CalculatedChannels that matches the filter. """ - if not filter_query: - filters = [] - if name: - filters.append(cel.equals("name", name)) - if name_contains: - filters.append(cel.contains("name", name_contains)) - if name_regex: - filters.append(cel.match("name", name_regex)) - if created_after: - filters.append(cel.greater_than("created_date", created_after)) - if created_before: - filters.append(cel.less_than("created_date", created_before)) - if modified_after: - filters.append(cel.greater_than("modified_date", modified_after)) - if modified_before: - filters.append(cel.less_than("modified_date", modified_before)) - if created_by: - raise NotImplementedError - if modified_by: - raise NotImplementedError - if client_key: - filters.append(cel.equals("client_key", client_key)) - if asset_id: - filters.append(cel.equals("asset_id", asset_id)) - if asset_name: - filters.append(cel.equals("asset_name", asset_name)) - if tag_id: - filters.append(cel.equals("tag_id", tag_id)) - if tag_name: - filters.append(cel.equals("tag_name", tag_name)) - if version: - filters.append(cel.equals("version", version)) - if not include_archived: - filters.append(cel.equals_null("archived_date")) - filter_query = cel.and_(*filters) + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + created_by=created_by, + modified_by=modified_by, + ), + *self._build_tags_metadata_cel_filters(tags=tags, metadata=metadata), + *self._build_common_cel_filters( + description_contains=description_contains, + include_archived=include_archived, + filter_query=filter_query, + ), + ] + if calculated_channel_ids: + filter_parts.append(cel.in_("calculated_channel_id", calculated_channel_ids)) + if client_keys: + filter_parts.append(cel.in_("client_key", client_keys)) + if asset: + asset_id = asset._id_or_error if isinstance(asset, Asset) else asset + filter_parts.append(cel.equals("asset_id", asset_id)) + if run: + run_id = run._id_or_error if isinstance(run, Run) else run + filter_parts.append(cel.equals("run_id", run_id)) + if version: + filter_parts.append(cel.equals("version", version)) + + query_filter = cel.and_(*filter_parts) calculated_channels = await self._low_level_client.list_all_calculated_channels( - query_filter=filter_query, + query_filter=query_filter or None, order_by=order_by, max_results=limit, - organization_id=organization_id, ) return self._apply_client_to_instances(calculated_channels) @@ -172,7 +174,7 @@ async def find(self, **kwargs) -> CalculatedChannel | None: Will raise an error if multiple calculated channels are found. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The CalculatedChannel found or None. @@ -180,8 +182,7 @@ async def find(self, **kwargs) -> CalculatedChannel | None: calculated_channels = await self.list_(**kwargs) if len(calculated_channels) > 1: raise ValueError( - f"Multiple calculated channels found for query: {kwargs}. " - "Use `list` to handle all matching calculated channels." + f"Multiple ({len(calculated_channels)}) calculated channels found for query" ) elif len(calculated_channels) == 1: return calculated_channels[0] @@ -189,67 +190,31 @@ async def find(self, **kwargs) -> CalculatedChannel | None: async def create( self, - *, - name: str, - expression: str, - channel_references: list[ChannelReference], - description: str = "", - units: str | None = None, - client_key: str | None = None, - asset_ids: list[str] | None = None, - tag_ids: list[str] | None = None, - all_assets: bool = False, - user_notes: str = "", + create: CalculatedChannelCreate | dict, ) -> CalculatedChannel: """Create a calculated channel. Args: - name: The name of the calculated channel. - expression: The expression to calculate the value of the calculated channel. - channel_references: A list of channel references that are used in the expression. - description: The description of the calculated channel. - units: The units of the calculated channel. - client_key: A user-defined unique identifier for the calculated channel. - asset_ids: A list of asset IDs to make the calculation available for. - tag_ids: A list of tag IDs to make the calculation available for. - all_assets: A flag that, when set to True, associates the calculated channel with all assets. - user_notes: User notes for the calculated channel. + create: A CalculatedChannelCreate object or dictionary with configuration for the new calculated channel. + This should include properties like name, expression, channel_references, etc. Returns: The created CalculatedChannel. - Raises: - ValueError: If asset configuration is invalid. """ - # Validate asset configuration - if all_assets and (asset_ids or tag_ids): - raise ValueError("Cannot specify both all_assets and asset_ids/tag_ids") - if not all_assets and not asset_ids and not tag_ids: - raise ValueError("Must specify either all_assets=True or provide asset_ids/tag_ids") + if isinstance(create, dict): + create = CalculatedChannelCreate.model_validate(create) - ( - calculated_channel, - inapplicable_assets, - ) = await self._low_level_client.create_calculated_channel( - name=name, - all_assets=all_assets, - asset_ids=asset_ids, - tag_ids=tag_ids, - expression=expression, - channel_references=channel_references, - description=description, - user_notes=user_notes, - units=units, - client_key=client_key, + created_calc_channel, _ = await self._low_level_client.create_calculated_channel( + create=create ) - - return self._apply_client_to_instance(calculated_channel) + return self._apply_client_to_instance(created_calc_channel) async def update( self, - *, - calculated_channel: str | CalculatedChannel, + calculated_channel: CalculatedChannel | str, update: CalculatedChannelUpdate | dict, + *, user_notes: str | None = None, ) -> CalculatedChannel: """Update a Calculated Channel. @@ -282,90 +247,112 @@ async def update( return self._apply_client_to_instance(updated_calculated_channel) - async def archive(self, *, calculated_channel: str | CalculatedChannel) -> None: - """Archive a Calculated Channel.""" - update = CalculatedChannelUpdate( - archived_date=datetime.now(tz=timezone.utc), + async def archive(self, calculated_channel: str | CalculatedChannel) -> CalculatedChannel: + """Archive a calculated channel. + + Args: + calculated_channel: The id or CalculatedChannel object of the calculated channel to archive. + + Returns: + The archived CalculatedChannel. + """ + return await self.update( + calculated_channel=calculated_channel, update=CalculatedChannelUpdate(is_archived=True) + ) + + async def unarchive(self, calculated_channel: str | CalculatedChannel) -> CalculatedChannel: + """Unarchive a calculated channel. + + Args: + calculated_channel: The id or CalculatedChannel object of the calculated channel to unarchive. + + Returns: + The unarchived CalculatedChannel. + """ + return await self.update( + calculated_channel=calculated_channel, update=CalculatedChannelUpdate(is_archived=False) ) - await self.update(calculated_channel=calculated_channel, update=update) async def list_versions( self, *, - calculated_channel_id: str | None = None, + # self ids + calculated_channel: CalculatedChannel | str | None = None, client_key: str | None = None, - organization_id: str | None = None, name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - tag_id: str | None = None, - tag_name: str | None = None, - version: int | None = None, + # created/modified ranges + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + # created/modified users + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + # tags + tags: list[Any] | list[str] | None = None, + # metadata + metadata: list[Any] | None = None, + # common filters + description_contains: str | None = None, include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[CalculatedChannel]: """List versions of a calculated channel. Args: - calculated_channel_id: The ID of the calculated channel. + calculated_channel: The CalculatedChannel or ID of the calculated channel to get versions for. client_key: The client key of the calculated channel. - name: The name of the calculated channel. - name_contains: The name of the calculated channel. - name_regex: The name of the calculated channel. - asset_id: The asset ID of the calculated channel. - asset_name: The asset name of the calculated channel. - tag_id: The tag ID of the calculated channel. - tag_name: The tag name of the calculated channel. - version: The version of the calculated channel. - include_archived: Whether to include archived calculated channels. - organization_id: The organization ID. Required if your user belongs to multiple organizations. - order_by: The field to order by. - limit: How many versions to retrieve. If None, retrieves all matches. + name: Exact name of the calculated channel. + name_contains: Partial name of the calculated channel. + name_regex: Regular expression string to filter calculated channels by name. + created_after: Filter versions created after this datetime. + created_before: Filter versions created before this datetime. + modified_after: Filter versions modified after this datetime. + modified_before: Filter versions modified before this datetime. + created_by: Filter versions created by this user or user ID. + modified_by: Filter versions modified by this user or user ID. + tags: Filter versions with any of these Tags or tag names. + metadata: Filter versions by metadata criteria. + description_contains: Partial description of the calculated channel. + include_archived: Include archived versions. + filter_query: Explicit CEL query to filter versions. + order_by: How to order the retrieved versions. + limit: Maximum number of versions to return. If None, returns all matches. Returns: - A list of CalculatedChannel versions. - - Raises: - ValueError: If neither calculated_channel_id nor client_key is provided. + A list of CalculatedChannel versions that match the filter criteria. """ - if sum(bool(v) for v in [calculated_channel_id, name, name_contains, name_regex]) != 1: - raise ValueError( - "Exactly one of calculated_channel_id, name, name_contains, or name_regex must be provided" - ) - if asset_id and asset_name: - raise ValueError("Cannot specify both asset_id and asset_name") - if tag_id and tag_name: - raise ValueError("Cannot specify both tag_id and tag_name") - - filter_query_parts = [] - if name: - filter_query_parts.append(cel.equals("name", name)) - if name_contains: - filter_query_parts.append(cel.contains("name", name_contains)) - if name_regex: - filter_query_parts.append(cel.match("name", name_regex)) - if asset_id: - filter_query_parts.append(cel.equals("asset_id", asset_id)) - if asset_name: - filter_query_parts.append(cel.equals("asset_name", asset_name)) - if tag_id: - filter_query_parts.append(cel.equals("tag_id", tag_id)) - if tag_name: - filter_query_parts.append(cel.equals("tag_name", tag_name)) - if version: - filter_query_parts.append(cel.equals("version", version)) - if not include_archived: - filter_query_parts.append(cel.equals_null("archived_date")) - filter_query = cel.and_(*filter_query_parts) + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + created_by=created_by, + modified_by=modified_by, + ), + *self._build_tags_metadata_cel_filters(tags=tags, metadata=metadata), + *self._build_common_cel_filters( + description_contains=description_contains, + include_archived=include_archived, + filter_query=filter_query, + ), + ] + query_filter = cel.and_(*filter_parts) versions = await self._low_level_client.list_all_calculated_channel_versions( - calculated_channel_id=calculated_channel_id, client_key=client_key, - organization_id=organization_id, - query_filter=filter_query, + calculated_channel_id=calculated_channel.id_ + if isinstance(calculated_channel, CalculatedChannel) + else calculated_channel, + query_filter=query_filter or None, order_by=order_by, limit=limit, ) diff --git a/python/lib/sift_client/resources/channels.py b/python/lib/sift_client/resources/channels.py index 352677715..7a5bb1d9b 100644 --- a/python/lib/sift_client/resources/channels.py +++ b/python/lib/sift_client/resources/channels.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING import pyarrow as pa @@ -8,9 +7,12 @@ from sift_client._internal.low_level_wrappers.channels import ChannelsLowLevelClient from sift_client._internal.low_level_wrappers.data import DataLowLevelClient from sift_client.resources._base import ResourceBase +from sift_client.sift_types.asset import Asset +from sift_client.sift_types.run import Run from sift_client.util import cel_utils as cel if TYPE_CHECKING: + import re from datetime import datetime import pandas as pd @@ -58,91 +60,78 @@ async def get( async def list_( self, *, - asset_id: str | None = None, name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - description: str | None = None, - description_contains: str | None = None, - active: bool | None = None, - run_id: str | None = None, - run_name: str | None = None, - client_key: str | None = None, - created_before: datetime | None = None, + # self ids + channel_ids: list[str] | None = None, + # created/modified ranges created_after: datetime | None = None, - modified_before: datetime | None = None, + created_before: datetime | None = None, modified_after: datetime | None = None, + modified_before: datetime | None = None, + # channel specific + asset: Asset | str | None = None, + run: Run | str | None = None, + # common filters + description_contains: str | None = None, + include_archived: bool | None = None, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[Channel]: """List channels with optional filtering. Args: - asset_id: The asset ID to get. - name: The name of the channel to get. - name_contains: The partial name of the channel to get. - name_regex: The regex name of the channel to get. - description: The description of the channel to get. - description_contains: The partial description of the channel to get. - active: Whether the channel is active. - run_id: The run ID to get. - run_name: The name of the run to get. - client_key: The client key of the run to get. - created_before: The created date of the channel to get. - created_after: The created date of the channel to get. - modified_before: The modified date of the channel to get. - modified_after: The modified date of the channel to get. - order_by: How to order the retrieved channels. - limit: How many channels to retrieve. If None, retrieves all matches. + name: Exact name of the channel. + name_contains: Partial name of the channel. + name_regex: Regular expression to filter channels by name. + channel_ids: Filter to channels with any of these IDs. + created_after: Filter channels created after this datetime. + created_before: Filter channels created before this datetime. + modified_after: Filter channels modified after this datetime. + modified_before: Filter channels modified before this datetime. + asset: Filter channels associated with this Asset or asset ID. + run: Filter channels associated with this Run or run ID. + description_contains: Partial description of the channel. + include_archived: If True, include archived channels in results. + filter_query: Explicit CEL query to filter channels. + order_by: Field and direction to order results by. + limit: Maximum number of channels to return. If None, returns all matches. Returns: - A list of Channels that matches the filter. + A list of Channels that matches the filter criteria. """ - if sum(bool(x) for x in [name, name_contains, name_regex]) > 1: - raise ValueError("Cannot provide more than one of name, name_contains, or name_regex") - if sum(bool(x) for x in [description, description_contains]) > 1: - raise ValueError("Cannot provide both description and description_contains") - if sum(bool(x) for x in [created_before, created_after]) > 1: - raise ValueError("Cannot provide both created_before and created_after") - if sum(bool(x) for x in [modified_before, modified_after]) > 1: - raise ValueError("Cannot provide both modified_before and modified_after") - - filter_parts = [] - if asset_id: + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + ), + *self._build_common_cel_filters( + description_contains=description_contains, filter_query=filter_query + ), + ] + if channel_ids: + filter_parts.append(cel.in_("channel_id", channel_ids)) + if asset is not None: + asset_id = asset.id_ if isinstance(asset, Asset) else asset filter_parts.append(cel.equals("asset_id", asset_id)) - if name: - filter_parts.append(cel.equals("name", name)) - elif name_contains: - filter_parts.append(cel.contains("name", name_contains)) - elif name_regex: - if isinstance(name_regex, re.Pattern): - name_regex = name_regex.pattern - filter_parts.append(cel.match("name", name_regex)) # type: ignore - if description: - filter_parts.append(cel.equals("description", description)) - elif description_contains: - filter_parts.append(cel.contains("description", description_contains)) - if active: - filter_parts.append(cel.equals("active", active)) - if run_id: + if run is not None: + run_id = run.id_ if isinstance(run, Run) else run filter_parts.append(cel.equals("run_id", run_id)) - if run_name: - filter_parts.append(cel.equals("run_name", run_name)) - if client_key: - filter_parts.append(cel.equals("client_key", client_key)) - if created_before: - filter_parts.append(cel.less_than("created_date", created_before)) - if created_after: - filter_parts.append(cel.greater_than("created_date", created_after)) - if modified_before: - filter_parts.append(cel.less_than("modified_date", modified_before)) - if modified_after: - filter_parts.append(cel.greater_than("modified_date", modified_after)) - - filter_str = " && ".join(filter_parts) + # This is opposite of usual archived state + if include_archived is not None: + filter_parts.append(cel.equals("active", not include_archived)) + + query_filter = cel.and_(*filter_parts) channels = await self._low_level_client.list_all_channels( - query_filter=filter_str, + query_filter=query_filter or None, order_by=order_by, max_results=limit, ) @@ -153,14 +142,14 @@ async def find(self, **kwargs) -> Channel | None: raises an error. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The Channel found or None. """ channels = await self.list_(**kwargs) if len(channels) > 1: - raise ValueError("Multiple channels found for query") + raise ValueError(f"Multiple ({len(channels)}) channels found for query") elif len(channels) == 1: return channels[0] return None @@ -169,7 +158,7 @@ async def get_data( self, *, channels: list[Channel], - run_id: str | None = None, + run: Run | str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, limit: int | None = None, @@ -178,11 +167,15 @@ async def get_data( Args: channels: The channels to get data for. - run_id: The run to get data for. + run: The Run or run_id to get data for. start_time: The start time to get data for. end_time: The end time to get data for. limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided. + + Returns: + A dictionary mapping channel names to pandas DataFrames containing the channel data. """ + run_id = run._id_or_error if isinstance(run, Run) else run return await self._data_low_level_client.get_channel_data( channels=channels, run_id=run_id, @@ -195,15 +188,16 @@ async def get_data_as_arrow( self, *, channels: list[Channel], - run_id: str | None = None, + run: Run | str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, limit: int | None = None, ) -> dict[str, pa.Table]: """Get data for one or more channels as pyarrow tables.""" + run_id = run.id_ if isinstance(run, Run) else run data = await self.get_data( channels=channels, - run_id=run_id, + run=run_id, start_time=start_time, end_time=end_time, limit=limit, diff --git a/python/lib/sift_client/resources/rules.py b/python/lib/sift_client/resources/rules.py index a101a3ae5..09d9f255c 100644 --- a/python/lib/sift_client/resources/rules.py +++ b/python/lib/sift_client/resources/rules.py @@ -1,17 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from sift_client._internal.low_level_wrappers.rules import RulesLowLevelClient from sift_client.resources._base import ResourceBase -from sift_client.sift_types.rule import Rule, RuleAction, RuleUpdate +from sift_client.sift_types.rule import Rule, RuleCreate, RuleUpdate from sift_client.util import cel_utils as cel if TYPE_CHECKING: import re + from datetime import datetime from sift_client.client import SiftClient - from sift_client.sift_types.channel import ChannelReference class RulesAPIAsync(ResourceBase): @@ -57,9 +57,28 @@ async def list_( name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, + # self ids + rule_ids: list[str] | None = None, + client_keys: list[str] | None = None, + # created/modified ranges + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + # created/modified users + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + # metadata + metadata: list[Any] | None = None, + # rule specific + asset_ids: list[str] | None = None, + asset_tag_ids: list[str] | None = None, + # common filters + description_contains: str | None = None, + include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, - include_deleted: bool = False, ) -> list[Rule]: """List rules with optional filtering. @@ -67,28 +86,56 @@ async def list_( name: Exact name of the rule. name_contains: Partial name of the rule. name_regex: Regular expression string to filter rules by name. - order_by: How to order the retrieved rules. - limit: How many rules to retrieve. If None, retrieves all matches. - include_deleted: Include deleted rules. + rule_ids: IDs of rules to filter to. + client_keys: Client keys of rules to filter to. + created_after: Rules created after this datetime. + created_before: Rules created before this datetime. + modified_after: Rules modified after this datetime. + modified_before: Rules modified before this datetime. + created_by: Filter rules created by this User or user ID. + modified_by: Filter rules last modified by this User or user ID. + metadata: Filter rules by metadata criteria. + asset_ids: Filter rules associated with any of these Asset IDs. + asset_tag_ids: Filter rules associated with any of these Asset Tag IDs. + description_contains: Partial description of the rule. + include_archived: If True, include archived rules in results. + filter_query: Explicit CEL query to filter rules. + order_by: Field and direction to order results by. + limit: Maximum number of rules to return. If None, returns all matches. Returns: A list of Rules that matches the filter. """ - if int(name is not None) + int(name_contains is not None) + int(name_regex is not None) > 1: - raise ValueError("Must use EITHER name, name_contains, or name_regex, not multiple") - - filters = [] - if name: - filters.append(cel.equals("name", name)) - if name_contains: - filters.append(cel.contains("name", name_contains)) - if name_regex: - filters.append(cel.match("name", name_regex)) - if not include_deleted: - filters.append(cel.equals_null("deleted_date")) - filter_str = " && ".join(filters) if filters else "" + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + created_by=created_by, + modified_by=modified_by, + ), + *self._build_tags_metadata_cel_filters(metadata=metadata), + *self._build_common_cel_filters( + description_contains=description_contains, + include_archived=include_archived, + filter_query=filter_query, + ), + ] + if rule_ids: + filter_parts.append(cel.in_("rule_id", rule_ids)) + if client_keys: + filter_parts.append(cel.in_("client_key", client_keys)) + if asset_ids: + filter_parts.append(cel.in_("asset_id", asset_ids)) + if asset_tag_ids: + filter_parts.append(cel.in_("tag_id", asset_tag_ids)) + query_filter = cel.and_(*filter_parts) rules = await self._low_level_client.list_all_rules( - filter_query=filter_str, + filter_query=query_filter, order_by=order_by, max_results=limit, page_size=limit, @@ -114,34 +161,28 @@ async def find(self, **kwargs) -> Rule | None: async def create( self, - name: str, - description: str, - expression: str, - channel_references: list[ChannelReference], - action: RuleAction, - organization_id: str | None = None, - client_key: str | None = None, - asset_ids: list[str] | None = None, - contextual_channels: list[str] | None = None, - is_external: bool = False, + create: RuleCreate | dict, ) -> Rule: - """Create a new rule.""" - created_rule = await self._low_level_client.create_rule( - name=name, - description=description, - organization_id=organization_id, - expression=expression, - action=action, - channel_references=channel_references, - client_key=client_key, - asset_ids=asset_ids, - contextual_channels=contextual_channels, - is_external=is_external, - ) + """Create a new rule. + + Args: + create: A RuleCreate object or dictionary with configuration for the new rule. + + Returns: + The created Rule. + """ + if isinstance(create, dict): + create = RuleCreate.model_validate(create) + + created_rule = await self._low_level_client.create_rule(create=create) return self._apply_client_to_instance(created_rule) async def update( - self, rule: str | Rule, update: RuleUpdate | dict, version_notes: str | None = None + self, + rule: Rule | str, + update: RuleUpdate | dict, + *, + version_notes: str | None = None, ) -> Rule: """Update a Rule. @@ -153,110 +194,38 @@ async def update( Returns: The updated Rule. """ + rule_obj: Rule if isinstance(rule, str): - rule = await self.get(rule_id=rule) + rule_obj = await self.get(rule_id=rule) + else: + rule_obj = rule if isinstance(update, dict): update = RuleUpdate.model_validate(update) - updated_rule = await self._low_level_client.update_rule(rule, update, version_notes) + updated_rule = await self._low_level_client.update_rule( + rule=rule_obj, update=update, version_notes=version_notes + ) return self._apply_client_to_instance(updated_rule) - async def archive( - self, - *, - rule: str | Rule | None = None, - rules: list[Rule] | None = None, - rule_ids: list[str] | None = None, - client_keys: list[str] | None = None, - ) -> None: - """Archive a rule or multiple. + async def archive(self, rule: str | Rule) -> Rule: + """Archive a rule. Args: - rule: The Rule to archive. - rules: The Rules to archive. - rule_ids: The rule IDs to archive. - client_keys: The client keys to archive. - """ - if rule: - if isinstance(rule, Rule): - await self._low_level_client.archive_rule(rule_id=rule.id_) - else: - await self._low_level_client.archive_rule(rule_id=rule) - elif rules: - if len(rules) == 1: - await self._low_level_client.archive_rule(rule_id=rules[0].id_) - else: - await self._low_level_client.batch_archive_rules( - rule_ids=[r.id_ for r in rules], # type: ignore - ) - elif rule_ids: - if len(rule_ids) == 1: - await self._low_level_client.archive_rule(rule_id=rule_ids[0]) - else: - await self._low_level_client.batch_archive_rules(rule_ids=rule_ids) - elif client_keys: - await self._low_level_client.batch_archive_rules(client_keys=client_keys) - else: - raise ValueError("Either rules, rule_ids, or client_keys must be provided") - - async def restore( - self, - *, - rule: str | Rule, - rule_id: str | None = None, - client_key: str | None = None, - ) -> Rule: - """Restore a rule. - - Args: - rule: The Rule or rule ID to restore. - rule_id: The rule ID to restore (alternative to rule parameter). - client_key: The client key to restore (alternative to rule parameter). + rule: The id or Rule object of the rule to archive. Returns: - The restored Rule. + The archived Rule. """ - if rule_id or client_key: - restored_rule = await self._low_level_client.restore_rule( - rule_id=rule_id, client_key=client_key - ) - else: - rule_id = rule.id_ if isinstance(rule, Rule) else rule - restored_rule = await self._low_level_client.restore_rule(rule_id=rule_id) - - return self._apply_client_to_instance(restored_rule) + return await self.update(rule=rule, update=RuleUpdate(is_archived=True)) - async def batch_restore( - self, - *, - rule_ids: list[str] | None = None, - client_keys: list[str] | None = None, - ) -> None: - """Batch restore rules. + async def unarchive(self, rule: str | Rule) -> Rule: + """Unarchive a rule. Args: - rule_ids: List of rule IDs to restore. - client_keys: List of client keys to undelete. - """ - await self._low_level_client.batch_restore_rules(rule_ids=rule_ids, client_keys=client_keys) - - async def batch_get( - self, - *, - rule_ids: list[str] | None = None, - client_keys: list[str] | None = None, - ) -> list[Rule]: - """Get multiple rules by rule IDs or client keys. - - Args: - rule_ids: List of rule IDs to get. - client_keys: List of client keys to get. + rule: The id or Rule object of the rule to unarchive. Returns: - List of Rules. + The unarchived Rule. """ - rules = await self._low_level_client.batch_get_rules( - rule_ids=rule_ids, client_keys=client_keys - ) - return self._apply_client_to_instances(rules) + return await self.update(rule=rule, update=RuleUpdate(is_archived=False)) diff --git a/python/lib/sift_client/resources/runs.py b/python/lib/sift_client/resources/runs.py index abb324f89..cf62b6b29 100644 --- a/python/lib/sift_client/resources/runs.py +++ b/python/lib/sift_client/resources/runs.py @@ -1,17 +1,18 @@ from __future__ import annotations -import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from sift_client._internal.low_level_wrappers.runs import RunsLowLevelClient from sift_client.resources._base import ResourceBase -from sift_client.sift_types.run import Run, RunUpdate -from sift_client.util.cel_utils import contains, equals, equals_null, match, not_ +from sift_client.sift_types.run import Run, RunCreate, RunUpdate +from sift_client.util import cel_utils as cel if TYPE_CHECKING: - from datetime import datetime + import re + from datetime import datetime, timedelta from sift_client.client import SiftClient + from sift_client.sift_types.asset import Asset class RunsAPIAsync(ResourceBase): @@ -33,20 +34,25 @@ def __init__(self, sift_client: SiftClient): super().__init__(sift_client) self._low_level_client = RunsLowLevelClient(grpc_client=self.client.grpc_client) - async def get( - self, - *, - run_id: str, - ) -> Run: + async def get(self, *, run_id: str | None = None, client_key: str | None = None) -> Run: """Get a Run. Args: run_id: The ID of the run. + client_key: The client key of the run. Returns: The Run. """ - run = await self._low_level_client.get_run(run_id=run_id) + run: Run | None + if run_id is not None: + run = await self._low_level_client.get_run(run_id=run_id) + elif client_key is not None: + run = await self.find(client_keys=[client_key]) + if run is None: + raise ValueError(f"Run with client_key {client_key} not found") + else: + raise ValueError("Either run_id or client_key must be provided") return self._apply_client_to_instance(run) async def list_( @@ -55,15 +61,32 @@ async def list_( name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - description: str | None = None, - description_contains: str | None = None, - duration_seconds: int | None = None, - client_key: str | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - created_by_user_id: str | None = None, + # self ids + run_ids: list[str] | None = None, + client_keys: list[str] | None = None, + # created/modified ranges + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + # created/modified users + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + # metadata + metadata: list[Any] | None = None, + # run specific + assets: list[Asset] | list[str] | None = None, + duration_less_than: timedelta | None = None, + duration_greater_than: timedelta | None = None, + start_time_after: datetime | None = None, + start_time_before: datetime | None = None, + stop_time_after: datetime | None = None, + stop_time_before: datetime | None = None, is_stopped: bool | None = None, + # common filters + description_contains: str | None = None, include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[Run]: @@ -72,122 +95,119 @@ async def list_( Args: name: Exact name of the run. name_contains: Partial name of the run. - name_regex: Regular expression string to filter runs by name. - description: Exact description of the run. - description_contains: Partial description of the run. - duration_seconds: Duration of the run in seconds. - client_key: Client key to filter by. - asset_id: Asset ID to filter by. - asset_name: Asset name to filter by. - created_by_user_id: User ID who created the run. + name_regex: Regular expression to filter runs by name. + run_ids: Filter to runs with any of these IDs. + client_keys: Filter to runs with any of these client keys. + created_after: Filter runs created after this datetime. + created_before: Filter runs created before this datetime. + modified_after: Filter runs modified after this datetime. + modified_before: Filter runs modified before this datetime. + created_by: Filter runs created by this User or user ID. + modified_by: Filter runs last modified by this User or user ID. + metadata: Filter runs by metadata criteria. + assets: Filter runs associated with any of these Assets or asset IDs. + duration_less_than: Filter runs with duration less than this time. + duration_greater_than: Filter runs with duration greater than this time. + start_time_after: Filter runs that started after this datetime. + start_time_before: Filter runs that started before this datetime. + stop_time_after: Filter runs that stopped after this datetime. + stop_time_before: Filter runs that stopped before this datetime. is_stopped: Whether the run is stopped. - include_archived: Whether to include archived runs. - order_by: How to order the retrieved runs. - limit: How many runs to retrieve. If None, retrieves all matches. + description_contains: Partial description of the run. + include_archived: If True, include archived runs in results. + filter_query: Explicit CEL query to filter runs. + order_by: Field and direction to order results by. + limit: Maximum number of runs to return. If None, returns all matches. Returns: - A list of Runs that matches the filter. + A list of Run objects that match the filter criteria. """ - # Build CEL filter - filter_parts = [] - - if name: - filter_parts.append(equals("name", name)) - elif name_contains: - filter_parts.append(contains("name", name_contains)) - elif name_regex: - if isinstance(name_regex, re.Pattern): - name_regex = name_regex.pattern - filter_parts.append(match("name", name_regex)) # type: ignore - - if description: - filter_parts.append(equals("description", description)) - elif description_contains: - filter_parts.append(contains("description", description_contains)) - - if duration_seconds: - filter_parts.append(equals("duration", duration_seconds)) - - if client_key: - filter_parts.append(equals("client_key", client_key)) - - if asset_id: - filter_parts.append(equals("asset_id", asset_id)) - - if asset_name: - filter_parts.append(equals("asset_name", asset_name)) - - if created_by_user_id: - filter_parts.append(equals("created_by_user_id", created_by_user_id)) - + filter_parts = [ + *self._build_name_cel_filters( + name=name, name_contains=name_contains, name_regex=name_regex + ), + *self._build_time_cel_filters( + created_after=created_after, + created_before=created_before, + modified_after=modified_after, + modified_before=modified_before, + created_by=created_by, + modified_by=modified_by, + ), + *self._build_tags_metadata_cel_filters(metadata=metadata), + *self._build_common_cel_filters( + description_contains=description_contains, + include_archived=include_archived, + filter_query=filter_query, + ), + ] + if run_ids: + filter_parts.append(cel.in_("run_id", run_ids)) + if client_keys: + filter_parts.append(cel.in_("client_key", client_keys)) + if assets: + if all(isinstance(s, str) for s in assets): + ids = cast("list[str]", assets) # linting + filter_parts.append(cel.in_("asset_ids", ids)) + else: + asset = cast("list[Asset]", assets) # linting + filter_parts.append(cel.in_("asset_ids", [a._id_or_error for a in asset])) + if duration_less_than: + filter_parts.append(cel.less_than("duration_string", duration_less_than)) + if duration_greater_than: + filter_parts.append(cel.greater_than("duration_string", duration_greater_than)) + if start_time_after: + filter_parts.append(cel.greater_than("start_time", start_time_after)) + if start_time_before: + filter_parts.append(cel.less_than("start_time", start_time_before)) + if stop_time_after: + filter_parts.append(cel.greater_than("stop_time", stop_time_after)) + if stop_time_before: + filter_parts.append(cel.less_than("stop_time", stop_time_before)) if is_stopped is not None: - filter_parts.append(not_(equals_null("stop_time"))) - - if not include_archived: - filter_parts.append(equals("archived_date", None)) - - query_filter = " && ".join(filter_parts) if filter_parts else None + filter_parts.append(cel.not_(cel.equals_null("stop_time"))) + query_filter = cel.and_(*filter_parts) runs = await self._low_level_client.list_all_runs( - query_filter=query_filter, + query_filter=query_filter or None, order_by=order_by, max_results=limit, ) return self._apply_client_to_instances(runs) async def find(self, **kwargs) -> Run | None: - """Find a single run matching the given query. Takes the same arguments as `list`. If more than one run is found, + """Find a single run matching the given query. Takes the same arguments as `list_`. If more than one run is found, raises an error. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The Run found or None. """ runs = await self.list_(**kwargs) if len(runs) > 1: - raise ValueError("Multiple runs found for query") + raise ValueError(f"Multiple ({len(runs)}) runs found for query") elif len(runs) == 1: return runs[0] return None async def create( self, - name: str, - description: str, - tags: list[str] | None = None, - start_time: datetime | None = None, - stop_time: datetime | None = None, - organization_id: str | None = None, - client_key: str | None = None, - metadata: dict[str, str | float | bool] | None = None, + create: RunCreate | dict, ) -> Run: """Create a new run. Args: - name: The name of the run. - description: The description of the run. - tags: Tags to associate with the run. - start_time: The start time of the run. - stop_time: The stop time of the run. - organization_id: The organization ID. - client_key: A unique client key for the run. - metadata: Metadata values for the run. + create: The Run definition to create. Returns: The created Run. """ - created_run = await self._low_level_client.create_run( - name=name, - description=description, - tags=tags, - start_time=start_time, - stop_time=stop_time, - organization_id=organization_id, - client_key=client_key, - metadata=metadata, - ) + if isinstance(create, dict): + create = RunCreate.model_validate(create) + + created_run = await self._low_level_client.create_run(create=create) return self._apply_client_to_instance(created_run) async def update(self, run: str | Run, update: RunUpdate | dict) -> Run: @@ -200,34 +220,37 @@ async def update(self, run: str | Run, update: RunUpdate | dict) -> Run: Returns: The updated Run. """ - if isinstance(run, str): - run = await self.get(run_id=run) - + run_id = run._id_or_error if isinstance(run, Run) else run if isinstance(update, dict): update = RunUpdate.model_validate(update) - - update.resource_id = run.id_ - updated_run = await self._low_level_client.update_run(run, update) + update.resource_id = run_id + updated_run = await self._low_level_client.update_run(update) return self._apply_client_to_instance(updated_run) async def archive( self, - *, run: str | Run, - ) -> None: + ) -> Run: """Archive a run. Args: run: The Run or run ID to archive. """ - run_id = run.id_ if isinstance(run, Run) else run - if not isinstance(run_id, str): - raise TypeError(f"run_id must be a string not {type(run_id)}") - await self._low_level_client.archive_run(run_id=run_id) + return await self.update(run, RunUpdate(is_archived=True)) + + async def unarchive( + self, + run: str | Run, + ) -> Run: + """Unarchive a run. + + Args: + run: The Run or run ID to unarchive. + """ + return await self.update(run, RunUpdate(is_archived=False)) async def stop( self, - *, run: str | Run, ) -> None: """Stop a run by setting its stop time to the current time. @@ -235,12 +258,13 @@ async def stop( Args: run: The Run or run ID to stop. """ - run_id = run.id_ if isinstance(run, Run) else run + run_id = run._id_or_error if isinstance(run, Run) else run await self._low_level_client.stop_run(run_id=run_id or "") async def create_automatic_association_for_assets( self, run: str | Run, + *, asset_names: list[str], ) -> None: """Associate assets with a run for automatic data ingestion. @@ -249,16 +273,7 @@ async def create_automatic_association_for_assets( run: The Run or run ID. asset_names: List of asset names to associate. """ - run_id = run.id_ or "" if isinstance(run, Run) else run + run_id = run._id_or_error or "" if isinstance(run, Run) else run await self._low_level_client.create_automatic_run_association_for_assets( run_id=run_id, asset_names=asset_names ) - - async def stop_run(self, run: str | Run) -> None: - """Stop a run by setting its stop time to the current time. - - Args: - run: The Run or run ID to stop. - """ - run_id = run.id_ or "" if isinstance(run, Run) else run - await self._low_level_client.stop_run(run_id=run_id or "") diff --git a/python/lib/sift_client/resources/sync_stubs/__init__.pyi b/python/lib/sift_client/resources/sync_stubs/__init__.pyi index 0c52d3b15..15302378b 100644 --- a/python/lib/sift_client/resources/sync_stubs/__init__.pyi +++ b/python/lib/sift_client/resources/sync_stubs/__init__.pyi @@ -3,7 +3,7 @@ from __future__ import annotations import re -from datetime import datetime +from datetime import datetime, timedelta from typing import Any import pandas as pd @@ -11,22 +11,25 @@ import pyarrow as pa from sift_client.client import SiftClient from sift_client.sift_types.asset import Asset, AssetUpdate -from sift_client.sift_types.calculated_channel import CalculatedChannel, CalculatedChannelUpdate -from sift_client.sift_types.channel import Channel, ChannelReference -from sift_client.sift_types.rule import Rule, RuleAction, RuleUpdate -from sift_client.sift_types.run import Run, RunUpdate +from sift_client.sift_types.calculated_channel import ( + CalculatedChannel, + CalculatedChannelCreate, + CalculatedChannelUpdate, +) +from sift_client.sift_types.channel import Channel +from sift_client.sift_types.rule import Rule, RuleCreate, RuleUpdate +from sift_client.sift_types.run import Run, RunCreate, RunUpdate class AssetsAPI: """Sync counterpart to `AssetsAPIAsync`. High-level API for interacting with assets. - This class provides a Pythonic, notebook-friendly interface for interacting with the AssetsAPI. - It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. - - All methods in this class use the Asset class from the low-level wrapper, which is a user-friendly - representation of an asset using standard Python data structures and types. + This class provides a Pythonic, notebook-friendly interface for interacting with the AssetsAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + All methods in this class use the Asset class from the low-level wrapper, which is a user-friendly + representation of an asset using standard Python data structures and types. """ def __init__(self, sift_client: SiftClient): @@ -85,11 +88,12 @@ class AssetsAPI: created_before: datetime | None = None, modified_after: datetime | None = None, modified_before: datetime | None = None, - created_by: Any | None = None, - modified_by: Any | None = None, - tags: list[str] | None = None, - tag_ids: list[str] | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + tags: list[Any] | list[str] | None = None, + _tag_ids: list[str] | None = None, metadata: list[Any] | None = None, + description_contains: str | None = None, include_archived: bool = False, filter_query: str | None = None, order_by: str | None = None, @@ -98,27 +102,37 @@ class AssetsAPI: """List assets with optional filtering. Args: - asset_ids: List of asset IDs to filter by. name: Exact name of the asset. name_contains: Partial name of the asset. - name_regex: Regular expression string to filter assets by name. - asset_ids: List of asset IDs to filter by. - created_after: Created after this date. - created_before: Created before this date. - modified_after: Modified after this date. - modified_before: Modified before this date. - created_by: Assets created by this user. - modified_by: Assets last modified by this user. - tags: Assets with these tags. - tag_ids: List of asset tag IDs to filter by. - metadata: metadata filter - include_archived: Include archived assets. + name_regex: Regular expression to filter assets by name. + asset_ids: Filter to assets with any of these Ids. + created_after: Filter assets created after this datetime. + created_before: Filter assets created before this datetime. + modified_after: Filter assets modified after this datetime. + modified_before: Filter assets modified before this datetime. + created_by: Filter assets created by this User or user ID. + modified_by: Filter assets last modified by this User or user ID. + tags: Filter assets with any of these Tags or tag names. + metadata: Filter assets by metadata criteria. + description_contains: Partial description of the asset. + include_archived: If True, include archived assets in results. filter_query: Explicit CEL query to filter assets. - order_by: How to order the retrieved assets. # TODO: tooling for this? - limit: How many assets to retrieve. If None, retrieves all matches. + order_by: Field and direction to order results by. + limit: Maximum number of assets to return. If None, returns all matches. + + Returns: + A list of Asset objects that match the filter criteria. + """ + ... + + def unarchive(self, asset: str | Asset) -> Asset: + """Unarchive an asset. + + Args: + asset: The Asset or asset ID to unarchive. Returns: - A list of Assets that matches the filter. + The unarchived Asset. """ ... @@ -139,12 +153,11 @@ class CalculatedChannelsAPI: High-level API for interacting with calculated channels. - This class provides a Pythonic, notebook-friendly interface for interacting with the CalculatedChannelsAPI. - It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. - - All methods in this class use the CalculatedChannel class from the low-level wrapper, which is a user-friendly - representation of a calculated channel using standard Python data structures and types. + This class provides a Pythonic, notebook-friendly interface for interacting with the CalculatedChannelsAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + All methods in this class use the CalculatedChannel class from the low-level wrapper, which is a user-friendly + representation of a calculated channel using standard Python data structures and types. """ def __init__(self, sift_client: SiftClient): @@ -156,43 +169,26 @@ class CalculatedChannelsAPI: ... def _run(self, coro): ... - def archive(self, *, calculated_channel: str | CalculatedChannel) -> None: - """Archive a Calculated Channel.""" + def archive(self, calculated_channel: str | CalculatedChannel) -> CalculatedChannel: + """Archive a calculated channel. + + Args: + calculated_channel: The id or CalculatedChannel object of the calculated channel to archive. + + Returns: + The archived CalculatedChannel. + """ ... - def create( - self, - *, - name: str, - expression: str, - channel_references: list[ChannelReference], - description: str = "", - units: str | None = None, - client_key: str | None = None, - asset_ids: list[str] | None = None, - tag_ids: list[str] | None = None, - all_assets: bool = False, - user_notes: str = "", - ) -> CalculatedChannel: + def create(self, create: CalculatedChannelCreate | dict) -> CalculatedChannel: """Create a calculated channel. Args: - name: The name of the calculated channel. - expression: The expression to calculate the value of the calculated channel. - channel_references: A list of channel references that are used in the expression. - description: The description of the calculated channel. - units: The units of the calculated channel. - client_key: A user-defined unique identifier for the calculated channel. - asset_ids: A list of asset IDs to make the calculation available for. - tag_ids: A list of tag IDs to make the calculation available for. - all_assets: A flag that, when set to True, associates the calculated channel with all assets. - user_notes: User notes for the calculated channel. + create: A CalculatedChannelCreate object or dictionary with configuration for the new calculated channel. + This should include properties like name, expression, channel_references, etc. Returns: The created CalculatedChannel. - - Raises: - ValueError: If asset configuration is invalid. """ ... @@ -201,7 +197,7 @@ class CalculatedChannelsAPI: Will raise an error if multiple calculated channels are found. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The CalculatedChannel found or None. @@ -209,18 +205,13 @@ class CalculatedChannelsAPI: ... def get( - self, - *, - calculated_channel_id: str | None = None, - client_key: str | None = None, - organization_id: str | None = None, + self, *, calculated_channel_id: str | None = None, client_key: str | None = None ) -> CalculatedChannel: """Get a Calculated Channel. Args: calculated_channel_id: The ID of the calculated channel. client_key: The client key of the calculated channel. - organization_id: The organization ID (required if using client_key and user belongs to multiple organizations). Returns: The CalculatedChannel. @@ -236,47 +227,49 @@ class CalculatedChannelsAPI: name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, + calculated_channel_ids: list[str] | None = None, + client_keys: list[str] | None = None, created_after: datetime | None = None, created_before: datetime | None = None, modified_after: datetime | None = None, modified_before: datetime | None = None, - created_by: Any | None = None, - modified_by: Any | None = None, - client_key: str | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - tag_id: str | None = None, - tag_name: str | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + tags: list[Any] | list[str] | None = None, + metadata: list[Any] | None = None, + asset: Asset | str | None = None, + run: Run | str | None = None, version: int | None = None, + description_contains: str | None = None, include_archived: bool = False, filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, - organization_id: str | None = None, ) -> list[CalculatedChannel]: - """List calculated channels with optional filtering. + """List calculated channels with optional filtering. This will return the latest version. To find all versions, use `list_versions`. Args: name: Exact name of the calculated channel. name_contains: Partial name of the calculated channel. name_regex: Regular expression string to filter calculated channels by name. + calculated_channel_ids: Filter to calculated channels with any of these IDs. + client_keys: Filter to calculated channels with any of these client keys. created_after: Created after this date. created_before: Created before this date. modified_after: Modified after this date. modified_before: Modified before this date. created_by: Calculated channels created by this user. modified_by: Calculated channels last modified by this user. - client_key: The client key of the calculated channel. - asset_id: The asset ID associated with the calculated channel. - asset_name: The asset name associated with the calculated channel. - tag_id: The tag ID associated with the calculated channel. - tag_name: The tag name associated with the calculated channel. + tags: Filter calculated channels with any of these Tags or tag names. + metadata: Filter calculated channels by metadata criteria. + asset: Filter calculated channels associated with this Asset or asset ID. + run: Filter calculated channels associated with this Run or run ID. version: The version of the calculated channel. + description_contains: Partial description of the calculated channel. include_archived: Include archived calculated channels. filter_query: Explicit CEL query to filter calculated channels. order_by: How to order the retrieved calculated channels. limit: How many calculated channels to retrieve. If None, retrieves all matches. - organization_id: The organization ID (required if user belongs to multiple organizations). Returns: A list of CalculatedChannels that matches the filter. @@ -286,52 +279,68 @@ class CalculatedChannelsAPI: def list_versions( self, *, - calculated_channel_id: str | None = None, + calculated_channel: CalculatedChannel | str | None = None, client_key: str | None = None, - organization_id: str | None = None, name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - tag_id: str | None = None, - tag_name: str | None = None, - version: int | None = None, + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + tags: list[Any] | list[str] | None = None, + metadata: list[Any] | None = None, + description_contains: str | None = None, include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[CalculatedChannel]: """List versions of a calculated channel. Args: - calculated_channel_id: The ID of the calculated channel. + calculated_channel: The CalculatedChannel or ID of the calculated channel to get versions for. client_key: The client key of the calculated channel. - name: The name of the calculated channel. - name_contains: The name of the calculated channel. - name_regex: The name of the calculated channel. - asset_id: The asset ID of the calculated channel. - asset_name: The asset name of the calculated channel. - tag_id: The tag ID of the calculated channel. - tag_name: The tag name of the calculated channel. - version: The version of the calculated channel. - include_archived: Whether to include archived calculated channels. - organization_id: The organization ID. Required if your user belongs to multiple organizations. - order_by: The field to order by. - limit: How many versions to retrieve. If None, retrieves all matches. + name: Exact name of the calculated channel. + name_contains: Partial name of the calculated channel. + name_regex: Regular expression string to filter calculated channels by name. + created_after: Filter versions created after this datetime. + created_before: Filter versions created before this datetime. + modified_after: Filter versions modified after this datetime. + modified_before: Filter versions modified before this datetime. + created_by: Filter versions created by this user or user ID. + modified_by: Filter versions modified by this user or user ID. + tags: Filter versions with any of these Tags or tag names. + metadata: Filter versions by metadata criteria. + description_contains: Partial description of the calculated channel. + include_archived: Include archived versions. + filter_query: Explicit CEL query to filter versions. + order_by: How to order the retrieved versions. + limit: Maximum number of versions to return. If None, returns all matches. Returns: - A list of CalculatedChannel versions. + A list of CalculatedChannel versions that match the filter criteria. + """ + ... - Raises: - ValueError: If neither calculated_channel_id nor client_key is provided. + def unarchive(self, calculated_channel: str | CalculatedChannel) -> CalculatedChannel: + """Unarchive a calculated channel. + + Args: + calculated_channel: The id or CalculatedChannel object of the calculated channel to unarchive. + + Returns: + The unarchived CalculatedChannel. """ ... def update( self, - *, - calculated_channel: str | CalculatedChannel, + calculated_channel: CalculatedChannel | str, update: CalculatedChannelUpdate | dict, + *, user_notes: str | None = None, ) -> CalculatedChannel: """Update a Calculated Channel. @@ -351,12 +360,11 @@ class ChannelsAPI: High-level API for interacting with channels. - This class provides a Pythonic, notebook-friendly interface for interacting with the ChannelsAPI. - It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. - - All methods in this class use the Channel class from the low-level wrapper, which is a user-friendly - representation of a channel using standard Python data structures and types. + This class provides a Pythonic, notebook-friendly interface for interacting with the ChannelsAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + All methods in this class use the Channel class from the low-level wrapper, which is a user-friendly + representation of a channel using standard Python data structures and types. """ def __init__(self, sift_client: SiftClient): @@ -373,7 +381,7 @@ class ChannelsAPI: raises an error. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The Channel found or None. @@ -395,7 +403,7 @@ class ChannelsAPI: self, *, channels: list[Channel], - run_id: str | None = None, + run: Run | str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, limit: int | None = None, @@ -404,10 +412,13 @@ class ChannelsAPI: Args: channels: The channels to get data for. - run_id: The run to get data for. + run: The Run or run_id to get data for. start_time: The start time to get data for. end_time: The end time to get data for. limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided. + + Returns: + A dictionary mapping channel names to pandas DataFrames containing the channel data. """ ... @@ -415,7 +426,7 @@ class ChannelsAPI: self, *, channels: list[Channel], - run_id: str | None = None, + run: Run | str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, limit: int | None = None, @@ -426,45 +437,43 @@ class ChannelsAPI: def list_( self, *, - asset_id: str | None = None, name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - description: str | None = None, - description_contains: str | None = None, - active: bool | None = None, - run_id: str | None = None, - run_name: str | None = None, - client_key: str | None = None, - created_before: datetime | None = None, + channel_ids: list[str] | None = None, created_after: datetime | None = None, - modified_before: datetime | None = None, + created_before: datetime | None = None, modified_after: datetime | None = None, + modified_before: datetime | None = None, + asset: Asset | str | None = None, + run: Run | str | None = None, + description_contains: str | None = None, + include_archived: bool | None = None, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[Channel]: """List channels with optional filtering. Args: - asset_id: The asset ID to get. - name: The name of the channel to get. - name_contains: The partial name of the channel to get. - name_regex: The regex name of the channel to get. - description: The description of the channel to get. - description_contains: The partial description of the channel to get. - active: Whether the channel is active. - run_id: The run ID to get. - run_name: The name of the run to get. - client_key: The client key of the run to get. - created_before: The created date of the channel to get. - created_after: The created date of the channel to get. - modified_before: The modified date of the channel to get. - modified_after: The modified date of the channel to get. - order_by: How to order the retrieved channels. - limit: How many channels to retrieve. If None, retrieves all matches. + name: Exact name of the channel. + name_contains: Partial name of the channel. + name_regex: Regular expression to filter channels by name. + channel_ids: Filter to channels with any of these IDs. + created_after: Filter channels created after this datetime. + created_before: Filter channels created before this datetime. + modified_after: Filter channels modified after this datetime. + modified_before: Filter channels modified before this datetime. + asset: Filter channels associated with this Asset or asset ID. + run: Filter channels associated with this Run or run ID. + description_contains: Partial description of the channel. + include_archived: If True, include archived channels in results. + filter_query: Explicit CEL query to filter channels. + order_by: Field and direction to order results by. + limit: Maximum number of channels to return. If None, returns all matches. Returns: - A list of Channels that matches the filter. + A list of Channels that matches the filter criteria. """ ... @@ -496,12 +505,11 @@ class RulesAPI: High-level API for interacting with rules. - This class provides a Pythonic, notebook-friendly interface for interacting with the RulesAPI. - It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. - - All methods in this class use the Rule class from the low-level wrapper, which is a user-friendly - representation of a rule using standard Python data structures and types. + This class provides a Pythonic, notebook-friendly interface for interacting with the RulesAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + All methods in this class use the Rule class from the low-level wrapper, which is a user-friendly + representation of a rule using standard Python data structures and types. """ def __init__(self, sift_client: SiftClient): @@ -513,63 +521,26 @@ class RulesAPI: ... def _run(self, coro): ... - def archive( - self, - *, - rule: str | Rule | None = None, - rules: list[Rule] | None = None, - rule_ids: list[str] | None = None, - client_keys: list[str] | None = None, - ) -> None: - """Archive a rule or multiple. + def archive(self, rule: str | Rule) -> Rule: + """Archive a rule. Args: - rule: The Rule to archive. - rules: The Rules to archive. - rule_ids: The rule IDs to archive. - client_keys: The client keys to archive. - """ - ... - - def batch_get( - self, *, rule_ids: list[str] | None = None, client_keys: list[str] | None = None - ) -> list[Rule]: - """Get multiple rules by rule IDs or client keys. - - Args: - rule_ids: List of rule IDs to get. - client_keys: List of client keys to get. + rule: The id or Rule object of the rule to archive. Returns: - List of Rules. + The archived Rule. """ ... - def batch_restore( - self, *, rule_ids: list[str] | None = None, client_keys: list[str] | None = None - ) -> None: - """Batch restore rules. + def create(self, create: RuleCreate | dict) -> Rule: + """Create a new rule. Args: - rule_ids: List of rule IDs to restore. - client_keys: List of client keys to undelete. - """ - ... + create: A RuleCreate object or dictionary with configuration for the new rule. - def create( - self, - name: str, - description: str, - expression: str, - channel_references: list[ChannelReference], - action: RuleAction, - organization_id: str | None = None, - client_key: str | None = None, - asset_ids: list[str] | None = None, - contextual_channels: list[str] | None = None, - is_external: bool = False, - ) -> Rule: - """Create a new rule.""" + Returns: + The created Rule. + """ ... def find(self, **kwargs) -> Rule | None: @@ -602,9 +573,22 @@ class RulesAPI: name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, + rule_ids: list[str] | None = None, + client_keys: list[str] | None = None, + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + metadata: list[Any] | None = None, + asset_ids: list[str] | None = None, + asset_tag_ids: list[str] | None = None, + description_contains: str | None = None, + include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, - include_deleted: bool = False, ) -> list[Rule]: """List rules with optional filtering. @@ -612,32 +596,41 @@ class RulesAPI: name: Exact name of the rule. name_contains: Partial name of the rule. name_regex: Regular expression string to filter rules by name. - order_by: How to order the retrieved rules. - limit: How many rules to retrieve. If None, retrieves all matches. - include_deleted: Include deleted rules. + rule_ids: IDs of rules to filter to. + client_keys: Client keys of rules to filter to. + created_after: Rules created after this datetime. + created_before: Rules created before this datetime. + modified_after: Rules modified after this datetime. + modified_before: Rules modified before this datetime. + created_by: Filter rules created by this User or user ID. + modified_by: Filter rules last modified by this User or user ID. + metadata: Filter rules by metadata criteria. + asset_ids: Filter rules associated with any of these Asset IDs. + asset_tag_ids: Filter rules associated with any of these Asset Tag IDs. + description_contains: Partial description of the rule. + include_archived: If True, include archived rules in results. + filter_query: Explicit CEL query to filter rules. + order_by: Field and direction to order results by. + limit: Maximum number of rules to return. If None, returns all matches. Returns: A list of Rules that matches the filter. """ ... - def restore( - self, *, rule: str | Rule, rule_id: str | None = None, client_key: str | None = None - ) -> Rule: - """Restore a rule. + def unarchive(self, rule: str | Rule) -> Rule: + """Unarchive a rule. Args: - rule: The Rule or rule ID to restore. - rule_id: The rule ID to restore (alternative to rule parameter). - client_key: The client key to restore (alternative to rule parameter). + rule: The id or Rule object of the rule to unarchive. Returns: - The restored Rule. + The unarchived Rule. """ ... def update( - self, rule: str | Rule, update: RuleUpdate | dict, version_notes: str | None = None + self, rule: Rule | str, update: RuleUpdate | dict, *, version_notes: str | None = None ) -> Rule: """Update a Rule. @@ -656,12 +649,11 @@ class RunsAPI: High-level API for interacting with runs. - This class provides a Pythonic, notebook-friendly interface for interacting with the RunsAPI. - It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. - - All methods in this class use the Run class from the low-level wrapper, which is a user-friendly - representation of a run using standard Python data structures and types. + This class provides a Pythonic, notebook-friendly interface for interacting with the RunsAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + All methods in this class use the Run class from the low-level wrapper, which is a user-friendly + representation of a run using standard Python data structures and types. """ def __init__(self, sift_client: SiftClient): @@ -673,7 +665,7 @@ class RunsAPI: ... def _run(self, coro): ... - def archive(self, *, run: str | Run) -> None: + def archive(self, run: str | Run) -> Run: """Archive a run. Args: @@ -681,28 +673,11 @@ class RunsAPI: """ ... - def create( - self, - name: str, - description: str, - tags: list[str] | None = None, - start_time: datetime | None = None, - stop_time: datetime | None = None, - organization_id: str | None = None, - client_key: str | None = None, - metadata: dict[str, str | float | bool] | None = None, - ) -> Run: + def create(self, create: RunCreate | dict) -> Run: """Create a new run. Args: - name: The name of the run. - description: The description of the run. - tags: Tags to associate with the run. - start_time: The start time of the run. - stop_time: The stop time of the run. - organization_id: The organization ID. - client_key: A unique client key for the run. - metadata: Metadata values for the run. + create: The Run definition to create. Returns: The created Run. @@ -710,7 +685,7 @@ class RunsAPI: ... def create_automatic_association_for_assets( - self, run: str | Run, asset_names: list[str] + self, run: str | Run, *, asset_names: list[str] ) -> None: """Associate assets with a run for automatic data ingestion. @@ -721,22 +696,23 @@ class RunsAPI: ... def find(self, **kwargs) -> Run | None: - """Find a single run matching the given query. Takes the same arguments as `list`. If more than one run is found, + """Find a single run matching the given query. Takes the same arguments as `list_`. If more than one run is found, raises an error. Args: - **kwargs: Keyword arguments to pass to `list`. + **kwargs: Keyword arguments to pass to `list_`. Returns: The Run found or None. """ ... - def get(self, *, run_id: str) -> Run: + def get(self, *, run_id: str | None = None, client_key: str | None = None) -> Run: """Get a Run. Args: run_id: The ID of the run. + client_key: The client key of the run. Returns: The Run. @@ -749,15 +725,26 @@ class RunsAPI: name: str | None = None, name_contains: str | None = None, name_regex: str | re.Pattern | None = None, - description: str | None = None, - description_contains: str | None = None, - duration_seconds: int | None = None, - client_key: str | None = None, - asset_id: str | None = None, - asset_name: str | None = None, - created_by_user_id: str | None = None, + run_ids: list[str] | None = None, + client_keys: list[str] | None = None, + created_after: datetime | None = None, + created_before: datetime | None = None, + modified_after: datetime | None = None, + modified_before: datetime | None = None, + created_by: Any | str | None = None, + modified_by: Any | str | None = None, + metadata: list[Any] | None = None, + assets: list[Asset] | list[str] | None = None, + duration_less_than: timedelta | None = None, + duration_greater_than: timedelta | None = None, + start_time_after: datetime | None = None, + start_time_before: datetime | None = None, + stop_time_after: datetime | None = None, + stop_time_before: datetime | None = None, is_stopped: bool | None = None, + description_contains: str | None = None, include_archived: bool = False, + filter_query: str | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[Run]: @@ -766,25 +753,36 @@ class RunsAPI: Args: name: Exact name of the run. name_contains: Partial name of the run. - name_regex: Regular expression string to filter runs by name. - description: Exact description of the run. - description_contains: Partial description of the run. - duration_seconds: Duration of the run in seconds. - client_key: Client key to filter by. - asset_id: Asset ID to filter by. - asset_name: Asset name to filter by. - created_by_user_id: User ID who created the run. + name_regex: Regular expression to filter runs by name. + run_ids: Filter to runs with any of these IDs. + client_keys: Filter to runs with any of these client keys. + created_after: Filter runs created after this datetime. + created_before: Filter runs created before this datetime. + modified_after: Filter runs modified after this datetime. + modified_before: Filter runs modified before this datetime. + created_by: Filter runs created by this User or user ID. + modified_by: Filter runs last modified by this User or user ID. + metadata: Filter runs by metadata criteria. + assets: Filter runs associated with any of these Assets or asset IDs. + duration_less_than: Filter runs with duration less than this time. + duration_greater_than: Filter runs with duration greater than this time. + start_time_after: Filter runs that started after this datetime. + start_time_before: Filter runs that started before this datetime. + stop_time_after: Filter runs that stopped after this datetime. + stop_time_before: Filter runs that stopped before this datetime. is_stopped: Whether the run is stopped. - include_archived: Whether to include archived runs. - order_by: How to order the retrieved runs. - limit: How many runs to retrieve. If None, retrieves all matches. + description_contains: Partial description of the run. + include_archived: If True, include archived runs in results. + filter_query: Explicit CEL query to filter runs. + order_by: Field and direction to order results by. + limit: Maximum number of runs to return. If None, returns all matches. Returns: - A list of Runs that matches the filter. + A list of Run objects that match the filter criteria. """ ... - def stop(self, *, run: str | Run) -> None: + def stop(self, run: str | Run) -> None: """Stop a run by setting its stop time to the current time. Args: @@ -792,11 +790,11 @@ class RunsAPI: """ ... - def stop_run(self, run: str | Run) -> None: - """Stop a run by setting its stop time to the current time. + def unarchive(self, run: str | Run) -> Run: + """Unarchive a run. Args: - run: The Run or run ID to stop. + run: The Run or run ID to unarchive. """ ... diff --git a/python/lib/sift_client/sift_types/__init__.py b/python/lib/sift_client/sift_types/__init__.py index 6a389fa51..1af4eb80d 100644 --- a/python/lib/sift_client/sift_types/__init__.py +++ b/python/lib/sift_client/sift_types/__init__.py @@ -1,6 +1,7 @@ from sift_client.sift_types.asset import Asset, AssetUpdate from sift_client.sift_types.calculated_channel import ( CalculatedChannel, + CalculatedChannelCreate, CalculatedChannelUpdate, ) from sift_client.sift_types.channel import ( @@ -9,33 +10,39 @@ ChannelDataType, ChannelReference, ) -from sift_client.sift_types.ingestion import IngestionConfig +from sift_client.sift_types.ingestion import ChannelConfig, Flow, IngestionConfig from sift_client.sift_types.rule import ( Rule, RuleAction, RuleActionType, RuleAnnotationType, + RuleCreate, RuleUpdate, RuleVersion, ) -from sift_client.sift_types.run import Run, RunUpdate +from sift_client.sift_types.run import Run, RunCreate, RunUpdate __all__ = [ "Asset", "AssetUpdate", "CalculatedChannel", + "CalculatedChannelCreate", "CalculatedChannelUpdate", "Channel", "ChannelBitFieldElement", + "ChannelConfig", "ChannelDataType", "ChannelReference", + "Flow", "IngestionConfig", "Rule", "RuleAction", "RuleActionType", "RuleAnnotationType", + "RuleCreate", "RuleUpdate", "RuleVersion", "Run", + "RunCreate", "RunUpdate", ] diff --git a/python/lib/sift_client/sift_types/_base.py b/python/lib/sift_client/sift_types/_base.py index 9254992cf..adbdf37e7 100644 --- a/python/lib/sift_client/sift_types/_base.py +++ b/python/lib/sift_client/sift_types/_base.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar from google.protobuf import field_mask_pb2, message -from pydantic import BaseModel, ConfigDict, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr if TYPE_CHECKING: from sift_client.client import SiftClient @@ -18,6 +18,7 @@ class BaseType(BaseModel, Generic[ProtoT, SelfT], ABC): model_config = ConfigDict(frozen=True) id_: str | None = None + proto: Any | None = Field(default=None, exclude=True) # For user reference only _client: SiftClient | None = None @property @@ -28,6 +29,13 @@ def client(self) -> SiftClient: ) return self._client + @property + def _id_or_error(self) -> str: + """Get the ID of this instance or raise an error if it's not set for type safe usage.""" + if self.id_ is None: + raise ValueError("ID is not set") + return self.id_ + @classmethod @abstractmethod def _from_proto(cls, proto: ProtoT, sift_client: SiftClient | None = None) -> SelfT: ... @@ -42,6 +50,9 @@ def _update(self, other: BaseType[ProtoT, SelfT]) -> BaseType[ProtoT, SelfT]: for key in other.__class__.model_fields.keys(): if key in self.model_fields: self.__dict__.update({key: getattr(other, key)}) + + # Make sure we also update the proto since it is excluded + self.__dict__["proto"] = other.proto return self @@ -58,13 +69,10 @@ class MappingHelper(BaseModel): converter: type[Any] | Callable[[Any], Any] | None = None -# TODO: how to handle nulling fields, needs to be default value for the type -class ModelUpdate(BaseModel, Generic[ProtoT], ABC): - """Base class for Pydantic models that generate proto patches with field masks.""" +class ModelCreateUpdateBase(BaseModel, ABC): + """Base class for Pydantic models that generate proto messages.""" model_config = ConfigDict(frozen=False) - - _resource_id: Any | None = PrivateAttr(default=None) _to_proto_helpers: ClassVar[dict[str, MappingHelper]] = PrivateAttr(default={}) def __init__(self, **data: Any): @@ -77,28 +85,6 @@ def __init__(self, **data: Any): f"MappingHelper created for {expected_field} but {self.__class__.__name__} has no matching variable names." ) - @property - def resource_id(self): - return self._resource_id - - @resource_id.setter - def resource_id(self, value): - self._resource_id = value - - def to_proto_with_mask(self) -> tuple[ProtoT, field_mask_pb2.FieldMask]: - """Convert to proto with field mask.""" - # Get the corresponding proto class - proto_cls: type[ProtoT] = self._get_proto_class() - proto_msg = proto_cls() - - # Get only explicitly set fields, including those set to None - data = self.model_dump(exclude_unset=True, exclude_none=False) - paths = self._build_proto_and_paths(proto_msg, data) - - self._add_resource_id_to_proto(proto_msg) - mask = field_mask_pb2.FieldMask(paths=paths) - return proto_msg, mask - def _build_proto_and_paths( self, proto_msg, data, prefix="", already_setting_path_override=False ) -> list[str]: @@ -176,10 +162,61 @@ def _build_proto_and_paths( return paths + +class ModelCreate(ModelCreateUpdateBase, Generic[ProtoT], ABC): + """Base class for Pydantic models that generate proto messages for creation.""" + + @abstractmethod + def _get_proto_class(self) -> type[ProtoT]: + """Get the corresponding proto class - override in subclasses since typing is not strict.""" + raise NotImplementedError("Subclasses must implement this") + + def to_proto(self) -> ProtoT: + """Convert to proto.""" + # Get the corresponding proto class + proto_cls: type[ProtoT] = self._get_proto_class() + proto_msg = proto_cls() + + # Get all fields + data = self.model_dump(exclude_none=False) + self._build_proto_and_paths(proto_msg, data) + + return proto_msg + + +class ModelUpdate(ModelCreateUpdateBase, Generic[ProtoT], ABC): + """Base class for Pydantic models that generate proto patches with field masks.""" + + _resource_id: str | None = PrivateAttr(default=None) + + @property + def resource_id(self): + return self._resource_id + + @resource_id.setter + def resource_id(self, value): + self._resource_id = value + + def to_proto_with_mask(self) -> tuple[ProtoT, field_mask_pb2.FieldMask]: + """Convert to proto with field mask.""" + # Get the corresponding proto class + proto_cls: type[ProtoT] = self._get_proto_class() + proto_msg = proto_cls() + + # Get only explicitly set fields, including those set to None + data = self.model_dump(exclude_unset=True, exclude_none=False) + paths = self._build_proto_and_paths(proto_msg, data) + + self._add_resource_id_to_proto(proto_msg) + mask = field_mask_pb2.FieldMask(paths=paths) + return proto_msg, mask + + @abstractmethod def _get_proto_class(self) -> type[ProtoT]: """Get the corresponding proto class - override in subclasses since typing is not strict.""" raise NotImplementedError("Subclasses must implement this") + @abstractmethod def _add_resource_id_to_proto(self, proto_msg: ProtoT): """Assigns a resource ID (such as Asset ID) to the proto message.""" raise NotImplementedError("Subclasses must implement this") diff --git a/python/lib/sift_client/sift_types/asset.py b/python/lib/sift_client/sift_types/asset.py index ad078a9ee..7b02a3dce 100644 --- a/python/lib/sift_client/sift_types/asset.py +++ b/python/lib/sift_client/sift_types/asset.py @@ -17,6 +17,7 @@ class Asset(BaseType[AssetProto, "Asset"]): """Model of the Sift Asset.""" + # Required fields name: str organization_id: str created_date: datetime @@ -25,15 +26,10 @@ class Asset(BaseType[AssetProto, "Asset"]): modified_by_user_id: str tags: list[str] metadata: dict[str, str | float | bool] - archived_date: datetime | None + is_archived: bool - @property - def is_archived(self): - """Whether the asset is archived.""" - # TODO: clean up this logic when gRPC returns a null. - return self.archived_date is not None and self.archived_date > datetime( - 1970, 1, 1, tzinfo=timezone.utc - ) + # Optional fields + archived_date: datetime | None @property def created_by(self): @@ -48,11 +44,11 @@ def modified_by(self): @property def runs(self) -> list[Run]: """Get the runs associated with this asset.""" - return self.client.runs.list_(asset_id=self.id_) + return self.client.runs.list_(assets=[self]) - def channels(self, run_id: str | None = None, limit: int | None = None) -> list[Channel]: + def channels(self, run: Run | str | None = None, limit: int | None = None) -> list[Channel]: """Get the channels for this asset.""" - return self.client.channels.list_(asset_id=self.id_, run_id=run_id, limit=limit) + return self.client.channels.list_(asset=self, run=run, limit=limit) @property def rules(self): @@ -74,6 +70,12 @@ def archive(self, *, archive_runs: bool = False) -> Asset: self._update(updated_asset) return self + def unarchive(self) -> Asset: + """Unarchive the asset.""" + updated_asset = self.client.assets.unarchive(asset=self) + self._update(updated_asset) + return self + def update(self, update: AssetUpdate | dict) -> Asset: """Update the Asset. @@ -88,6 +90,7 @@ def update(self, update: AssetUpdate | dict) -> Asset: @classmethod def _from_proto(cls, proto: AssetProto, sift_client: SiftClient | None = None) -> Asset: return cls( + proto=proto, id_=proto.asset_id, name=proto.name, organization_id=proto.organization_id, @@ -97,6 +100,7 @@ def _from_proto(cls, proto: AssetProto, sift_client: SiftClient | None = None) - modified_by_user_id=proto.modified_by_user_id, tags=list(proto.tags) if proto.tags else [], archived_date=proto.archived_date.ToDatetime(tzinfo=timezone.utc), + is_archived=proto.is_archived, metadata=metadata_proto_to_dict(proto.metadata), # type: ignore _client=sift_client, ) @@ -106,12 +110,14 @@ class AssetUpdate(ModelUpdate[AssetProto]): """Model of the Asset Fields that can be updated.""" tags: list[str] | None = None - archived_date: datetime | str | None = None metadata: dict[str, str | float | bool] | None = None + is_archived: bool | None = None _to_proto_helpers: ClassVar = { "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, ), } diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index 4f2bd8b71..987bb386a 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -1,17 +1,26 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, ClassVar +from pydantic import model_validator from sift.calculated_channels.v2.calculated_channels_pb2 import ( CalculatedChannel as CalculatedChannelProto, ) from sift.calculated_channels.v2.calculated_channels_pb2 import ( CalculatedChannelAbstractChannelReference, + CreateCalculatedChannelRequest, ) -from sift_client.sift_types._base import BaseType, MappingHelper, ModelUpdate +from sift_client.sift_types._base import ( + BaseType, + MappingHelper, + ModelCreate, + ModelCreateUpdateBase, + ModelUpdate, +) from sift_client.sift_types.channel import ChannelReference +from sift_client.util.metadata import metadata_dict_to_proto if TYPE_CHECKING: from sift_client.client import SiftClient @@ -24,6 +33,7 @@ class CalculatedChannel(BaseType[CalculatedChannelProto, "CalculatedChannel"]): description: str expression: str channel_references: list[ChannelReference] + is_archived: bool units: str | None asset_ids: list[str] | None @@ -41,13 +51,6 @@ class CalculatedChannel(BaseType[CalculatedChannelProto, "CalculatedChannel"]): created_by_user_id: str | None modified_by_user_id: str | None - @property - def is_archived(self): - """Whether the calculated channel is archived.""" - return self.archived_date is not None and self.archived_date > datetime( - 1970, 1, 1, tzinfo=timezone.utc - ) - @property def created_by(self): """Get the user that created this calculated channel.""" @@ -63,6 +66,11 @@ def archive(self) -> CalculatedChannel: self.client.calculated_channels.archive(calculated_channel=self) return self + def unarchive(self) -> CalculatedChannel: + """Unarchive the calculated channel.""" + self.client.calculated_channels.unarchive(calculated_channel=self) + return self + def update( self, update: CalculatedChannelUpdate | dict, @@ -106,6 +114,7 @@ def _from_proto( if proto.HasField("archived_date") else None ), + is_archived=proto.is_archived, version_id=proto.version_id, version=proto.version, change_message=proto.change_message, @@ -122,17 +131,23 @@ def _from_proto( ) -class CalculatedChannelUpdate(ModelUpdate[CalculatedChannelProto]): - """Model of the Calculated Channel Fields that can be updated.""" +class CalculatedChannelBase(ModelCreateUpdateBase): + """Base class for CalculatedChannel create and update models with shared fields and validation.""" - name: str | None = None description: str | None = None + user_notes: str | None = None units: str | None = None + expression: str | None = None # This is named expression_channel_references to match the protobuf field name for easier deserialization. expression_channel_references: list[ChannelReference] | None = None + + # Scoping of the calculated channel. tag_ids: list[str] | None = None - archived_date: datetime | None = None + asset_ids: list[str] | None = None + all_assets: bool | None = None + + metadata: dict[str, str | float | bool] | None = None _to_proto_helpers: ClassVar = { "expression": MappingHelper( @@ -148,23 +163,61 @@ class CalculatedChannelUpdate(ModelUpdate[CalculatedChannelProto]): proto_attr_path="calculated_channel_configuration.asset_configuration.selection.tag_ids", update_field="asset_configuration", ), + "asset_ids": MappingHelper( + proto_attr_path="calculated_channel_configuration.asset_configuration.selection.asset_ids", + update_field="asset_configuration", + ), + "all_assets": MappingHelper( + proto_attr_path="calculated_channel_configuration.asset_configuration.all_assets", + ), + "metadata": MappingHelper( + proto_attr_path="metadata", + update_field="metadata", + converter=metadata_dict_to_proto, + ), } - def __init__(self, **data: Any): - """Initialize a CalculatedChannelUpdate instance. - - Args: - **data: Keyword arguments for the update fields. + @model_validator(mode="after") + def _validate_asset_configuration(self): + """Validate that either all_assets is True or at least one of tag_ids or asset_ids is provided, but not both.""" + if self.all_assets is not None and self.all_assets and (self.asset_ids or self.tag_ids): + raise ValueError("Cannot specify both all_assets=True and asset_ids/tag_ids") + return self - Raises: - ValueError: If only one of expression or expression_channel_references is provided. - Both must be provided together or neither should be provided. - """ - super().__init__(**data) + @model_validator(mode="after") + def _validate_expression_and_channel_references(self): + """Validate that expression and expression_channel_references are set together.""" if any([self.expression, self.expression_channel_references]) and not all( [self.expression, self.expression_channel_references] ): raise ValueError("Expression and channel references must be set together") + return self + + +class CalculatedChannelCreate(CalculatedChannelBase, ModelCreate[CreateCalculatedChannelRequest]): + """Create model for a Calculated Channel.""" + + name: str + client_key: str | None = None + + def _get_proto_class(self) -> type[CreateCalculatedChannelRequest]: + return CreateCalculatedChannelRequest + + +class CalculatedChannelUpdate(CalculatedChannelBase, ModelUpdate[CalculatedChannelProto]): + """Update model for a Calculated Channel.""" + + name: str | None = None + is_archived: bool | None = None + + @model_validator(mode="after") + def _validate_non_updatable_fields(self): + """Validate that the fields that cannot be updated are not set.""" + if self.user_notes is not None: + raise ValueError("Cannot update user notes") + if self.client_key is not None: + raise ValueError("Cannot update client key") + return self def _get_proto_class(self) -> type[CalculatedChannelProto]: return CalculatedChannelProto diff --git a/python/lib/sift_client/sift_types/channel.py b/python/lib/sift_client/sift_types/channel.py index 9d91afe0c..2c18dbb83 100644 --- a/python/lib/sift_client/sift_types/channel.py +++ b/python/lib/sift_client/sift_types/channel.py @@ -24,7 +24,6 @@ Uint32Values, Uint64Values, ) -from sift.ingestion_configs.v2.ingestion_configs_pb2 import ChannelConfig from sift_client.sift_types._base import BaseType @@ -207,17 +206,21 @@ def _to_proto(self) -> ChannelBitFieldElementPb: class Channel(BaseType[ChannelProto, "Channel"]): """Model representing a Sift Channel.""" + # Required fields name: str data_type: ChannelDataType - description: str | None = None - unit: str | None = None + description: str + unit: str bit_field_elements: list[ChannelBitFieldElement] = Field(default_factory=list) enum_types: dict[str, int] = Field(default_factory=dict) - asset_id: str | None = None - created_date: datetime | None = None - modified_date: datetime | None = None - created_by_user_id: str | None = None - modified_by_user_id: str | None = None + asset_id: str + created_date: datetime + modified_date: datetime + created_by_user_id: str + modified_by_user_id: str + + # Optional fields + ... @staticmethod def _enum_types_to_proto_list(enum_types: dict[str, int] | None) -> list[ChannelEnumTypePb]: @@ -231,45 +234,24 @@ def _enum_types_from_proto_list(enum_types: list[ChannelEnumTypePb]) -> dict[str return {enum.name: enum.key for enum in enum_types} @classmethod - def _from_proto( - cls, proto: ChannelProto | ChannelConfig, sift_client: SiftClient | None = None - ) -> Channel: - if isinstance(proto, ChannelProto): - return cls( - id_=proto.channel_id, - name=proto.name, - data_type=ChannelDataType(proto.data_type), - description=proto.description, - unit=proto.unit_id, - bit_field_elements=[ - ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements - ], - enum_types=cls._enum_types_from_proto_list(proto.enum_types), # type: ignore - asset_id=proto.asset_id, - created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), - modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), - created_by_user_id=proto.created_by_user_id, - modified_by_user_id=proto.modified_by_user_id, - _client=sift_client, - ) - elif isinstance(proto, ChannelConfig): - return cls( - id_=proto.name, - name=proto.name, - data_type=ChannelDataType(proto.data_type), - _client=sift_client, - ) - - def _to_config_proto(self) -> ChannelConfig: - return ChannelConfig( - name=self.name, - data_type=self.data_type.value, - description=self.description, # type: ignore - unit=self.unit, # type: ignore - bit_field_elements=[el._to_proto() for el in self.bit_field_elements] - if self.bit_field_elements - else None, - enum_types=self._enum_types_to_proto_list(self.enum_types), + def _from_proto(cls, proto: ChannelProto, sift_client: SiftClient | None = None) -> Channel: + return cls( + proto=proto, + id_=proto.channel_id, + name=proto.name, + data_type=ChannelDataType(proto.data_type), + description=proto.description, + unit=proto.unit_id, + bit_field_elements=[ + ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements + ], + enum_types=cls._enum_types_from_proto_list(proto.enum_types), # type: ignore + asset_id=proto.asset_id, + created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), + modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), + created_by_user_id=proto.created_by_user_id, + modified_by_user_id=proto.modified_by_user_id, + _client=sift_client, ) def data( @@ -296,7 +278,7 @@ def data( if as_arrow: data = self.client.channels.get_data_as_arrow( channels=[self], - run_id=run_id, + run=run_id, start_time=start_time, end_time=end_time, limit=limit, # type: ignore @@ -304,7 +286,7 @@ def data( else: data = self.client.channels.get_data( channels=[self], - run_id=run_id, + run=run_id, start_time=start_time, end_time=end_time, limit=limit, # type: ignore diff --git a/python/lib/sift_client/sift_types/ingestion.py b/python/lib/sift_client/sift_types/ingestion.py index b449bf228..2d6d22693 100644 --- a/python/lib/sift_client/sift_types/ingestion.py +++ b/python/lib/sift_client/sift_types/ingestion.py @@ -4,8 +4,11 @@ from typing import TYPE_CHECKING, Any from google.protobuf.empty_pb2 import Empty -from pydantic import ConfigDict +from pydantic import ConfigDict, model_validator from sift.ingest.v1.ingest_pb2 import IngestWithConfigDataChannelValue +from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( + ChannelConfig as ChannelConfigProto, +) from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( FlowConfig, ) @@ -22,12 +25,13 @@ ) from sift_client.sift_types._base import BaseType -from sift_client.sift_types.channel import Channel, ChannelDataType +from sift_client.sift_types.channel import ChannelBitFieldElement, ChannelDataType if TYPE_CHECKING: from datetime import datetime from sift_client.client import SiftClient + from sift_client.sift_types.channel import Channel class IngestionConfig(BaseType[IngestionConfigProto, "IngestionConfig"]): @@ -41,6 +45,7 @@ def _from_proto( cls, proto: IngestionConfigProto, sift_client: SiftClient | None = None ) -> IngestionConfig: return cls( + proto=proto, id_=proto.ingestion_config_id, asset_id=proto.asset_id, client_key=proto.client_key, @@ -48,6 +53,102 @@ def _from_proto( ) +class ChannelConfig(BaseType[ChannelConfigProto, "ChannelConfig"]): + """Channel configuration model for ingestion purposes. + + This model contains only the fields needed for ingestion configuration, + without the full metadata from the Channels API. + """ + + model_config = ConfigDict(frozen=False) + name: str + data_type: ChannelDataType + description: str | None = None + unit: str | None = None + bit_field_elements: list[ChannelBitFieldElement] | None = None + enum_types: dict[str, int] | None = None + + @model_validator(mode="after") + def _validate_enum_types(self): + """Validate that enum_types is provided when data_type is ENUM.""" + if self.data_type == ChannelDataType.ENUM and not self.enum_types: + raise ValueError( + f"Channel '{self.name}' has data_type ENUM but enum_types is not provided" + ) + elif self.data_type == ChannelDataType.BIT_FIELD and not self.bit_field_elements: + raise ValueError( + f"Channel '{self.name}' has data_type BIT_FIELD but bit_field_elements is not provided" + ) + return self + + @classmethod + def _from_proto( + cls, proto: ChannelConfigProto, sift_client: SiftClient | None = None + ) -> ChannelConfig: + """Create ChannelConfig from ChannelConfigProto.""" + return cls( + proto=proto, + name=proto.name, + data_type=ChannelDataType(proto.data_type), + description=proto.description if proto.description else None, + unit=proto.unit if proto.unit else None, + bit_field_elements=[ + ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements + ] + if proto.bit_field_elements + else None, + enum_types={enum.name: enum.key for enum in proto.enum_types} + if proto.enum_types + else None, + _client=sift_client, + ) + + @classmethod + def from_channel(cls, channel: Channel) -> ChannelConfig: + """Create ChannelConfig from a Channel. + + Args: + channel: The Channel to convert. + + Returns: + A ChannelConfig with the channel's configuration data. + """ + return cls( + name=channel.name, + data_type=channel.data_type, + description=channel.description, + unit=channel.unit, + bit_field_elements=channel.bit_field_elements if channel.bit_field_elements else None, + enum_types=channel.enum_types, + ) + + def _to_config_proto(self) -> ChannelConfigProto: + """Convert to ChannelConfigProto for ingestion.""" + from sift.common.type.v1.channel_bit_field_element_pb2 import ( + ChannelBitFieldElement as ChannelBitFieldElementPb, + ) + from sift.common.type.v1.channel_enum_type_pb2 import ChannelEnumType as ChannelEnumTypePb + + return ChannelConfigProto( + name=self.name, + data_type=self.data_type.value, + description=self.description or "", + unit=self.unit or "", + bit_field_elements=[ + ChannelBitFieldElementPb( + name=bfe.name, + index=bfe.index, + bit_count=bfe.bit_count, + ) + for bfe in self.bit_field_elements or [] + ], + enum_types=[ + ChannelEnumTypePb(name=name, key=key) + for name, key in (self.enum_types or {}).items() + ], + ) + + class Flow(BaseType[FlowConfig, "Flow"]): """Model representing a data flow for ingestion. @@ -56,15 +157,16 @@ class Flow(BaseType[FlowConfig, "Flow"]): model_config = ConfigDict(frozen=False) name: str - channels: list[Channel] + channels: list[ChannelConfig] ingestion_config_id: str | None = None run_id: str | None = None @classmethod def _from_proto(cls, proto: FlowConfig, sift_client: SiftClient | None = None) -> Flow: return cls( + proto=proto, name=proto.name, - channels=[Channel._from_proto(channel) for channel in proto.channels], + channels=[ChannelConfig._from_proto(channel) for channel in proto.channels], _client=sift_client, ) @@ -80,11 +182,11 @@ def _to_rust_config(self) -> FlowConfigPy: channels=[_channel_to_rust_config(channel) for channel in self.channels], ) - def add_channel(self, channel: Channel): - """Add a Channel to this Flow. + def add_channel(self, channel: ChannelConfig): + """Add a ChannelConfig to this Flow. Args: - channel: The Channel to add. + channel: The ChannelConfig to add. Raises: ValueError: If the flow has already been created with an ingestion config. @@ -113,7 +215,7 @@ def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]): # Converter functions. -def _channel_to_rust_config(channel: Channel) -> ChannelConfigPy: +def _channel_to_rust_config(channel: ChannelConfig) -> ChannelConfigPy: return ChannelConfigPy( name=channel.name, data_type=_to_rust_type(channel.data_type), @@ -133,7 +235,7 @@ def _channel_to_rust_config(channel: Channel) -> ChannelConfigPy: def _rust_channel_value_from_bitfield( - channel: Channel, value: Any + channel: ChannelConfig, value: Any ) -> IngestWithConfigDataChannelValuePy: """Helper function to convert a bitfield value to a ChannelValuePy object. @@ -169,10 +271,10 @@ def _rust_channel_value_from_bitfield( return IngestWithConfigDataChannelValuePy.bitfield(byte_array) -def _to_rust_value(channel: Channel, value: Any) -> IngestWithConfigDataChannelValuePy: +def _to_rust_value(channel: ChannelConfig, value: Any) -> IngestWithConfigDataChannelValuePy: if value is None: return IngestWithConfigDataChannelValuePy.empty() - if channel.data_type == ChannelDataType.ENUM: + if channel.data_type == ChannelDataType.ENUM and channel.enum_types is not None: enum_name = value enum_val = channel.enum_types.get(enum_name) if enum_val is None: diff --git a/python/lib/sift_client/sift_types/rule.py b/python/lib/sift_client/sift_types/rule.py index 90c40e6ab..ac4a4e3d6 100644 --- a/python/lib/sift_client/sift_types/rule.py +++ b/python/lib/sift_client/sift_types/rule.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING @@ -16,6 +15,9 @@ ChannelReferencesEntry = CalculatedChannelConfig.ChannelReferencesEntry del CalculatedChannelConfig +from sift.rules.v1.rules_pb2 import ( + CreateRuleRequest, +) from sift.rules.v1.rules_pb2 import ( Rule as RuleProto, ) @@ -26,10 +28,12 @@ RuleVersion as RuleVersionProto, ) -from sift_client.sift_types._base import BaseType, ModelUpdate +from sift_client.sift_types._base import BaseType, ModelCreate, ModelUpdate from sift_client.sift_types.channel import ChannelReference if TYPE_CHECKING: + from datetime import datetime + from sift_client.client import SiftClient from sift_client.sift_types.asset import Asset @@ -37,38 +41,33 @@ class Rule(BaseType[RuleProto, "Rule"]): """Model of the Sift Rule.""" + # Required fields name: str description: str - is_enabled: bool = True - expression: str | None = None - channel_references: list[ChannelReference] | None = None - action: RuleAction | None = None - asset_ids: list[str] | None = None - asset_tag_ids: list[str] | None = None - contextual_channels: list[str] | None = None - client_key: str | None = None - - # Fields from proto - created_date: datetime | None = None - modified_date: datetime | None = None - created_by_user_id: str | None = None - modified_by_user_id: str | None = None - organization_id: str | None = None - rule_version: RuleVersion | None = None - archived_date: datetime | None = None - is_external: bool | None = None - - @property - def is_archived(self) -> bool: - """Whether the rule is archived.""" - return self.archived_date is not None and self.archived_date > datetime( - 1970, 1, 1, tzinfo=timezone.utc - ) + is_enabled: bool + created_date: datetime + modified_date: datetime + created_by_user_id: str + modified_by_user_id: str + organization_id: str + is_archived: bool + is_external: bool + + # Optional fields + expression: str | None + channel_references: list[ChannelReference] | None + action: RuleAction | None + asset_ids: list[str] | None + asset_tag_ids: list[str] | None + contextual_channels: list[str] | None + client_key: str | None + rule_version: RuleVersion | None + archived_date: datetime | None @property def assets(self) -> list[Asset]: """Get the assets that this rule applies to.""" - return self.client.assets.list_(asset_ids=self.asset_ids, tag_ids=self.asset_tag_ids) + return self.client.assets.list_(asset_ids=self.asset_ids, _tag_ids=self.asset_tag_ids) @property def organization(self): @@ -103,9 +102,17 @@ def update(self, update: RuleUpdate | dict, version_notes: str | None = None) -> self._update(updated_rule) return self - def archive(self) -> None: + def archive(self) -> Rule: """Archive the rule.""" - self.client.rules.archive(rule=self) + updated_rule = self.client.rules.archive(rule=self) + self._update(updated_rule) + return self + + def unarchive(self) -> Rule: + """Unarchive the rule.""" + updated_rule = self.client.rules.unarchive(rule=self) + self._update(updated_rule) + return self @classmethod def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> Rule: @@ -115,6 +122,7 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> else None ) return cls( + proto=proto, id_=proto.rule_id, name=proto.name, description=proto.description, @@ -139,12 +147,37 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> asset_ids=proto.asset_configuration.asset_ids, # type: ignore asset_tag_ids=proto.asset_configuration.tag_ids, # type: ignore contextual_channels=[c.name for c in proto.contextual_channels.channels], - archived_date=proto.deleted_date.ToDatetime() if proto.deleted_date else None, + archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + is_archived=proto.is_archived, is_external=proto.is_external, _client=sift_client, ) +class RuleCreate(ModelCreate[CreateRuleRequest]): + """Model for creating a new Rule. + + Note: + - asset_ids applies this rule to those assets. + - asset_tag_ids applies this rule to assets with those tags. + """ + + name: str + description: str + expression: str + channel_references: list[ChannelReference] + action: RuleAction + organization_id: str | None = None + client_key: str | None = None + asset_ids: list[str] | None = None + asset_tag_ids: list[str] | None = None + contextual_channels: list[str] | None = None + is_external: bool = False + + def _get_proto_class(self) -> type[CreateRuleRequest]: + return CreateRuleRequest + + class RuleUpdate(ModelUpdate[RuleProto]): """Model of the Rule fields that can be updated. @@ -161,6 +194,7 @@ class RuleUpdate(ModelUpdate[RuleProto]): asset_ids: list[str] | None = None asset_tag_ids: list[str] | None = None contextual_channels: list[str] | None = None + is_archived: bool | None = None def _get_proto_class(self) -> type[RuleProto]: return RuleProto @@ -279,11 +313,13 @@ def _from_proto( else None ), action_type=action_type, - annotation_type=RuleAnnotationType.from_str( - proto.configuration.annotation.annotation_type # type: ignore - ) - if action_type == RuleActionType.ANNOTATION - else None, + annotation_type=( + RuleAnnotationType.from_str( + proto.configuration.annotation.annotation_type # type: ignore + ) + if action_type == RuleActionType.ANNOTATION + else None + ), _client=sift_client, ) @@ -314,7 +350,8 @@ class RuleVersion(BaseType[RuleVersionProto, "RuleVersion"]): created_by_user_id: str version_notes: str generated_change_message: str - deleted_date: datetime | None = None + archived_date: datetime | None + is_archived: bool @classmethod def _from_proto( @@ -328,6 +365,7 @@ def _from_proto( created_by_user_id=proto.created_by_user_id, version_notes=proto.version_notes, generated_change_message=proto.generated_change_message, - deleted_date=proto.deleted_date.ToDatetime() if proto.deleted_date else None, + archived_date=(proto.archived_date.ToDatetime() if proto.archived_date else None), + is_archived=proto.is_archived, _client=sift_client, ) diff --git a/python/lib/sift_client/sift_types/run.py b/python/lib/sift_client/sift_types/run.py index 3549e75be..e8292b8d3 100644 --- a/python/lib/sift_client/sift_types/run.py +++ b/python/lib/sift_client/sift_types/run.py @@ -1,12 +1,19 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, ClassVar -from pydantic import ConfigDict +from pydantic import model_validator +from sift.runs.v2.runs_pb2 import CreateRunRequest as CreateRunRequestProto from sift.runs.v2.runs_pb2 import Run as RunProto -from sift_client.sift_types._base import BaseType, MappingHelper, ModelUpdate +from sift_client.sift_types._base import ( + BaseType, + MappingHelper, + ModelCreate, + ModelCreateUpdateBase, + ModelUpdate, +) from sift_client.util.metadata import metadata_dict_to_proto, metadata_proto_to_dict if TYPE_CHECKING: @@ -14,40 +21,10 @@ from sift_client.sift_types.asset import Asset -class RunUpdate(ModelUpdate[RunProto]): - """Update model for Run.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - name: str | None = None - description: str | None = None - start_time: datetime | None = None - stop_time: datetime | None = None - is_pinned: bool | None = None - client_key: str | None = None - tags: list[str] | None = None - metadata: dict[str, str | float | bool] | None = None - - _to_proto_helpers: ClassVar = { - "metadata": MappingHelper( - proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto - ), - } - - def _get_proto_class(self) -> type[RunProto]: - return RunProto - - def _add_resource_id_to_proto(self, proto_msg: RunProto): - if self._resource_id is None: - raise ValueError("Resource ID must be set before adding to proto") - proto_msg.run_id = self._resource_id - - class Run(BaseType[RunProto, "Run"]): """Run model representing a data collection run.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - + # Required fields name: str description: str created_date: datetime @@ -55,18 +32,24 @@ class Run(BaseType[RunProto, "Run"]): created_by_user_id: str modified_by_user_id: str organization_id: str - start_time: datetime | None = None - stop_time: datetime | None = None - tags: list[str] | None = None - default_report_id: str | None = None - client_key: str | None = None metadata: dict[str, str | float | bool] - asset_ids: list[str] | None = None - archived_date: datetime | None = None + tags: list[str] + asset_ids: list[str] + is_adhoc: bool + is_archived: bool + + # Optional fields + start_time: datetime | None + stop_time: datetime | None + duration: timedelta | None + default_report_id: str | None + client_key: str | None + archived_date: datetime | None @classmethod def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> Run: return cls( + proto=proto, id_=proto.run_id, created_date=proto.created_date.ToDatetime(tzinfo=timezone.utc), modified_date=proto.modified_date.ToDatetime(tzinfo=timezone.utc), @@ -79,6 +62,7 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> stop_time=proto.stop_time.ToDatetime(tzinfo=timezone.utc) if proto.HasField("stop_time") else None, + duration=proto.duration.ToTimedelta() if proto.HasField("duration") else None, name=proto.name, description=proto.description, tags=list(proto.tags), @@ -89,48 +73,113 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> archived_date=proto.archived_date.ToDatetime() if proto.HasField("archived_date") else None, + is_archived=proto.is_archived, + is_adhoc=proto.is_adhoc, _client=sift_client, ) - def _to_proto(self) -> RunProto: - """Convert to protobuf message.""" - proto = RunProto( - run_id=self.id_ or "", - created_date=self.created_date, # type: ignore - modified_date=self.modified_date, # type: ignore - created_by_user_id=self.created_by_user_id, - modified_by_user_id=self.modified_by_user_id, - organization_id=self.organization_id, - is_pinned=False, - name=self.name, - description=self.description, - tags=self.tags, - metadata=metadata_dict_to_proto(self.metadata), - asset_ids=self.asset_ids, - ) + @property + def assets(self) -> list[Asset]: + """Return all assets associated with this run.""" + if not self.asset_ids: + return [] + return self.client.assets.list_(asset_ids=self.asset_ids) - if self.start_time is not None: - proto.start_time.FromDatetime(self.start_time) + def archive(self) -> Run: + """Archive the run.""" + updated_run = self.client.runs.archive(run=self) + self._update(updated_run) + return self - if self.stop_time is not None: - proto.stop_time.FromDatetime(self.stop_time) + def unarchive(self) -> Run: + """Unarchive the run.""" + updated_run = self.client.runs.unarchive(run=self) + self._update(updated_run) + return self - if self.default_report_id is not None: - proto.default_report_id = self.default_report_id + def update(self, update: RunUpdate | dict) -> Run: + """Update the Run. - if self.client_key is not None: - proto.client_key = self.client_key + Args: + update: The update to apply to the run. See RunUpdate for more updatable fields. - if self.archived_date is not None: - proto.archived_date.FromDatetime(self.archived_date) + Returns: + The updated run. + """ + updated_run = self.client.runs.update(run=self, update=update) + self._update(updated_run) + return self - return proto - @property - def assets(self) -> list[Asset]: - """Return all assets associated with this run.""" - if not hasattr(self, "client") or self.client is None: - raise RuntimeError("Run is not bound to a client instance.") - if not self.asset_ids: - return [] - return self.client.assets.list_(asset_ids=self.asset_ids) +class RunBase(ModelCreateUpdateBase): + """Base class for Run create and update models with shared fields and validation.""" + + description: str | None = None + start_time: datetime | None = None + stop_time: datetime | None = None + tags: list[str] | None = None + metadata: dict[str, str | float | bool] | None = None + + _to_proto_helpers: ClassVar = { + "metadata": MappingHelper( + proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + ), + } + + @model_validator(mode="after") + def _validate_time_fields(self): + """Validate time-related fields after initialization.""" + if self.stop_time is not None and self.start_time is None: + raise ValueError("start_time must be provided if stop_time is provided") + + if self.start_time is not None and self.stop_time is not None: + if self.start_time >= self.stop_time: + raise ValueError("start_time must be before stop_time") + + return self + + +class RunCreate(RunBase, ModelCreate[CreateRunRequestProto]): + """Create model for Run.""" + + name: str + client_key: str | None = None + tags: list[str] | None = None + metadata: dict[str, str | float | bool] | None = None + start_time: datetime | None = None + stop_time: datetime | None = None + organization_id: str | None = None + + _to_proto_helpers: ClassVar = { + "metadata": MappingHelper( + proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + ), + } + + def _get_proto_class(self) -> type[CreateRunRequestProto]: + return CreateRunRequestProto + + +class RunUpdate(RunBase, ModelUpdate[RunProto]): + """Update model for Run.""" + + name: str | None = None + tags: list[str] | None = None + metadata: dict[str, str | float | bool] | None = None + start_time: datetime | None = None + stop_time: datetime | None = None + is_archived: bool | None = None + + _to_proto_helpers: ClassVar = { + "metadata": MappingHelper( + proto_attr_path="metadata", update_field="metadata", converter=metadata_dict_to_proto + ), + } + + def _get_proto_class(self) -> type[RunProto]: + return RunProto + + def _add_resource_id_to_proto(self, proto_msg: RunProto): + if self._resource_id is None: + raise ValueError("Resource ID must be set before adding to proto") + proto_msg.run_id = self._resource_id diff --git a/python/lib/sift_client/util/cel_utils.py b/python/lib/sift_client/util/cel_utils.py index 219f6fe19..ea74f2171 100644 --- a/python/lib/sift_client/util/cel_utils.py +++ b/python/lib/sift_client/util/cel_utils.py @@ -6,7 +6,7 @@ from __future__ import annotations import re -from datetime import datetime +from datetime import datetime, timedelta from typing import Any @@ -50,9 +50,11 @@ def equals(key: str, value: Any) -> str: A CEL expression string """ if value is None: - return f"{key} == null" + return equals_null(key) elif isinstance(value, str): return f"{key} == '{value}'" + elif isinstance(value, bool): + return f"{key} == {str(value).lower()}" else: return f"{key} == {value}" @@ -184,7 +186,7 @@ def match(field: str, query: str | re.Pattern) -> str: return f"{field}.matches('{escaped_regex}')" -def greater_than(field: str, value: int | float | datetime) -> str: +def greater_than(field: str, value: int | float | datetime | timedelta) -> str: """Generates a CEL expression that checks whether a numeric or datetime field is greater than a given value. Args: @@ -195,13 +197,15 @@ def greater_than(field: str, value: int | float | datetime) -> str: A CEL expression string """ if isinstance(value, datetime): - as_string = value.isoformat() + as_string = f"timestamp('{value.isoformat()}')" + elif isinstance(value, timedelta): + as_string = f"duration('{value.total_seconds()}s')" else: as_string = str(value) return f"{field} > {as_string}" -def less_than(field: str, value: int | float | datetime) -> str: +def less_than(field: str, value: int | float | datetime | timedelta) -> str: """Generates a CEL expression that checks whether a numeric or datetime field is less than a given value. Args: @@ -212,7 +216,9 @@ def less_than(field: str, value: int | float | datetime) -> str: A CEL expression string """ if isinstance(value, datetime): - as_string = value.isoformat() + as_string = f"timestamp('{value.isoformat()}')" + elif isinstance(value, timedelta): + as_string = f"duration('{value.total_seconds()}s')" else: as_string = str(value) return f"{field} < {as_string}" diff --git a/python/mkdocs.yml b/python/mkdocs.yml index e598f2d8d..00ae85827 100644 --- a/python/mkdocs.yml +++ b/python/mkdocs.yml @@ -64,6 +64,7 @@ plugins: python: import: - https://docs.python.org/3/objects.inv + - https://docs.pydantic.dev/latest/objects.inv options: show_docstring_examples: false load_external_modules: true @@ -90,6 +91,8 @@ plugins: - griffe_extensions/sync_stubs_inspector.py:InspectSpecificObjects: paths: - sift_client.resources.sync_stubs + - griffe_pydantic: + schema: false - api-autonav: nav_section_title: Sift Py API diff --git a/python/pyproject.toml b/python/pyproject.toml index 82483cc42..cc833a58c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -57,7 +57,13 @@ development = [ "ruff~=0.12.10", ] build = ["pdoc==14.5.0", "build==1.2.1"] -docs = ["mkdocs", "mkdocs-material", "mkdocstrings[python]", "mkdocs-include-markdown-plugin", "mkdocs-api-autonav", "mike"] +docs = ["mkdocs", + "mkdocs-material", + "mkdocstrings[python]", + "mkdocs-include-markdown-plugin", + "mkdocs-api-autonav", + "mike", + "griffe-pydantic"] # May be required for certain library functionality openssl = ["pyOpenSSL<24.0.0", "types-pyOpenSSL<24.0.0", "cffi~=1.14"] diff --git a/python/scripts/dev b/python/scripts/dev index 8565c84e9..1ecf01ba9 100755 --- a/python/scripts/dev +++ b/python/scripts/dev @@ -18,6 +18,7 @@ Subcommands: lint Runs 'ruff check' to lint the lib directory mypy Runs 'mypy lib' for static analysis pyright Runs 'pyright lib' for type checking + check Runs all python linting checks gen-stubs Generates pyi stubs for sift_client synchronous wrappers mypy-stubs Runs stubtest (mypy) on the generated pyi stubs pip-install Install project dependencies