Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
dependencies = [
"sentry-arroyo>=2.33.1",
"sentry-sdk[http2]>=2.43.0",
"sentry-protos>=0.4.11",
"sentry-protos>=0.8.5",
"confluent_kafka>=2.3.0",
"cronsim>=2.6",
"grpcio>=1.67.0",
Expand Down
13 changes: 12 additions & 1 deletion clients/python/src/examples/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,16 @@ def scheduler() -> None:
help="The number of child processes to start.",
default=2,
)
def worker(rpc_host: str, concurrency: int) -> None:
@click.option(
"--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True
)
Comment thread
cursor[bot] marked this conversation as resolved.
@click.option(
"--grpc-port",
help="Port for the gRPC server to listen on.",
default=50052,
type=int,
)
def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None:
from taskbroker_client.worker import TaskWorker

click.echo("Starting worker")
Expand All @@ -87,6 +96,8 @@ def worker(rpc_host: str, concurrency: int) -> None:
rebalance_after=32,
processing_pool_name="examples",
process_type="forkserver",
push_mode=push_mode,
grpc_port=grpc_port,
)
exitcode = worker.start()
raise SystemExit(exitcode)
Expand Down
28 changes: 28 additions & 0 deletions clients/python/src/taskbroker_client/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ class MetricsBackend(Protocol):
An abstract class that defines the interface for metrics backends.
"""

    @abstractmethod
    def gauge(
        self,
        key: str,
        value: float,
        instance: str | None = None,
        tags: Tags | None = None,
        sample_rate: float = 1,
        unit: str | None = None,
        stacklevel: int = 0,
    ) -> None:
        """
        Records a gauge metric (a point-in-time value).

        Args:
            key: The metric name.
            value: The current value of the gauge.
            instance: Optional instance identifier for the metric.
            tags: Optional tag key/value pairs attached to the metric.
            sample_rate: Client-side sampling rate (1 = always record).
            unit: Optional unit label for the value.
            stacklevel: Backend-specific; presumably used to attribute the
                recording call site — confirm per concrete backend.
        """
        raise NotImplementedError
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to be careful about implementing this when sentry's client library version is updated. I don't think there are any tests in sentry that will fail because of this new abstract method.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on this?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sentry (and launchpad) now have implementations of this MetricsBackend, but they won't have this method defined. When the worker runtime calls metrics.gauge(), it will raise, as those applications won't have implemented this abstract method.

When we update applications to the new client library release, we have to remember to implement this method, or the push worker will be broken.


@abstractmethod
def incr(
self,
Expand Down Expand Up @@ -71,6 +87,18 @@ class NoOpMetricsBackend(MetricsBackend):
Default metrics backend that does not record anything.
"""

def gauge(
self,
key: str,
value: float,
instance: str | None = None,
tags: Tags | None = None,
sample_rate: float = 1,
unit: str | None = None,
stacklevel: int = 0,
) -> None:
pass

