Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 91 additions & 44 deletions python/lib/sift_client/_internal/low_level_wrappers/rules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Sequence, cast

from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier, ResourceIdentifiers
from sift.rule_evaluation.v1.rule_evaluation_pb2 import (
Expand Down Expand Up @@ -97,6 +97,48 @@ async def get_rule(self, rule_id: str | None = None, client_key: str | None = No
grpc_rule = cast("GetRuleResponse", response).rule
return Rule._from_proto(grpc_rule)

def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequest:
"""Create an UpdateRuleRequest from a RuleCreate object.

Args:
create: The RuleCreate model with the rule configuration.

Returns:
The UpdateRuleRequest proto message.
"""
expression_proto = RuleConditionExpression(
calculated_channel=CalculatedChannelConfig(
expression=create.expression,
channel_references={
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
for c in create.channel_references
},
)
)
conditions_request = [
UpdateConditionRequest(
expression=expression_proto,
actions=[create.action._to_update_request()],
)
]
update_request = UpdateRuleRequest(
name=create.name,
description=create.description,
is_enabled=True,
organization_id=create.organization_id or "",
client_key=create.client_key,
is_external=create.is_external,
conditions=conditions_request,
asset_configuration=RuleAssetConfiguration(
asset_ids=create.asset_ids or [],
tag_ids=create.asset_tag_ids or [],
),
contextual_channels=ContextualChannels(
channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []]
), # type: ignore
)
return update_request

async def batch_get_rules(
self, rule_ids: list[str] | None = None, client_keys: list[str] | None = None
) -> list[Rule]:
Expand Down Expand Up @@ -138,39 +180,7 @@ async def create_rule(
Returns:
The created Rule.
"""
# Convert rule to UpdateRuleRequest
expression_proto = RuleConditionExpression(
calculated_channel=CalculatedChannelConfig(
expression=create.expression,
channel_references={
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
for c in create.channel_references
},
)
)
conditions_request = [
UpdateConditionRequest(
expression=expression_proto,
actions=[create.action._to_update_request()],
)
]
update_request = UpdateRuleRequest(
name=create.name,
description=create.description,
is_enabled=True,
organization_id=create.organization_id or "",
client_key=create.client_key,
is_external=create.is_external,
conditions=conditions_request,
asset_configuration=RuleAssetConfiguration(
asset_ids=create.asset_ids or [],
tag_ids=create.asset_tag_ids or [],
),
contextual_channels=ContextualChannels(
channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []]
), # type: ignore
)

update_request = self._update_rule_request_from_create(create)
request = CreateRuleRequest(update=update_request)
created_rule = cast(
"CreateRuleResponse",
Expand Down Expand Up @@ -301,22 +311,59 @@ async def update_rule(
# Get the updated rule
return await self.get_rule(rule_id=rule.id_)

async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse:
"""Batch update rules.
async def batch_update_rules(
self,
rules: Sequence[RuleCreate | RuleUpdate],
validate_only: bool = False,
override_expression_validation: bool = False,
) -> BatchUpdateRulesResponse:
"""Batch update or create rules.

Args:
rules: List of rule updates to apply.
rules: List of rule creates or updates to apply. RuleUpdate objects must have
resource_id set.

Returns:
The batch update response.
"""
update_requests = []
for rule_update in rules:
rule = await self.get_rule(rule_id=rule_update.resource_id)
request = self._update_rule_request_from_update(rule, rule_update)
update_requests.append(request)

request = BatchUpdateRulesRequest(rules=update_requests) # type: ignore
Raises:
ValueError: If any RuleUpdate objects are missing resource_id or the rule is not found for updating.
"""
# Collect resource_ids from only RuleUpdate objects
rule_ids: list[str] = []
for rule in rules:
if isinstance(rule, RuleUpdate):
if rule.resource_id is None:
raise ValueError("RuleUpdate objects must have resource_id set")
rule_ids.append(rule.resource_id)

# Fetch existing rules for updates
existing_rules = await self.batch_get_rules(rule_ids=rule_ids) if rule_ids else []
existing_rules_by_id = {rule.id_: rule for rule in existing_rules}

# Build update requests maintaining the input order
update_requests: list[UpdateRuleRequest] = []
for rule in rules:
if isinstance(rule, RuleCreate):
# Convert RuleCreate to UpdateRuleRequest
update_request = self._update_rule_request_from_create(rule)
update_requests.append(update_request)
elif isinstance(rule, RuleUpdate):
# Use existing rule + update to create request
existing_rule = existing_rules_by_id.get(rule.resource_id)
if existing_rule is None:
raise ValueError(
f"Rule with resource_id {rule.resource_id} not found for update"
)
update_request = self._update_rule_request_from_update(existing_rule, rule)
update_requests.append(update_request)

# Call the batch update request
request = BatchUpdateRulesRequest(
rules=update_requests,
validate_only=validate_only,
override_expression_validation=override_expression_validation,
) # type: ignore
response = await self._grpc_client.get_stub(RuleServiceStub).BatchUpdateRules(request)
return cast("BatchUpdateRulesResponse", response)

Expand Down
130 changes: 130 additions & 0 deletions python/lib/sift_client/_tests/resources/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,136 @@ async def test_unarchive_rule(self, rules_api_async, new_rule):
finally:
await rules_api_async.archive(new_rule.id_)

class TestBatchUpdate:
"""Tests for the async batch_update_rules method."""

@pytest.mark.asyncio
async def test_batch_update_or_create_rules(self, rules_api_async, nostromo_asset):
"""Test updating multiple rules with different fields."""
from datetime import datetime, timezone

rule1_name = f"test_batch_rule_1_{datetime.now(timezone.utc).isoformat()}"
rule2_name = f"test_batch_rule_2_{datetime.now(timezone.utc).isoformat()}"

rule1 = await rules_api_async.create(
RuleCreate(
name=rule1_name,
client_key=f"test_batch_1_{str(uuid.uuid4())[-8:]}",
description="Test rule 1 for batch update",
expression="$1 > $2",
channel_references=[
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
ChannelReference(channel_reference="$2", channel_identifier="channel2"),
],
action=RuleAction.annotation(
annotation_type=RuleAnnotationType.DATA_REVIEW,
tags=[],
),
asset_ids=[nostromo_asset.id_],
)
)

rule2 = await rules_api_async.create(
RuleCreate(
name=rule2_name,
client_key=f"test_batch_2_{str(uuid.uuid4())[-8:]}",
description="Test rule 2 for batch update",
expression="$1 > 0.5",
channel_references=[
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
],
action=RuleAction.annotation(
annotation_type=RuleAnnotationType.DATA_REVIEW,
tags=[],
),
asset_ids=[nostromo_asset.id_],
)
)

try:
# Batch update both rules
rule1_update = RuleUpdate(description="Updated description 1")
rule1_update.resource_id = rule1.id_

rule2_update = RuleUpdate(description="Updated description 2")
rule2_update.resource_id = rule2.id_

updates = [rule1_update, rule2_update]

updated_rules = await rules_api_async.batch_update_or_create_rules(updates)

assert isinstance(updated_rules, list)
assert len(updated_rules) == 2

# Verify updates were applied
assert updated_rules[0].description == "Updated description 1"
assert updated_rules[1].description == "Updated description 2"
finally:
await rules_api_async.archive(rule1.id_)
await rules_api_async.archive(rule2.id_)

@pytest.mark.asyncio
async def test_batch_update_rules_creates_rules(self, rules_api_async, nostromo_asset):
"""Test batch updating rules that don't already exist."""
from datetime import datetime, timezone

rule1_name = f"test_batch_rule_1_{datetime.now(timezone.utc).isoformat()}"
rule2_name = f"test_batch_rule_2_{datetime.now(timezone.utc).isoformat()}"

rule1 = RuleCreate(
name=rule1_name,
client_key=f"test_batch_1_{str(uuid.uuid4())[-8:]}",
description="Test rule 1 for batch update",
expression="$1 > $2",
channel_references=[
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
ChannelReference(channel_reference="$2", channel_identifier="channel2"),
],
action=RuleAction.annotation(
annotation_type=RuleAnnotationType.DATA_REVIEW,
tags=[],
),
asset_ids=[nostromo_asset.id_],
)

rule2 = RuleCreate(
name=rule2_name,
client_key=f"test_batch_2_{str(uuid.uuid4())[-8:]}",
description="Test rule 2 for batch update",
expression="$1 > 0.5",
channel_references=[
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
],
action=RuleAction.annotation(
annotation_type=RuleAnnotationType.DATA_REVIEW,
tags=[],
),
asset_ids=[nostromo_asset.id_],
)

updated_rules: list[Rule] = []
try:
# Batch update (actually create) both rules
updates = [rule1, rule2]
updated_rules = await rules_api_async.batch_update_or_create_rules(updates)

assert isinstance(updated_rules, list)
assert len(updated_rules) == 2

assert updated_rules[0].client_key == rule1.client_key
assert updated_rules[1].client_key == rule2.client_key
finally:
for rule in updated_rules:
await rules_api_async.archive(rule.id_)

@pytest.mark.asyncio
async def test_batch_update_rules_empty_list(self, rules_api_async):
"""Test handling empty list."""
updated_rules = await rules_api_async.batch_update_or_create_rules([])

assert isinstance(updated_rules, list)
assert len(updated_rules) == 0


class TestRulesAPISync:
"""Test suite for the synchronous Rules API functionality."""
Expand Down
Loading
Loading