From 70ee047a03eee5e9aae9a25cd5aa08e7d205c90c Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 17:07:56 -0700 Subject: [PATCH 01/12] Add Push Mode to Taskworker + Unit Tests --- clients/python/pyproject.toml | 2 +- clients/python/src/examples/cli.py | 8 +- .../python/src/taskbroker_client/metrics.py | 28 +++ .../src/taskbroker_client/worker/client.py | 14 +- .../src/taskbroker_client/worker/worker.py | 118 ++++++++++- clients/python/tests/worker/test_worker.py | 194 +++++++++++++++++- uv.lock | 8 +- 7 files changed, 353 insertions(+), 19 deletions(-) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 1dc56621..33d51c77 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" dependencies = [ "sentry-arroyo>=2.33.1", "sentry-sdk[http2]>=2.43.0", - "sentry-protos>=0.4.11", + "sentry-protos>=0.8.5", "confluent_kafka>=2.3.0", "cronsim>=2.6", "grpcio>=1.67.0", diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 26cd8e01..d58761bc 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -73,7 +73,12 @@ def scheduler() -> None: help="The number of child processes to start.", default=2, ) -def worker(rpc_host: str, concurrency: int) -> None: +@click.option( + "--push-mode", + help="Whether to run in PUSH or PULL mode.", + default=False, +) +def worker(rpc_host: str, concurrency: int, push_mode: bool) -> None: from taskbroker_client.worker import TaskWorker click.echo("Starting worker") @@ -87,6 +92,7 @@ def worker(rpc_host: str, concurrency: int) -> None: rebalance_after=32, processing_pool_name="examples", process_type="forkserver", + push_mode=push_mode, ) exitcode = worker.start() raise SystemExit(exitcode) diff --git a/clients/python/src/taskbroker_client/metrics.py b/clients/python/src/taskbroker_client/metrics.py index f48fdf27..6324029a 100644 --- a/clients/python/src/taskbroker_client/metrics.py +++ b/clients/python/src/taskbroker_client/metrics.py @@ -14,6 +14,22 @@ class MetricsBackend(Protocol): An abstract class that defines the interface for metrics backends. """ + @abstractmethod + def gauge( + self, + key: str, + value: float, + instance: str | None = None, + tags: Tags | None = None, + sample_rate: float = 1, + unit: str | None = None, + stacklevel: int = 0, + ) -> None: + """ + Records a gauge metric (a point-in-time value). + """ + raise NotImplementedError + @abstractmethod def incr( self, @@ -71,6 +87,18 @@ class NoOpMetricsBackend(MetricsBackend): Default metrics backend that does not record anything. """ + def gauge( + self, + key: str, + value: float, + instance: str | None = None, + tags: Tags | None = None, + sample_rate: float = 1, + unit: str | None = None, + stacklevel: int = 0, + ) -> None: + pass + def incr( self, name: str, diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 0670af9c..f19aeaf9 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -139,12 +139,16 @@ def __init__( health_check_settings: HealthCheckSettings | None = None, rpc_secret: str | None = None, grpc_config: str | None = None, + push_mode: bool = False, + grpc_port: int = 50052, ) -> None: assert len(hosts) > 0, "You must provide at least one RPC host to connect to" self._application = application self._hosts = hosts self._rpc_secret = rpc_secret self._metrics = metrics + self._push_mode = push_mode + self._grpc_port = grpc_port self._grpc_options: list[tuple[str, Any]] = [ ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) @@ -157,6 +161,7 @@ def __init__( ) self._cur_host = random.choice(self._hosts) + self._host_to_stubs_lock = threading.Lock() self._host_to_stubs: dict[str, ConsumerServiceStub] = { self._cur_host: self._connect_to_host(self._cur_host) } @@ -200,6 +205,12 @@ def _connect_to_host(self, host: str) -> ConsumerServiceStub: channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) return ConsumerServiceStub(channel) + def _get_stub(self, host: str) -> ConsumerServiceStub: + with self._host_to_stubs_lock: + if host not in self._host_to_stubs: + self._host_to_stubs[host] = self._connect_to_host(host) + return self._host_to_stubs[host] + def _check_consecutive_unavailable_errors(self) -> None: if self._num_consecutive_unavailable_errors >= self._max_consecutive_unavailable_errors: self._temporary_unavailable_hosts[self._cur_host] = ( @@ -324,10 +335,11 @@ def update_task( f"Host: {processing_result.host} is temporarily unavailable" ) + stub = self._get_stub(processing_result.host) with self._metrics.timer( "taskworker.update_task.rpc", tags={"host": processing_result.host} ): - response = self._host_to_stubs[processing_result.host].SetTaskStatus(request) + response = stub.SetTaskStatus(request) except grpc.RpcError as err: self._metrics.incr( "taskworker.client.rpc_error", diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index ca05073f..6855668e 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -13,7 +13,12 @@ from typing import Any import grpc -from sentry_protos.taskbroker.v1.taskbroker_pb2 import FetchNextTask +from sentry_protos.taskbroker.v1 import taskbroker_pb2_grpc +from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( + FetchNextTask, + PushTaskRequest, + PushTaskResponse, +) from taskbroker_client.app import import_app from taskbroker_client.constants import ( @@ -33,6 +38,34 @@ logger = logging.getLogger(__name__) +class WorkerServicer(taskbroker_pb2_grpc.WorkerServiceServicer): + """ + gRPC servicer that receives task activations pushed from the broker + """ + + def __init__(self, worker: TaskWorker) -> None: + self.worker = worker + + def PushTask( + self, + request: PushTaskRequest, + context: grpc.ServicerContext, + ) -> PushTaskResponse: + """Handle incoming task activation.""" + # Create `InflightTaskActivation` from the pushed task + inflight = InflightTaskActivation( + activation=request.task, + host=request.callback_url, + receive_timestamp=time.monotonic(), + ) + + # Push the task to the worker queue (wait at most 5 seconds) + if not self.worker.push_task(inflight, timeout=5): + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy") + + return PushTaskResponse() + + class TaskWorker: """ A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. @@ -60,6 +93,8 @@ def __init__( process_type: str = "spawn", health_check_file_path: str | None = None, health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, + push_mode: bool = False, + grpc_port: int = 50052, **kwargs: dict[str, Any], ) -> None: self.options = kwargs @@ -69,6 +104,11 @@ def __init__( self._concurrency = concurrency app = import_app(app_module) + if push_mode: + logger.info("Running in PUSH mode") + else: + logger.info("Running in PULL mode") + self.client = TaskbrokerClient( hosts=broker_hosts, application=app.name, @@ -81,6 +121,8 @@ def __init__( ), rpc_secret=app.config["rpc_secret"], grpc_config=app.config["grpc_config"], + push_mode=push_mode, + grpc_port=grpc_port, ) self._metrics = app.metrics @@ -110,12 +152,13 @@ def __init__( self._processing_pool_name: str = processing_pool_name or "unknown" + self._push_mode = push_mode + self._grpc_port = grpc_port + def start(self) -> int: """ - Run the worker main loop - - Once started a Worker will loop until it is killed, or - completes its max_task_count when it shuts down. + When in PULL mode, this starts a loop that runs until the worker completes its `max_task_count` or it is killed. + When in PUSH mode, this starts the worker gRPC server. """ self.start_result_thread() self.start_spawn_children_thread() @@ -128,12 +171,32 @@ def signal_handler(*args: Any) -> None: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - try: - while True: - self.run_once() - except KeyboardInterrupt: - self.shutdown() - raise + if self._push_mode: + try: + # Start gRPC server + server = grpc.server(ThreadPoolExecutor(max_workers=self._concurrency)) + taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server( + WorkerServicer(self), server + ) + server.add_insecure_port(f"[::]:{self._grpc_port}") + server.start() + logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port}) + + # Wait for shutdown signal + server.wait_for_termination() + + except KeyboardInterrupt: + server.stop(grace=5) + self.shutdown() + else: + try: + while True: + self.run_once() + except KeyboardInterrupt: + self.shutdown() + raise + + return 0 def run_once(self) -> None: """Access point for tests to run a single worker loop""" @@ -172,6 +235,39 @@ def shutdown(self) -> None: logger.info("taskworker.worker.shutdown.complete") + def push_task(self, inflight: InflightTaskActivation, timeout: float | None = None) -> bool: + """ + Push a task to child tasks queue. + + When timeout is `None`, blocks until the queue has space. When timeout is + set (e.g. 5.0), waits at most that many seconds and returns `False` if the + queue is still full (worker busy). + """ + try: + self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) + except Exception as e: + # `qsize` does not work on all machines + logger.warning(f"Could not report size of child tasks queue - {e}") + + start_time = time.monotonic() + try: + if timeout is None: + self._child_tasks.put(inflight) + else: + self._child_tasks.put(inflight, timeout=timeout) + except queue.Full: + self._metrics.incr( + "taskworker.worker.push_task.busy", + tags={"processing_pool": self._processing_pool_name}, + ) + return False + self._metrics.distribution( + "taskworker.worker.child_task.put.duration", + time.monotonic() - start_time, + tags={"processing_pool": self._processing_pool_name}, + ) + return True + def _add_task(self) -> bool: """ Add a task to child tasks queue. Returns False if no new task was fetched. diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 8bfb35b2..84c27398 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -16,6 +16,8 @@ TASK_ACTIVATION_STATUS_COMPLETE, TASK_ACTIVATION_STATUS_FAILURE, TASK_ACTIVATION_STATUS_RETRY, + PushTaskRequest, + PushTaskResponse, RetryState, TaskActivation, ) @@ -25,7 +27,7 @@ from taskbroker_client.retry import NoRetriesRemainingError from taskbroker_client.state import current_task from taskbroker_client.types import InflightTaskActivation, ProcessingResult -from taskbroker_client.worker.worker import TaskWorker +from taskbroker_client.worker.worker import TaskWorker, WorkerServicer from taskbroker_client.worker.workerchild import ProcessingDeadlineExceeded, child_process SIMPLE_TASK = InflightTaskActivation( @@ -300,6 +302,91 @@ def get_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | Non assert mock_client.get_task.called assert mock_client.update_task.call_count == 3 + def test_push_task_success_no_timeout(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=2, + ) + mock_metrics = mock.MagicMock() + taskworker._metrics = mock_metrics + mock_queue = mock.MagicMock() + mock_queue.full.return_value = False + taskworker._child_tasks = mock_queue + + result = taskworker.push_task(SIMPLE_TASK, timeout=None) + + self.assertTrue(result) + mock_metrics.gauge.assert_called_once_with("taskworker.child_tasks.size", mock.ANY) + mock_metrics.distribution.assert_called_once_with( + "taskworker.worker.child_task.put.duration", + mock.ANY, + tags={"processing_pool": "unknown"}, + ) + mock_queue.put.assert_called_once_with(SIMPLE_TASK) + + def test_push_task_success_with_timeout(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=2, + ) + taskworker._metrics = mock.MagicMock() + mock_queue = mock.MagicMock() + taskworker._child_tasks = mock_queue + + result = taskworker.push_task(SIMPLE_TASK, timeout=1.0) + + self.assertTrue(result) + mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=1.0) + + def test_push_task_queue_full_returns_false_and_incr_busy(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=1, + ) + taskworker._metrics = mock.MagicMock() + mock_queue = mock.MagicMock() + mock_queue.qsize.return_value = 1 + mock_queue.put.side_effect = queue.Full + taskworker._child_tasks = mock_queue + + result = taskworker.push_task(RETRY_TASK, timeout=0.01) + + self.assertFalse(result) + taskworker._metrics.incr.assert_called_once_with( + "taskworker.worker.push_task.busy", + tags={"processing_pool": "unknown"}, + ) + self.assertEqual(mock_queue.put.call_count, 1) + + def test_push_task_gauge_exception_still_enqueues(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=2, + ) + mock_metrics = mock.MagicMock() + mock_metrics.gauge.side_effect = RuntimeError("qsize not supported") + taskworker._metrics = mock_metrics + mock_queue = mock.MagicMock() + taskworker._child_tasks = mock_queue + + result = taskworker.push_task(SIMPLE_TASK, timeout=None) + + self.assertTrue(result) + mock_queue.put.assert_called_once_with(SIMPLE_TASK) + mock_metrics.distribution.assert_called_once() + def test_run_once_current_task_state(self) -> None: # Run a task that uses retry_task() helper # to raise and catch a NoRetriesRemainingError @@ -348,6 +435,111 @@ def update_task_response(*args: Any, **kwargs: Any) -> None: assert redis.get("no-retries-remaining"), "key should exist if except block was hit" redis.delete("no-retries-remaining") + def test_constructor_push_mode_and_grpc_port_stored(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + push_mode=True, + grpc_port=50099, + ) + self.assertTrue(taskworker._push_mode) + self.assertEqual(taskworker._grpc_port, 50099) + self.assertTrue(taskworker.client._push_mode) + self.assertEqual(taskworker.client._grpc_port, 50099) + + def test_constructor_pull_mode_default(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + ) + self.assertFalse(taskworker._push_mode) + self.assertEqual(taskworker._grpc_port, 50052) + + def test_start_push_mode_server_creation_and_shutdown(self) -> None: + mock_server = mock.MagicMock() + mock_server.wait_for_termination.side_effect = KeyboardInterrupt() + + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + push_mode=True, + grpc_port=50060, + ) + with mock.patch("taskbroker_client.worker.worker.grpc.server") as mock_grpc_server: + mock_grpc_server.return_value = mock_server + with mock.patch("taskbroker_client.worker.worker.ThreadPoolExecutor") as mock_tpe: + with mock.patch( + "taskbroker_client.worker.worker.taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server" + ) as mock_add_servicer: + exitcode = taskworker.start() + + self.assertEqual(exitcode, 0) + mock_grpc_server.assert_called_once() + self.assertEqual(mock_tpe.call_args[1]["max_workers"], 1) + mock_add_servicer.assert_called_once() + self.assertIsInstance(mock_add_servicer.call_args[0][0], WorkerServicer) + self.assertEqual(mock_add_servicer.call_args[0][1], mock_server) + mock_server.add_insecure_port.assert_called_once_with("[::]:50060") + mock_server.start.assert_called_once() + mock_server.wait_for_termination.assert_called_once() + mock_server.stop.assert_called_once_with(grace=5) + self.assertTrue(taskworker._shutdown_event.is_set()) + + +class TestWorkerServicer(TestCase): + def test_push_task_success(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + push_mode=True, + ) + with mock.patch.object(taskworker, "push_task", return_value=True) as mock_push_task: + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker) + + response = servicer.PushTask(request, mock_context) + + self.assertIsInstance(response, PushTaskResponse) + mock_context.abort.assert_not_called() + mock_push_task.assert_called_once_with(mock.ANY, timeout=5) + (inflight,) = mock_push_task.call_args[0] + self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) + self.assertEqual(inflight.host, "broker-host:50051") + + def test_push_task_worker_busy(self) -> None: + taskworker = TaskWorker( + app_module="examples.app:app", + broker_hosts=["127.0.0.1:50051"], + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=1, + ) + with mock.patch.object(taskworker, "push_task", return_value=False): + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker) + + servicer.PushTask(request, mock_context) + + mock_context.abort.assert_called_once_with( + grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" + ) + @mock.patch("taskbroker_client.worker.workerchild.capture_checkin") def test_child_process_complete(mock_capture_checkin: mock.MagicMock) -> None: diff --git a/uv.lock b/uv.lock index fdf8f228..941ab117 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "sys_platform == 'darwin' or sys_platform == 'linux'", @@ -546,7 +546,7 @@ wheels = [ [[package]] name = "sentry-protos" -version = "0.4.11" +version = "0.8.6" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ { name = "grpc-stubs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -554,7 +554,7 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] wheels = [ - { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.4.11-py3-none-any.whl", hash = "sha256:d60709cd9989679fbe1ca1a9a02393a3af59292da333905d5b28beaa04220352" }, + { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.8.6-py3-none-any.whl", hash = "sha256:bffd32fae9df928a1d4fc519c1ff02fa3ba8fac7bf8ba0ea6495b1eb353575ef" }, ] [[package]] @@ -698,7 +698,7 @@ requires-dist = [ { name = "redis", specifier = ">=3.4.1" }, { name = "redis-py-cluster", specifier = ">=2.1.0" }, { name = "sentry-arroyo", specifier = ">=2.33.1" }, - { name = "sentry-protos", specifier = ">=0.4.11" }, + { name = "sentry-protos", specifier = ">=0.8.5" }, { name = "sentry-sdk", extras = ["http2"], specifier = ">=2.43.0" }, { name = "setuptools", marker = "extra == 'examples'", specifier = ">=80.0" }, { name = "zstandard", specifier = ">=0.18.0" }, From 357099be684597368a0b5e0c1d0f1270bb53e6cc Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 17:42:20 -0700 Subject: [PATCH 02/12] Fix Worker CLI Options --- clients/python/src/examples/cli.py | 16 ++++++++++++---- config.yaml | 2 ++ 2 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 config.yaml diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index d58761bc..28088fb3 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -74,11 +74,18 @@ def scheduler() -> None: default=2, ) @click.option( - "--push-mode", - help="Whether to run in PUSH or PULL mode.", - default=False, + "--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True ) -def worker(rpc_host: str, concurrency: int, push_mode: bool) -> None: +@click.option( + "--grpc-port", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True +) +@click.option( + "--grpc-port", + help="Port for the gRPC server to listen on.", + default=50052, + type=int, +) +def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None: from taskbroker_client.worker import TaskWorker click.echo("Starting worker") @@ -93,6 +100,7 @@ def worker(rpc_host: str, concurrency: int, push_mode: bool) -> None: processing_pool_name="examples", process_type="forkserver", push_mode=push_mode, + grpc_port=grpc_port, ) exitcode = worker.start() raise SystemExit(exitcode) diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..26d870f0 --- /dev/null +++ b/config.yaml @@ -0,0 +1,2 @@ +push_mode: true +log_filter: "debug,sqlx=debug,librdkafka=warn,h2=off" From eef475aa2d06f2c4e29e132d50bea386199c6031 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 18:03:53 -0700 Subject: [PATCH 03/12] Fix gRPC Port CLI Option Duplicate --- clients/python/src/examples/cli.py | 3 --- config.yaml | 2 -- 2 files changed, 5 deletions(-) delete mode 100644 config.yaml diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 28088fb3..02ce6e40 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -76,9 +76,6 @@ def scheduler() -> None: @click.option( "--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True ) -@click.option( - "--grpc-port", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True -) @click.option( "--grpc-port", help="Port for the gRPC server to listen on.", diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 26d870f0..00000000 --- a/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -push_mode: true -log_filter: "debug,sqlx=debug,librdkafka=warn,h2=off" From 1988ba9a97f2bd0b1c2b6a01b0c4f5ef7913463b Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 18:49:28 -0700 Subject: [PATCH 04/12] Stop Server Only if Exists --- clients/python/src/taskbroker_client/worker/worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 6855668e..9937b1d8 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -186,7 +186,10 @@ def signal_handler(*args: Any) -> None: server.wait_for_termination() except KeyboardInterrupt: - server.stop(grace=5) + # This may be triggered before the server is initialized + if server: + server.stop(grace=5) + self.shutdown() else: try: From 1df88aa6982f759aa8de3141114288ec94f8a298 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 19:13:48 -0700 Subject: [PATCH 05/12] Make Sure Server Variable Exists --- clients/python/src/taskbroker_client/worker/worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 9937b1d8..50e42cb3 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -172,6 +172,8 @@ def signal_handler(*args: Any) -> None: signal.signal(signal.SIGTERM, signal_handler) if self._push_mode: + server = None + try: # Start gRPC server server = grpc.server(ThreadPoolExecutor(max_workers=self._concurrency)) From 43ec7275d8a201f407863ebe6bbee52866f284d4 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 19 Mar 2026 11:21:18 -0700 Subject: [PATCH 06/12] Pass Timeout to Queue Put Directly --- clients/python/src/taskbroker_client/worker/worker.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 50e42cb3..efd75e9b 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -251,15 +251,12 @@ def push_task(self, inflight: InflightTaskActivation, timeout: float | None = No try: self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) except Exception as e: - # `qsize` does not work on all machines + # 'qsize' does not work on macOS logger.warning(f"Could not report size of child tasks queue - {e}") start_time = time.monotonic() try: - if timeout is None: - self._child_tasks.put(inflight) - else: - self._child_tasks.put(inflight, timeout=timeout) + self._child_tasks.put(inflight, timeout=timeout) except queue.Full: self._metrics.incr( "taskworker.worker.push_task.busy", From b61581ca193e6dd100cccf4b2e4ff392cad3af26 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 19 Mar 2026 13:47:06 -0700 Subject: [PATCH 07/12] Fix Python Client Tests --- clients/python/tests/worker/test_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 84c27398..3244c719 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -325,7 +325,7 @@ def test_push_task_success_no_timeout(self) -> None: mock.ANY, tags={"processing_pool": "unknown"}, ) - mock_queue.put.assert_called_once_with(SIMPLE_TASK) + mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=None) def test_push_task_success_with_timeout(self) -> None: taskworker = TaskWorker( @@ -384,7 +384,7 @@ def test_push_task_gauge_exception_still_enqueues(self) -> None: result = taskworker.push_task(SIMPLE_TASK, timeout=None) self.assertTrue(result) - mock_queue.put.assert_called_once_with(SIMPLE_TASK) + mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=None) mock_metrics.distribution.assert_called_once() def test_run_once_current_task_state(self) -> None: From 9ed3733e9244b613ed35089d637045a4e7ea1cc3 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Sun, 22 Mar 2026 20:37:41 -0700 Subject: [PATCH 08/12] Fix Inconsistent Stub Locking, Shutdown on Termination --- .../python/src/taskbroker_client/worker/client.py | 5 ++--- .../python/src/taskbroker_client/worker/worker.py | 12 +++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index f19aeaf9..ba8c852e 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -257,11 +257,10 @@ def _get_cur_stub(self) -> tuple[str, ConsumerServiceStub]: tags={"reason": "max_tasks_reached"}, ) - if self._cur_host not in self._host_to_stubs: - self._host_to_stubs[self._cur_host] = self._connect_to_host(self._cur_host) + stub = self._get_stub(self._cur_host) self._num_tasks_before_rebalance -= 1 - return self._cur_host, self._host_to_stubs[self._cur_host] + return self._cur_host, stub def get_task(self, namespace: str | None = None) -> InflightTaskActivation | None: """ diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index efd75e9b..7213a64f 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -184,12 +184,14 @@ def signal_handler(*args: Any) -> None: server.start() logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port}) - # Wait for shutdown signal - server.wait_for_termination() + try: + server.wait_for_termination() + except KeyboardInterrupt: + # Signals are converted to KeyboardInterrupt, swallow for exit code 0 + pass - except KeyboardInterrupt: - # This may be triggered before the server is initialized - if server: + finally: + if server is not None: server.stop(grace=5) self.shutdown() From a0303111192c2f47aa0a7909359f323672da1b90 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 23 Mar 2026 14:18:32 -0700 Subject: [PATCH 09/12] Remove Unused Fields from TaskbrokerClient --- clients/python/src/taskbroker_client/worker/client.py | 4 ---- clients/python/src/taskbroker_client/worker/worker.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index ba8c852e..12caba15 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -139,16 +139,12 @@ def __init__( health_check_settings: HealthCheckSettings | None = None, rpc_secret: str | None = None, grpc_config: str | None = None, - push_mode: bool = False, - grpc_port: int = 50052, ) -> None: assert len(hosts) > 0, "You must provide at least one RPC host to connect to" self._application = application self._hosts = hosts self._rpc_secret = rpc_secret self._metrics = metrics - self._push_mode = push_mode - self._grpc_port = grpc_port self._grpc_options: list[tuple[str, Any]] = [ ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 7213a64f..206e42c2 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -121,8 +121,6 @@ def __init__( ), rpc_secret=app.config["rpc_secret"], grpc_config=app.config["grpc_config"], - push_mode=push_mode, - grpc_port=grpc_port, ) self._metrics = app.metrics From 447ef381d8d8ec9aaffd9a672aaf69b786d2634b Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 23 Mar 2026 14:23:29 -0700 Subject: [PATCH 10/12] Fix Tests --- clients/python/tests/worker/test_worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 3244c719..af80386c 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -446,8 +446,6 @@ def test_constructor_push_mode_and_grpc_port_stored(self) -> None: ) self.assertTrue(taskworker._push_mode) self.assertEqual(taskworker._grpc_port, 50099) - self.assertTrue(taskworker.client._push_mode) - self.assertEqual(taskworker.client._grpc_port, 50099) def test_constructor_pull_mode_default(self) -> None: taskworker = TaskWorker( From f0414bfa092e3a1343a0c0c8e7e25520cf9b470d Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 24 Mar 2026 14:33:35 -0700 Subject: [PATCH 11/12] Make Queue Size Log Debug, Remove Silly Tests --- .../src/taskbroker_client/worker/client.py | 114 ++++++++++++++++- .../src/taskbroker_client/worker/worker.py | 24 +++- clients/python/tests/worker/test_client.py | 92 ++++++++++++++ clients/python/tests/worker/test_worker.py | 116 ++---------------- 4 files changed, 236 insertions(+), 110 deletions(-) diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 12caba15..cf1462a5 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union import grpc import orjson @@ -27,6 +27,11 @@ from taskbroker_client.metrics import MetricsBackend from taskbroker_client.types import InflightTaskActivation, ProcessingResult +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + logger = logging.getLogger(__name__) MAX_ACTIVATION_SIZE = 1024 * 1024 * 10 @@ -108,6 +113,109 @@ def intercept_unary_unary( return continuation(call_details_with_meta, request) +def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None: + """ + Parse the task app `rpc_secret` JSON array string into a list of secrets. + Returns `None` when unset, invalid, or empty (no authentication). + """ + if not rpc_secret: + return None + + # Try to parse the provided secret + parsed = orjson.loads(rpc_secret) + + if not isinstance(parsed, list) or len(parsed) == 0: + # If the secret string is not a list with at least one element, it is invalid + return None + + return [str(x) for x in parsed] + + +# The `grpc` package defines this type in `grpc._typing` but does not export it for some reason +type Metadata = Sequence[Tuple[str, Union[str, bytes]]] + + +def grpc_metadata_get(metadata: Metadata, key: str) -> str | None: + """ + First matching gRPC metadata value for `key` or `None` if not present. + """ + # gRPC metadata keys are ASCII and compared case-insensitively + key = key.lower() + + for k, v in metadata: + if k.lower() == key: + return v.decode("utf-8") if isinstance(v, bytes) else v + + return None + + +def verify_rpc_request_signature_hmac( + secrets: list[str], + method: str, + request_body: bytes, + signature_hex: str | None, +) -> bool: + """ + Verify the 'sentry-signature' metadata for a unary RPC body. + Uses the same signing contract as `RequestSignatureInterceptor` and the taskbroker. + """ + if not secrets: + return True + + if not signature_hex: + return False + + try: + sig_bytes = bytes.fromhex(signature_hex) + except ValueError: + return False + + signing_payload = method.encode("utf-8") + b":" + request_body + + for secret in secrets: + expected = hmac.new(secret.encode("utf-8"), signing_payload, hashlib.sha256).digest() + if hmac.compare_digest(expected, sig_bytes): + return True + + return False + + +class RequestSignatureServerInterceptor(ServerInterceptor): + """ + Enforces HMAC request signing on unary-unary RPCs like `WorkerService.PushTask`. + """ + + def __init__(self, secrets: list[str]) -> None: + self._secrets = secrets + + def intercept_service(self, continuation: Any, handler_call_details: Any) -> Any: + handler = continuation(handler_call_details) + if handler is None or not self._secrets: + return handler + + if handler.request_streaming or handler.response_streaming or handler.unary_unary is None: + return handler + + method = handler_call_details.method + metadata = handler_call_details.invocation_metadata + original = handler.unary_unary + + def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: + signature = grpc_metadata_get(metadata, "sentry-signature") + body = request.SerializeToString() + + if verify_rpc_request_signature_hmac(self._secrets, method, body, signature): + return original(request, context) + + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed") + + return grpc.unary_unary_rpc_method_handler( + unary_unary, + request_deserializer=handler.request_deserializer, + response_serializer=handler.response_serializer, + ) + + class HostTemporarilyUnavailable(Exception): """Raised when a host is temporarily unavailable and should be retried later.""" @@ -196,8 +304,8 @@ def _emit_health_check(self) -> None: def _connect_to_host(self, host: str) -> ConsumerServiceStub: logger.info("taskworker.client.connect", extra={"host": host}) channel = grpc.insecure_channel(host, options=self._grpc_options) - if self._rpc_secret: - secrets = orjson.loads(self._rpc_secret) + secrets = parse_rpc_secret_list(self._rpc_secret) + if secrets: channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) return ConsumerServiceStub(channel) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 206e42c2..81c116ea 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -10,7 +10,7 @@ from multiprocessing.context import ForkContext, ForkServerContext, SpawnContext from multiprocessing.process import BaseProcess from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, List import grpc from sentry_protos.taskbroker.v1 import taskbroker_pb2_grpc @@ -31,10 +31,18 @@ from taskbroker_client.worker.client import ( HealthCheckSettings, HostTemporarilyUnavailable, + RequestSignatureServerInterceptor, TaskbrokerClient, + parse_rpc_secret_list, ) from taskbroker_client.worker.workerchild import child_process +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + + logger = logging.getLogger(__name__) @@ -152,6 +160,7 @@ def __init__( self._push_mode = push_mode self._grpc_port = grpc_port + self._grpc_secrets = parse_rpc_secret_list(app.config["rpc_secret"]) def start(self) -> int: """ @@ -174,7 +183,16 @@ def signal_handler(*args: Any) -> None: try: # Start gRPC server - server = grpc.server(ThreadPoolExecutor(max_workers=self._concurrency)) + interceptors: List[ServerInterceptor] = [] + + if self._grpc_secrets: + interceptors = [RequestSignatureServerInterceptor(self._grpc_secrets)] + + server = grpc.server( + ThreadPoolExecutor(max_workers=self._concurrency), + interceptors=interceptors, + ) + taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server( WorkerServicer(self), server ) @@ -252,7 +270,7 @@ def push_task(self, inflight: InflightTaskActivation, timeout: float | None = No self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) except Exception as e: # 'qsize' does not work on macOS - logger.warning(f"Could not report size of child tasks queue - {e}") + logger.debug("taskworker.child_tasks.size.error", extra={"error": e}) start_time = time.monotonic() try: diff --git a/clients/python/tests/worker/test_client.py b/clients/python/tests/worker/test_client.py index 4297ccfc..f2509cd2 100644 --- a/clients/python/tests/worker/test_client.py +++ b/clients/python/tests/worker/test_client.py @@ -1,4 +1,6 @@ import dataclasses +import hashlib +import hmac import random import string import time @@ -17,6 +19,7 @@ FetchNextTask, GetTaskRequest, GetTaskResponse, + PushTaskRequest, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, @@ -29,7 +32,10 @@ HealthCheckSettings, HostTemporarilyUnavailable, TaskbrokerClient, + grpc_metadata_get, make_broker_hosts, + parse_rpc_secret_list, + verify_rpc_request_signature_hmac, ) @@ -288,6 +294,92 @@ def test_get_task_with_interceptor() -> None: assert result.activation.namespace == "testing" +_PUSH_TASK_METHOD = "/sentry_protos.taskbroker.v1.WorkerService/PushTask" + + +def _push_task_hmac(secret: bytes, body: bytes) -> str: + payload = _PUSH_TASK_METHOD.encode("utf-8") + b":" + body + return hmac.new(secret, payload, hashlib.sha256).hexdigest() + + +def test_verify_rpc_request_signature_hmac_empty_secrets_skips_check() -> None: + body = PushTaskRequest( + task=TaskActivation( + id="abc123", + namespace="testing", + taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=1, + ), + callback_url="taskbroker:50051", + ).SerializeToString() + + # If the secrets list is empty, any signature is valid + assert verify_rpc_request_signature_hmac([], _PUSH_TASK_METHOD, body, None) is True + assert verify_rpc_request_signature_hmac([], _PUSH_TASK_METHOD, body, "deadbeef") is True + + +def test_verify_rpc_request_signature_hmac_valid() -> None: + body = PushTaskRequest( + task=TaskActivation( + id="abc123", + namespace="testing", + taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=1, + ), + callback_url="taskbroker:50051", + ).SerializeToString() + + # If 'correct' is one of the secrets, this signature should be valid + signature = _push_task_hmac(b"correct", body) + assert ( + verify_rpc_request_signature_hmac(["correct", "other"], _PUSH_TASK_METHOD, body, signature) + is True + ) + + +def test_verify_rpc_request_signature_hmac_wrong_secret() -> None: + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + + # We should catch an invalid signature + signature = _push_task_hmac(b"expected", body) + assert verify_rpc_request_signature_hmac(["wrong"], _PUSH_TASK_METHOD, body, signature) is False + + +def test_verify_rpc_request_signature_hmac_missing_signature() -> None: + # We should catch a missing signature when one is expected + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + assert verify_rpc_request_signature_hmac(["s"], _PUSH_TASK_METHOD, body, None) is False + + +def test_verify_rpc_request_signature_hmac_bad_hex() -> None: + # We should reject a signature with the wrong shape + body = PushTaskRequest(callback_url="taskbroker:50051").SerializeToString() + assert verify_rpc_request_signature_hmac(["s"], _PUSH_TASK_METHOD, body, "not-hex") is False + + +def test_grpc_metadata_get() -> None: + # Make sure our 'grpc_metadata_get' helper works + md = (("Sentry-Signature", b"abc"), ("other", "v")) + assert grpc_metadata_get(md, "sentry-signature") == "abc" + + +def test_parse_rpc_secret_list() -> None: + # Handle scenarios in which... + # - No input is provided + # - The input string is not a list + # - The input string is an empty list + assert parse_rpc_secret_list(None) is None + assert parse_rpc_secret_list("") is None + assert parse_rpc_secret_list("[]") is None + + # Correctly parse a valid list of strings + assert parse_rpc_secret_list('["a","b"]') == ["a", "b"] + + def test_get_task_with_namespace() -> None: channel = MockChannel() channel.add_response( diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index af80386c..5e41361e 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -302,7 +302,7 @@ def get_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | Non assert mock_client.get_task.called assert mock_client.update_task.call_count == 3 - def test_push_task_success_no_timeout(self) -> None: + def test_push_task_queue(self) -> None: taskworker = TaskWorker( app_module="examples.app:app", broker_hosts=["127.0.0.1:50051"], @@ -310,82 +310,18 @@ def test_push_task_success_no_timeout(self) -> None: process_type="fork", child_tasks_queue_maxsize=2, ) - mock_metrics = mock.MagicMock() - taskworker._metrics = mock_metrics - mock_queue = mock.MagicMock() - mock_queue.full.return_value = False - taskworker._child_tasks = mock_queue + # We can enqueue the first task result = taskworker.push_task(SIMPLE_TASK, timeout=None) - self.assertTrue(result) - mock_metrics.gauge.assert_called_once_with("taskworker.child_tasks.size", mock.ANY) - mock_metrics.distribution.assert_called_once_with( - "taskworker.worker.child_task.put.duration", - mock.ANY, - tags={"processing_pool": "unknown"}, - ) - mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=None) - - def test_push_task_success_with_timeout(self) -> None: - taskworker = TaskWorker( - app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], - max_child_task_count=100, - process_type="fork", - child_tasks_queue_maxsize=2, - ) - taskworker._metrics = mock.MagicMock() - mock_queue = mock.MagicMock() - taskworker._child_tasks = mock_queue - - result = taskworker.push_task(SIMPLE_TASK, timeout=1.0) + # We can enqueue the second task + result = taskworker.push_task(SIMPLE_TASK, timeout=1) self.assertTrue(result) - mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=1.0) - - def test_push_task_queue_full_returns_false_and_incr_busy(self) -> None: - taskworker = TaskWorker( - app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], - max_child_task_count=100, - process_type="fork", - child_tasks_queue_maxsize=1, - ) - taskworker._metrics = mock.MagicMock() - mock_queue = mock.MagicMock() - mock_queue.qsize.return_value = 1 - mock_queue.put.side_effect = queue.Full - taskworker._child_tasks = mock_queue - - result = taskworker.push_task(RETRY_TASK, timeout=0.01) + # We cannot enqueue the third task because the queue is full + result = taskworker.push_task(SIMPLE_TASK, timeout=1) self.assertFalse(result) - taskworker._metrics.incr.assert_called_once_with( - "taskworker.worker.push_task.busy", - tags={"processing_pool": "unknown"}, - ) - self.assertEqual(mock_queue.put.call_count, 1) - - def test_push_task_gauge_exception_still_enqueues(self) -> None: - taskworker = TaskWorker( - app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], - max_child_task_count=100, - process_type="fork", - child_tasks_queue_maxsize=2, - ) - mock_metrics = mock.MagicMock() - mock_metrics.gauge.side_effect = RuntimeError("qsize not supported") - taskworker._metrics = mock_metrics - mock_queue = mock.MagicMock() - taskworker._child_tasks = mock_queue - - result = taskworker.push_task(SIMPLE_TASK, timeout=None) - - self.assertTrue(result) - mock_queue.put.assert_called_once_with(SIMPLE_TASK, timeout=None) - mock_metrics.distribution.assert_called_once() def test_run_once_current_task_state(self) -> None: # Run a task that uses retry_task() helper @@ -435,7 +371,7 @@ def update_task_response(*args: Any, **kwargs: Any) -> None: assert redis.get("no-retries-remaining"), "key should exist if except block was hit" redis.delete("no-retries-remaining") - def test_constructor_push_mode_and_grpc_port_stored(self) -> None: + def test_constructor_push_mode(self) -> None: taskworker = TaskWorker( app_module="examples.app:app", broker_hosts=["127.0.0.1:50051"], @@ -444,51 +380,23 @@ def test_constructor_push_mode_and_grpc_port_stored(self) -> None: push_mode=True, grpc_port=50099, ) + + # Make sure delivery mode and gRPC port arguments are stored self.assertTrue(taskworker._push_mode) self.assertEqual(taskworker._grpc_port, 50099) - def test_constructor_pull_mode_default(self) -> None: + def test_constructor_pull_mode(self) -> None: taskworker = TaskWorker( app_module="examples.app:app", broker_hosts=["127.0.0.1:50051"], max_child_task_count=100, process_type="fork", ) + + # Make sure delivery mode and gRPC port are set to their defaults self.assertFalse(taskworker._push_mode) self.assertEqual(taskworker._grpc_port, 50052) - def test_start_push_mode_server_creation_and_shutdown(self) -> None: - mock_server = mock.MagicMock() - mock_server.wait_for_termination.side_effect = KeyboardInterrupt() - - taskworker = TaskWorker( - app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], - max_child_task_count=100, - process_type="fork", - push_mode=True, - grpc_port=50060, - ) - with mock.patch("taskbroker_client.worker.worker.grpc.server") as mock_grpc_server: - mock_grpc_server.return_value = mock_server - with mock.patch("taskbroker_client.worker.worker.ThreadPoolExecutor") as mock_tpe: - with mock.patch( - "taskbroker_client.worker.worker.taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server" - ) as mock_add_servicer: - exitcode = taskworker.start() - - self.assertEqual(exitcode, 0) - mock_grpc_server.assert_called_once() - self.assertEqual(mock_tpe.call_args[1]["max_workers"], 1) - mock_add_servicer.assert_called_once() - self.assertIsInstance(mock_add_servicer.call_args[0][0], WorkerServicer) - self.assertEqual(mock_add_servicer.call_args[0][1], mock_server) - mock_server.add_insecure_port.assert_called_once_with("[::]:50060") - mock_server.start.assert_called_once() - mock_server.wait_for_termination.assert_called_once() - mock_server.stop.assert_called_once_with(grace=5) - self.assertTrue(taskworker._shutdown_event.is_set()) - class TestWorkerServicer(TestCase): def test_push_task_success(self) -> None: From f41ba745505dd2252c15fff1ee63c8203c14e862 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 25 Mar 2026 17:27:46 -0700 Subject: [PATCH 12/12] Fix Deserialization Bug in HMAC Verification --- .../src/taskbroker_client/worker/client.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index cf1462a5..751b4c6c 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union import grpc import orjson @@ -34,6 +34,10 @@ logger = logging.getLogger(__name__) +# gRPC runs the unary request deserializer and the servicer on the same thread for a given call +# If HMAC verification fails inside a deserializer wrapper, raising turns into INTERNAL, so we set this flag and abort UNAUTHENTICATED from the servicer +_RPC_SIGNATURE_AUTH_TLS = threading.local() + MAX_ACTIVATION_SIZE = 1024 * 1024 * 10 """Max payload size we will process.""" @@ -73,17 +77,19 @@ def __init__( self.credentials = credentials -# Type alias based on grpc-stubs -ContinuationType = Callable[[ClientCallDetails, Message], Any] - - if TYPE_CHECKING: + RpcMethodHandler = grpc.RpcMethodHandler[Any, Any] InterceptorBase = grpc.UnaryUnaryClientInterceptor[Message, Message] CallFuture = grpc.CallFuture[Message] else: + RpcMethodHandler = grpc.RpcMethodHandler InterceptorBase = grpc.UnaryUnaryClientInterceptor CallFuture = Any +ClientContinuation = Callable[[ClientCallDetails, Message], Any] +ServerContinuation = Callable[[grpc.HandlerCallDetails], Optional[RpcMethodHandler]] +Metadata = Sequence[Tuple[str, Union[str, bytes]]] + class RequestSignatureInterceptor(InterceptorBase): def __init__(self, shared_secret: list[str]): @@ -91,7 +97,7 @@ def __init__(self, shared_secret: list[str]): def intercept_unary_unary( self, - continuation: ContinuationType, + continuation: ClientContinuation, client_call_details: grpc.ClientCallDetails, request: Message, ) -> CallFuture: @@ -131,10 +137,6 @@ def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None: return [str(x) for x in parsed] -# The `grpc` package defines this type in `grpc._typing` but does not export it for some reason -type Metadata = Sequence[Tuple[str, Union[str, bytes]]] - - def grpc_metadata_get(metadata: Metadata, key: str) -> str | None: """ First matching gRPC metadata value for `key` or `None` if not present. @@ -183,12 +185,16 @@ def verify_rpc_request_signature_hmac( class RequestSignatureServerInterceptor(ServerInterceptor): """ Enforces HMAC request signing on unary-unary RPCs like `WorkerService.PushTask`. + Verification uses the raw request bytes from the wire (via a wrapped deserializer) + so it stays consistent across languages and map encodings. """ def __init__(self, secrets: list[str]) -> None: self._secrets = secrets - def intercept_service(self, continuation: Any, handler_call_details: Any) -> Any: + def intercept_service( + self, continuation: ServerContinuation, handler_call_details: grpc.HandlerCallDetails + ) -> Any: handler = continuation(handler_call_details) if handler is None or not self._secrets: return handler @@ -196,22 +202,35 @@ def intercept_service(self, continuation: Any, handler_call_details: Any) -> Any if handler.request_streaming or handler.response_streaming or handler.unary_unary is None: return handler + inner_deserializer = handler.request_deserializer + if inner_deserializer is None: + return handler + method = handler_call_details.method metadata = handler_call_details.invocation_metadata + signature = grpc_metadata_get(metadata, "sentry-signature") original = handler.unary_unary - def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: - signature = grpc_metadata_get(metadata, "sentry-signature") - body = request.SerializeToString() + def request_deserializer(serialized_request: bytes) -> Any: + _RPC_SIGNATURE_AUTH_TLS.failed = False + + if not verify_rpc_request_signature_hmac( + self._secrets, method, serialized_request, signature + ): + _RPC_SIGNATURE_AUTH_TLS.failed = True + return inner_deserializer(b"") + + return inner_deserializer(serialized_request) - if verify_rpc_request_signature_hmac(self._secrets, method, body, signature): - return original(request, context) + def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: + if getattr(_RPC_SIGNATURE_AUTH_TLS, "failed", False): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed") - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed") + return original(request, context) return grpc.unary_unary_rpc_method_handler( unary_unary, - request_deserializer=handler.request_deserializer, + request_deserializer=request_deserializer, response_serializer=handler.response_serializer, )