Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/lib/sift_py/rule/_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
"assignee": "assignee@abc.com",
"type": "review",
"asset_names": ["asset"],
"tag_names": ["tag1"],
}
with mock.patch.object(RuleService, "create_or_update_rule"):
with mock.patch(
Expand All @@ -115,6 +116,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
assert rule_config.expression == rule_yaml["expression"]
assert rule_config.action.assignee == rule_yaml["assignee"]
assert rule_config.asset_names == rule_yaml["asset_names"]
assert rule_config.tag_names == rule_yaml["tag_names"]
assert isinstance(rule_config.action, RuleActionCreateDataReviewAnnotation)


Expand Down
6 changes: 5 additions & 1 deletion python/lib/sift_py/rule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class RuleConfig(AsJson):
- `channel_references`: Reference to channel. If an expression is "$1 < 10", then "$1" is the reference and thus should the key in the dict.
- `rule_client_key`: User defined unique string that uniquely identifies this rule.
- `asset_names`: A list of asset names that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
- `tag_names`: A list of asset names that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
- `tag_names`: A list of asset tags that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
- `contextual_channels`: A list of channel names that provide context but aren't directly used in the expression.
- `is_external`: If this is an external rule.
- `is_live`: If set to True then this rule will be evaluated on live data, otherwise live rule evaluation will be disabled.
Expand All @@ -38,6 +38,7 @@ class RuleConfig(AsJson):
channel_references: List[ExpressionChannelReference]
rule_client_key: Optional[str]
asset_names: List[str]
tag_names: List[str]
contextual_channels: List[str]
is_external: bool
is_live: bool
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(

self.name = name
self.asset_names = asset_names or []
self.tag_names = tag_names or []
self.action = action
self.rule_client_key = rule_client_key
self.description = description
Expand Down Expand Up @@ -133,6 +135,8 @@ def interpolate_sub_expressions(


class RuleAction(ABC):
tags: Optional[List[str]]

@abstractmethod
def kind(self) -> RuleActionKind:
pass
Expand Down
56 changes: 53 additions & 3 deletions python/lib/sift_py/rule/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
UpdateRuleRequest,
)
from sift.rules.v1.rules_pb2_grpc import RuleServiceStub
from sift.tags.v2.tags_pb2 import Tag, TagType
from sift.tags.v2.tags_pb2_grpc import TagServiceStub
from sift.users.v2.users_pb2_grpc import UserServiceStub

from sift_py._internal.cel import cel_in
Expand All @@ -55,6 +57,7 @@
RuleActionKind,
RuleConfig,
)
from sift_py.tag._internal.shared import list_tags_impl
from sift_py.yaml.rule import load_rule_modules


Expand All @@ -72,13 +75,15 @@ class RuleService:
_channel_service_stub: ChannelServiceStub
_rule_service_stub: RuleServiceStub
_user_service_stub: UserServiceStub
_tag_service_stub: TagServiceStub
_enable_caching: bool

def __init__(self, channel: SiftChannel, enable_caching=False):
self._asset_service_stub = AssetServiceStub(channel)
self._channel_service_stub = ChannelServiceStub(channel)
self._rule_service_stub = RuleServiceStub(channel)
self._user_service_stub = UserServiceStub(channel)
self._tag_service_stub = TagServiceStub(channel)
self._enable_caching = enable_caching

def load_rules_from_yaml(
Expand Down Expand Up @@ -244,6 +249,7 @@ def _parse_rules_from_yaml(
channel_references=rule_channel_references,
contextual_channels=contextual_channels,
asset_names=rule_yaml.get("asset_names", []),
tag_names=rule_yaml.get("tag_names", []),
sub_expressions=subexpr,
is_external=rule_yaml.get("is_external", False),
is_live=rule_yaml.get("is_live", False),
Expand Down Expand Up @@ -402,8 +408,20 @@ def _update_req_from_rule_config(
"See `sift_py.rule.config.RuleAction` for available actions."
)

# TODO: once we have TagService_ListTags we can do asset-agnostic rules via tags
assets = self._get_assets(names=config.asset_names) if config.asset_names else None
asset_tags = (
self._get_tags(names=config.tag_names, tag_type=TagType.TAG_TYPE_ASSET)
if config.tag_names
else None
)
annotation_tags = (
self._get_tags(
names=[tag for tag in config.action.tags],
tag_type=TagType.TAG_TYPE_ANNOTATION,
)
if config.action.tags
else None
)

actions = []
if config.action.kind() == RuleActionKind.NOTIFICATION:
Expand All @@ -412,6 +430,10 @@ def _update_req_from_rule_config(
"Please contact the Sift team for assistance."
)
elif config.action.kind() == RuleActionKind.ANNOTATION:
annotation_tag_ids = (
[tag.tag_id for tag in annotation_tags] if annotation_tags else None
)

if isinstance(config.action, RuleActionCreateDataReviewAnnotation):
assignee = config.action.assignee
user_id = None
Expand All @@ -431,7 +453,7 @@ def _update_req_from_rule_config(
annotation=AnnotationActionConfiguration(
assigned_to_user_id=user_id,
annotation_type=AnnotationType.ANNOTATION_TYPE_DATA_REVIEW,
# tag_ids=config.action.tags, # TODO: Requires TagService
tag_ids=annotation_tag_ids,
)
),
)
Expand All @@ -442,7 +464,7 @@ def _update_req_from_rule_config(
configuration=RuleActionConfiguration(
annotation=AnnotationActionConfiguration(
annotation_type=AnnotationType.ANNOTATION_TYPE_PHASE,
# tag_ids=config.action.tags, # TODO: Requires TagService
tag_ids=annotation_tag_ids,
)
),
)
Expand Down Expand Up @@ -523,6 +545,7 @@ def _update_req_from_rule_config(
],
asset_configuration=RuleAssetConfiguration(
asset_ids=[asset.asset_id for asset in assets] if assets else None,
tag_ids=[tag.tag_id for tag in asset_tags] if asset_tags else None,
),
contextual_channels=ContextualChannels(channels=contextual_channel_names),
is_external=config.is_external,
Expand Down Expand Up @@ -574,6 +597,12 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
)
asset_names = [asset.name for asset in assets]

asset_tags = self._get_tags(
ids=[tag_id for tag_id in rule_pb.asset_configuration.tag_ids],
tag_type=TagType.TAG_TYPE_ASSET,
)
asset_tag_names = [tag.name for tag in asset_tags]

contextual_channels = []
for channel_ref in rule_pb.contextual_channels.channels:
contextual_channels.append(channel_ref.name)
Expand All @@ -585,6 +614,7 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
channel_references=channel_references, # type: ignore
contextual_channels=contextual_channels,
asset_names=asset_names,
tag_names=asset_tag_names,
action=action,
expression=expression,
)
Expand Down Expand Up @@ -616,6 +646,17 @@ def _get_assets(self, names: List[str] = [], ids: List[str] = []) -> List[Asset]
else:
return list_assets_impl(self._asset_service_stub, names, ids)

