diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 1dc56621..33d51c77 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" dependencies = [ "sentry-arroyo>=2.33.1", "sentry-sdk[http2]>=2.43.0", - "sentry-protos>=0.4.11", + "sentry-protos>=0.8.5", "confluent_kafka>=2.3.0", "cronsim>=2.6", "grpcio>=1.67.0", diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 26cd8e01..02ce6e40 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -73,7 +73,16 @@ def scheduler() -> None: help="The number of child processes to start.", default=2, ) -def worker(rpc_host: str, concurrency: int) -> None: +@click.option( + "--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True +) +@click.option( + "--grpc-port", + help="Port for the gRPC server to listen on.", + default=50052, + type=int, +) +def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None: from taskbroker_client.worker import TaskWorker click.echo("Starting worker") @@ -87,6 +96,8 @@ def worker(rpc_host: str, concurrency: int) -> None: rebalance_after=32, processing_pool_name="examples", process_type="forkserver", + push_mode=push_mode, + grpc_port=grpc_port, ) exitcode = worker.start() raise SystemExit(exitcode) diff --git a/clients/python/src/taskbroker_client/metrics.py b/clients/python/src/taskbroker_client/metrics.py index f48fdf27..6324029a 100644 --- a/clients/python/src/taskbroker_client/metrics.py +++ b/clients/python/src/taskbroker_client/metrics.py @@ -14,6 +14,22 @@ class MetricsBackend(Protocol): An abstract class that defines the interface for metrics backends. 
""" + @abstractmethod + def gauge( + self, + key: str, + value: float, + instance: str | None = None, + tags: Tags | None = None, + sample_rate: float = 1, + unit: str | None = None, + stacklevel: int = 0, + ) -> None: + """ + Records a gauge metric (a point-in-time value). + """ + raise NotImplementedError + @abstractmethod def incr( self, @@ -71,6 +87,18 @@ class NoOpMetricsBackend(MetricsBackend): Default metrics backend that does not record anything. """ + def gauge( + self, + key: str, + value: float, + instance: str | None = None, + tags: Tags | None = None, + sample_rate: float = 1, + unit: str | None = None, + stacklevel: int = 0, + ) -> None: + pass + def incr( self, name: str, diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 0670af9c..751b4c6c 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union import grpc import orjson @@ -27,8 +27,17 @@ from taskbroker_client.metrics import MetricsBackend from taskbroker_client.types import InflightTaskActivation, ProcessingResult +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + logger = logging.getLogger(__name__) +# gRPC runs the unary request deserializer and the servicer on the same thread for a given call +# If HMAC verification fails inside a deserializer wrapper, raising turns into INTERNAL, so we set this flag and abort UNAUTHENTICATED from the servicer +_RPC_SIGNATURE_AUTH_TLS = threading.local() + MAX_ACTIVATION_SIZE = 1024 * 1024 * 10 """Max payload size we will process.""" @@ -68,17 +77,19 @@ def __init__( self.credentials = credentials -# Type alias based on 
grpc-stubs -ContinuationType = Callable[[ClientCallDetails, Message], Any] - - if TYPE_CHECKING: + RpcMethodHandler = grpc.RpcMethodHandler[Any, Any] InterceptorBase = grpc.UnaryUnaryClientInterceptor[Message, Message] CallFuture = grpc.CallFuture[Message] else: + RpcMethodHandler = grpc.RpcMethodHandler InterceptorBase = grpc.UnaryUnaryClientInterceptor CallFuture = Any +ClientContinuation = Callable[[ClientCallDetails, Message], Any] +ServerContinuation = Callable[[grpc.HandlerCallDetails], Optional[RpcMethodHandler]] +Metadata = Sequence[Tuple[str, Union[str, bytes]]] + class RequestSignatureInterceptor(InterceptorBase): def __init__(self, shared_secret: list[str]): @@ -86,7 +97,7 @@ def __init__(self, shared_secret: list[str]): def intercept_unary_unary( self, - continuation: ContinuationType, + continuation: ClientContinuation, client_call_details: grpc.ClientCallDetails, request: Message, ) -> CallFuture: @@ -108,6 +119,122 @@ def intercept_unary_unary( return continuation(call_details_with_meta, request) + +def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None: + """ + Parse the task app `rpc_secret` JSON array string into a list of secrets. + Returns `None` when unset or when the parsed JSON is not a non-empty list (no authentication); raises `orjson.JSONDecodeError` on malformed JSON. + """ + if not rpc_secret: + return None + + # Try to parse the provided secret + parsed = orjson.loads(rpc_secret) + + if not isinstance(parsed, list) or len(parsed) == 0: + # If the secret string is not a list with at least one element, it is invalid + return None + + return [str(x) for x in parsed] + + +def grpc_metadata_get(metadata: Metadata, key: str) -> str | None: + """ + First matching gRPC metadata value for `key` or `None` if not present.
+ """ + # gRPC metadata keys are ASCII and compared case-insensitively + key = key.lower() + + for k, v in metadata: + if k.lower() == key: + return v.decode("utf-8") if isinstance(v, bytes) else v + + return None + + +def verify_rpc_request_signature_hmac( + secrets: list[str], + method: str, + request_body: bytes, + signature_hex: str | None, +) -> bool: + """ + Verify the 'sentry-signature' metadata for a unary RPC body. + Uses the same signing contract as `RequestSignatureInterceptor` and the taskbroker. + """ + if not secrets: + return True + + if not signature_hex: + return False + + try: + sig_bytes = bytes.fromhex(signature_hex) + except ValueError: + return False + + signing_payload = method.encode("utf-8") + b":" + request_body + + for secret in secrets: + expected = hmac.new(secret.encode("utf-8"), signing_payload, hashlib.sha256).digest() + if hmac.compare_digest(expected, sig_bytes): + return True + + return False + + +class RequestSignatureServerInterceptor(ServerInterceptor): + """ + Enforces HMAC request signing on unary-unary RPCs like `WorkerService.PushTask`. + Verification uses the raw request bytes from the wire (via a wrapped deserializer) + so it stays consistent across languages and map encodings. 
+ """ + + def __init__(self, secrets: list[str]) -> None: + self._secrets = secrets + + def intercept_service( + self, continuation: ServerContinuation, handler_call_details: grpc.HandlerCallDetails + ) -> Any: + handler = continuation(handler_call_details) + if handler is None or not self._secrets: + return handler + + if handler.request_streaming or handler.response_streaming or handler.unary_unary is None: + return handler + + inner_deserializer = handler.request_deserializer + if inner_deserializer is None: + return handler + + method = handler_call_details.method + metadata = handler_call_details.invocation_metadata + signature = grpc_metadata_get(metadata, "sentry-signature") + original = handler.unary_unary + + def request_deserializer(serialized_request: bytes) -> Any: + _RPC_SIGNATURE_AUTH_TLS.failed = False + + if not verify_rpc_request_signature_hmac( + self._secrets, method, serialized_request, signature + ): + _RPC_SIGNATURE_AUTH_TLS.failed = True + return inner_deserializer(b"") + + return inner_deserializer(serialized_request) + + def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: + if getattr(_RPC_SIGNATURE_AUTH_TLS, "failed", False): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed") + + return original(request, context) + + return grpc.unary_unary_rpc_method_handler( + unary_unary, + request_deserializer=request_deserializer, + response_serializer=handler.response_serializer, + ) + + class HostTemporarilyUnavailable(Exception): """Raised when a host is temporarily unavailable and should be retried later.""" @@ -157,6 +284,7 @@ def __init__( ) self._cur_host = random.choice(self._hosts) + self._host_to_stubs_lock = threading.Lock() self._host_to_stubs: dict[str, ConsumerServiceStub] = { self._cur_host: self._connect_to_host(self._cur_host) } @@ -195,11 +323,17 @@ def _emit_health_check(self) -> None: def _connect_to_host(self, host: str) -> ConsumerServiceStub: logger.info("taskworker.client.connect", 
extra={"host": host}) channel = grpc.insecure_channel(host, options=self._grpc_options) - if self._rpc_secret: - secrets = orjson.loads(self._rpc_secret) + secrets = parse_rpc_secret_list(self._rpc_secret) + if secrets: channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) return ConsumerServiceStub(channel) + def _get_stub(self, host: str) -> ConsumerServiceStub: + with self._host_to_stubs_lock: + if host not in self._host_to_stubs: + self._host_to_stubs[host] = self._connect_to_host(host) + return self._host_to_stubs[host] + def _check_consecutive_unavailable_errors(self) -> None: if self._num_consecutive_unavailable_errors >= self._max_consecutive_unavailable_errors: self._temporary_unavailable_hosts[self._cur_host] = ( @@ -246,11 +380,10 @@ def _get_cur_stub(self) -> tuple[str, ConsumerServiceStub]: tags={"reason": "max_tasks_reached"}, ) - if self._cur_host not in self._host_to_stubs: - self._host_to_stubs[self._cur_host] = self._connect_to_host(self._cur_host) + stub = self._get_stub(self._cur_host) self._num_tasks_before_rebalance -= 1 - return self._cur_host, self._host_to_stubs[self._cur_host] + return self._cur_host, stub def get_task(self, namespace: str | None = None) -> InflightTaskActivation | None: """ @@ -324,10 +457,11 @@ def update_task( f"Host: {processing_result.host} is temporarily unavailable" ) + stub = self._get_stub(processing_result.host) with self._metrics.timer( "taskworker.update_task.rpc", tags={"host": processing_result.host} ): - response = self._host_to_stubs[processing_result.host].SetTaskStatus(request) + response = stub.SetTaskStatus(request) except grpc.RpcError as err: self._metrics.incr( "taskworker.client.rpc_error", diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index ca05073f..81c116ea 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -10,10 
+10,15 @@ from multiprocessing.context import ForkContext, ForkServerContext, SpawnContext from multiprocessing.process import BaseProcess from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, List import grpc -from sentry_protos.taskbroker.v1.taskbroker_pb2 import FetchNextTask +from sentry_protos.taskbroker.v1 import taskbroker_pb2_grpc +from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( + FetchNextTask, + PushTaskRequest, + PushTaskResponse, +) from taskbroker_client.app import import_app from taskbroker_client.constants import ( @@ -26,13 +31,49 @@ from taskbroker_client.worker.client import ( HealthCheckSettings, HostTemporarilyUnavailable, + RequestSignatureServerInterceptor, TaskbrokerClient, + parse_rpc_secret_list, ) from taskbroker_client.worker.workerchild import child_process +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + + logger = logging.getLogger(__name__) +class WorkerServicer(taskbroker_pb2_grpc.WorkerServiceServicer): + """ + gRPC servicer that receives task activations pushed from the broker + """ + + def __init__(self, worker: TaskWorker) -> None: + self.worker = worker + + def PushTask( + self, + request: PushTaskRequest, + context: grpc.ServicerContext, + ) -> PushTaskResponse: + """Handle incoming task activation.""" + # Create `InflightTaskActivation` from the pushed task + inflight = InflightTaskActivation( + activation=request.task, + host=request.callback_url, + receive_timestamp=time.monotonic(), + ) + + # Push the task to the worker queue (wait at most 5 seconds) + if not self.worker.push_task(inflight, timeout=5): + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy") + + return PushTaskResponse() + + class TaskWorker: """ A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. 
@@ -60,6 +101,8 @@ def __init__( process_type: str = "spawn", health_check_file_path: str | None = None, health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, + push_mode: bool = False, + grpc_port: int = 50052, **kwargs: dict[str, Any], ) -> None: self.options = kwargs @@ -69,6 +112,11 @@ def __init__( self._concurrency = concurrency app = import_app(app_module) + if push_mode: + logger.info("Running in PUSH mode") + else: + logger.info("Running in PULL mode") + self.client = TaskbrokerClient( hosts=broker_hosts, application=app.name, @@ -110,12 +158,14 @@ def __init__( self._processing_pool_name: str = processing_pool_name or "unknown" + self._push_mode = push_mode + self._grpc_port = grpc_port + self._grpc_secrets = parse_rpc_secret_list(app.config["rpc_secret"]) + def start(self) -> int: """ - Run the worker main loop - - Once started a Worker will loop until it is killed, or - completes its max_task_count when it shuts down. + When in PULL mode, this starts a loop that runs until the worker completes its `max_task_count` or it is killed. + When in PUSH mode, this starts the worker gRPC server. 
""" self.start_result_thread() self.start_spawn_children_thread() @@ -128,12 +178,48 @@ def signal_handler(*args: Any) -> None: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - try: - while True: - self.run_once() - except KeyboardInterrupt: - self.shutdown() - raise + if self._push_mode: + server = None + + try: + # Start gRPC server + interceptors: List[ServerInterceptor] = [] + + if self._grpc_secrets: + interceptors = [RequestSignatureServerInterceptor(self._grpc_secrets)] + + server = grpc.server( + ThreadPoolExecutor(max_workers=self._concurrency), + interceptors=interceptors, + ) + + taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server( + WorkerServicer(self), server + ) + server.add_insecure_port(f"[::]:{self._grpc_port}") + server.start() + logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port}) + + try: + server.wait_for_termination() + except KeyboardInterrupt: + # Signals are converted to KeyboardInterrupt, swallow for exit code 0 + pass + + finally: + if server is not None: + server.stop(grace=5) + + self.shutdown() + else: + try: + while True: + self.run_once() + except KeyboardInterrupt: + self.shutdown() + raise + + return 0 def run_once(self) -> None: """Access point for tests to run a single worker loop""" @@ -172,6 +258,36 @@ def shutdown(self) -> None: logger.info("taskworker.worker.shutdown.complete") + def push_task(self, inflight: InflightTaskActivation, timeout: float | None = None) -> bool: + """ + Push a task to child tasks queue. + + When timeout is `None`, blocks until the queue has space. When timeout is + set (e.g. 5.0), waits at most that many seconds and returns `False` if the + queue is still full (worker busy). 
+ """ + try: + self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) + except Exception as e: + # 'qsize' does not work on macOS + logger.debug("taskworker.child_tasks.size.error", extra={"error": e}) + + start_time = time.monotonic() + try: + self._child_tasks.put(inflight, timeout=timeout) + except queue.Full: + self._metrics.incr( + "taskworker.worker.push_task.busy", + tags={"processing_pool": self._processing_pool_name}, + ) + return False + self._metrics.distribution( + "taskworker.worker.child_task.put.duration", + time.monotonic() - start_time, + tags={"processing_pool": self._processing_pool_name}, + ) + return True + def _add_task(self) -> bool: """ Add a task to child tasks queue. Returns False if no new task was fetched. diff --git a/clients/python/tests/worker/test_client.py b/clients/python/tests/worker/test_client.py index 4297ccfc..f2509cd2 100644 --- a/clients/python/tests/worker/test_client.py +++ b/clients/python/tests/worker/test_client.py @@ -1,4 +1,6 @@ import dataclasses +import hashlib +import hmac import random import string import time @@ -17,6 +19,7 @@ FetchNextTask, GetTaskRequest, GetTaskResponse, + PushTaskRequest, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, @@ -29,7 +32,10 @@ HealthCheckSettings, HostTemporarilyUnavailable, TaskbrokerClient, + grpc_metadata_get, make_broker_hosts, + parse_rpc_secret_list, + verify_rpc_request_signature_hmac, ) @@ -288,6 +294,92 @@ def test_get_task_with_interceptor() -> None: assert result.activation.namespace == "testing" +_PUSH_TASK_METHOD = "/sentry_protos.taskbroker.v1.WorkerService/PushTask" + + +def _push_task_hmac(secret: bytes, body: bytes) -> str: + payload = _PUSH_TASK_METHOD.encode("utf-8") + b":" + body + return hmac.new(secret, payload, hashlib.sha256).hexdigest() + + +def test_verify_rpc_request_signature_hmac_empty_secrets_skips_check() -> None: + body = PushTaskRequest( + task=TaskActivation( + id="abc123", + namespace="testing", + 
taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=1, + ), + callback_url="taskbroker:50051", + ).SerializeToString() + + # If the secrets list is empty, any signature is valid + assert verify_rpc_request_signature_hmac([], _PUSH_TASK_METHOD, body, None) is True + assert verify_rpc_request_signature_hmac([], _PUSH_TASK_METHOD, body, "deadbeef") is True + + +def test_verify_rpc_request_signature_hmac_valid() -> None: + body = PushTaskRequest( + task=TaskActivation( + id="abc123", + namespace="testing", + taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=1, + ), + callback_url="taskbroker:50051", + ).SerializeToString() + + # If 'correct' is one of the secrets, this signature should be valid + signature = _push_task_hmac(b"correct", body) + assert ( + verify_rpc_request_signature_hmac(["correct", "other"], _PUSH_TASK_METHOD, body, signature) + is True + ) + + +def test_verify_rpc_request_signature_hmac_wrong_secret() -> None: + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + + # We should catch an invalid signature + signature = _push_task_hmac(b"expected", body) + assert verify_rpc_request_signature_hmac(["wrong"], _PUSH_TASK_METHOD, body, signature) is False + + +def test_verify_rpc_request_signature_hmac_missing_signature() -> None: + # We should catch a missing signature when one is expected + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + assert verify_rpc_request_signature_hmac(["s"], _PUSH_TASK_METHOD, body, None) is False + + +def test_verify_rpc_request_signature_hmac_bad_hex() -> None: + # We should reject a signature with the wrong shape + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + assert verify_rpc_request_signature_hmac(["s"], _PUSH_TASK_METHOD, body, "not-hex") is False + + +def test_grpc_metadata_get() -> None: + # Make sure our 'grpc_metadata_get' helper works + md = (("Sentry-Signature", 
b"abc"), ("other", "v")) + assert grpc_metadata_get(md, "sentry-signature") == "abc" + + +def test_parse_rpc_secret_list() -> None: + # Handle scenarios in which... + # - No input is provided + # - The input string is not a list + # - The input string is an empty list + assert parse_rpc_secret_list(None) is None + assert parse_rpc_secret_list("") is None + assert parse_rpc_secret_list("[]") is None + + # Correctly parse a valid list of strings + assert parse_rpc_secret_list('["a","b"]') == ["a", "b"] + + def test_get_task_with_namespace() -> None: channel = MockChannel() channel.add_response( diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 8bfb35b2..5e41361e 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -16,6 +16,8 @@ TASK_ACTIVATION_STATUS_COMPLETE, TASK_ACTIVATION_STATUS_FAILURE, TASK_ACTIVATION_STATUS_RETRY, + PushTaskRequest, + PushTaskResponse, RetryState, TaskActivation, ) @@ -25,7 +27,7 @@ from taskbroker_client.retry import NoRetriesRemainingError from taskbroker_client.state import current_task from taskbroker_client.types import InflightTaskActivation, ProcessingResult -from taskbroker_client.worker.worker import TaskWorker +from taskbroker_client.worker.worker import TaskWorker, WorkerServicer from taskbroker_client.worker.workerchild import ProcessingDeadlineExceeded, child_process SIMPLE_TASK = InflightTaskActivation( @@ -300,6 +302,27 @@ def get_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | Non assert mock_client.get_task.called assert mock_client.update_task.call_count == 3 + def test_push_task_queue(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=2, + ) + + # We can enqueue the first task + result = taskworker.push_task(SIMPLE_TASK, timeout=None) + 
self.assertTrue(result) + + # We can enqueue the second task + result = taskworker.push_task(SIMPLE_TASK, timeout=1) + self.assertTrue(result) + + # We cannot enqueue the third task because the queue is full + result = taskworker.push_task(SIMPLE_TASK, timeout=1) + self.assertFalse(result) + def test_run_once_current_task_state(self) -> None: # Run a task that uses retry_task() helper # to raise and catch a NoRetriesRemainingError @@ -348,6 +371,81 @@ def update_task_response(*args: Any, **kwargs: Any) -> None: assert redis.get("no-retries-remaining"), "key should exist if except block was hit" redis.delete("no-retries-remaining") + def test_constructor_push_mode(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + push_mode=True, + grpc_port=50099, + ) + + # Make sure delivery mode and gRPC port arguments are stored + self.assertTrue(taskworker._push_mode) + self.assertEqual(taskworker._grpc_port, 50099) + + def test_constructor_pull_mode(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + ) + + # Make sure delivery mode and gRPC port are set to their defaults + self.assertFalse(taskworker._push_mode) + self.assertEqual(taskworker._grpc_port, 50052) + + +class TestWorkerServicer(TestCase): + def test_push_task_success(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + push_mode=True, + ) + with mock.patch.object(taskworker, "push_task", return_value=True) as mock_push_task: + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker) + + response = servicer.PushTask(request, mock_context) + + self.assertIsInstance(response, 
PushTaskResponse) + mock_context.abort.assert_not_called() + mock_push_task.assert_called_once_with(mock.ANY, timeout=5) + (inflight,) = mock_push_task.call_args[0] + self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) + self.assertEqual(inflight.host, "broker-host:50051") + + def test_push_task_worker_busy(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=1, + ) + with mock.patch.object(taskworker, "push_task", return_value=False): + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker) + + servicer.PushTask(request, mock_context) + + mock_context.abort.assert_called_once_with( + grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" + ) + @mock.patch("taskbroker_client.worker.workerchild.capture_checkin") def test_child_process_complete(mock_capture_checkin: mock.MagicMock) -> None: diff --git a/uv.lock b/uv.lock index fdf8f228..941ab117 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "sys_platform == 'darwin' or sys_platform == 'linux'", @@ -546,7 +546,7 @@ wheels = [ [[package]] name = "sentry-protos" -version = "0.4.11" +version = "0.8.6" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ { name = "grpc-stubs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -554,7 +554,7 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] wheels = [ - { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.4.11-py3-none-any.whl", hash = "sha256:d60709cd9989679fbe1ca1a9a02393a3af59292da333905d5b28beaa04220352" }, + { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.8.6-py3-none-any.whl", hash = 
"sha256:bffd32fae9df928a1d4fc519c1ff02fa3ba8fac7bf8ba0ea6495b1eb353575ef" }, ] [[package]] @@ -698,7 +698,7 @@ requires-dist = [ { name = "redis", specifier = ">=3.4.1" }, { name = "redis-py-cluster", specifier = ">=2.1.0" }, { name = "sentry-arroyo", specifier = ">=2.33.1" }, - { name = "sentry-protos", specifier = ">=0.4.11" }, + { name = "sentry-protos", specifier = ">=0.8.5" }, { name = "sentry-sdk", extras = ["http2"], specifier = ">=2.43.0" }, { name = "setuptools", marker = "extra == 'examples'", specifier = ">=80.0" }, { name = "zstandard", specifier = ">=0.18.0" },