-
-
Notifications
You must be signed in to change notification settings - Fork 6
feat(taskworker): Add Push Mode to Taskworker #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
70ee047
357099b
eef475a
1988ba9
1df88aa
43ec727
b61581c
9ed3733
a030311
447ef38
f0414bf
f41ba74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,22 @@ class MetricsBackend(Protocol): | |
| An abstract class that defines the interface for metrics backends. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def gauge( | ||
| self, | ||
| key: str, | ||
| value: float, | ||
| instance: str | None = None, | ||
| tags: Tags | None = None, | ||
| sample_rate: float = 1, | ||
| unit: str | None = None, | ||
| stacklevel: int = 0, | ||
| ) -> None: | ||
| """ | ||
| Records a gauge metric (a point-in-time value). | ||
| """ | ||
| raise NotImplementedError | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll have to be careful about implementing this when sentry's client library version is updated. I don't think there are any tests in sentry that will fail because of this new abstract method.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you expand on this?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sentry (and launchpad) now have implementations of this When we update applications to the new client library release, we have to remember to implement this method, or the push worker will be broken. |
||
|
|
||
| @abstractmethod | ||
| def incr( | ||
| self, | ||
|
|
@@ -71,6 +87,18 @@ class NoOpMetricsBackend(MetricsBackend): | |
| Default metrics backend that does not record anything. | ||
| """ | ||
|
|
||
| def gauge( | ||
| self, | ||
| key: str, | ||
| value: float, | ||
| instance: str | None = None, | ||
| tags: Tags | None = None, | ||
| sample_rate: float = 1, | ||
| unit: str | None = None, | ||
| stacklevel: int = 0, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| def incr( | ||
| self, | ||
| name: str, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any | ||
| from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union | ||
|
|
||
| import grpc | ||
| import orjson | ||
|
|
@@ -27,8 +27,17 @@ | |
| from taskbroker_client.metrics import MetricsBackend | ||
| from taskbroker_client.types import InflightTaskActivation, ProcessingResult | ||
|
|
||
| if TYPE_CHECKING: | ||
| ServerInterceptor = grpc.ServerInterceptor[Any, Any] | ||
| else: | ||
| ServerInterceptor = grpc.ServerInterceptor | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # gRPC runs the unary request deserializer and the servicer on the same thread for a given call | ||
| # If HMAC verification fails inside a deserializer wrapper, raising turns into INTERNAL, so we set this flag and abort UNAUTHENTICATED from the servicer | ||
| _RPC_SIGNATURE_AUTH_TLS = threading.local() | ||
|
|
||
| MAX_ACTIVATION_SIZE = 1024 * 1024 * 10 | ||
| """Max payload size we will process.""" | ||
|
|
||
|
|
@@ -68,25 +77,27 @@ def __init__( | |
| self.credentials = credentials | ||
|
|
||
|
|
||
| # Type alias based on grpc-stubs | ||
| ContinuationType = Callable[[ClientCallDetails, Message], Any] | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| RpcMethodHandler = grpc.RpcMethodHandler[Any, Any] | ||
| InterceptorBase = grpc.UnaryUnaryClientInterceptor[Message, Message] | ||
| CallFuture = grpc.CallFuture[Message] | ||
| else: | ||
| RpcMethodHandler = grpc.RpcMethodHandler | ||
| InterceptorBase = grpc.UnaryUnaryClientInterceptor | ||
| CallFuture = Any | ||
|
|
||
| ClientContinuation = Callable[[ClientCallDetails, Message], Any] | ||
| ServerContinuation = Callable[[grpc.HandlerCallDetails], Optional[RpcMethodHandler]] | ||
| Metadata = Sequence[Tuple[str, Union[str, bytes]]] | ||
|
|
||
|
|
||
| class RequestSignatureInterceptor(InterceptorBase): | ||
| def __init__(self, shared_secret: list[str]): | ||
| self._secret = shared_secret[0].encode("utf-8") | ||
|
|
||
| def intercept_unary_unary( | ||
| self, | ||
| continuation: ContinuationType, | ||
| continuation: ClientContinuation, | ||
| client_call_details: grpc.ClientCallDetails, | ||
| request: Message, | ||
| ) -> CallFuture: | ||
|
|
@@ -108,6 +119,122 @@ def intercept_unary_unary( | |
| return continuation(call_details_with_meta, request) | ||
|
|
||
|
|
||
| def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None: | ||
| """ | ||
| Parse the task app `rpc_secret` JSON array string into a list of secrets. | ||
| Returns `None` when unset, invalid, or empty (no authentication). | ||
| """ | ||
| if not rpc_secret: | ||
| return None | ||
|
|
||
| # Try to parse the provided secret | ||
| parsed = orjson.loads(rpc_secret) | ||
|
|
||
| if not isinstance(parsed, list) or len(parsed) == 0: | ||
| # If the secret string is not a list with at least one element, it is invalid | ||
| return None | ||
|
|
||
| return [str(x) for x in parsed] | ||
|
|
||
|
|
||
| def grpc_metadata_get(metadata: Metadata, key: str) -> str | None: | ||
| """ | ||
| First matching gRPC metadata value for `key` or `None` if not present. | ||
| """ | ||
| # gRPC metadata keys are ASCII and compared case-insensitively | ||
| key = key.lower() | ||
|
|
||
| for k, v in metadata: | ||
| if k.lower() == key: | ||
| return v.decode("utf-8") if isinstance(v, bytes) else v | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def verify_rpc_request_signature_hmac( | ||
| secrets: list[str], | ||
| method: str, | ||
| request_body: bytes, | ||
| signature_hex: str | None, | ||
| ) -> bool: | ||
| """ | ||
| Verify the 'sentry-signature' metadata for a unary RPC body. | ||
| Uses the same signing contract as `RequestSignatureInterceptor` and the taskbroker. | ||
| """ | ||
| if not secrets: | ||
| return True | ||
|
|
||
| if not signature_hex: | ||
| return False | ||
|
|
||
| try: | ||
| sig_bytes = bytes.fromhex(signature_hex) | ||
| except ValueError: | ||
| return False | ||
|
|
||
| signing_payload = method.encode("utf-8") + b":" + request_body | ||
|
|
||
| for secret in secrets: | ||
| expected = hmac.new(secret.encode("utf-8"), signing_payload, hashlib.sha256).digest() | ||
| if hmac.compare_digest(expected, sig_bytes): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| class RequestSignatureServerInterceptor(ServerInterceptor): | ||
| """ | ||
| Enforces HMAC request signing on unary-unary RPCs like `WorkerService.PushTask`. | ||
| Verification uses the raw request bytes from the wire (via a wrapped deserializer) | ||
| so it stays consistent across languages and map encodings. | ||
| """ | ||
|
|
||
| def __init__(self, secrets: list[str]) -> None: | ||
| self._secrets = secrets | ||
|
|
||
| def intercept_service( | ||
| self, continuation: ServerContinuation, handler_call_details: grpc.HandlerCallDetails | ||
| ) -> Any: | ||
| handler = continuation(handler_call_details) | ||
| if handler is None or not self._secrets: | ||
| return handler | ||
|
|
||
| if handler.request_streaming or handler.response_streaming or handler.unary_unary is None: | ||
| return handler | ||
|
|
||
| inner_deserializer = handler.request_deserializer | ||
| if inner_deserializer is None: | ||
| return handler | ||
|
|
||
| method = handler_call_details.method | ||
| metadata = handler_call_details.invocation_metadata | ||
| signature = grpc_metadata_get(metadata, "sentry-signature") | ||
| original = handler.unary_unary | ||
|
|
||
| def request_deserializer(serialized_request: bytes) -> Any: | ||
| _RPC_SIGNATURE_AUTH_TLS.failed = False | ||
|
|
||
| if not verify_rpc_request_signature_hmac( | ||
| self._secrets, method, serialized_request, signature | ||
| ): | ||
| _RPC_SIGNATURE_AUTH_TLS.failed = True | ||
| return inner_deserializer(b"") | ||
|
|
||
| return inner_deserializer(serialized_request) | ||
|
|
||
| def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: | ||
| if getattr(_RPC_SIGNATURE_AUTH_TLS, "failed", False): | ||
| context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed") | ||
|
|
||
| return original(request, context) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Auth bypass:
|
||
|
|
||
| return grpc.unary_unary_rpc_method_handler( | ||
| unary_unary, | ||
| request_deserializer=request_deserializer, | ||
| response_serializer=handler.response_serializer, | ||
| ) | ||
|
|
||
|
|
||
| class HostTemporarilyUnavailable(Exception): | ||
| """Raised when a host is temporarily unavailable and should be retried later.""" | ||
|
|
||
|
|
@@ -157,6 +284,7 @@ def __init__( | |
| ) | ||
|
|
||
| self._cur_host = random.choice(self._hosts) | ||
| self._host_to_stubs_lock = threading.Lock() | ||
| self._host_to_stubs: dict[str, ConsumerServiceStub] = { | ||
| self._cur_host: self._connect_to_host(self._cur_host) | ||
| } | ||
|
|
@@ -195,11 +323,17 @@ def _emit_health_check(self) -> None: | |
| def _connect_to_host(self, host: str) -> ConsumerServiceStub: | ||
| logger.info("taskworker.client.connect", extra={"host": host}) | ||
| channel = grpc.insecure_channel(host, options=self._grpc_options) | ||
| if self._rpc_secret: | ||
| secrets = orjson.loads(self._rpc_secret) | ||
| secrets = parse_rpc_secret_list(self._rpc_secret) | ||
| if secrets: | ||
| channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) | ||
| return ConsumerServiceStub(channel) | ||
|
|
||
| def _get_stub(self, host: str) -> ConsumerServiceStub: | ||
|
george-sentry marked this conversation as resolved.
|
||
| with self._host_to_stubs_lock: | ||
| if host not in self._host_to_stubs: | ||
| self._host_to_stubs[host] = self._connect_to_host(host) | ||
| return self._host_to_stubs[host] | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| def _check_consecutive_unavailable_errors(self) -> None: | ||
| if self._num_consecutive_unavailable_errors >= self._max_consecutive_unavailable_errors: | ||
| self._temporary_unavailable_hosts[self._cur_host] = ( | ||
|
|
@@ -246,11 +380,10 @@ def _get_cur_stub(self) -> tuple[str, ConsumerServiceStub]: | |
| tags={"reason": "max_tasks_reached"}, | ||
| ) | ||
|
|
||
| if self._cur_host not in self._host_to_stubs: | ||
| self._host_to_stubs[self._cur_host] = self._connect_to_host(self._cur_host) | ||
| stub = self._get_stub(self._cur_host) | ||
|
|
||
| self._num_tasks_before_rebalance -= 1 | ||
| return self._cur_host, self._host_to_stubs[self._cur_host] | ||
| return self._cur_host, stub | ||
|
|
||
| def get_task(self, namespace: str | None = None) -> InflightTaskActivation | None: | ||
| """ | ||
|
|
@@ -324,10 +457,11 @@ def update_task( | |
| f"Host: {processing_result.host} is temporarily unavailable" | ||
| ) | ||
|
|
||
| stub = self._get_stub(processing_result.host) | ||
|
george-sentry marked this conversation as resolved.
|
||
| with self._metrics.timer( | ||
| "taskworker.update_task.rpc", tags={"host": processing_result.host} | ||
| ): | ||
| response = self._host_to_stubs[processing_result.host].SetTaskStatus(request) | ||
| response = stub.SetTaskStatus(request) | ||
| except grpc.RpcError as err: | ||
| self._metrics.incr( | ||
| "taskworker.client.rpc_error", | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.