def _get_tags(
self,
names: List[str] = [],
ids: List[str] = [],
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
) -> List[Tag]:
if self._enable_caching:
return self._get_tags_cached(tuple(sorted(names)), tuple(sorted(ids)), tag_type)
else:
return list_tags_impl(self._tag_service_stub, names, ids, tag_type)

def _get_channels(self, filter: str) -> List[ChannelPb]:
if self._enable_caching:
return self._get_channels_cached(filter)
Expand All @@ -632,6 +673,15 @@ def _get_active_users(self, filter: str) -> List[UserPb]:
def _get_assets_cached(self, names: Tuple[str], ids: Tuple[str]) -> List[Asset]:
return list_assets_impl(self._asset_service_stub, names, ids)

@cache
def _get_tags_cached(
self,
names: Tuple[str],
ids: Tuple[str],
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
) -> List[Tag]:
return list_tags_impl(self._tag_service_stub, names, ids, tag_type)

@cache
def _get_channels_cached(self, filter: str) -> List[ChannelPb]:
return get_channels(channel_service=self._channel_service_stub, filter=filter)
Expand Down
63 changes: 63 additions & 0 deletions python/lib/sift_py/tag/_internal/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import List, Optional, Tuple, Union, cast

from sift.tags.v2.tags_pb2 import ListTagsRequest, ListTagsResponse, Tag, TagType
from sift.tags.v2.tags_pb2_grpc import TagServiceStub
from sift_py._internal.cel import cel_in


def list_tags_impl(
tag_service_stub: TagServiceStub,
names: Optional[Union[Tuple[str], List[str]]] = None,
ids: Optional[Union[Tuple[str], List[str]]] = None,
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
) -> List[Tag]:
"""
Lists tags in an organization.

Args:
tag_service_stub: The tag service stub to use.
names: Optional collection of names to filter by.
ids: Optional collection of IDs to filter by.
tag_type: Optional tag type to filter by.

Returns:
A list of tags matching the criteria.
"""

def get_tags_with_filter(
tag_service_stub: TagServiceStub,
cel_filter: str,
tag_type: TagType.ValueType,
) -> List[Tag]:
tags: List[Tag] = []
next_page_token = ""
while True:
req = ListTagsRequest(
filter=cel_filter,
page_size=1_000,
page_token=next_page_token,
tag_type=tag_type,
)
res = cast(ListTagsResponse, tag_service_stub.ListTags(req))
tags.extend(res.tags)

if not res.next_page_token:
break
next_page_token = res.next_page_token

return tags

if names is None:
names = []
if ids is None:
ids = []

results: List[Tag] = []
if names:
names_cel = cel_in("name", names)
results.extend(get_tags_with_filter(tag_service_stub, names_cel, tag_type))
if ids:
ids_cel = cel_in("tag_id", ids)
results.extend(get_tags_with_filter(tag_service_stub, ids_cel, tag_type))

return results
Loading