Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/sentry/taskworker/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

from __future__ import annotations

import contextlib
import logging
import threading
from collections.abc import MutableMapping
from contextlib import contextmanager
from typing import Generator
from typing import Any, Generator

import orjson
from arroyo.backends.kafka import KafkaProducer
from django.conf import settings
from django.core.cache.backends.base import BaseCache
Expand All @@ -26,6 +30,9 @@
from sentry.utils import metrics as sentry_metrics
from sentry.utils.arroyo_producer import SingletonProducer, get_arroyo_producer
from sentry.utils.memory import track_memory_usage as sentry_track_memory_usage
from sentry.viewer_context import ViewerContext, get_viewer_context, viewer_context_scope

logger = logging.getLogger(__name__)


class DjangoCacheAtMostOnceStore(AtMostOnceStore):
Expand Down Expand Up @@ -145,6 +152,34 @@ def route_namespace(self, name: str) -> str:
return self._default_topic.value


class ViewerContextHook:
"""
ContextHook that propagates ViewerContext through task headers.

Uses a single JSON header, matching the RPC layer's serialization
via ViewerContext.serialize() / ViewerContext.deserialize().
"""

HEADER = "sentry-viewer-context"

def on_dispatch(self, headers: MutableMapping[str, Any]) -> None:
ctx = get_viewer_context()
if ctx is None:
return
headers[self.HEADER] = orjson.dumps(ctx.serialize()).decode()

def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]:
raw = headers.get(self.HEADER)
if not raw:
return contextlib.nullcontext()
try:
ctx = ViewerContext.deserialize(orjson.loads(raw))
except (orjson.JSONDecodeError, TypeError, KeyError, AttributeError):
logger.exception("Failed to deserialize viewer context header")
return contextlib.nullcontext()
Comment thread
sentry[bot] marked this conversation as resolved.
return viewer_context_scope(ctx)


_producer_local = threading.local()


Expand Down
2 changes: 2 additions & 0 deletions src/sentry/taskworker/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DjangoCacheAtMostOnceStore,
SentryMetricsBackend,
SentryRouter,
ViewerContextHook,
make_producer,
)

Expand All @@ -15,6 +16,7 @@
metrics_class=SentryMetricsBackend(),
router_class=SentryRouter(),
at_most_once_store=DjangoCacheAtMostOnceStore(cache),
context_hooks=[ViewerContextHook()],
Comment thread
sentry-warden[bot] marked this conversation as resolved.
)
app.set_config(
{
Expand Down
107 changes: 106 additions & 1 deletion tests/sentry/taskworker/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import contextlib

import orjson
import pytest
from django.test.utils import override_settings
from taskbroker_client.registry import TaskRegistry

from sentry.conf.types.kafka_definition import Topic
from sentry.silo.base import SiloMode
from sentry.taskworker.adapters import SentryMetricsBackend, SentryRouter, make_producer
from sentry.taskworker.adapters import (
SentryMetricsBackend,
SentryRouter,
ViewerContextHook,
make_producer,
)
from sentry.viewer_context import ActorType, ViewerContext, get_viewer_context, viewer_context_scope


@pytest.mark.django_db
Expand Down Expand Up @@ -51,3 +60,99 @@ def test_default_router_topic_control_silo() -> None:
router = SentryRouter()
topic = router.route_namespace("test.tasks.test_router.control")
assert topic == Topic.TASKWORKER_CONTROL.value


class TestViewerContextHook:
def test_on_dispatch_with_context(self) -> None:
hook = ViewerContextHook()
headers: dict[str, str] = {}
ctx = ViewerContext(organization_id=42, user_id=7, actor_type=ActorType.USER)
with viewer_context_scope(ctx):
hook.on_dispatch(headers)

payload = orjson.loads(headers["sentry-viewer-context"])
assert payload["organization_id"] == 42
assert payload["user_id"] == 7
assert payload["actor_type"] == "user"

def test_on_dispatch_without_context(self) -> None:
hook = ViewerContextHook()
headers: dict[str, str] = {}
hook.on_dispatch(headers)

assert "sentry-viewer-context" not in headers

def test_on_dispatch_partial_context(self) -> None:
hook = ViewerContextHook()
headers: dict[str, str] = {}
ctx = ViewerContext(organization_id=42, actor_type=ActorType.SYSTEM)
with viewer_context_scope(ctx):
hook.on_dispatch(headers)

payload = orjson.loads(headers["sentry-viewer-context"])
assert payload["organization_id"] == 42
assert "user_id" not in payload
assert payload["actor_type"] == "system"

def test_on_execute_restores_context(self) -> None:
hook = ViewerContextHook()
headers = {
"sentry-viewer-context": orjson.dumps(
{"organization_id": 42, "user_id": 7, "actor_type": "user"}
).decode(),
}
with hook.on_execute(headers):
ctx = get_viewer_context()
assert ctx is not None
assert ctx.organization_id == 42
assert ctx.user_id == 7
assert ctx.actor_type == ActorType.USER

assert get_viewer_context() is None

def test_on_execute_no_headers(self) -> None:
hook = ViewerContextHook()
cm = hook.on_execute({})
assert isinstance(cm, contextlib.nullcontext)

def test_on_execute_partial_headers(self) -> None:
hook = ViewerContextHook()
headers = {
"sentry-viewer-context": orjson.dumps(
{"organization_id": 99, "actor_type": "integration"}
).decode(),
}
with hook.on_execute(headers):
ctx = get_viewer_context()
assert ctx is not None
assert ctx.organization_id == 99
assert ctx.user_id is None
assert ctx.actor_type == ActorType.INTEGRATION

def test_roundtrip(self) -> None:
"""Dispatch then execute produces the same ViewerContext."""
hook = ViewerContextHook()
headers: dict[str, str] = {}

original = ViewerContext(organization_id=123, user_id=456, actor_type=ActorType.USER)
with viewer_context_scope(original):
hook.on_dispatch(headers)

with hook.on_execute(headers):
restored = get_viewer_context()
assert restored is not None
assert restored.organization_id == original.organization_id
assert restored.user_id == original.user_id
assert restored.actor_type == original.actor_type

def test_on_execute_malformed_json(self) -> None:
hook = ViewerContextHook()
headers = {"sentry-viewer-context": "not-valid-json{"}
cm = hook.on_execute(headers)
assert isinstance(cm, contextlib.nullcontext)

def test_on_execute_non_dict_json(self) -> None:
hook = ViewerContextHook()
headers = {"sentry-viewer-context": "[1, 2, 3]"}
cm = hook.on_execute(headers)
assert isinstance(cm, contextlib.nullcontext)
Loading