From 0cb7501c49c48ca5656550dd5afb4bad9e8ef057 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Tue, 19 May 2026 17:23:33 +0200 Subject: [PATCH 01/14] feat(llm-gateway): register slack_app product as billable --- .../src/llm_gateway/callbacks/posthog.py | 12 ++++ .../src/llm_gateway/products/config.py | 46 ++++++++----- .../tests/callbacks/test_posthog.py | 68 +++++++++++++++++++ .../llm-gateway/tests/test_product_config.py | 28 ++++++++ 4 files changed, 138 insertions(+), 16 deletions(-) diff --git a/services/llm-gateway/src/llm_gateway/callbacks/posthog.py b/services/llm-gateway/src/llm_gateway/callbacks/posthog.py index 64ee60acadca..34441a76220f 100644 --- a/services/llm-gateway/src/llm_gateway/callbacks/posthog.py +++ b/services/llm-gateway/src/llm_gateway/callbacks/posthog.py @@ -10,6 +10,7 @@ from llm_gateway.auth.models import resolve_distinct_id from llm_gateway.callbacks.base import InstrumentedCallback +from llm_gateway.products.config import get_product_config from llm_gateway.request_context import ( get_auth_user, get_posthog_flags, @@ -53,6 +54,15 @@ def _replace_binary_content(data: Any) -> Any: _TRUNCATION_MARKER = "[truncated: content too large for capture]" _TRUNCATABLE_FIELDS = ("$ai_output_choices", "$ai_input") + +def _is_product_billable(product: str) -> bool: + """Look up the product's billable flag in the central registry. False for + unknown products so we never accidentally bill calls we can't attribute. + """ + config = get_product_config(product) + return bool(config and config.billable) + + # Stable namespace for hashing non-UUID trace identifiers (e.g. Claude Code's # JSON-encoded session blobs sent via Anthropic's metadata.user_id) into a # deterministic UUID. Generated once and frozen so the same input always maps @@ -147,6 +157,7 @@ async def _on_success( "$ai_trace_id": trace_id, "$ai_span_id": str(uuid4()), "ai_product": product, + "$ai_billable": _is_product_billable(product), } posthog_properties = get_posthog_properties() or {} @@ -226,6 +237,7 @@ async def _on_failure( "$ai_is_error": True, "$ai_error": standard_logging_object.get("error_str", ""), "ai_product": product, + "$ai_billable": _is_product_billable(product), } posthog_properties = get_posthog_properties() or {} diff --git a/services/llm-gateway/src/llm_gateway/products/config.py b/services/llm-gateway/src/llm_gateway/products/config.py index 95febfb0ba05..2bda40439fef 100644 --- a/services/llm-gateway/src/llm_gateway/products/config.py +++ b/services/llm-gateway/src/llm_gateway/products/config.py @@ -16,6 +16,10 @@ class ProductConfig: allowed_application_ids: frozenset[str] | None = frozenset() allowed_models: frozenset[str] | None = None # None = all allowed allow_api_keys: bool = True + # Tag emitted $ai_generation events with $ai_billable=true so the PHAI + # daily aggregator (posthog/tasks/usage_report.py) rolls them into the + # customer team's AI credits bucket. + billable: bool = False BEDROCK_MODELS = BEDROCK_MODEL_IDS @@ -28,6 +32,25 @@ class ProductConfig: WIZARD_US_APP_ID = "019a0c79-b69d-0000-f31b-b41345208c9d" WIZARD_EU_APP_ID = "019a12d0-6edd-0000-0458-86616af3a3db" +# Shared by `posthog_code` and `slack_app` — the agent that runs in the sandbox +# is the same code regardless of where the task was initiated, so the model +# allowlist is identical. +_POSTHOG_CODE_AGENT_MODELS: Final[frozenset[str]] = frozenset( + { + "claude-opus-4-5", + "claude-opus-4-6", + "claude-opus-4-7", + "claude-sonnet-4-5", + "claude-sonnet-4-6", + "claude-haiku-4-5", + "gpt-5.5", + "gpt-5.4", + "gpt-5.3-codex", + "gpt-5.2", + "gpt-5-mini", + } +) + PRODUCTS: Final[dict[str, ProductConfig]] = { "llm_gateway": ProductConfig( allowed_application_ids=None, @@ -36,22 +59,7 @@ class ProductConfig: ), "posthog_code": ProductConfig( allowed_application_ids=frozenset({POSTHOG_CODE_US_APP_ID, POSTHOG_CODE_EU_APP_ID}), - allowed_models=frozenset( - { - "claude-opus-4-5", - "claude-opus-4-6", - "claude-opus-4-7", - "claude-sonnet-4-5", - "claude-sonnet-4-6", - "claude-haiku-4-5", - "gpt-5.5", - "gpt-5.4", - "gpt-5.3-codex", - "gpt-5.2", - "gpt-5-mini", - } - | BEDROCK_MODELS - ), + allowed_models=_POSTHOG_CODE_AGENT_MODELS | BEDROCK_MODELS, allow_api_keys=False, ), "background_agents": ProductConfig( @@ -72,6 +80,12 @@ class ProductConfig: ), allow_api_keys=False, ), + "slack_app": ProductConfig( + allowed_application_ids=frozenset({POSTHOG_CODE_US_APP_ID, POSTHOG_CODE_EU_APP_ID}), + allowed_models=_POSTHOG_CODE_AGENT_MODELS | BEDROCK_MODELS, + allow_api_keys=False, + billable=True, + ), "wizard": ProductConfig( allowed_application_ids=frozenset({WIZARD_US_APP_ID, WIZARD_EU_APP_ID}), allowed_models=None, diff --git a/services/llm-gateway/tests/callbacks/test_posthog.py b/services/llm-gateway/tests/callbacks/test_posthog.py index 9a860cc8ce8e..2d2f165cb82e 100644 --- a/services/llm-gateway/tests/callbacks/test_posthog.py +++ b/services/llm-gateway/tests/callbacks/test_posthog.py @@ -303,6 +303,74 @@ async def test_on_failure_includes_ai_product( props = mock_client.capture.call_args.kwargs["properties"] assert props["ai_product"] == product + @pytest.mark.asyncio + async def test_on_success_marks_slack_app_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + standard_logging_object: dict, + mock_posthog_client: tuple, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = {"standard_logging_object": standard_logging_object, "litellm_params": {}} + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value="slack_app"), + ): + await callback._on_success(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is True + + @pytest.mark.asyncio + @pytest.mark.parametrize("product", ["posthog_code", "background_agents", "wizard", "llm_gateway"]) + async def test_on_success_does_not_mark_other_products_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + standard_logging_object: dict, + product: str, + mock_posthog_client: tuple, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = {"standard_logging_object": standard_logging_object, "litellm_params": {}} + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), + ): + await callback._on_success(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is False + + @pytest.mark.asyncio + async def test_on_failure_marks_slack_app_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + mock_posthog_client: tuple, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = { + "standard_logging_object": { + "model": "claude-sonnet-4-6", + "custom_llm_provider": "anthropic", + "error_str": "boom", + }, + "litellm_params": {}, + } + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value="slack_app"), + ): + await callback._on_failure(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is True + @pytest.mark.asyncio async def test_on_success_uses_passed_end_user_id( self, diff --git a/services/llm-gateway/tests/test_product_config.py b/services/llm-gateway/tests/test_product_config.py index d19e8ded2e4b..7a68d45a5257 100644 --- a/services/llm-gateway/tests/test_product_config.py +++ b/services/llm-gateway/tests/test_product_config.py @@ -255,6 +255,34 @@ def test_slack_twig_rejects_non_haiku_models(self): assert error is not None assert "not allowed" in error + @pytest.mark.parametrize( + "model", + [ + "claude-opus-4-7", + "claude-sonnet-4-6", + "claude-haiku-4-5", + "gpt-5.3-codex", + ], + ) + def test_slack_app_allows_agent_models(self, model: str): + allowed, error = check_product_access("slack_app", "oauth_access_token", POSTHOG_CODE_US_APP_ID, model) + assert allowed is True + assert error is None + + def test_slack_app_rejects_api_keys(self): + allowed, error = check_product_access("slack_app", "personal_api_key", None, "claude-sonnet-4-6") + assert allowed is False + assert error is not None + assert "requires OAuth" in error + + def test_slack_app_rejects_unauthorized_oauth_app(self): + allowed, error = check_product_access( + "slack_app", "oauth_access_token", "00000000-0000-0000-0000-000000000000", "claude-sonnet-4-6" + ) + assert allowed is False + assert error is not None + assert "not authorized" in error + class TestBackwardsCompatibility: def test_twig_app_id_constants_are_aliases(self): From b685a8fa9f70d866f25e8a875676200e1029eea3 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Tue, 19 May 2026 17:23:52 +0200 Subject: [PATCH 02/14] feat(slack-bot): gate task creation and follow-ups on AI credits quota --- .../temporal/ai/posthog_code_slack_mention.py | 72 +++++++++++++++++++ .../backend/tests/test_followup_forwarding.py | 61 ++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/posthog/temporal/ai/posthog_code_slack_mention.py b/posthog/temporal/ai/posthog_code_slack_mention.py index b3ab69cb693b..0cc43bd9f3d3 100644 --- a/posthog/temporal/ai/posthog_code_slack_mention.py +++ b/posthog/temporal/ai/posthog_code_slack_mention.py @@ -61,6 +61,55 @@ def _safe_react(client: Any, channel: str, timestamp: str, name: str) -> None: _INITIATOR_PLACEHOLDER = "" +_QUOTA_EXHAUSTED_MESSAGE = ( + "Your team has used its monthly PostHog AI credits. " + "Top up at https://us.posthog.com/organization/billing to continue." +) + + +def _block_if_team_over_quota( + *, + integration: Any, + slack: Any, + channel: str, + thread_ts: str, + slack_user_id: str, + context: Literal["task_create", "followup"], +) -> bool: + """Reject a Slack-bot turn when the team is over its AI credits quota. + + Mirrors PHAI's enforcement model (ee/api/conversation.py): every + user-initiated turn — new mention or follow-up reply — is gated against the + same Redis-backed `QuotaResource.AI_CREDITS` set. Returns True when the + team is blocked, posts a friendly in-thread denial as a side effect. + """ + from products.slack_app.backend.api import _post_slack_user_feedback + + from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + if not is_team_limited( + integration.team.api_token, QuotaResource.AI_CREDITS, QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY + ): + return False + + logger.info( + "posthog_code_slack_blocked_by_quota", + context=context, + team_id=integration.team_id, + channel=channel, + thread_ts=thread_ts, + ) + _post_slack_user_feedback( + slack, + channel, + slack_user_id, + thread_ts, + _QUOTA_EXHAUSTED_MESSAGE, + prefer_thread_message=True, + ) + return True + + def _strip_context_tag(text: str) -> str: return re.sub(rf"", "", text, flags=re.IGNORECASE) @@ -1036,6 +1085,19 @@ def create_posthog_code_task_for_repo_activity( integration_id=inputs.slack_team_id, ) slack = SlackIntegration(integration) + + # Refuse before the :seedling: reaction or the permalink fetch: a denied + # mention should not first ack-react and then refuse a second later. + if _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="task_create", + ): + return + user_message_ts = event.get("ts") if user_message_ts: _safe_react(slack.client, channel, user_message_ts, "seedling") @@ -1241,6 +1303,16 @@ def forward_posthog_code_followup_activity( ) return True + if _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="followup", + ): + return True + if task_run.is_terminal: return _resume_task_with_new_run( mapping, diff --git a/products/slack_app/backend/tests/test_followup_forwarding.py b/products/slack_app/backend/tests/test_followup_forwarding.py index 353e26d609b1..6e1fcb672aaf 100644 --- a/products/slack_app/backend/tests/test_followup_forwarding.py +++ b/products/slack_app/backend/tests/test_followup_forwarding.py @@ -33,6 +33,16 @@ def _command_result(**kwargs): return SimpleNamespace(**defaults) +def _assert_quota_denial_posted(mock_slack_instance: MagicMock, channel: str, thread_ts: str) -> None: + denial_calls = [ + call + for call in mock_slack_instance.client.chat_postMessage.call_args_list + if call.kwargs.get("channel") == channel and call.kwargs.get("thread_ts") == thread_ts + ] + assert denial_calls, "Expected an in-thread denial message when over quota" + assert "PostHog AI credits" in denial_calls[0].kwargs["text"] + + class TestSlackThreadTaskMapping(TestCase): def setUp(self): self.Task = apps.get_model("tasks", "Task") @@ -401,6 +411,36 @@ def test_description_encloses_thread_context_in_a_tag(self, mock_slack_cls, mock ) assert task.description.endswith("do something") + @patch("products.tasks.backend.temporal.client.execute_task_processing_workflow") + @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + def test_quota_exceeded_blocks_task_creation_with_thread_message( + self, + _mock_is_team_limited, + mock_slack_cls, + mock_execute_workflow, + ): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + create_posthog_code_task_for_repo_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + self.user.id, + inputs.event, + [{"user": "U_ALICE", "text": "do something"}], + None, + ) + + # No task created, no workflow started. + assert not self.Task.objects.filter(team=self.team).exists() + mock_execute_workflow.assert_not_called() + + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + class TestForwardPostHogCodeFollowupActivity(TestCase): def setUp(self): @@ -446,6 +486,27 @@ def test_no_mapping_returns_false(self): ) assert result is False + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") + def test_quota_exceeded_blocks_followup_with_thread_message( + self, + mock_slack_cls, + _mock_is_team_limited, + ): + self._create_mapping() + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + result = forward_posthog_code_followup_activity( + inputs, "C123", "1234.5678", "U_ALICE", "do something", "1234.5679" + ) + + # The follow-up was handled by refusal, so the workflow shouldn't fall + # through to new-task creation. + assert result is True + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + @patch("posthog.temporal.ai.posthog_code_slack_mention.execute_task_processing_workflow") @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") def test_terminal_run_resumes_same_task(self, mock_slack_cls, mock_execute_workflow): From c008271b4e547fc19d1df90eded63b4ae585ab62 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Thu, 21 May 2026 12:12:02 +0200 Subject: [PATCH 03/14] feat(llm-gateway): rename slack-posthog-code product to slack_app_routing and bill it --- posthog/llm/gateway_client.py | 1 + products/slack_app/backend/api.py | 2 +- .../src/llm_gateway/products/config.py | 6 ++-- .../tests/callbacks/test_posthog.py | 6 ++-- .../llm-gateway/tests/test_product_config.py | 33 +++++++++++++++---- 5 files changed, 37 insertions(+), 11 deletions(-) diff --git a/posthog/llm/gateway_client.py b/posthog/llm/gateway_client.py index 67f63d0ae5c1..c4727016bda8 100644 --- a/posthog/llm/gateway_client.py +++ b/posthog/llm/gateway_client.py @@ -9,6 +9,7 @@ "llm_gateway", "posthog_code", "background_agents", + "slack_app_routing", "wizard", "django", "growth", diff --git a/products/slack_app/backend/api.py b/products/slack_app/backend/api.py index d1552c6f9205..289ca4901b15 100644 --- a/products/slack_app/backend/api.py +++ b/products/slack_app/backend/api.py @@ -1034,7 +1034,7 @@ def classify_task_needs_repo( 'Respond with ONLY a JSON object: {{"needs_repo": true}} or {{"needs_repo": false}}' ) try: - client = get_llm_client("slack-posthog-code") + client = get_llm_client("slack_app_routing") response = client.chat.completions.create( model="claude-haiku-4-5-20251001", messages=[{"role": "user", "content": prompt}], diff --git a/services/llm-gateway/src/llm_gateway/products/config.py b/services/llm-gateway/src/llm_gateway/products/config.py index 2bda40439fef..f5b7abc1dd3b 100644 --- a/services/llm-gateway/src/llm_gateway/products/config.py +++ b/services/llm-gateway/src/llm_gateway/products/config.py @@ -101,10 +101,11 @@ class ProductConfig: allowed_models=None, allow_api_keys=True, ), - "slack-posthog-code": ProductConfig( + "slack_app_routing": ProductConfig( allowed_application_ids=None, allowed_models=frozenset({"claude-haiku-4-5"}), allow_api_keys=True, + billable=True, ), "growth": ProductConfig( allowed_application_ids=None, @@ -153,7 +154,8 @@ class ProductConfig: PRODUCT_ALIASES: Final[dict[str, str]] = { "array": "posthog_code", "twig": "posthog_code", - "slack-twig": "slack-posthog-code", + "slack-posthog-code": "slack_app_routing", + "slack-twig": "slack_app_routing", } diff --git a/services/llm-gateway/tests/callbacks/test_posthog.py b/services/llm-gateway/tests/callbacks/test_posthog.py index 2d2f165cb82e..fe1e2897f7f7 100644 --- a/services/llm-gateway/tests/callbacks/test_posthog.py +++ b/services/llm-gateway/tests/callbacks/test_posthog.py @@ -304,11 +304,13 @@ async def test_on_failure_includes_ai_product( assert props["ai_product"] == product @pytest.mark.asyncio - async def test_on_success_marks_slack_app_billable( + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_on_success_marks_slack_products_billable( self, callback: PostHogCallback, auth_user: AuthenticatedUser, standard_logging_object: dict, + product: str, mock_posthog_client: tuple, ) -> None: _, mock_client = mock_posthog_client @@ -316,7 +318,7 @@ async def test_on_success_marks_slack_app_billable( with ( patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), - patch("llm_gateway.callbacks.posthog.get_product", return_value="slack_app"), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), ): await callback._on_success(kwargs, None, 0.0, 1.0, end_user_id=None) diff --git a/services/llm-gateway/tests/test_product_config.py b/services/llm-gateway/tests/test_product_config.py index 7a68d45a5257..a1a7beec56fa 100644 --- a/services/llm-gateway/tests/test_product_config.py +++ b/services/llm-gateway/tests/test_product_config.py @@ -255,6 +255,23 @@ def test_slack_twig_rejects_non_haiku_models(self): assert error is not None assert "not allowed" in error + def test_slack_app_routing_allows_claude_haiku_via_api_key(self): + allowed, error = check_product_access("slack_app_routing", "personal_api_key", None, "claude-haiku-4-5") + assert allowed is True + assert error is None + + def test_slack_app_routing_rejects_non_haiku_models(self): + allowed, error = check_product_access("slack_app_routing", "personal_api_key", None, "claude-sonnet-4-5") + assert allowed is False + assert error is not None + assert "not allowed" in error + + def test_slack_posthog_code_alias_still_resolves(self): + # Legacy alias kept for backward compat (old URL paths, Django integration kind). + allowed, error = check_product_access("slack-posthog-code", "personal_api_key", None, "claude-haiku-4-5") + assert allowed is True + assert error is None + @pytest.mark.parametrize( "model", [ @@ -294,10 +311,11 @@ def test_twig_app_id_constants_are_aliases(self): [ ("twig", "posthog_code"), ("array", "posthog_code"), - ("slack-twig", "slack-posthog-code"), + ("slack-twig", "slack_app_routing"), + ("slack-posthog-code", "slack_app_routing"), ], ) - def test_aliases_resolve_to_posthog_code(self, alias: str, target: str): + def test_aliases_resolve_to_canonical_product(self, alias: str, target: str): assert resolve_product_alias(alias) == target def test_twig_alias_returns_same_config_as_posthog_code(self): @@ -312,9 +330,11 @@ def test_twig_alias_validates_to_posthog_code(self): def test_array_alias_validates_to_posthog_code(self): assert validate_product("array") == "posthog_code" - def test_slack_twig_alias_resolves_to_slack_posthog_code(self): - assert get_product_config("slack-twig") is get_product_config("slack-posthog-code") - assert validate_product("slack-twig") == "slack-posthog-code" + def test_slack_aliases_resolve_to_slack_app_routing(self): + assert get_product_config("slack-twig") is get_product_config("slack_app_routing") + assert get_product_config("slack-posthog-code") is get_product_config("slack_app_routing") + assert validate_product("slack-twig") == "slack_app_routing" + assert validate_product("slack-posthog-code") == "slack_app_routing" class TestValidateProduct: @@ -338,7 +358,8 @@ def test_alias_resolves_to_target_product(self, alias: str, target: str): def test_resolve_product_alias_returns_alias_target(self): assert resolve_product_alias("array") == "posthog_code" assert resolve_product_alias("twig") == "posthog_code" - assert resolve_product_alias("slack-twig") == "slack-posthog-code" + assert resolve_product_alias("slack-twig") == "slack_app_routing" + assert resolve_product_alias("slack-posthog-code") == "slack_app_routing" def test_resolve_product_alias_returns_input_if_not_aliased(self): assert resolve_product_alias("wizard") == "wizard" From 87c5c91b2f040206266f1d01064d97de91ce82bb Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Thu, 21 May 2026 12:40:00 +0200 Subject: [PATCH 04/14] feat(llm-gateway): gate billable products on AI credits quota at the edge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds BillableCreditThrottle that, for any product with ProductConfig.billable=True, ZSCOREs the team's API token against @posthog/quota-limits/ai_credits — the same Redis set Django's daily aggregator writes to. Over-quota teams get a 429 with a friendly denial before the request reaches the upstream LLM provider. Closes the cross-surface gap for slack-origin tasks: PostHog Code desktop replies bypass the Temporal pre-flight, but cannot bypass the gateway. Plumbs team_api_token through AuthenticatedUser (LEFT JOIN posthog_team in both authenticators) so the throttle has the key it needs to look up Redis. Fails open when Redis or team_api_token is missing — matches the rest of the throttle chain. A startup warning surfaces when Redis isn't configured so we notice silent disablement in production. --- .../src/llm_gateway/auth/authenticators.py | 10 +- .../src/llm_gateway/auth/models.py | 4 + services/llm-gateway/src/llm_gateway/main.py | 2 + .../billable_credits_throttle.py | 66 +++++++++++++ services/llm-gateway/tests/conftest.py | 2 + .../tests/test_billable_credits_throttle.py | 95 +++++++++++++++++++ 6 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py create mode 100644 services/llm-gateway/tests/test_billable_credits_throttle.py diff --git a/services/llm-gateway/src/llm_gateway/auth/authenticators.py b/services/llm-gateway/src/llm_gateway/auth/authenticators.py index 88685f50bf7c..9a67467d759f 100644 --- a/services/llm-gateway/src/llm_gateway/auth/authenticators.py +++ b/services/llm-gateway/src/llm_gateway/auth/authenticators.py @@ -62,9 +62,11 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat async with acquire_connection(pool) as conn: row = await conn.fetchrow( """ - SELECT pak.id, pak.user_id, pak.scopes, u.current_team_id, u.distinct_id + SELECT pak.id, pak.user_id, pak.scopes, u.current_team_id, u.distinct_id, + t.api_token AS team_api_token FROM posthog_personalapikey pak JOIN posthog_user u ON pak.user_id = u.id + LEFT JOIN posthog_team t ON u.current_team_id = t.id WHERE pak.secure_value = $1 AND u.is_active = true """, token_hash, @@ -83,6 +85,7 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat auth_method=self.auth_type, distinct_id=row["distinct_id"], scopes=scopes, + team_api_token=row["team_api_token"], ) @@ -108,9 +111,11 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat row = await conn.fetchrow( """ SELECT oat.id, oat.user_id, oat.scope, oat.expires, - oat.application_id, u.current_team_id, u.distinct_id + oat.application_id, u.current_team_id, u.distinct_id, + t.api_token AS team_api_token FROM posthog_oauthaccesstoken oat JOIN posthog_user u ON oat.user_id = u.id + LEFT JOIN posthog_team t ON u.current_team_id = t.id WHERE oat.token_checksum = $1 AND u.is_active = true """, token_hash, @@ -138,4 +143,5 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat scopes=scopes, token_expires_at=expires, application_id=str(row["application_id"]), + team_api_token=row["team_api_token"], ) diff --git a/services/llm-gateway/src/llm_gateway/auth/models.py b/services/llm-gateway/src/llm_gateway/auth/models.py index 5bcad3e287cd..9c8f407e0f0b 100644 --- a/services/llm-gateway/src/llm_gateway/auth/models.py +++ b/services/llm-gateway/src/llm_gateway/auth/models.py @@ -11,6 +11,10 @@ class AuthenticatedUser: scopes: list[str] | None = None token_expires_at: datetime | None = None application_id: str | None = None + # The team's `posthog_team.api_token` — used by quota-limit throttles that + # read Django's `@posthog/quota-limits/...` Redis sets, which are keyed by + # team API token rather than team_id. + team_api_token: str | None = None def resolve_distinct_id(auth_user: AuthenticatedUser, end_user_id: str | None) -> str: diff --git a/services/llm-gateway/src/llm_gateway/main.py b/services/llm-gateway/src/llm_gateway/main.py index 50278966c7d9..fc483566131a 100644 --- a/services/llm-gateway/src/llm_gateway/main.py +++ b/services/llm-gateway/src/llm_gateway/main.py @@ -26,6 +26,7 @@ from llm_gateway.config import Settings, get_settings from llm_gateway.db.postgres import close_db_pool, init_db_pool from llm_gateway.metrics.prometheus import DB_POOL_SIZE, get_instrumentator +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle from llm_gateway.rate_limiting.cost_gauge_publisher import publish_product_cost_gauges_loop from llm_gateway.rate_limiting.cost_refresh import ensure_costs_fresh from llm_gateway.rate_limiting.cost_throttles import ( @@ -158,6 +159,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) app.state.throttle_runner = ThrottleRunner( throttles=[ + BillableCreditThrottle(redis=app.state.redis), product_throttle, UserCostBurstThrottle(redis=app.state.redis), UserCostSustainedThrottle(redis=app.state.redis), diff --git a/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py new file mode 100644 index 000000000000..0c7184234508 --- /dev/null +++ b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import time +from collections.abc import Callable +from typing import TYPE_CHECKING + +import structlog + +from llm_gateway.products.config import get_product_config +from llm_gateway.rate_limiting.throttles import Throttle, ThrottleContext, ThrottleResult + +if TYPE_CHECKING: + from redis.asyncio import Redis + +logger = structlog.get_logger(__name__) + + +# Mirror of ee/billing/quota_limiting.py: +# QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY = "@posthog/quota-limits/" +# QuotaResource.AI_CREDITS = "ai_credits" +_AI_CREDITS_LIMIT_KEY = "@posthog/quota-limits/ai_credits" + + +class BillableCreditThrottle(Throttle): + """Gate billable-product LLM calls on the team's AI credits balance. + + Reads the same Redis sorted set Django populates via + ee/billing/quota_limiting.add_limited_team_tokens. Members are team API + tokens; scores are Unix timestamps marking when the limit expires. + + Fail-open when Redis is unavailable or the user's team API token isn't + known — matches the rest of the throttle chain. Without this we'd close + requests on infrastructure incidents that have nothing to do with billing. + """ + + scope = "billable_credits" + + def __init__(self, redis: Redis[bytes] | None, clock: Callable[[], float] | None = None): + self._redis = redis + self._now = clock or time.time + if redis is None: + logger.warning( + "billable_credits_throttle_disabled_no_redis", + reason="Redis client not configured; throttle is fail-open and will allow all billable calls.", + ) + + async def allow_request(self, context: ThrottleContext) -> ThrottleResult: + config = get_product_config(context.product) + if not (config and config.billable): + return ThrottleResult.allow() + + if self._redis is None or context.user.team_api_token is None: + return ThrottleResult.allow() + + score = await self._redis.zscore(_AI_CREDITS_LIMIT_KEY, context.user.team_api_token) + if score is None or score <= self._now(): + return ThrottleResult.allow() + + return ThrottleResult.deny( + detail=( + "Your team has used its monthly PostHog AI credits. " + "Top up at https://us.posthog.com/organization/billing to continue." + ), + scope=self.scope, + retry_after=max(int(score - self._now()), 1), + ) diff --git a/services/llm-gateway/tests/conftest.py b/services/llm-gateway/tests/conftest.py index c691b5a8ba24..2f6f24b7d0be 100644 --- a/services/llm-gateway/tests/conftest.py +++ b/services/llm-gateway/tests/conftest.py @@ -9,6 +9,7 @@ from llm_gateway.auth.models import AuthenticatedUser from llm_gateway.main import http_exception_handler +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle from llm_gateway.rate_limiting.cost_throttles import ( ProductCostThrottle, UserCostBurstThrottle, @@ -27,6 +28,7 @@ def create_test_app( from llm_gateway.api.routes import router default_throttles: list[Throttle] = [ + BillableCreditThrottle(redis=None), ProductCostThrottle(redis=None), UserCostBurstThrottle(redis=None), UserCostSustainedThrottle(redis=None), diff --git a/services/llm-gateway/tests/test_billable_credits_throttle.py b/services/llm-gateway/tests/test_billable_credits_throttle.py new file mode 100644 index 000000000000..c561e4fdf4e3 --- /dev/null +++ b/services/llm-gateway/tests/test_billable_credits_throttle.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from llm_gateway.auth.models import AuthenticatedUser +from llm_gateway.rate_limiting.billable_credits_throttle import ( + _AI_CREDITS_LIMIT_KEY, + BillableCreditThrottle, +) +from llm_gateway.rate_limiting.throttles import ThrottleContext + +_TEAM_TOKEN = "phc_team_under_test" + + +def _make_user(team_api_token: str | None = _TEAM_TOKEN) -> AuthenticatedUser: + return AuthenticatedUser( + user_id=1, + team_id=42, + auth_method="personal_api_key", + distinct_id="distinct-1", + scopes=["llm_gateway:read"], + team_api_token=team_api_token, + ) + + +def _make_context(product: str, user: AuthenticatedUser | None = None) -> ThrottleContext: + return ThrottleContext(user=user or _make_user(), product=product) + + +class TestBillableCreditThrottle: + @pytest.mark.asyncio + async def test_allows_non_billable_product_without_redis_lookup(self) -> None: + redis = AsyncMock() + throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + + result = await throttle.allow_request(_make_context(product="posthog_code")) + + assert result.allowed is True + redis.zscore.assert_not_called() + + @pytest.mark.asyncio + async def test_allows_billable_product_when_team_not_limited(self) -> None: + redis = AsyncMock() + redis.zscore = AsyncMock(return_value=None) + throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + + result = await throttle.allow_request(_make_context(product="slack_app")) + + assert result.allowed is True + redis.zscore.assert_awaited_once_with(_AI_CREDITS_LIMIT_KEY, _TEAM_TOKEN) + + @pytest.mark.asyncio + async def test_allows_billable_product_when_limit_expired(self) -> None: + redis = AsyncMock() + redis.zscore = AsyncMock(return_value=999_999.0) + throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + + result = await throttle.allow_request(_make_context(product="slack_app")) + + assert result.allowed is True + + @pytest.mark.asyncio + async def test_denies_billable_product_when_team_currently_limited(self) -> None: + redis = AsyncMock() + redis.zscore = AsyncMock(return_value=1_003_600.0) + throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + + result = await throttle.allow_request(_make_context(product="slack_app")) + + assert result.allowed is False + assert result.status_code == 429 + assert result.scope == "billable_credits" + assert "PostHog AI credits" in result.detail + assert result.retry_after == 3600 + + @pytest.mark.asyncio + async def test_allows_when_redis_is_not_configured(self) -> None: + throttle = BillableCreditThrottle(redis=None, clock=lambda: 1_000_000) + + result = await throttle.allow_request(_make_context(product="slack_app")) + + assert result.allowed is True + + @pytest.mark.asyncio + async def test_allows_when_team_api_token_is_missing(self) -> None: + redis = AsyncMock() + throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + user_without_token = _make_user(team_api_token=None) + + result = await throttle.allow_request(_make_context(product="slack_app", user=user_without_token)) + + assert result.allowed is True + redis.zscore.assert_not_called() From 1b686e6707f3da1b45866623f06d0017653d33f1 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Mon, 25 May 2026 12:02:54 +0200 Subject: [PATCH 05/14] chore(slackbot): llm gtw using quata limits django endpoint for usage limiting --- ee/api/quota_limits.py | 67 ++++++++ ee/api/test/test_quota_limits.py | 150 ++++++++++++++++++ posthog/api/__init__.py | 9 ++ .../llm-gateway/src/llm_gateway/api/usage.py | 25 ++- .../src/llm_gateway/auth/authenticators.py | 10 +- .../src/llm_gateway/auth/models.py | 4 - .../llm-gateway/src/llm_gateway/config.py | 4 + .../src/llm_gateway/dependencies.py | 9 +- services/llm-gateway/src/llm_gateway/main.py | 7 +- .../billable_credits_throttle.py | 48 ++---- .../llm_gateway/rate_limiting/throttles.py | 1 + .../llm_gateway/services/quota_resolver.py | 126 +++++++++++++++ .../tests/callbacks/test_posthog.py | 6 +- services/llm-gateway/tests/conftest.py | 11 +- .../tests/test_billable_credits_throttle.py | 76 +++------ .../llm-gateway/tests/test_quota_resolver.py | 116 ++++++++++++++ services/llm-gateway/tests/test_usage.py | 34 ++++ 17 files changed, 588 insertions(+), 115 deletions(-) create mode 100644 ee/api/quota_limits.py create mode 100644 ee/api/test/test_quota_limits.py create mode 100644 services/llm-gateway/src/llm_gateway/services/quota_resolver.py create mode 100644 services/llm-gateway/tests/test_quota_resolver.py diff --git a/ee/api/quota_limits.py b/ee/api/quota_limits.py new file mode 100644 index 000000000000..eb7dfe66d7d1 --- /dev/null +++ b/ee/api/quota_limits.py @@ -0,0 +1,67 @@ +"""Expose a team's quota-limit state. + +Backs the LLM gateway's `QuotaResolver`, which forwards the caller's auth +header here to learn whether a given team is currently over its AI credits +quota. Project-nested so org membership and token `scoped_teams`/ +`scoped_organizations` enforcement come from the standard +`TeamAndOrgViewSetMixin` permission chain — see +`posthog.permissions.APIScopePermission.check_team_and_org_permissions`. +""" + +from __future__ import annotations + +from typing import Any + +from drf_spectacular.utils import extend_schema +from rest_framework import serializers, viewsets +from rest_framework.request import Request +from rest_framework.response import Response + +from posthog.api.routing import TeamAndOrgViewSetMixin + +from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + +class QuotaResourceLimitSerializer(serializers.Serializer): + limited = serializers.BooleanField( + help_text="True when the team is currently over its quota for this resource and limits are in effect.", + ) + + +class QuotaLimitsResponseSerializer(serializers.Serializer): + limited = serializers.DictField( + child=QuotaResourceLimitSerializer(), + help_text=( + "Per-resource limit state keyed by `QuotaResource` value. " + "Currently only `ai_credits` is reported; additional resources may be added." + ), + ) + + +@extend_schema(tags=["quota_limits"]) +class QuotaLimitsViewSet(TeamAndOrgViewSetMixin, viewsets.ViewSet): + """Read-only view of a team's quota-limit state.""" + + scope_object = "project" + required_scopes = ["project:read"] + http_method_names = ["get", "head", "options"] + + @extend_schema( + summary="Get a team's quota-limit state", + description=( + "Return the current quota-limit state for the team identified in the URL. " + "Used by the LLM gateway to gate billable products on AI credits exhaustion." + ), + responses={200: QuotaLimitsResponseSerializer}, + ) + def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: + limited = { + QuotaResource.AI_CREDITS.value: { + "limited": is_team_limited( + self.team.api_token, + QuotaResource.AI_CREDITS, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ), + }, + } + return Response(QuotaLimitsResponseSerializer({"limited": limited}).data) diff --git a/ee/api/test/test_quota_limits.py b/ee/api/test/test_quota_limits.py new file mode 100644 index 000000000000..58c2258ae6aa --- /dev/null +++ b/ee/api/test/test_quota_limits.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from posthog.test.base import APIBaseTest + +from rest_framework import status + +from posthog.models.organization import Organization +from posthog.models.personal_api_key import PersonalAPIKey +from posthog.models.team import Team +from posthog.models.utils import generate_random_token_personal, hash_key_value + +from ee.billing.quota_limiting import ( + QuotaLimitingCaches, + QuotaResource, + add_limited_team_tokens, + replace_limited_team_tokens, +) + + +def _clear_ai_credits_limits() -> None: + replace_limited_team_tokens(QuotaResource.AI_CREDITS, {}, QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY) + + +class TestQuotaLimitsAPI(APIBaseTest): + def setUp(self) -> None: + super().setUp() + _clear_ai_credits_limits() + + def tearDown(self) -> None: + _clear_ai_credits_limits() + super().tearDown() + + def _url(self, team_id: int | None = None) -> str: + return f"/api/projects/{team_id if team_id is not None else self.team.pk}/quota_limits/" + + def _set_ai_credits_limit(self, team_api_token: str, expires_at: int) -> None: + add_limited_team_tokens( + QuotaResource.AI_CREDITS, + {team_api_token: expires_at}, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ) + + def test_unauthenticated_returns_401_or_403(self) -> None: + self.client.logout() + response = self.client.get(self._url()) + # DRF returns 401 when no creds are presented and an authenticator that supports + # a WWW-Authenticate challenge is configured; otherwise it returns 403. Either is + # an auth failure — we only care that the endpoint refuses unauthenticated reads. + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + def test_session_auth_returns_under_quota_when_team_not_limited(self) -> None: + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertEqual(data["limited"]["ai_credits"], {"limited": False}) + + def test_returns_limited_when_team_is_over_quota(self) -> None: + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": True}) + + def test_returns_unlimited_when_limit_has_already_expired(self) -> None: + self._set_ai_credits_limit(self.team.api_token, 1) # epoch 1970 + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": False}) + + def test_personal_api_key_auth_works(self) -> None: + self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + secure_value=hash_key_value(raw_key), + scopes=["project:read"], + ) + + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": True}) + + def test_user_not_in_teams_org_is_forbidden(self) -> None: + other_org = Organization.objects.create(name="other-org") + other_team = Team.objects.create(organization=other_org, name="other-team") + + response = self.client.get(self._url(other_team.pk)) + # The caller is logged in to a team in a different org — TeamMemberAccessPermission + # rejects with 403 (or 404 if the queryset can't see the team at all). + self.assertIn(response.status_code, (status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND)) + + def test_personal_api_key_scoped_to_a_different_team_is_forbidden(self) -> None: + # Caller has access to both teams via membership, but the token is scoped to + # `other_team` only — the standalone-endpoint design would have missed this and + # leaked the other team's state. + other_team = Team.objects.create(organization=self.organization, name="other-team") + self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + secure_value=hash_key_value(raw_key), + scopes=["project:read"], + scoped_teams=[other_team.pk], + ) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_personal_api_key_missing_required_scope_is_forbidden(self) -> None: + # A token with only `feature_flag:read` shouldn't be able to read quota state. + self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + secure_value=hash_key_value(raw_key), + scopes=["feature_flag:read"], + ) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_multi_team_user_gets_per_team_answers(self) -> None: + # Same user belongs to two teams in their org; each team's quota is independent. + # This is the regression that "me" couldn't model — `user.team` (current team) + # picked one arbitrary answer for users in multiple teams. + other_team = Team.objects.create(organization=self.organization, name="other-team") + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + # other_team's token deliberately not limited + + resp_self = self.client.get(self._url()) + resp_other = self.client.get(self._url(other_team.pk)) + + self.assertEqual(resp_self.json()["limited"]["ai_credits"], {"limited": True}) + self.assertEqual(resp_other.json()["limited"]["ai_credits"], {"limited": False}) + diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index 8667f5916ac9..23aa28942dfc 100644 --- a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -142,6 +142,7 @@ ) from products.web_analytics.backend.api.web_analytics_filter_preset import WebAnalyticsFilterPresetViewSet +from ee.api.quota_limits import QuotaLimitsViewSet from ee.api.session_summaries import SessionGroupSummaryViewSet from ee.api.vercel import vercel_installation, vercel_product, vercel_proxy, vercel_resource @@ -371,6 +372,14 @@ def register_grandfathered_environment_nested_viewset( # Seats (proxied to billing service) router.register(r"seats", seats.SeatViewSet, "seats") +# Quota limits (project-scoped — backs the LLM gateway's QuotaResolver) +projects_router.register( + r"quota_limits", + QuotaLimitsViewSet, + "project_quota_limits", + ["team_id"], +) + projects_router.register(r"surveys", survey.SurveyViewSet, "project_surveys", ["project_id"]) projects_router.register(r"product_tours", ProductTourViewSet, "project_product_tours", ["project_id"]) projects_router.register( diff --git a/services/llm-gateway/src/llm_gateway/api/usage.py b/services/llm-gateway/src/llm_gateway/api/usage.py index 557a61bdfd1b..bdc3facde6b4 100644 --- a/services/llm-gateway/src/llm_gateway/api/usage.py +++ b/services/llm-gateway/src/llm_gateway/api/usage.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from datetime import UTC, datetime, timedelta from typing import Annotated @@ -9,6 +10,7 @@ from llm_gateway.auth.models import AuthenticatedUser from llm_gateway.dependencies import get_authenticated_user +from llm_gateway.products.config import get_product_config from llm_gateway.rate_limiting.cost_throttles import CostStatus, UserCostBurstThrottle, UserCostSustainedThrottle from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import ThrottleContext @@ -19,6 +21,7 @@ parse_iso_utc, resolve_plan_info, ) +from llm_gateway.services.quota_resolver import resolve_quota_status logger = structlog.get_logger(__name__) @@ -35,11 +38,16 @@ class CostLimitStatus(BaseModel): exceeded: bool +class AiCreditsStatus(BaseModel): + exhausted: bool + + class UsageResponse(BaseModel): product: str user_id: int burst: CostLimitStatus sustained: CostLimitStatus + ai_credits: AiCreditsStatus is_rate_limited: bool is_pro: bool billing_period_end: datetime | None = None @@ -65,9 +73,20 @@ async def get_usage( user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], ) -> UsageResponse: runner: ThrottleRunner = request.app.state.throttle_runner - plan_info = await resolve_plan_info(request, user.user_id, product) + + plan_info, quota_status = await asyncio.gather( + resolve_plan_info(request, user.user_id, product), + resolve_quota_status(request, user.team_id), + ) now = datetime.now(tz=UTC) + # Non-billable products don't actually get denied by BillableCreditThrottle, + # so reporting "exhausted" here would be misleading regardless of the team's + # AI credits state. + product_config = get_product_config(product) + is_billable = bool(product_config and product_config.billable) + ai_credits_exhausted = quota_status.limited if is_billable else False + context = ThrottleContext( user=user, product=product, @@ -75,6 +94,7 @@ async def get_usage( plan_key=plan_info.plan_key, seat_created_at=plan_info.seat_created_at, billing_period_start=plan_info.billing_period.current_period_start if plan_info.billing_period else None, + ai_credits_exhausted=ai_credits_exhausted, ) burst_status: CostLimitStatus | None = None @@ -118,7 +138,8 @@ async def get_usage( user_id=user.user_id, burst=burst_status, sustained=sustained_status, - is_rate_limited=burst_status.exceeded or sustained_status.exceeded, + ai_credits=AiCreditsStatus(exhausted=ai_credits_exhausted), + is_rate_limited=burst_status.exceeded or sustained_status.exceeded or ai_credits_exhausted, is_pro=is_pro_plan(plan_info.plan_key), billing_period_end=billing_period_end, ) diff --git a/services/llm-gateway/src/llm_gateway/auth/authenticators.py b/services/llm-gateway/src/llm_gateway/auth/authenticators.py index 9a67467d759f..88685f50bf7c 100644 --- a/services/llm-gateway/src/llm_gateway/auth/authenticators.py +++ b/services/llm-gateway/src/llm_gateway/auth/authenticators.py @@ -62,11 +62,9 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat async with acquire_connection(pool) as conn: row = await conn.fetchrow( """ - SELECT pak.id, pak.user_id, pak.scopes, u.current_team_id, u.distinct_id, - t.api_token AS team_api_token + SELECT pak.id, pak.user_id, pak.scopes, u.current_team_id, u.distinct_id FROM posthog_personalapikey pak JOIN posthog_user u ON pak.user_id = u.id - LEFT JOIN posthog_team t ON u.current_team_id = t.id WHERE pak.secure_value = $1 AND u.is_active = true """, token_hash, @@ -85,7 +83,6 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat auth_method=self.auth_type, distinct_id=row["distinct_id"], scopes=scopes, - team_api_token=row["team_api_token"], ) @@ -111,11 +108,9 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat row = await conn.fetchrow( """ SELECT oat.id, oat.user_id, oat.scope, oat.expires, - oat.application_id, u.current_team_id, u.distinct_id, - t.api_token AS team_api_token + oat.application_id, u.current_team_id, u.distinct_id FROM posthog_oauthaccesstoken oat JOIN posthog_user u ON oat.user_id = u.id - LEFT JOIN posthog_team t ON u.current_team_id = t.id WHERE oat.token_checksum = $1 AND u.is_active = true """, token_hash, @@ -143,5 +138,4 @@ async def authenticate(self, token_hash: str, pool: asyncpg.Pool) -> Authenticat scopes=scopes, token_expires_at=expires, application_id=str(row["application_id"]), - team_api_token=row["team_api_token"], ) diff --git a/services/llm-gateway/src/llm_gateway/auth/models.py b/services/llm-gateway/src/llm_gateway/auth/models.py index 9c8f407e0f0b..5bcad3e287cd 100644 --- a/services/llm-gateway/src/llm_gateway/auth/models.py +++ b/services/llm-gateway/src/llm_gateway/auth/models.py @@ -11,10 +11,6 @@ class AuthenticatedUser: scopes: list[str] | None = None token_expires_at: datetime | None = None application_id: str | None = None - # The team's `posthog_team.api_token` — used by quota-limit throttles that - # read Django's `@posthog/quota-limits/...` Redis sets, which are keyed by - # team API token rather than team_id. - team_api_token: str | None = None def resolve_distinct_id(auth_user: AuthenticatedUser, end_user_id: str | None) -> str: diff --git a/services/llm-gateway/src/llm_gateway/config.py b/services/llm-gateway/src/llm_gateway/config.py index 84461bb68856..843072a6f68c 100644 --- a/services/llm-gateway/src/llm_gateway/config.py +++ b/services/llm-gateway/src/llm_gateway/config.py @@ -160,6 +160,10 @@ class Settings(BaseSettings): posthog_api_base_url: str = "https://us.posthog.com" plan_cache_ttl: int = 900 # 15 minutes + # AI credits quota state is fast-moving — Django itself caches the underlying + # Redis set for 30 seconds, so caching the resolver result for any longer + # would just stack staleness without saving lookups. + quota_cache_ttl: int = 30 billing_period_days: int = 30 # Anthropic -> Bedrock circuit breaker. When the trailing failure rate of the Anthropic diff --git a/services/llm-gateway/src/llm_gateway/dependencies.py b/services/llm-gateway/src/llm_gateway/dependencies.py index ccc3d80a3741..01f948a9de98 100644 --- a/services/llm-gateway/src/llm_gateway/dependencies.py +++ b/services/llm-gateway/src/llm_gateway/dependencies.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from typing import Annotated, Any @@ -20,6 +21,7 @@ set_throttle_context, ) from llm_gateway.services.plan_resolver import resolve_plan_info +from llm_gateway.services.quota_resolver import resolve_quota_status logger = structlog.get_logger(__name__) @@ -163,7 +165,11 @@ async def enforce_throttles( else: end_user_id = await _extract_end_user_id_from_body(request) - plan_info = await resolve_plan_info(request, user.user_id, product) + # Plan + quota are independent Django roundtrips on cache miss — overlap them. + plan_info, quota_status = await asyncio.gather( + resolve_plan_info(request, user.user_id, product), + resolve_quota_status(request, user.team_id), + ) context = ThrottleContext( user=user, @@ -173,6 +179,7 @@ async def enforce_throttles( plan_key=plan_info.plan_key, seat_created_at=plan_info.seat_created_at, billing_period_start=plan_info.billing_period.current_period_start if plan_info.billing_period else None, + ai_credits_exhausted=quota_status.limited, ) request.state.throttle_context = context set_throttle_context(runner, context) diff --git a/services/llm-gateway/src/llm_gateway/main.py b/services/llm-gateway/src/llm_gateway/main.py index fc483566131a..42dec77a8ad5 100644 --- a/services/llm-gateway/src/llm_gateway/main.py +++ b/services/llm-gateway/src/llm_gateway/main.py @@ -38,6 +38,7 @@ from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.request_context import RequestContext, set_request_context from llm_gateway.services.plan_resolver import PlanResolver +from llm_gateway.services.quota_resolver import QuotaResolver def configure_logging(debug: bool = False) -> None: @@ -159,7 +160,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) app.state.throttle_runner = ThrottleRunner( throttles=[ - BillableCreditThrottle(redis=app.state.redis), + BillableCreditThrottle(), product_throttle, UserCostBurstThrottle(redis=app.state.redis), UserCostSustainedThrottle(redis=app.state.redis), @@ -188,6 +189,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: redis=app.state.redis, http_client=app.state.http_client, ) + app.state.quota_resolver = QuotaResolver( + redis=app.state.redis, + http_client=app.state.http_client, + ) logger.info("Plan resolver initialized", posthog_api_base_url=settings.posthog_api_base_url or "(not configured)") logger.info( diff --git a/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py index 0c7184234508..4be99145f703 100644 --- a/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py +++ b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py @@ -1,59 +1,31 @@ from __future__ import annotations -import time -from collections.abc import Callable -from typing import TYPE_CHECKING - -import structlog - from llm_gateway.products.config import get_product_config from llm_gateway.rate_limiting.throttles import Throttle, ThrottleContext, ThrottleResult -if TYPE_CHECKING: - from redis.asyncio import Redis - -logger = structlog.get_logger(__name__) - - -# Mirror of ee/billing/quota_limiting.py: -# QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY = "@posthog/quota-limits/" -# QuotaResource.AI_CREDITS = "ai_credits" -_AI_CREDITS_LIMIT_KEY = "@posthog/quota-limits/ai_credits" +# Hint the client to back off for a minute. A precise expiry timestamp is +# available upstream in Redis, but with Django's 30s in-process cache and the +# gateway's 30s resolver cache stacked on top, any computed retry window would +# be 0–60s stale anyway — a fixed minute is no worse and cuts the data +# plumbing. +_RETRY_AFTER_SECONDS = 60 class BillableCreditThrottle(Throttle): """Gate billable-product LLM calls on the team's AI credits balance. - Reads the same Redis sorted set Django populates via - ee/billing/quota_limiting.add_limited_team_tokens. Members are team API - tokens; scores are Unix timestamps marking when the limit expires. - - Fail-open when Redis is unavailable or the user's team API token isn't - known — matches the rest of the throttle chain. Without this we'd close - requests on infrastructure incidents that have nothing to do with billing. + Reads ``ai_credits_exhausted`` from :class:`ThrottleContext`, pre-resolved + by the dependency layer (see ``resolve_quota_status``). """ scope = "billable_credits" - def __init__(self, redis: Redis[bytes] | None, clock: Callable[[], float] | None = None): - self._redis = redis - self._now = clock or time.time - if redis is None: - logger.warning( - "billable_credits_throttle_disabled_no_redis", - reason="Redis client not configured; throttle is fail-open and will allow all billable calls.", - ) - async def allow_request(self, context: ThrottleContext) -> ThrottleResult: config = get_product_config(context.product) if not (config and config.billable): return ThrottleResult.allow() - if self._redis is None or context.user.team_api_token is None: - return ThrottleResult.allow() - - score = await self._redis.zscore(_AI_CREDITS_LIMIT_KEY, context.user.team_api_token) - if score is None or score <= self._now(): + if not context.ai_credits_exhausted: return ThrottleResult.allow() return ThrottleResult.deny( @@ -62,5 +34,5 @@ async def allow_request(self, context: ThrottleContext) -> ThrottleResult: "Top up at https://us.posthog.com/organization/billing to continue." ), scope=self.scope, - retry_after=max(int(score - self._now()), 1), + retry_after=_RETRY_AFTER_SECONDS, ) diff --git a/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py b/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py index aff139e2e529..9f47757b1d89 100644 --- a/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py +++ b/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py @@ -23,6 +23,7 @@ class ThrottleContext: plan_key: str | None = None seat_created_at: str | None = None billing_period_start: str | None = None + ai_credits_exhausted: bool = False @dataclass diff --git a/services/llm-gateway/src/llm_gateway/services/quota_resolver.py b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py new file mode 100644 index 000000000000..3fab65fc30d2 --- /dev/null +++ b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py @@ -0,0 +1,126 @@ +"""Resolves a team's quota state via the PostHog API quota_limits endpoint. + +Mirrors :mod:`llm_gateway.services.plan_resolver` — forwards the caller's +``Authorization`` header to ``GET /api/projects/{team_id}/quota_limits/`` and +caches the result per team-and-resource in the gateway's own Redis. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import structlog + +from llm_gateway.config import get_settings + +if TYPE_CHECKING: + import httpx + from fastapi import Request + from redis.asyncio import Redis + +logger = structlog.get_logger(__name__) + +_AI_CREDITS_RESOURCE = "ai_credits" + +# Cache window for the fail-open path (4xx from Django, e.g. expired token). +# Short enough that a fixed-up token recovers quickly; long enough to keep a +# misconfigured client off Django's neck during an auth-failure storm. +_FAIL_OPEN_CACHE_TTL_SECONDS = 5 + + +@dataclass +class QuotaResourceStatus: + limited: bool + + +def _redis_key(resource_key: str, team_id: int) -> str: + return f"quota:{resource_key}:team:{team_id}" + + +async def resolve_quota_status(request: Request, team_id: int | None) -> QuotaResourceStatus: + """Resolve the team's AI credits quota state, falling open on errors.""" + if team_id is None: + return QuotaResourceStatus(limited=False) + auth_header = request.headers.get("Authorization", "") + if not auth_header: + return QuotaResourceStatus(limited=False) + + quota_resolver: QuotaResolver = request.app.state.quota_resolver + try: + return await quota_resolver.get_ai_credits_status(team_id=team_id, auth_header=auth_header) + except Exception: + logger.warning("quota_resolve_failed", team_id=team_id) + return QuotaResourceStatus(limited=False) + + +class QuotaResolver: + """Fetches team quota state from Django, caches per team.""" + + def __init__(self, redis: Redis[bytes] | None, http_client: httpx.AsyncClient): + self._redis = redis + self._http = http_client + self._cache_ttl = get_settings().quota_cache_ttl + + async def get_ai_credits_status(self, team_id: int, auth_header: str) -> QuotaResourceStatus: + return await self._get_resource_status(_AI_CREDITS_RESOURCE, team_id, auth_header) + + async def _get_resource_status(self, resource_key: str, team_id: int, auth_header: str) -> QuotaResourceStatus: + cached = await self._get_cached(resource_key, team_id) + if cached is not None: + return cached + + try: + status, ttl = await self._fetch(resource_key, team_id, auth_header) + except Exception: + logger.warning("quota_fetch_failed", resource=resource_key, team_id=team_id, exc_info=True) + return QuotaResourceStatus(limited=False) + + await self._set_cached(resource_key, team_id, status, ttl) + return status + + async def _fetch(self, resource_key: str, team_id: int, auth_header: str) -> tuple[QuotaResourceStatus, int]: + """Return the resource status and the TTL the caller should cache it for. + + 4xx responses are treated as "not limited" and cached briefly so a hot + loop with a broken token can't hammer Django. + """ + settings = get_settings() + if not settings.posthog_api_base_url: + return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS + + url = f"{settings.posthog_api_base_url.rstrip('/')}/api/projects/{team_id}/quota_limits/" + resp = await self._http.get( + url, + headers={"Authorization": auth_header}, + timeout=2.0, + ) + if resp.status_code >= 400: + return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS + + data = resp.json() + resource = (data.get("limited") or {}).get(resource_key) or {} + return QuotaResourceStatus(limited=bool(resource.get("limited"))), self._cache_ttl + + async def _get_cached(self, resource_key: str, team_id: int) -> QuotaResourceStatus | None: + if not self._redis: + return None + try: + val = await self._redis.get(_redis_key(resource_key, team_id)) + if val is None: + return None + payload = json.loads(val.decode()) + return QuotaResourceStatus(limited=bool(payload.get("limited"))) + except Exception: + logger.debug("quota_cache_read_failed", resource=resource_key, team_id=team_id) + return None + + async def _set_cached(self, resource_key: str, team_id: int, status: QuotaResourceStatus, ttl: int) -> None: + if not self._redis: + return + try: + payload = json.dumps({"limited": status.limited}) + await self._redis.set(_redis_key(resource_key, team_id), payload, ex=ttl) + except Exception: + logger.debug("quota_cache_write_failed", resource=resource_key, team_id=team_id) diff --git a/services/llm-gateway/tests/callbacks/test_posthog.py b/services/llm-gateway/tests/callbacks/test_posthog.py index fe1e2897f7f7..1d50233ccfaf 100644 --- a/services/llm-gateway/tests/callbacks/test_posthog.py +++ b/services/llm-gateway/tests/callbacks/test_posthog.py @@ -348,11 +348,13 @@ async def test_on_success_does_not_mark_other_products_billable( assert props["$ai_billable"] is False @pytest.mark.asyncio - async def test_on_failure_marks_slack_app_billable( + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_on_failure_marks_slack_products_billable( self, callback: PostHogCallback, auth_user: AuthenticatedUser, mock_posthog_client: tuple, + product: str, ) -> None: _, mock_client = mock_posthog_client kwargs = { @@ -366,7 +368,7 @@ async def test_on_failure_marks_slack_app_billable( with ( patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), - patch("llm_gateway.callbacks.posthog.get_product", return_value="slack_app"), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), ): await callback._on_failure(kwargs, None, 0.0, 1.0, end_user_id=None) diff --git a/services/llm-gateway/tests/conftest.py b/services/llm-gateway/tests/conftest.py index 2f6f24b7d0be..200c7b7282ba 100644 --- a/services/llm-gateway/tests/conftest.py +++ b/services/llm-gateway/tests/conftest.py @@ -18,6 +18,13 @@ from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import Throttle from llm_gateway.services.plan_resolver import PlanInfo +from llm_gateway.services.quota_resolver import QuotaResourceStatus + + +def _make_fake_quota_resolver() -> AsyncMock: + resolver = AsyncMock() + resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=False)) + return resolver def create_test_app( @@ -27,8 +34,9 @@ def create_test_app( from llm_gateway.api.health import health_router from llm_gateway.api.routes import router + quota_resolver = _make_fake_quota_resolver() default_throttles: list[Throttle] = [ - BillableCreditThrottle(redis=None), + BillableCreditThrottle(), ProductCostThrottle(redis=None), UserCostBurstThrottle(redis=None), UserCostSustainedThrottle(redis=None), @@ -43,6 +51,7 @@ async def test_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.plan_resolver = AsyncMock() app.state.plan_resolver.get_plan = AsyncMock(return_value=PlanInfo(plan_key=None, seat_created_at=None)) app.state.anthropic_circuit_breaker = None + app.state.quota_resolver = quota_resolver yield app = FastAPI(title="LLM Gateway Test", lifespan=test_lifespan) diff --git a/services/llm-gateway/tests/test_billable_credits_throttle.py b/services/llm-gateway/tests/test_billable_credits_throttle.py index c561e4fdf4e3..50a276af80e7 100644 --- a/services/llm-gateway/tests/test_billable_credits_throttle.py +++ b/services/llm-gateway/tests/test_billable_credits_throttle.py @@ -1,95 +1,55 @@ from __future__ import annotations -from unittest.mock import AsyncMock - import pytest from llm_gateway.auth.models import AuthenticatedUser -from llm_gateway.rate_limiting.billable_credits_throttle import ( - _AI_CREDITS_LIMIT_KEY, - BillableCreditThrottle, -) +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle from llm_gateway.rate_limiting.throttles import ThrottleContext -_TEAM_TOKEN = "phc_team_under_test" - -def _make_user(team_api_token: str | None = _TEAM_TOKEN) -> AuthenticatedUser: +def _make_user() -> AuthenticatedUser: return AuthenticatedUser( user_id=1, team_id=42, auth_method="personal_api_key", distinct_id="distinct-1", scopes=["llm_gateway:read"], - team_api_token=team_api_token, ) -def _make_context(product: str, user: AuthenticatedUser | None = None) -> ThrottleContext: - return ThrottleContext(user=user or _make_user(), product=product) +def _make_context(product: str, *, ai_credits_exhausted: bool = False) -> ThrottleContext: + return ThrottleContext(user=_make_user(), product=product, ai_credits_exhausted=ai_credits_exhausted) class TestBillableCreditThrottle: @pytest.mark.asyncio - async def test_allows_non_billable_product_without_redis_lookup(self) -> None: - redis = AsyncMock() - throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + async def test_allows_non_billable_product_even_when_exhausted(self) -> None: + # `posthog_code` is non-billable; exhaustion at the context level is + # irrelevant — the throttle short-circuits before checking the flag. + throttle = BillableCreditThrottle() - result = await throttle.allow_request(_make_context(product="posthog_code")) + result = await throttle.allow_request(_make_context(product="posthog_code", ai_credits_exhausted=True)) assert result.allowed is True - redis.zscore.assert_not_called() @pytest.mark.asyncio - async def test_allows_billable_product_when_team_not_limited(self) -> None: - redis = AsyncMock() - redis.zscore = AsyncMock(return_value=None) - throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_allows_billable_product_when_not_exhausted(self, product: str) -> None: + throttle = BillableCreditThrottle() - result = await throttle.allow_request(_make_context(product="slack_app")) + result = await throttle.allow_request(_make_context(product=product, ai_credits_exhausted=False)) assert result.allowed is True - redis.zscore.assert_awaited_once_with(_AI_CREDITS_LIMIT_KEY, _TEAM_TOKEN) @pytest.mark.asyncio - async def test_allows_billable_product_when_limit_expired(self) -> None: - redis = AsyncMock() - redis.zscore = AsyncMock(return_value=999_999.0) - throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_denies_billable_product_when_exhausted(self, product: str) -> None: + throttle = BillableCreditThrottle() - result = await throttle.allow_request(_make_context(product="slack_app")) - - assert result.allowed is True - - @pytest.mark.asyncio - async def test_denies_billable_product_when_team_currently_limited(self) -> None: - redis = AsyncMock() - redis.zscore = AsyncMock(return_value=1_003_600.0) - throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) - - result = await throttle.allow_request(_make_context(product="slack_app")) + result = await throttle.allow_request(_make_context(product=product, ai_credits_exhausted=True)) assert result.allowed is False assert result.status_code == 429 assert result.scope == "billable_credits" assert "PostHog AI credits" in result.detail - assert result.retry_after == 3600 - - @pytest.mark.asyncio - async def test_allows_when_redis_is_not_configured(self) -> None: - throttle = BillableCreditThrottle(redis=None, clock=lambda: 1_000_000) - - result = await throttle.allow_request(_make_context(product="slack_app")) - - assert result.allowed is True - - @pytest.mark.asyncio - async def test_allows_when_team_api_token_is_missing(self) -> None: - redis = AsyncMock() - throttle = BillableCreditThrottle(redis=redis, clock=lambda: 1_000_000) - user_without_token = _make_user(team_api_token=None) - - result = await throttle.allow_request(_make_context(product="slack_app", user=user_without_token)) - - assert result.allowed is True - redis.zscore.assert_not_called() + assert result.retry_after == 60 diff --git a/services/llm-gateway/tests/test_quota_resolver.py b/services/llm-gateway/tests/test_quota_resolver.py new file mode 100644 index 000000000000..d7bae046fe00 --- /dev/null +++ b/services/llm-gateway/tests/test_quota_resolver.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from llm_gateway.services.quota_resolver import ( + QuotaResolver, + QuotaResourceStatus, + _redis_key, +) + + +def _make_response(status_code: int, payload: dict[str, object] | None = None) -> httpx.Response: + content = json.dumps(payload or {}).encode() + return httpx.Response(status_code, content=content, headers={"content-type": "application/json"}) + + +def _make_http_client(response: httpx.Response | Exception) -> MagicMock: + client = MagicMock() + if isinstance(response, Exception): + client.get = AsyncMock(side_effect=response) + else: + client.get = AsyncMock(return_value=response) + return client + + +class TestQuotaResolver: + @pytest.mark.asyncio + async def test_fetches_and_parses_limited_response(self) -> None: + http_client = _make_http_client( + _make_response(200, {"team_id": 1, "limited": {"ai_credits": {"limited": True}}}) + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=True) + http_client.get.assert_awaited_once() + assert http_client.get.await_args.kwargs["headers"]["Authorization"] == "Bearer phx_test" + + @pytest.mark.asyncio + async def test_fetches_and_parses_unlimited_response(self) -> None: + http_client = _make_http_client( + _make_response(200, {"team_id": 1, "limited": {"ai_credits": {"limited": False}}}) + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + + @pytest.mark.asyncio + async def test_fail_open_on_http_error(self) -> None: + http_client = _make_http_client(httpx.ConnectError("boom")) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + + @pytest.mark.asyncio + async def test_fail_open_on_4xx(self) -> None: + http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + + @pytest.mark.asyncio + async def test_uses_cached_result_and_skips_http(self) -> None: + redis = AsyncMock() + redis.get = AsyncMock(return_value=json.dumps({"limited": True}).encode()) + http_client = _make_http_client(_make_response(200)) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=True) + redis.get.assert_awaited_once_with(_redis_key("ai_credits", 42)) + http_client.get.assert_not_called() + + @pytest.mark.asyncio + async def test_writes_cache_on_miss(self) -> None: + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + http_client = _make_http_client(_make_response(200, {"limited": {"ai_credits": {"limited": True}}})) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + redis.set.assert_awaited_once() + call = redis.set.await_args + assert call.args[0] == _redis_key("ai_credits", 42) + assert json.loads(call.args[1]) == {"limited": True} + # TTL matches the gateway settings default of 30s. + assert call.kwargs.get("ex") == 30 + + @pytest.mark.asyncio + async def test_caches_fail_open_briefly_on_4xx(self) -> None: + # 4xx responses (e.g. expired token) cache for a short window so a hot + # loop with a broken token doesn't hammer Django. + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + redis.set.assert_awaited_once() + assert redis.set.await_args.kwargs.get("ex") == 5 diff --git a/services/llm-gateway/tests/test_usage.py b/services/llm-gateway/tests/test_usage.py index e9ec3e6999c7..3b75ada2366a 100644 --- a/services/llm-gateway/tests/test_usage.py +++ b/services/llm-gateway/tests/test_usage.py @@ -358,6 +358,40 @@ def test_ignores_user_id_query_param(self, authenticated_usage_client: TestClien assert response.status_code == 200 assert response.json()["user_id"] == 42 + def test_ai_credits_reported_unlimited_for_non_billable_product( + self, authenticated_usage_client: TestClient + ) -> None: + # posthog_code is not billable; ai_credits should be unlimited and not contribute + # to is_rate_limited even if the resolver thinks the team is over. + from llm_gateway.services.quota_resolver import QuotaResourceStatus + + app = authenticated_usage_client.app + app.state.quota_resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=True)) + + response = authenticated_usage_client.get( + "/v1/usage/posthog_code", + headers={"Authorization": "Bearer phx_test"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["ai_credits"] == {"exhausted": False} + assert data["is_rate_limited"] is False + + def test_ai_credits_reflects_resolver_for_billable_product(self, authenticated_usage_client: TestClient) -> None: + from llm_gateway.services.quota_resolver import QuotaResourceStatus + + app = authenticated_usage_client.app + app.state.quota_resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=True)) + + response = authenticated_usage_client.get( + "/v1/usage/slack_app", + headers={"Authorization": "Bearer phx_test"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["ai_credits"] == {"exhausted": True} + assert data["is_rate_limited"] is True + def test_invalidate_plan_cache_calls_resolver(self, authenticated_usage_client: TestClient) -> None: app = authenticated_usage_client.app app.state.plan_resolver.invalidate = AsyncMock() From fc31b557b495bcfd765f3350733dc29ac176e87f Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Mon, 25 May 2026 12:44:05 +0200 Subject: [PATCH 06/14] feat(slack-bot): gate AI credits at webhook and workflow entry The existing gate fires deep inside `create_posthog_code_task_for_repo_activity`, after the classifier has already been called against the llm-gateway. For over-quota teams the classifier 429s and the bot's catch-all error handler steers it into the repo-picker, posting in the wrong thread. Two new entry points: - Webhook: refuse `app_mention` before `start_workflow` so no Temporal execution is created. Saves Slack roundtrips and a billable classifier call. - Workflow: new `enforce_posthog_code_billing_quota_activity`, first call in `PostHogCodeSlackMentionWorkflow.run()`, guarded by `workflow.patched( "posthog-code-slack-billing-gate")` so in-flight workflows stay deterministic. Existing in-activity gates stay as defense in depth (and to keep the direct-activity unit tests passing). --- posthog/temporal/ai/__init__.py | 2 + .../temporal/ai/posthog_code_slack_mention.py | 50 ++++++++++++++++++ .../tests/ai/test_module_integrity.py | 1 + products/slack_app/backend/api.py | 51 +++++++++++++++++++ .../backend/tests/test_followup_forwarding.py | 47 +++++++++++++++++ .../tests/test_posthog_code_event_handler.py | 41 +++++++++++++++ 6 files changed, 192 insertions(+) diff --git a/posthog/temporal/ai/__init__.py b/posthog/temporal/ai/__init__.py index 4a541bb62099..2a234d65c0f3 100644 --- a/posthog/temporal/ai/__init__.py +++ b/posthog/temporal/ai/__init__.py @@ -18,6 +18,7 @@ create_posthog_code_routing_rule_activity, create_posthog_code_task_for_repo_activity, discover_posthog_code_repository_via_agent_activity, + enforce_posthog_code_billing_quota_activity, forward_posthog_code_followup_activity, handle_posthog_code_rules_command_activity, post_posthog_code_internal_error_activity, @@ -68,6 +69,7 @@ process_research_agent_activity, summarize_llm_traces_activity, process_slack_conversation_activity, + enforce_posthog_code_billing_quota_activity, resolve_posthog_code_slack_user_activity, handle_posthog_code_rules_command_activity, collect_posthog_code_thread_messages_activity, diff --git a/posthog/temporal/ai/posthog_code_slack_mention.py b/posthog/temporal/ai/posthog_code_slack_mention.py index 0cc43bd9f3d3..683a998f85ea 100644 --- a/posthog/temporal/ai/posthog_code_slack_mention.py +++ b/posthog/temporal/ai/posthog_code_slack_mention.py @@ -254,6 +254,24 @@ async def run(self, inputs: PostHogCodeSlackMentionWorkflowInputs) -> None: return try: + # Gate every workflow entry on the team's AI-credits quota before any + # other activity runs. Webhook-level short-circuit catches the common + # case (see products/slack_app/backend/api.py); this is the defense in + # depth that also covers replays, manual workflow starts, and the race + # where the webhook saw "not limited" but Redis flipped before we got + # here. Wrapped in `workflow.patched` so in-flight workflows from + # before this deploy stay deterministic on replay. + if workflow.patched("posthog-code-slack-billing-gate"): + blocked = await _execute_posthog_code_activity( + enforce_posthog_code_billing_quota_activity, + inputs, + channel, + thread_ts, + slack_user_id, + ) + if blocked: + return + followup_handled = await _execute_posthog_code_activity( forward_posthog_code_followup_activity, inputs, @@ -939,6 +957,38 @@ def classify_posthog_code_task_needs_repo_activity( return classify_task_needs_repo(event_text, thread_messages) +@activity.defn +def enforce_posthog_code_billing_quota_activity( + inputs: PostHogCodeSlackMentionWorkflowInputs, + channel: str, + thread_ts: str, + slack_user_id: str, +) -> bool: + """Block the workflow when the team has exhausted its AI-credits quota. + + Returns True when a denial was posted and the workflow should stop. Called + as the first activity in the mention workflow so the bot never proceeds to + Slack roundtrips, thread fetches, or billable LLM calls (the classifier, + notably) for an over-quota team. + """ + from posthog.models.integration import Integration, SlackIntegration + + integration = Integration.objects.select_related("team", "team__organization").get( + id=inputs.integration_id, + kind="slack-posthog-code", + integration_id=inputs.slack_team_id, + ) + slack = SlackIntegration(integration) + return _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="task_create", + ) + + @activity.defn def post_posthog_code_no_repos_activity( inputs: PostHogCodeSlackMentionWorkflowInputs, channel: str, thread_ts: str diff --git a/posthog/temporal/tests/ai/test_module_integrity.py b/posthog/temporal/tests/ai/test_module_integrity.py index 0ba3e89b31f4..ec0e15e60a1e 100644 --- a/posthog/temporal/tests/ai/test_module_integrity.py +++ b/posthog/temporal/tests/ai/test_module_integrity.py @@ -54,6 +54,7 @@ def test_activities_remain_unchanged(self): "process_research_agent_activity", "summarize_llm_traces_activity", "process_slack_conversation_activity", + "enforce_posthog_code_billing_quota_activity", "resolve_posthog_code_slack_user_activity", "handle_posthog_code_rules_command_activity", "collect_posthog_code_thread_messages_activity", diff --git a/products/slack_app/backend/api.py b/products/slack_app/backend/api.py index 289ca4901b15..a3bb72da41e1 100644 --- a/products/slack_app/backend/api.py +++ b/products/slack_app/backend/api.py @@ -299,6 +299,55 @@ def lookup_slack_user_id_by_email( return slack_user_id +def _refuse_mention_if_team_over_quota( + slack: SlackIntegration, + event: dict[str, Any], + integration: Integration, +) -> bool: + """Refuse an app_mention at the webhook layer when the team is over quota. + + Returns True when the mention was refused and a denial was posted, so the + caller should bail without starting the workflow. The activity-level gate + inside ``PostHogCodeSlackMentionWorkflow`` is the defense in depth; doing + the same check here saves a Slack roundtrip, a Temporal execution, and a + billable classifier call for the common "team is exhausted" case. + """ + from posthog.temporal.ai.posthog_code_slack_mention import _QUOTA_EXHAUSTED_MESSAGE + + from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + if not is_team_limited( + integration.team.api_token, + QuotaResource.AI_CREDITS, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ): + return False + + channel = event.get("channel") + thread_ts = event.get("thread_ts") or event.get("ts") + slack_user_id = event.get("user") + if not channel or not thread_ts or not slack_user_id: + # The denial needs somewhere to land; without these the workflow + # would have bailed at the top of `run()` anyway. + return True + + logger.info( + "posthog_code_slack_mention_blocked_by_quota_at_webhook", + team_id=integration.team_id, + channel=channel, + thread_ts=thread_ts, + ) + _post_slack_user_feedback( + slack, + channel, + slack_user_id, + thread_ts, + _QUOTA_EXHAUSTED_MESSAGE, + prefer_thread_message=True, + ) + return True + + def _post_slack_user_feedback( slack: SlackIntegration, channel: str, @@ -1221,6 +1270,8 @@ def route_posthog_code_event_to_relevant_region( return ROUTE_HANDLED_LOCALLY if _resolve_pending_repo_picker_from_followup(event, local_match): return ROUTE_HANDLED_LOCALLY + if _refuse_mention_if_team_over_quota(slack, event, local_match): + return ROUTE_HANDLED_LOCALLY workflow_inputs = PostHogCodeSlackMentionWorkflowInputs( event=event, integration_id=local_match.id, diff --git a/products/slack_app/backend/tests/test_followup_forwarding.py b/products/slack_app/backend/tests/test_followup_forwarding.py index 6e1fcb672aaf..804e375ba7b9 100644 --- a/products/slack_app/backend/tests/test_followup_forwarding.py +++ b/products/slack_app/backend/tests/test_followup_forwarding.py @@ -13,6 +13,7 @@ from posthog.temporal.ai.posthog_code_slack_mention import ( PostHogCodeSlackMentionWorkflowInputs, create_posthog_code_task_for_repo_activity, + enforce_posthog_code_billing_quota_activity, forward_posthog_code_followup_activity, ) @@ -778,6 +779,52 @@ def test_connection_error_retries_and_succeeds(self, mock_slack_cls, mock_send, mock_slack_instance.client.chat_postMessage.assert_not_called() +class TestEnforcePostHogCodeBillingQuotaActivity(TestCase): + """The workflow's first activity gate. Returns True (and posts a denial) when + the team is over its AI-credits quota; False otherwise.""" + + def setUp(self): + self.org = Organization.objects.create(name="TestOrg") + self.team = Team.objects.create(organization=self.org, name="TestTeam") + self.integration = Integration.objects.create( + team=self.team, kind="slack-posthog-code", integration_id="T_SLACK", config={} + ) + + @patch("posthog.models.integration.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + def test_returns_true_and_posts_denial_when_over_quota(self, _mock_is_team_limited, mock_slack_cls): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + blocked = enforce_posthog_code_billing_quota_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + ) + + assert blocked is True + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + + @patch("posthog.models.integration.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=False) + def test_returns_false_and_posts_nothing_when_under_quota(self, _mock_is_team_limited, mock_slack_cls): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + blocked = enforce_posthog_code_billing_quota_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + ) + + assert blocked is False + mock_slack_instance.client.chat_postMessage.assert_not_called() + + class TestEventLevelDedupe(TestCase): """Verify that the workflow ID format supports event-level deduplication.""" diff --git a/products/slack_app/backend/tests/test_posthog_code_event_handler.py b/products/slack_app/backend/tests/test_posthog_code_event_handler.py index cfd149899239..000e158a66b4 100644 --- a/products/slack_app/backend/tests/test_posthog_code_event_handler.py +++ b/products/slack_app/backend/tests/test_posthog_code_event_handler.py @@ -501,3 +501,44 @@ def test_app_mention_with_missing_scopes_posts_reauth_and_skips_workflow( assert "app_mentions:read" in feedback_text if scope_value and "chat:write.customize" in scope_value: assert "chat:write.customize" not in feedback_text + + @patch("products.slack_app.backend.api._post_slack_user_feedback") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + @patch("products.slack_app.backend.api._posthog_code_enabled_for_integration", return_value=True) + @patch("products.slack_app.backend.api.SlackIntegration") + @patch("products.slack_app.backend.api.asyncio.run") + @patch("products.slack_app.backend.api.sync_connect") + @override_settings(DEBUG=False) + def test_over_quota_team_refuses_at_webhook_without_starting_workflow( + self, + mock_sync_connect, + mock_asyncio_run, + mock_slack_cls, + _mock_flag, + _mock_is_team_limited, + mock_post_feedback, + ): + mock_slack_instance = mock_slack_cls.return_value + mock_slack_instance.missing_scopes.return_value = set() + + from products.slack_app.backend.api import ROUTE_HANDLED_LOCALLY, route_posthog_code_event_to_relevant_region + + request = self.factory.post("/slack/event-callback/", HTTP_HOST="eu.posthog.com") + event = { + "type": "app_mention", + "channel": "C001", + "user": "U123", + "ts": "1234.5678", + "thread_ts": "1234.5678", + } + + result = route_posthog_code_event_to_relevant_region(request, event, "T12345") + + assert result == ROUTE_HANDLED_LOCALLY + # The whole point of the webhook gate: no Temporal workflow gets created + # for a team that's already over its AI-credits quota. + mock_sync_connect.assert_not_called() + mock_asyncio_run.assert_not_called() + mock_post_feedback.assert_called_once() + feedback_text = mock_post_feedback.call_args.args[4] + assert "PostHog AI credits" in feedback_text From 5fafefbc35ac35b04b6eab67463ea8ed604ab419 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Mon, 25 May 2026 12:51:03 +0200 Subject: [PATCH 07/14] chore(ee): drop trailing newline in test_quota_limits.py --- ee/api/test/test_quota_limits.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ee/api/test/test_quota_limits.py b/ee/api/test/test_quota_limits.py index 58c2258ae6aa..dd58d18f7737 100644 --- a/ee/api/test/test_quota_limits.py +++ b/ee/api/test/test_quota_limits.py @@ -147,4 +147,3 @@ def test_multi_team_user_gets_per_team_answers(self) -> None: self.assertEqual(resp_self.json()["limited"]["ai_credits"], {"limited": True}) self.assertEqual(resp_other.json()["limited"]["ai_credits"], {"limited": False}) - From 8b9eb78fd35ca5cea7313d49a11c9b5f046f8cc7 Mon Sep 17 00:00:00 2001 From: "tests-posthog[bot]" <250237707+tests-posthog[bot]@users.noreply.github.com> Date: Mon, 25 May 2026 11:19:03 +0000 Subject: [PATCH 08/14] chore: update OpenAPI generated types --- services/mcp/src/api/generated.ts | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/services/mcp/src/api/generated.ts b/services/mcp/src/api/generated.ts index f3e9e141b277..11020534bd85 100644 --- a/services/mcp/src/api/generated.ts +++ b/services/mcp/src/api/generated.ts @@ -35076,6 +35076,21 @@ export namespace Schemas { query: EventsNode | ActionsNode | PersonsNode | DataWarehouseNode | FunnelsDataWarehouseNode | LifecycleDataWarehouseNode | EventsQuery | SessionsQuery | ActorsQuery | GroupsQuery | InsightActorsQuery | InsightActorsQueryOptions | SessionsTimelineQuery | HogQuery | HogQLQuery | HogQLMetadata | HogQLAutocomplete | SessionAttributionExplorerQuery | RevenueExampleEventsQuery | RevenueExampleDataWarehouseTablesQuery | ErrorTrackingQuery | ErrorTrackingSimilarIssuesQuery | ErrorTrackingBreakdownsQuery | ErrorTrackingIssueCorrelationQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery | ExperimentQuery | ExperimentExposureQuery | DocumentSimilarityQuery | WebOverviewQuery | WebStatsTableQuery | WebExternalClicksTableQuery | WebGoalsQuery | WebVitalsQuery | WebVitalsPathBreakdownQuery | WebPageURLSearchQuery | WebAnalyticsExternalSummaryQuery | WebNotableChangesQuery | RevenueAnalyticsGrossRevenueQuery | RevenueAnalyticsMetricsQuery | RevenueAnalyticsMRRQuery | RevenueAnalyticsOverviewQuery | RevenueAnalyticsTopCustomersQuery | MarketingAnalyticsTableQuery | MarketingAnalyticsAggregatedQuery | NonIntegratedConversionsTableQuery | DataVisualizationNode | DataTableNode | SavedInsightNode | InsightVizNode | TrendsQuery | FunnelsQuery | RetentionQuery | PathsQuery | StickinessQuery | LifecycleQuery | FunnelCorrelationQuery | DatabaseSchemaQuery | RecordingsQuery | LogsQuery | LogAttributesQuery | LogValuesQuery | TraceSpansQuery | TraceSpansAggregationQuery | TraceSpansTreeQuery | SuggestedQuestionsQuery | TeamTaxonomyQuery | EventTaxonomyQuery | ActorsPropertyTaxonomyQuery | TracesQuery | TraceQuery | TraceNeighborsQuery | VectorSearchQuery | UsageMetricsQuery | EndpointsUsageOverviewQuery | EndpointsUsageTableQuery | EndpointsUsageTrendsQuery | PropertyValuesQuery; } + export interface QuotaResourceLimit { + /** True when the team is currently over its quota for this resource and limits are in effect. */ + limited: boolean; + } + + /** + * Per-resource limit state keyed by `QuotaResource` value. Currently only `ai_credits` is reported; additional resources may be added. + */ + export type QuotaLimitsResponseLimited = {[key: string]: QuotaResourceLimit}; + + export interface QuotaLimitsResponse { + /** Per-resource limit state keyed by `QuotaResource` value. Currently only `ai_credits` is reported; additional resources may be added. */ + limited: QuotaLimitsResponseLimited; + } + export interface RecomputeResult { run: Run; counts_changed: boolean; From 9a125bbe82f9dee718f2ee74132d451e8442909e Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 10:14:51 +0200 Subject: [PATCH 09/14] feat(llm-gateway): report quota state for every resource --- ee/api/quota_limits.py | 12 +++++++----- ee/api/test/test_quota_limits.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/ee/api/quota_limits.py b/ee/api/quota_limits.py index eb7dfe66d7d1..718e5a88c556 100644 --- a/ee/api/quota_limits.py +++ b/ee/api/quota_limits.py @@ -49,19 +49,21 @@ class QuotaLimitsViewSet(TeamAndOrgViewSetMixin, viewsets.ViewSet): @extend_schema( summary="Get a team's quota-limit state", description=( - "Return the current quota-limit state for the team identified in the URL. " - "Used by the LLM gateway to gate billable products on AI credits exhaustion." + "Return the current quota-limit state for the team identified in the URL, " + "keyed by `QuotaResource` value. Used by the LLM gateway to gate billable " + "products on AI credits exhaustion." ), responses={200: QuotaLimitsResponseSerializer}, ) def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: limited = { - QuotaResource.AI_CREDITS.value: { + resource.value: { "limited": is_team_limited( self.team.api_token, - QuotaResource.AI_CREDITS, + resource, QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, ), - }, + } + for resource in QuotaResource } return Response(QuotaLimitsResponseSerializer({"limited": limited}).data) diff --git a/ee/api/test/test_quota_limits.py b/ee/api/test/test_quota_limits.py index dd58d18f7737..2042bcd77a7e 100644 --- a/ee/api/test/test_quota_limits.py +++ b/ee/api/test/test_quota_limits.py @@ -134,6 +134,21 @@ def test_personal_api_key_missing_required_scope_is_forbidden(self) -> None: ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_response_includes_every_quota_resource(self) -> None: + # Limiting one resource must not hide the unlimited state of the rest. + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + limited = response.json()["limited"] + expected_keys = {resource.value for resource in QuotaResource} + self.assertEqual(set(limited.keys()), expected_keys) + self.assertTrue(limited["ai_credits"]["limited"]) + for resource in QuotaResource: + if resource is QuotaResource.AI_CREDITS: + continue + self.assertFalse(limited[resource.value]["limited"], resource.value) + def test_multi_team_user_gets_per_team_answers(self) -> None: # Same user belongs to two teams in their org; each team's quota is independent. # This is the regression that "me" couldn't model — `user.team` (current team) From c445434c072534f66eac435da71b3938a7387015 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 10:16:27 +0200 Subject: [PATCH 10/14] refactor(llm-gateway): extract resolve_plan_and_quota helper --- .../llm-gateway/src/llm_gateway/api/usage.py | 22 +++------ .../src/llm_gateway/dependencies.py | 45 ++++++++++++++++--- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/services/llm-gateway/src/llm_gateway/api/usage.py b/services/llm-gateway/src/llm_gateway/api/usage.py index bdc3facde6b4..e94976b571c3 100644 --- a/services/llm-gateway/src/llm_gateway/api/usage.py +++ b/services/llm-gateway/src/llm_gateway/api/usage.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from datetime import UTC, datetime, timedelta from typing import Annotated @@ -9,8 +8,7 @@ from pydantic import BaseModel from llm_gateway.auth.models import AuthenticatedUser -from llm_gateway.dependencies import get_authenticated_user -from llm_gateway.products.config import get_product_config +from llm_gateway.dependencies import get_authenticated_user, resolve_plan_and_quota from llm_gateway.rate_limiting.cost_throttles import CostStatus, UserCostBurstThrottle, UserCostSustainedThrottle from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import ThrottleContext @@ -19,9 +17,7 @@ PlanResolver, is_pro_plan, parse_iso_utc, - resolve_plan_info, ) -from llm_gateway.services.quota_resolver import resolve_quota_status logger = structlog.get_logger(__name__) @@ -74,18 +70,14 @@ async def get_usage( ) -> UsageResponse: runner: ThrottleRunner = request.app.state.throttle_runner - plan_info, quota_status = await asyncio.gather( - resolve_plan_info(request, user.user_id, product), - resolve_quota_status(request, user.team_id), + plan_info, quota_status = await resolve_plan_and_quota( + request, + user_id=user.user_id, + team_id=user.team_id, + product=product, ) now = datetime.now(tz=UTC) - - # Non-billable products don't actually get denied by BillableCreditThrottle, - # so reporting "exhausted" here would be misleading regardless of the team's - # AI credits state. - product_config = get_product_config(product) - is_billable = bool(product_config and product_config.billable) - ai_credits_exhausted = quota_status.limited if is_billable else False + ai_credits_exhausted = quota_status.limited context = ThrottleContext( user=user, diff --git a/services/llm-gateway/src/llm_gateway/dependencies.py b/services/llm-gateway/src/llm_gateway/dependencies.py index 01f948a9de98..752357f92903 100644 --- a/services/llm-gateway/src/llm_gateway/dependencies.py +++ b/services/llm-gateway/src/llm_gateway/dependencies.py @@ -11,7 +11,12 @@ from llm_gateway.auth.models import AuthenticatedUser from llm_gateway.auth.service import AuthService, get_auth_service from llm_gateway.circuit_breaker import AnthropicCircuitBreaker -from llm_gateway.products.config import ALLOWED_PRODUCTS, check_product_access, resolve_product_alias +from llm_gateway.products.config import ( + ALLOWED_PRODUCTS, + check_product_access, + get_product_config, + resolve_product_alias, +) from llm_gateway.rate_limiting.cost_refresh import ensure_costs_fresh from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import ThrottleContext @@ -20,8 +25,8 @@ get_request_id, set_throttle_context, ) -from llm_gateway.services.plan_resolver import resolve_plan_info -from llm_gateway.services.quota_resolver import resolve_quota_status +from llm_gateway.services.plan_resolver import PlanInfo, resolve_plan_info +from llm_gateway.services.quota_resolver import QuotaResourceStatus, resolve_quota_status logger = structlog.get_logger(__name__) @@ -151,6 +156,31 @@ async def _extract_end_user_id_from_body(request: Request) -> str | None: return None +async def resolve_plan_and_quota( + request: Request, + *, + user_id: int, + team_id: int | None, + product: str, +) -> tuple[PlanInfo, QuotaResourceStatus]: + """Fetch plan info and (for billable products) AI credits quota in parallel. + + Both calls are independent Django roundtrips on cache miss, so for billable + products we overlap them. For non-billable products the throttle stack + short-circuits regardless of quota state, so we skip the resolver entirely + rather than paying for the Redis GET (and the HTTP fallback on cache miss). + """ + product_config = get_product_config(product) + if product_config and product_config.billable: + plan_info, quota_status = await asyncio.gather( + resolve_plan_info(request, user_id, product), + resolve_quota_status(request, team_id), + ) + return plan_info, quota_status + plan_info = await resolve_plan_info(request, user_id, product) + return plan_info, QuotaResourceStatus(limited=False) + + async def enforce_throttles( request: Request, user: Annotated[AuthenticatedUser, Depends(enforce_product_access)], @@ -165,10 +195,11 @@ async def enforce_throttles( else: end_user_id = await _extract_end_user_id_from_body(request) - # Plan + quota are independent Django roundtrips on cache miss — overlap them. - plan_info, quota_status = await asyncio.gather( - resolve_plan_info(request, user.user_id, product), - resolve_quota_status(request, user.team_id), + plan_info, quota_status = await resolve_plan_and_quota( + request, + user_id=user.user_id, + team_id=user.team_id, + product=product, ) context = ThrottleContext( From 9159625375a8b658bf326e4e490483b1bac57475 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 10:18:01 +0200 Subject: [PATCH 11/14] feat(llm-gateway): bump quota cache TTL to 5 minutes --- services/llm-gateway/src/llm_gateway/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/services/llm-gateway/src/llm_gateway/config.py b/services/llm-gateway/src/llm_gateway/config.py index 843072a6f68c..f279164f9197 100644 --- a/services/llm-gateway/src/llm_gateway/config.py +++ b/services/llm-gateway/src/llm_gateway/config.py @@ -160,10 +160,10 @@ class Settings(BaseSettings): posthog_api_base_url: str = "https://us.posthog.com" plan_cache_ttl: int = 900 # 15 minutes - # AI credits quota state is fast-moving — Django itself caches the underlying - # Redis set for 30 seconds, so caching the resolver result for any longer - # would just stack staleness without saving lookups. - quota_cache_ttl: int = 30 + # Billing recomputes quota state on at most an hourly cadence, so we are + # comfortable letting a team go slightly over their limit in exchange for + # avoiding a Django roundtrip on every billable request. + quota_cache_ttl: int = 300 # 5 minutes billing_period_days: int = 30 # Anthropic -> Bedrock circuit breaker. When the trailing failure rate of the Anthropic From 0baee26248253e8a5cfeade476dcd77895452abd Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 10:27:18 +0200 Subject: [PATCH 12/14] refactor(slack-bot): keep AI credits gate inside the workflow only --- .../temporal/ai/posthog_code_slack_mention.py | 63 +++---------------- products/slack_app/backend/api.py | 48 +++++++------- .../tests/test_posthog_code_event_handler.py | 16 ++--- 3 files changed, 42 insertions(+), 85 deletions(-) diff --git a/posthog/temporal/ai/posthog_code_slack_mention.py b/posthog/temporal/ai/posthog_code_slack_mention.py index 683a998f85ea..734288f73733 100644 --- a/posthog/temporal/ai/posthog_code_slack_mention.py +++ b/posthog/temporal/ai/posthog_code_slack_mention.py @@ -61,55 +61,6 @@ def _safe_react(client: Any, channel: str, timestamp: str, name: str) -> None: _INITIATOR_PLACEHOLDER = "" -_QUOTA_EXHAUSTED_MESSAGE = ( - "Your team has used its monthly PostHog AI credits. " - "Top up at https://us.posthog.com/organization/billing to continue." -) - - -def _block_if_team_over_quota( - *, - integration: Any, - slack: Any, - channel: str, - thread_ts: str, - slack_user_id: str, - context: Literal["task_create", "followup"], -) -> bool: - """Reject a Slack-bot turn when the team is over its AI credits quota. - - Mirrors PHAI's enforcement model (ee/api/conversation.py): every - user-initiated turn — new mention or follow-up reply — is gated against the - same Redis-backed `QuotaResource.AI_CREDITS` set. Returns True when the - team is blocked, posts a friendly in-thread denial as a side effect. - """ - from products.slack_app.backend.api import _post_slack_user_feedback - - from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited - - if not is_team_limited( - integration.team.api_token, QuotaResource.AI_CREDITS, QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY - ): - return False - - logger.info( - "posthog_code_slack_blocked_by_quota", - context=context, - team_id=integration.team_id, - channel=channel, - thread_ts=thread_ts, - ) - _post_slack_user_feedback( - slack, - channel, - slack_user_id, - thread_ts, - _QUOTA_EXHAUSTED_MESSAGE, - prefer_thread_message=True, - ) - return True - - def _strip_context_tag(text: str) -> str: return re.sub(rf"", "", text, flags=re.IGNORECASE) @@ -973,13 +924,15 @@ def enforce_posthog_code_billing_quota_activity( """ from posthog.models.integration import Integration, SlackIntegration - integration = Integration.objects.select_related("team", "team__organization").get( + from products.slack_app.backend.api import block_if_team_over_quota + + integration = Integration.objects.select_related("team").get( id=inputs.integration_id, kind="slack-posthog-code", integration_id=inputs.slack_team_id, ) slack = SlackIntegration(integration) - return _block_if_team_over_quota( + return block_if_team_over_quota( integration=integration, slack=slack, channel=channel, @@ -1129,6 +1082,8 @@ def create_posthog_code_task_for_repo_activity( thread_messages: list[dict[str, str]], repository: str | None, ) -> None: + from products.slack_app.backend.api import block_if_team_over_quota + integration = Integration.objects.select_related("team", "team__organization").get( id=inputs.integration_id, kind="slack-posthog-code", @@ -1138,7 +1093,7 @@ def create_posthog_code_task_for_repo_activity( # Refuse before the :seedling: reaction or the permalink fetch: a denied # mention should not first ack-react and then refuse a second later. - if _block_if_team_over_quota( + if block_if_team_over_quota( integration=integration, slack=slack, channel=channel, @@ -1319,6 +1274,8 @@ def forward_posthog_code_followup_activity( if _parse_rules_command(event_text): return False + from products.slack_app.backend.api import block_if_team_over_quota + try: mapping = SlackThreadTaskMapping.objects.select_related("task_run", "task__created_by").get( integration_id=inputs.integration_id, @@ -1353,7 +1310,7 @@ def forward_posthog_code_followup_activity( ) return True - if _block_if_team_over_quota( + if block_if_team_over_quota( integration=integration, slack=slack, channel=channel, diff --git a/products/slack_app/backend/api.py b/products/slack_app/backend/api.py index a3bb72da41e1..8eaa2ea42828 100644 --- a/products/slack_app/backend/api.py +++ b/products/slack_app/backend/api.py @@ -299,21 +299,30 @@ def lookup_slack_user_id_by_email( return slack_user_id -def _refuse_mention_if_team_over_quota( - slack: SlackIntegration, - event: dict[str, Any], +QUOTA_EXHAUSTED_MESSAGE = ( + "Your team has used its monthly PostHog AI credits. " + "Top up at https://us.posthog.com/organization/billing to continue." +) + + +def block_if_team_over_quota( + *, integration: Integration, + slack: SlackIntegration, + channel: str, + thread_ts: str, + slack_user_id: str, + context: str, ) -> bool: - """Refuse an app_mention at the webhook layer when the team is over quota. - - Returns True when the mention was refused and a denial was posted, so the - caller should bail without starting the workflow. The activity-level gate - inside ``PostHogCodeSlackMentionWorkflow`` is the defense in depth; doing - the same check here saves a Slack roundtrip, a Temporal execution, and a - billable classifier call for the common "team is exhausted" case. + """Refuse a Slack-bot turn when the team is over its AI credits quota. + + Mirrors PHAI's enforcement (ee/api/conversation.py): every user-initiated + turn — webhook mention, task-creation activity, or follow-up reply — is + gated against the same Redis-backed ``QuotaResource.AI_CREDITS`` set. + Returns True when the team is blocked and posts a friendly in-thread + denial as a side effect. ``context`` is a free-form label that the call + site uses to disambiguate itself in structured logs. """ - from posthog.temporal.ai.posthog_code_slack_mention import _QUOTA_EXHAUSTED_MESSAGE - from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited if not is_team_limited( @@ -323,16 +332,9 @@ def _refuse_mention_if_team_over_quota( ): return False - channel = event.get("channel") - thread_ts = event.get("thread_ts") or event.get("ts") - slack_user_id = event.get("user") - if not channel or not thread_ts or not slack_user_id: - # The denial needs somewhere to land; without these the workflow - # would have bailed at the top of `run()` anyway. - return True - logger.info( - "posthog_code_slack_mention_blocked_by_quota_at_webhook", + "posthog_code_slack_blocked_by_quota", + context=context, team_id=integration.team_id, channel=channel, thread_ts=thread_ts, @@ -342,7 +344,7 @@ def _refuse_mention_if_team_over_quota( channel, slack_user_id, thread_ts, - _QUOTA_EXHAUSTED_MESSAGE, + QUOTA_EXHAUSTED_MESSAGE, prefer_thread_message=True, ) return True @@ -1270,8 +1272,6 @@ def route_posthog_code_event_to_relevant_region( return ROUTE_HANDLED_LOCALLY if _resolve_pending_repo_picker_from_followup(event, local_match): return ROUTE_HANDLED_LOCALLY - if _refuse_mention_if_team_over_quota(slack, event, local_match): - return ROUTE_HANDLED_LOCALLY workflow_inputs = PostHogCodeSlackMentionWorkflowInputs( event=event, integration_id=local_match.id, diff --git a/products/slack_app/backend/tests/test_posthog_code_event_handler.py b/products/slack_app/backend/tests/test_posthog_code_event_handler.py index 000e158a66b4..685dc1b7a90f 100644 --- a/products/slack_app/backend/tests/test_posthog_code_event_handler.py +++ b/products/slack_app/backend/tests/test_posthog_code_event_handler.py @@ -509,7 +509,7 @@ def test_app_mention_with_missing_scopes_posts_reauth_and_skips_workflow( @patch("products.slack_app.backend.api.asyncio.run") @patch("products.slack_app.backend.api.sync_connect") @override_settings(DEBUG=False) - def test_over_quota_team_refuses_at_webhook_without_starting_workflow( + def test_over_quota_team_still_starts_workflow_for_in_workflow_enforcement( self, mock_sync_connect, mock_asyncio_run, @@ -535,10 +535,10 @@ def test_over_quota_team_refuses_at_webhook_without_starting_workflow( result = route_posthog_code_event_to_relevant_region(request, event, "T12345") assert result == ROUTE_HANDLED_LOCALLY - # The whole point of the webhook gate: no Temporal workflow gets created - # for a team that's already over its AI-credits quota. - mock_sync_connect.assert_not_called() - mock_asyncio_run.assert_not_called() - mock_post_feedback.assert_called_once() - feedback_text = mock_post_feedback.call_args.args[4] - assert "PostHog AI credits" in feedback_text + # Quota enforcement now lives entirely inside the workflow — the webhook + # always schedules it, and the activity-level gate posts the denial + # before any billable LLM call. This guarantees a single owner for the + # quota decision instead of dual gating. + mock_sync_connect.assert_called_once() + mock_asyncio_run.assert_called_once() + mock_post_feedback.assert_not_called() From 6cb735f1c8721d2991ec00f9cd8ee076a993d88f Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 11:05:16 +0200 Subject: [PATCH 13/14] feat(llm-gateway): retry transient quota fetches with exponential backoff --- .../llm_gateway/services/quota_resolver.py | 77 ++++++++++++---- .../llm-gateway/tests/test_quota_resolver.py | 87 ++++++++++++++++--- 2 files changed, 137 insertions(+), 27 deletions(-) diff --git a/services/llm-gateway/src/llm_gateway/services/quota_resolver.py b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py index 3fab65fc30d2..0004f09a8953 100644 --- a/services/llm-gateway/src/llm_gateway/services/quota_resolver.py +++ b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py @@ -3,20 +3,26 @@ Mirrors :mod:`llm_gateway.services.plan_resolver` — forwards the caller's ``Authorization`` header to ``GET /api/projects/{team_id}/quota_limits/`` and caches the result per team-and-resource in the gateway's own Redis. + +Transient upstream failures (network errors, 5xx) are retried within the +request with exponential backoff. 4xx responses or exhausted retries fall +open and briefly cache ``limited=False`` so a struggling Django isn't hit on +every subsequent request. """ from __future__ import annotations +import asyncio import json from dataclasses import dataclass from typing import TYPE_CHECKING +import httpx import structlog from llm_gateway.config import get_settings if TYPE_CHECKING: - import httpx from fastapi import Request from redis.asyncio import Redis @@ -24,10 +30,23 @@ _AI_CREDITS_RESOURCE = "ai_credits" -# Cache window for the fail-open path (4xx from Django, e.g. expired token). -# Short enough that a fixed-up token recovers quickly; long enough to keep a -# misconfigured client off Django's neck during an auth-failure storm. -_FAIL_OPEN_CACHE_TTL_SECONDS = 5 +# Cache window for the fail-open path (4xx from Django, or transient failure +# after retries are exhausted). Long enough to keep a misconfigured client off +# Django's neck during an auth-failure storm; short enough that a recovered +# upstream is consulted again within a minute. +_FAIL_OPEN_CACHE_TTL_SECONDS = 60 + +# Exponential backoff between within-request retries on transient failures. +# The first attempt fires immediately; each subsequent retry waits +# ``MULTIPLIER * 2**n`` seconds, doubling the gap each step. Tune the +# multiplier to widen or tighten the overall spacing without touching the +# formula. +_MAX_RETRIES = 3 +_RETRY_BACKOFF_MULTIPLIER_SECONDS = 5 +_RETRY_DELAYS_SECONDS: tuple[float, ...] = ( + 0, + *(_RETRY_BACKOFF_MULTIPLIER_SECONDS * 2**i for i in range(_MAX_RETRIES)), +) @dataclass @@ -35,6 +54,10 @@ class QuotaResourceStatus: limited: bool +class _TransientUpstreamError(Exception): + """Retryable failure: a 5xx response or a network-level error.""" + + def _redis_key(resource_key: str, team_id: int) -> str: return f"quota:{resource_key}:team:{team_id}" @@ -72,31 +95,53 @@ async def _get_resource_status(self, resource_key: str, team_id: int, auth_heade return cached try: - status, ttl = await self._fetch(resource_key, team_id, auth_header) + status, ttl = await self._fetch_with_retry(resource_key, team_id, auth_header) except Exception: logger.warning("quota_fetch_failed", resource=resource_key, team_id=team_id, exc_info=True) - return QuotaResourceStatus(limited=False) + status, ttl = QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS await self._set_cached(resource_key, team_id, status, ttl) return status - async def _fetch(self, resource_key: str, team_id: int, auth_header: str) -> tuple[QuotaResourceStatus, int]: - """Return the resource status and the TTL the caller should cache it for. + async def _fetch_with_retry( + self, resource_key: str, team_id: int, auth_header: str + ) -> tuple[QuotaResourceStatus, int]: + """Try the upstream up to ``len(_RETRY_DELAYS_SECONDS)`` times. - 4xx responses are treated as "not limited" and cached briefly so a hot - loop with a broken token can't hammer Django. + Network errors and 5xx responses are retried with growing waits between + attempts. 4xx and successful responses return immediately from + :meth:`_fetch`. If every attempt raises a transient error the last + exception is re-raised for the caller to fail open. """ + last_exc: Exception | None = None + for delay in _RETRY_DELAYS_SECONDS: + if delay: + await asyncio.sleep(delay) + try: + return await self._fetch(resource_key, team_id, auth_header) + except _TransientUpstreamError as exc: + last_exc = exc + assert last_exc is not None + raise last_exc + + async def _fetch(self, resource_key: str, team_id: int, auth_header: str) -> tuple[QuotaResourceStatus, int]: + """One attempt against Django. Raises :class:`_TransientUpstreamError` on retryable failures.""" settings = get_settings() if not settings.posthog_api_base_url: return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS url = f"{settings.posthog_api_base_url.rstrip('/')}/api/projects/{team_id}/quota_limits/" - resp = await self._http.get( - url, - headers={"Authorization": auth_header}, - timeout=2.0, - ) + try: + resp = await self._http.get(url, headers={"Authorization": auth_header}, timeout=2.0) + except httpx.RequestError as exc: + raise _TransientUpstreamError(str(exc)) from exc + + if resp.status_code >= 500: + raise _TransientUpstreamError(f"quota_limits returned {resp.status_code}") if resp.status_code >= 400: + # 4xx is permanent for the lifetime of this request — bad token, + # missing team, scope mismatch. Fail open briefly so a hot loop + # with a broken token doesn't hammer Django. return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS data = resp.json() diff --git a/services/llm-gateway/tests/test_quota_resolver.py b/services/llm-gateway/tests/test_quota_resolver.py index d7bae046fe00..28254e6c11e5 100644 --- a/services/llm-gateway/tests/test_quota_resolver.py +++ b/services/llm-gateway/tests/test_quota_resolver.py @@ -1,12 +1,15 @@ from __future__ import annotations import json -from unittest.mock import AsyncMock, MagicMock +from collections.abc import Iterator +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from llm_gateway.services.quota_resolver import ( + _FAIL_OPEN_CACHE_TTL_SECONDS, + _RETRY_DELAYS_SECONDS, QuotaResolver, QuotaResourceStatus, _redis_key, @@ -27,6 +30,26 @@ def _make_http_client(response: httpx.Response | Exception) -> MagicMock: return client +def _make_http_client_sequence(responses: list[httpx.Response | Exception]) -> MagicMock: + client = MagicMock() + + async def _next(*args: object, **kwargs: object) -> httpx.Response: + item = responses.pop(0) + if isinstance(item, Exception): + raise item + return item + + client.get = AsyncMock(side_effect=_next) + return client + + +@pytest.fixture(autouse=True) +def _no_retry_sleep() -> Iterator[None]: + """Don't actually sleep between retries — keeps tests fast and deterministic.""" + with patch("llm_gateway.services.quota_resolver.asyncio.sleep", new=AsyncMock()): + yield + + class TestQuotaResolver: @pytest.mark.asyncio async def test_fetches_and_parses_limited_response(self) -> None: @@ -53,22 +76,64 @@ async def test_fetches_and_parses_unlimited_response(self) -> None: assert status == QuotaResourceStatus(limited=False) @pytest.mark.asyncio - async def test_fail_open_on_http_error(self) -> None: - http_client = _make_http_client(httpx.ConnectError("boom")) + async def test_4xx_fails_open_without_retrying(self) -> None: + # 4xx is treated as a permanent failure for the lifetime of this request, + # so we don't burn the retry budget on a token that won't fix itself. + http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) resolver = QuotaResolver(redis=None, http_client=http_client) status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") assert status == QuotaResourceStatus(limited=False) + http_client.get.assert_awaited_once() @pytest.mark.asyncio - async def test_fail_open_on_4xx(self) -> None: - http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) + async def test_retries_on_5xx_and_succeeds(self) -> None: + # A transient 503 is retried; the next attempt succeeds. + http_client = _make_http_client_sequence( + [ + _make_response(503, {}), + _make_response(200, {"limited": {"ai_credits": {"limited": True}}}), + ] + ) resolver = QuotaResolver(redis=None, http_client=http_client) status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + assert status == QuotaResourceStatus(limited=True) + assert http_client.get.await_count == 2 + + @pytest.mark.asyncio + async def test_retries_on_network_error_and_succeeds(self) -> None: + http_client = _make_http_client_sequence( + [ + httpx.ConnectError("boom"), + _make_response(200, {"limited": {"ai_credits": {"limited": False}}}), + ] + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + assert http_client.get.await_count == 2 + + @pytest.mark.asyncio + async def test_gives_up_after_all_retries_and_fails_open(self) -> None: + # Consecutive network errors exhaust the retry budget; we fall open + # and cache the answer for the fail-open window. + http_client = _make_http_client(httpx.ConnectError("boom")) + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + resolver = QuotaResolver(redis=redis, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + assert status == QuotaResourceStatus(limited=False) + assert http_client.get.await_count == len(_RETRY_DELAYS_SECONDS) + redis.set.assert_awaited_once() + assert redis.set.await_args.kwargs.get("ex") == _FAIL_OPEN_CACHE_TTL_SECONDS @pytest.mark.asyncio async def test_uses_cached_result_and_skips_http(self) -> None: @@ -97,13 +162,13 @@ async def test_writes_cache_on_miss(self) -> None: call = redis.set.await_args assert call.args[0] == _redis_key("ai_credits", 42) assert json.loads(call.args[1]) == {"limited": True} - # TTL matches the gateway settings default of 30s. - assert call.kwargs.get("ex") == 30 + # Successful fetches use the gateway settings default of 5 minutes. + assert call.kwargs.get("ex") == 300 @pytest.mark.asyncio - async def test_caches_fail_open_briefly_on_4xx(self) -> None: - # 4xx responses (e.g. expired token) cache for a short window so a hot - # loop with a broken token doesn't hammer Django. + async def test_caches_fail_open_for_full_window_on_4xx(self) -> None: + # 4xx responses (e.g. expired token) cache for the fail-open window so + # a hot loop with a broken token doesn't hammer Django. redis = AsyncMock() redis.get = AsyncMock(return_value=None) redis.set = AsyncMock() @@ -113,4 +178,4 @@ async def test_caches_fail_open_briefly_on_4xx(self) -> None: await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") redis.set.assert_awaited_once() - assert redis.set.await_args.kwargs.get("ex") == 5 + assert redis.set.await_args.kwargs.get("ex") == _FAIL_OPEN_CACHE_TTL_SECONDS From 9d08641b0ecf4f05ee552af34585a2558eccc356 Mon Sep 17 00:00:00 2001 From: Vojta Bartos Date: Wed, 27 May 2026 13:43:44 +0200 Subject: [PATCH 14/14] fix(slack-bot): move quota check to temporal layer to satisfy tach --- .../temporal/ai/posthog_code_slack_mention.py | 51 +++++++++++++++---- products/slack_app/backend/api.py | 28 +++------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/posthog/temporal/ai/posthog_code_slack_mention.py b/posthog/temporal/ai/posthog_code_slack_mention.py index 734288f73733..8a4c8071a039 100644 --- a/posthog/temporal/ai/posthog_code_slack_mention.py +++ b/posthog/temporal/ai/posthog_code_slack_mention.py @@ -35,6 +35,45 @@ logger = structlog.get_logger(__name__) +def _block_if_team_over_quota( + *, + integration: Any, + slack: Any, + channel: str, + thread_ts: str, + slack_user_id: str, + context: str, +) -> bool: + """Refuse a Slack-bot turn when the team is over its AI credits quota. + + Tach blocks ``products.slack_app`` from importing ``ee.billing``, so the + quota lookup lives here (where the temporal layer can freely import ee) + while the user-facing denial message lives in ``slack_app.backend.api`` + (where the Slack-posting helpers live). Returns True when the team was + blocked and a denial was posted. + """ + from products.slack_app.backend.api import post_quota_exhausted_denial + + from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + if not is_team_limited( + integration.team.api_token, + QuotaResource.AI_CREDITS, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ): + return False + + post_quota_exhausted_denial( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context=context, + ) + return True + + def _safe_react(client: Any, channel: str, timestamp: str, name: str) -> None: try: client.reactions_add(channel=channel, timestamp=timestamp, name=name) @@ -924,15 +963,13 @@ def enforce_posthog_code_billing_quota_activity( """ from posthog.models.integration import Integration, SlackIntegration - from products.slack_app.backend.api import block_if_team_over_quota - integration = Integration.objects.select_related("team").get( id=inputs.integration_id, kind="slack-posthog-code", integration_id=inputs.slack_team_id, ) slack = SlackIntegration(integration) - return block_if_team_over_quota( + return _block_if_team_over_quota( integration=integration, slack=slack, channel=channel, @@ -1082,8 +1119,6 @@ def create_posthog_code_task_for_repo_activity( thread_messages: list[dict[str, str]], repository: str | None, ) -> None: - from products.slack_app.backend.api import block_if_team_over_quota - integration = Integration.objects.select_related("team", "team__organization").get( id=inputs.integration_id, kind="slack-posthog-code", @@ -1093,7 +1128,7 @@ def create_posthog_code_task_for_repo_activity( # Refuse before the :seedling: reaction or the permalink fetch: a denied # mention should not first ack-react and then refuse a second later. - if block_if_team_over_quota( + if _block_if_team_over_quota( integration=integration, slack=slack, channel=channel, @@ -1274,8 +1309,6 @@ def forward_posthog_code_followup_activity( if _parse_rules_command(event_text): return False - from products.slack_app.backend.api import block_if_team_over_quota - try: mapping = SlackThreadTaskMapping.objects.select_related("task_run", "task__created_by").get( integration_id=inputs.integration_id, @@ -1310,7 +1343,7 @@ def forward_posthog_code_followup_activity( ) return True - if block_if_team_over_quota( + if _block_if_team_over_quota( integration=integration, slack=slack, channel=channel, diff --git a/products/slack_app/backend/api.py b/products/slack_app/backend/api.py index 8eaa2ea42828..6eeb4b763328 100644 --- a/products/slack_app/backend/api.py +++ b/products/slack_app/backend/api.py @@ -305,7 +305,7 @@ def lookup_slack_user_id_by_email( ) -def block_if_team_over_quota( +def post_quota_exhausted_denial( *, integration: Integration, slack: SlackIntegration, @@ -313,25 +313,14 @@ def block_if_team_over_quota( thread_ts: str, slack_user_id: str, context: str, -) -> bool: - """Refuse a Slack-bot turn when the team is over its AI credits quota. - - Mirrors PHAI's enforcement (ee/api/conversation.py): every user-initiated - turn — webhook mention, task-creation activity, or follow-up reply — is - gated against the same Redis-backed ``QuotaResource.AI_CREDITS`` set. - Returns True when the team is blocked and posts a friendly in-thread - denial as a side effect. ``context`` is a free-form label that the call - site uses to disambiguate itself in structured logs. - """ - from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited - - if not is_team_limited( - integration.team.api_token, - QuotaResource.AI_CREDITS, - QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, - ): - return False +) -> None: + """Post the AI-credits denial message into a Slack thread. + Called by the workflow's quota gate after it determines the team is over + quota. Lives in this module so the Slack-posting helpers and the message + text stay co-located; the quota check itself lives in the temporal layer + (which is the only side allowed to import ``ee.billing``). + """ logger.info( "posthog_code_slack_blocked_by_quota", context=context, @@ -347,7 +336,6 @@ def block_if_team_over_quota( QUOTA_EXHAUSTED_MESSAGE, prefer_thread_message=True, ) - return True def _post_slack_user_feedback(