diff --git a/src/sentry/taskworker/adapters.py b/src/sentry/taskworker/adapters.py index 6d083f6a02a2c2..ad9315df0dd58c 100644 --- a/src/sentry/taskworker/adapters.py +++ b/src/sentry/taskworker/adapters.py @@ -7,10 +7,14 @@ from __future__ import annotations +import contextlib +import logging import threading +from collections.abc import MutableMapping from contextlib import contextmanager -from typing import Generator +from typing import Any, Generator +import orjson from arroyo.backends.kafka import KafkaProducer from django.conf import settings from django.core.cache.backends.base import BaseCache @@ -26,6 +30,9 @@ from sentry.utils import metrics as sentry_metrics from sentry.utils.arroyo_producer import SingletonProducer, get_arroyo_producer from sentry.utils.memory import track_memory_usage as sentry_track_memory_usage +from sentry.viewer_context import ViewerContext, get_viewer_context, viewer_context_scope + +logger = logging.getLogger(__name__) class DjangoCacheAtMostOnceStore(AtMostOnceStore): @@ -145,6 +152,34 @@ def route_namespace(self, name: str) -> str: return self._default_topic.value +class ViewerContextHook: + """ + ContextHook that propagates ViewerContext through task headers. + + Uses a single JSON header, matching the RPC layer's serialization + via ViewerContext.serialize() / ViewerContext.deserialize(). + """ + + HEADER = "sentry-viewer-context" + + def on_dispatch(self, headers: MutableMapping[str, Any]) -> None: + ctx = get_viewer_context() + if ctx is None: + return + headers[self.HEADER] = orjson.dumps(ctx.serialize()).decode() + + def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: + raw = headers.get(self.HEADER) + if not raw: + return contextlib.nullcontext() + try: + ctx = ViewerContext.deserialize(orjson.loads(raw)) + except (orjson.JSONDecodeError, TypeError, KeyError, AttributeError): + logger.exception("Failed to deserialize viewer context header") + return contextlib.nullcontext() + return viewer_context_scope(ctx) + + _producer_local = threading.local() diff --git a/src/sentry/taskworker/runtime.py b/src/sentry/taskworker/runtime.py index 9a46696f07f2bf..2ccba440dc0e46 100644 --- a/src/sentry/taskworker/runtime.py +++ b/src/sentry/taskworker/runtime.py @@ -6,6 +6,7 @@ DjangoCacheAtMostOnceStore, SentryMetricsBackend, SentryRouter, + ViewerContextHook, make_producer, ) @@ -15,6 +16,7 @@ metrics_class=SentryMetricsBackend(), router_class=SentryRouter(), at_most_once_store=DjangoCacheAtMostOnceStore(cache), + context_hooks=[ViewerContextHook()], ) app.set_config( { diff --git a/tests/sentry/taskworker/test_adapters.py b/tests/sentry/taskworker/test_adapters.py index ac8e9d9aa7732f..65c4e16505fc1e 100644 --- a/tests/sentry/taskworker/test_adapters.py +++ b/tests/sentry/taskworker/test_adapters.py @@ -1,10 +1,19 @@ +import contextlib + +import orjson import pytest from django.test.utils import override_settings from taskbroker_client.registry import TaskRegistry from sentry.conf.types.kafka_definition import Topic from sentry.silo.base import SiloMode -from sentry.taskworker.adapters import SentryMetricsBackend, SentryRouter, make_producer +from sentry.taskworker.adapters import ( + SentryMetricsBackend, + SentryRouter, + ViewerContextHook, + make_producer, +) +from sentry.viewer_context import ActorType, ViewerContext, get_viewer_context, viewer_context_scope @pytest.mark.django_db @@ -51,3 +60,99 @@ def test_default_router_topic_control_silo() -> None: router = SentryRouter() topic = router.route_namespace("test.tasks.test_router.control") assert topic == Topic.TASKWORKER_CONTROL.value + + +class TestViewerContextHook: + def test_on_dispatch_with_context(self) -> None: + hook = ViewerContextHook() + headers: dict[str, str] = {} + ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER) + with viewer_context_scope(ctx): + hook.on_dispatch(headers) + + payload = orjson.loads(headers["sentry-viewer-context"]) + assert payload["organization_id"] == 42 + assert payload["user_id"] == 7 + assert payload["actor_type"] == "user" + + def test_on_dispatch_without_context(self) -> None: + hook = ViewerContextHook() + headers: dict[str, str] = {} + hook.on_dispatch(headers) + + assert "sentry-viewer-context" not in headers + + def test_on_dispatch_partial_context(self) -> None: + hook = ViewerContextHook() + headers: dict[str, str] = {} + ctx = ViewerContext(organization_id=42, actor_type=ActorType.SYSTEM) + with viewer_context_scope(ctx): + hook.on_dispatch(headers) + + payload = orjson.loads(headers["sentry-viewer-context"]) + assert payload["organization_id"] == 42 + assert "user_id" not in payload + assert payload["actor_type"] == "system" + + def test_on_execute_restores_context(self) -> None: + hook = ViewerContextHook() + headers = { + "sentry-viewer-context": orjson.dumps( + {"organization_id": 42, "user_id": 7, "actor_type": "user"} + ).decode(), + } + with hook.on_execute(headers): + ctx = get_viewer_context() + assert ctx is not None + assert ctx.organization_id == 42 + assert ctx.user_id == 7 + assert ctx.actor_type == ActorType.USER + + assert get_viewer_context() is None + + def test_on_execute_no_headers(self) -> None: + hook = ViewerContextHook() + cm = hook.on_execute({}) + assert isinstance(cm, contextlib.nullcontext) + + def test_on_execute_partial_headers(self) -> None: + hook = ViewerContextHook() + headers = { + "sentry-viewer-context": orjson.dumps( + {"organization_id": 99, "actor_type": "integration"} + ).decode(), + } + with hook.on_execute(headers): + ctx = get_viewer_context() + assert ctx is not None + assert ctx.organization_id == 99 + assert ctx.user_id is None + assert ctx.actor_type == ActorType.INTEGRATION + + def test_roundtrip(self) -> None: + """Dispatch then execute produces the same ViewerContext.""" + hook = ViewerContextHook() + headers: dict[str, str] = {} + + original = ViewerContext(organization_id=123, user_id=456, actor_type=ActorType.USER) + with viewer_context_scope(original): + hook.on_dispatch(headers) + + with hook.on_execute(headers): + restored = get_viewer_context() + assert restored is not None + assert restored.organization_id == original.organization_id + assert restored.user_id == original.user_id + assert restored.actor_type == original.actor_type + + def test_on_execute_malformed_json(self) -> None: + hook = ViewerContextHook() + headers = {"sentry-viewer-context": "not-valid-json{"} + cm = hook.on_execute(headers) + assert isinstance(cm, contextlib.nullcontext) + + def test_on_execute_non_dict_json(self) -> None: + hook = ViewerContextHook() + headers = {"sentry-viewer-context": "[1, 2, 3]"} + cm = hook.on_execute(headers) + assert isinstance(cm, contextlib.nullcontext)