def incr(
self,
name: str,
Expand Down
158 changes: 146 additions & 12 deletions clients/python/src/taskbroker_client/worker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union

import grpc
import orjson
Expand All @@ -27,8 +27,17 @@
from taskbroker_client.metrics import MetricsBackend
from taskbroker_client.types import InflightTaskActivation, ProcessingResult

if TYPE_CHECKING:
    # grpc-stubs makes ServerInterceptor generic; the runtime class is not
    # subscriptable, hence the TYPE_CHECKING split.
    ServerInterceptor = grpc.ServerInterceptor[Any, Any]
else:
    ServerInterceptor = grpc.ServerInterceptor

logger = logging.getLogger(__name__)

# gRPC runs the unary request deserializer and the servicer on the same thread
# for a given call. If HMAC verification fails inside a deserializer wrapper,
# raising there surfaces as INTERNAL to the client, so we record the failure in
# this thread-local flag and abort with UNAUTHENTICATED from the servicer.
_RPC_SIGNATURE_AUTH_TLS = threading.local()

MAX_ACTIVATION_SIZE = 1024 * 1024 * 10
"""Max payload size we will process."""

Expand Down Expand Up @@ -68,25 +77,27 @@ def __init__(
self.credentials = credentials


# Type alias based on grpc-stubs
ContinuationType = Callable[[ClientCallDetails, Message], Any]


if TYPE_CHECKING:
RpcMethodHandler = grpc.RpcMethodHandler[Any, Any]
InterceptorBase = grpc.UnaryUnaryClientInterceptor[Message, Message]
CallFuture = grpc.CallFuture[Message]
else:
RpcMethodHandler = grpc.RpcMethodHandler
InterceptorBase = grpc.UnaryUnaryClientInterceptor
CallFuture = Any

ClientContinuation = Callable[[ClientCallDetails, Message], Any]
ServerContinuation = Callable[[grpc.HandlerCallDetails], Optional[RpcMethodHandler]]
Metadata = Sequence[Tuple[str, Union[str, bytes]]]


class RequestSignatureInterceptor(InterceptorBase):
    def __init__(self, shared_secret: list[str]):
        # Sign outgoing requests with the first (primary) secret only; the
        # server side accepts any secret in its list, which allows rotation.
        self._secret = shared_secret[0].encode("utf-8")

def intercept_unary_unary(
self,
continuation: ContinuationType,
continuation: ClientContinuation,
client_call_details: grpc.ClientCallDetails,
request: Message,
) -> CallFuture:
Expand All @@ -108,6 +119,122 @@ def intercept_unary_unary(
return continuation(call_details_with_meta, request)


def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None:
    """
    Parse the task app `rpc_secret` JSON array string into a list of secrets.

    Returns:
        The secrets (each coerced to `str`), or `None` when the value is
        unset, not valid JSON, not a list, or an empty list — all of which
        mean "no authentication".
    """
    if not rpc_secret:
        return None

    try:
        parsed = orjson.loads(rpc_secret)
    except orjson.JSONDecodeError:
        # The documented contract is that an invalid secret string means "no
        # authentication"; previously malformed JSON raised out of this
        # function instead of returning None.
        return None

    if not isinstance(parsed, list) or len(parsed) == 0:
        # Anything other than a non-empty list is invalid.
        return None

    return [str(x) for x in parsed]


def grpc_metadata_get(metadata: Metadata, key: str) -> str | None:
"""
First matching gRPC metadata value for `key` or `None` if not present.
"""
# gRPC metadata keys are ASCII and compared case-insensitively
key = key.lower()

for k, v in metadata:
if k.lower() == key:
return v.decode("utf-8") if isinstance(v, bytes) else v

return None


def verify_rpc_request_signature_hmac(
secrets: list[str],
method: str,
request_body: bytes,
signature_hex: str | None,
) -> bool:
"""
Verify the 'sentry-signature' metadata for a unary RPC body.
Uses the same signing contract as `RequestSignatureInterceptor` and the taskbroker.
"""
if not secrets:
return True

if not signature_hex:
return False

try:
sig_bytes = bytes.fromhex(signature_hex)
except ValueError:
return False

signing_payload = method.encode("utf-8") + b":" + request_body

for secret in secrets:
expected = hmac.new(secret.encode("utf-8"), signing_payload, hashlib.sha256).digest()
if hmac.compare_digest(expected, sig_bytes):
return True

return False


class RequestSignatureServerInterceptor(ServerInterceptor):
    """
    Enforces HMAC request signing on unary-unary RPCs like `WorkerService.PushTask`.
    Verification uses the raw request bytes from the wire (via a wrapped deserializer)
    so it stays consistent across languages and map encodings.
    """

    def __init__(self, secrets: list[str]) -> None:
        self._secrets = secrets

    def intercept_service(
        self, continuation: ServerContinuation, handler_call_details: grpc.HandlerCallDetails
    ) -> Any:
        handler = continuation(handler_call_details)
        # Nothing to wrap when the RPC is unknown or auth is disabled.
        if handler is None or not self._secrets:
            return handler

        # Only unary-unary calls carry a single signable body; pass streaming
        # RPCs through unmodified.
        if handler.request_streaming or handler.response_streaming or handler.unary_unary is None:
            return handler

        inner_deserializer = handler.request_deserializer
        if inner_deserializer is None:
            return handler

        method = handler_call_details.method
        metadata = handler_call_details.invocation_metadata
        signature = grpc_metadata_get(metadata, "sentry-signature")
        original = handler.unary_unary

        def request_deserializer(serialized_request: bytes) -> Any:
            # The deserializer and the servicer run on the same thread for a
            # given call; a thread-local flag carries the auth outcome to
            # `unary_unary`. Raising here would surface as INTERNAL rather
            # than UNAUTHENTICATED.
            _RPC_SIGNATURE_AUTH_TLS.failed = False

            if not verify_rpc_request_signature_hmac(
                self._secrets, method, serialized_request, signature
            ):
                _RPC_SIGNATURE_AUTH_TLS.failed = True
                # Hand the real deserializer an empty message; the servicer
                # aborts before the request is ever used.
                return inner_deserializer(b"")

            return inner_deserializer(serialized_request)

        def unary_unary(request: Any, context: grpc.ServicerContext) -> Any:
            if getattr(_RPC_SIGNATURE_AUTH_TLS, "failed", False):
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication failed")
                # Defensive: `abort()` is documented to raise in the sync
                # server, but gRPC Python has had fall-through bugs here
                # (grpc/grpc#30306). Never reach the real handler when
                # authentication failed.
                return None

            return original(request, context)

        return grpc.unary_unary_rpc_method_handler(
            unary_unary,
            request_deserializer=request_deserializer,
            response_serializer=handler.response_serializer,
        )


class HostTemporarilyUnavailable(Exception):
    """
    Raised when a host is temporarily unavailable and should be retried later.

    Raised by the client (e.g. in `update_task`) when the target host has been
    marked temporarily unavailable after repeated UNAVAILABLE errors.
    """

Expand Down Expand Up @@ -157,6 +284,7 @@ def __init__(
)

self._cur_host = random.choice(self._hosts)
self._host_to_stubs_lock = threading.Lock()
self._host_to_stubs: dict[str, ConsumerServiceStub] = {
self._cur_host: self._connect_to_host(self._cur_host)
}
Expand Down Expand Up @@ -195,11 +323,17 @@ def _emit_health_check(self) -> None:
def _connect_to_host(self, host: str) -> ConsumerServiceStub:
logger.info("taskworker.client.connect", extra={"host": host})
channel = grpc.insecure_channel(host, options=self._grpc_options)
if self._rpc_secret:
secrets = orjson.loads(self._rpc_secret)
secrets = parse_rpc_secret_list(self._rpc_secret)
if secrets:
channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets))
return ConsumerServiceStub(channel)

def _get_stub(self, host: str) -> ConsumerServiceStub:
Comment thread
george-sentry marked this conversation as resolved.
with self._host_to_stubs_lock:
if host not in self._host_to_stubs:
self._host_to_stubs[host] = self._connect_to_host(host)
return self._host_to_stubs[host]
Comment thread
cursor[bot] marked this conversation as resolved.

def _check_consecutive_unavailable_errors(self) -> None:
if self._num_consecutive_unavailable_errors >= self._max_consecutive_unavailable_errors:
self._temporary_unavailable_hosts[self._cur_host] = (
Expand Down Expand Up @@ -246,11 +380,10 @@ def _get_cur_stub(self) -> tuple[str, ConsumerServiceStub]:
tags={"reason": "max_tasks_reached"},
)

if self._cur_host not in self._host_to_stubs:
self._host_to_stubs[self._cur_host] = self._connect_to_host(self._cur_host)
stub = self._get_stub(self._cur_host)

self._num_tasks_before_rebalance -= 1
return self._cur_host, self._host_to_stubs[self._cur_host]
return self._cur_host, stub

def get_task(self, namespace: str | None = None) -> InflightTaskActivation | None:
"""
Expand Down Expand Up @@ -324,10 +457,11 @@ def update_task(
f"Host: {processing_result.host} is temporarily unavailable"
)

stub = self._get_stub(processing_result.host)
Comment thread
george-sentry marked this conversation as resolved.
with self._metrics.timer(
"taskworker.update_task.rpc", tags={"host": processing_result.host}
):
response = self._host_to_stubs[processing_result.host].SetTaskStatus(request)
response = stub.SetTaskStatus(request)
except grpc.RpcError as err:
self._metrics.incr(
"taskworker.client.rpc_error",
Expand Down
Loading
Loading