diff --git a/ee/api/quota_limits.py b/ee/api/quota_limits.py new file mode 100644 index 000000000000..718e5a88c556 --- /dev/null +++ b/ee/api/quota_limits.py @@ -0,0 +1,69 @@ +"""Expose a team's quota-limit state. + +Backs the LLM gateway's `QuotaResolver`, which forwards the caller's auth +header here to learn whether a given team is currently over its AI credits +quota. Project-nested so org membership and token `scoped_teams`/ +`scoped_organizations` enforcement come from the standard +`TeamAndOrgViewSetMixin` permission chain — see +`posthog.permissions.APIScopePermission.check_team_and_org_permissions`. +""" + +from __future__ import annotations + +from typing import Any + +from drf_spectacular.utils import extend_schema +from rest_framework import serializers, viewsets +from rest_framework.request import Request +from rest_framework.response import Response + +from posthog.api.routing import TeamAndOrgViewSetMixin + +from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + +class QuotaResourceLimitSerializer(serializers.Serializer): + limited = serializers.BooleanField( + help_text="True when the team is currently over its quota for this resource and limits are in effect.", + ) + + +class QuotaLimitsResponseSerializer(serializers.Serializer): + limited = serializers.DictField( + child=QuotaResourceLimitSerializer(), + help_text=( + "Per-resource limit state keyed by `QuotaResource` value. " + "Currently only `ai_credits` is reported; additional resources may be added." 
+ ), + ) + + +@extend_schema(tags=["quota_limits"]) +class QuotaLimitsViewSet(TeamAndOrgViewSetMixin, viewsets.ViewSet): + """Read-only view of a team's quota-limit state.""" + + scope_object = "project" + required_scopes = ["project:read"] + http_method_names = ["get", "head", "options"] + + @extend_schema( + summary="Get a team's quota-limit state", + description=( + "Return the current quota-limit state for the team identified in the URL, " + "keyed by `QuotaResource` value. Used by the LLM gateway to gate billable " + "products on AI credits exhaustion." + ), + responses={200: QuotaLimitsResponseSerializer}, + ) + def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: + limited = { + resource.value: { + "limited": is_team_limited( + self.team.api_token, + resource, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ), + } + for resource in QuotaResource + } + return Response(QuotaLimitsResponseSerializer({"limited": limited}).data) diff --git a/ee/api/test/test_quota_limits.py b/ee/api/test/test_quota_limits.py new file mode 100644 index 000000000000..2042bcd77a7e --- /dev/null +++ b/ee/api/test/test_quota_limits.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from posthog.test.base import APIBaseTest + +from rest_framework import status + +from posthog.models.organization import Organization +from posthog.models.personal_api_key import PersonalAPIKey +from posthog.models.team import Team +from posthog.models.utils import generate_random_token_personal, hash_key_value + +from ee.billing.quota_limiting import ( + QuotaLimitingCaches, + QuotaResource, + add_limited_team_tokens, + replace_limited_team_tokens, +) + + +def _clear_ai_credits_limits() -> None: + replace_limited_team_tokens(QuotaResource.AI_CREDITS, {}, QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY) + + +class TestQuotaLimitsAPI(APIBaseTest): + def setUp(self) -> None: + super().setUp() + _clear_ai_credits_limits() + + def tearDown(self) -> None: + _clear_ai_credits_limits() + 
super().tearDown() + + def _url(self, team_id: int | None = None) -> str: + return f"/api/projects/{team_id if team_id is not None else self.team.pk}/quota_limits/" + + def _set_ai_credits_limit(self, team_api_token: str, expires_at: int) -> None: + add_limited_team_tokens( + QuotaResource.AI_CREDITS, + {team_api_token: expires_at}, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ) + + def test_unauthenticated_returns_401_or_403(self) -> None: + self.client.logout() + response = self.client.get(self._url()) + # DRF returns 401 when no creds are presented and an authenticator that supports + # a WWW-Authenticate challenge is configured; otherwise it returns 403. Either is + # an auth failure — we only care that the endpoint refuses unauthenticated reads. + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + def test_session_auth_returns_under_quota_when_team_not_limited(self) -> None: + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertEqual(data["limited"]["ai_credits"], {"limited": False}) + + def test_returns_limited_when_team_is_over_quota(self) -> None: + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": True}) + + def test_returns_unlimited_when_limit_has_already_expired(self) -> None: + self._set_ai_credits_limit(self.team.api_token, 1) # epoch 1970 + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": False}) + + def test_personal_api_key_auth_works(self) -> None: + self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + 
secure_value=hash_key_value(raw_key), + scopes=["project:read"], + ) + + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["limited"]["ai_credits"], {"limited": True}) + + def test_user_not_in_teams_org_is_forbidden(self) -> None: + other_org = Organization.objects.create(name="other-org") + other_team = Team.objects.create(organization=other_org, name="other-team") + + response = self.client.get(self._url(other_team.pk)) + # The caller is logged in to a team in a different org — TeamMemberAccessPermission + # rejects with 403 (or 404 if the queryset can't see the team at all). + self.assertIn(response.status_code, (status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND)) + + def test_personal_api_key_scoped_to_a_different_team_is_forbidden(self) -> None: + # Caller has access to both teams via membership, but the token is scoped to + # `other_team` only — the standalone-endpoint design would have missed this and + # leaked the other team's state. + other_team = Team.objects.create(organization=self.organization, name="other-team") + self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + secure_value=hash_key_value(raw_key), + scopes=["project:read"], + scoped_teams=[other_team.pk], + ) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_personal_api_key_missing_required_scope_is_forbidden(self) -> None: + # A token with only `feature_flag:read` shouldn't be able to read quota state. 
+ self.client.logout() + raw_key = generate_random_token_personal() + PersonalAPIKey.objects.create( + label="quota_limits-test", + user=self.user, + secure_value=hash_key_value(raw_key), + scopes=["feature_flag:read"], + ) + + response = self.client.get( + self._url(), + headers={"authorization": f"Bearer {raw_key}"}, + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_response_includes_every_quota_resource(self) -> None: + # Limiting one resource must not hide the unlimited state of the rest. + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + + response = self.client.get(self._url()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + limited = response.json()["limited"] + expected_keys = {resource.value for resource in QuotaResource} + self.assertEqual(set(limited.keys()), expected_keys) + self.assertTrue(limited["ai_credits"]["limited"]) + for resource in QuotaResource: + if resource is QuotaResource.AI_CREDITS: + continue + self.assertFalse(limited[resource.value]["limited"], resource.value) + + def test_multi_team_user_gets_per_team_answers(self) -> None: + # Same user belongs to two teams in their org; each team's quota is independent. + # This is the regression that "me" couldn't model — `user.team` (current team) + # picked one arbitrary answer for users in multiple teams. 
+ other_team = Team.objects.create(organization=self.organization, name="other-team") + self._set_ai_credits_limit(self.team.api_token, 9_999_999_999) + # other_team's token deliberately not limited + + resp_self = self.client.get(self._url()) + resp_other = self.client.get(self._url(other_team.pk)) + + self.assertEqual(resp_self.json()["limited"]["ai_credits"], {"limited": True}) + self.assertEqual(resp_other.json()["limited"]["ai_credits"], {"limited": False}) diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index 8667f5916ac9..23aa28942dfc 100644 --- a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -142,6 +142,7 @@ ) from products.web_analytics.backend.api.web_analytics_filter_preset import WebAnalyticsFilterPresetViewSet +from ee.api.quota_limits import QuotaLimitsViewSet from ee.api.session_summaries import SessionGroupSummaryViewSet from ee.api.vercel import vercel_installation, vercel_product, vercel_proxy, vercel_resource @@ -371,6 +372,14 @@ def register_grandfathered_environment_nested_viewset( # Seats (proxied to billing service) router.register(r"seats", seats.SeatViewSet, "seats") +# Quota limits (project-scoped — backs the LLM gateway's QuotaResolver) +projects_router.register( + r"quota_limits", + QuotaLimitsViewSet, + "project_quota_limits", + ["team_id"], +) + projects_router.register(r"surveys", survey.SurveyViewSet, "project_surveys", ["project_id"]) projects_router.register(r"product_tours", ProductTourViewSet, "project_product_tours", ["project_id"]) projects_router.register( diff --git a/posthog/llm/gateway_client.py b/posthog/llm/gateway_client.py index 67f63d0ae5c1..c4727016bda8 100644 --- a/posthog/llm/gateway_client.py +++ b/posthog/llm/gateway_client.py @@ -9,6 +9,7 @@ "llm_gateway", "posthog_code", "background_agents", + "slack_app_routing", "wizard", "django", "growth", diff --git a/posthog/temporal/ai/__init__.py b/posthog/temporal/ai/__init__.py index 4a541bb62099..2a234d65c0f3 100644 --- 
a/posthog/temporal/ai/__init__.py +++ b/posthog/temporal/ai/__init__.py @@ -18,6 +18,7 @@ create_posthog_code_routing_rule_activity, create_posthog_code_task_for_repo_activity, discover_posthog_code_repository_via_agent_activity, + enforce_posthog_code_billing_quota_activity, forward_posthog_code_followup_activity, handle_posthog_code_rules_command_activity, post_posthog_code_internal_error_activity, @@ -68,6 +69,7 @@ process_research_agent_activity, summarize_llm_traces_activity, process_slack_conversation_activity, + enforce_posthog_code_billing_quota_activity, resolve_posthog_code_slack_user_activity, handle_posthog_code_rules_command_activity, collect_posthog_code_thread_messages_activity, diff --git a/posthog/temporal/ai/posthog_code_slack_mention.py b/posthog/temporal/ai/posthog_code_slack_mention.py index b3ab69cb693b..8a4c8071a039 100644 --- a/posthog/temporal/ai/posthog_code_slack_mention.py +++ b/posthog/temporal/ai/posthog_code_slack_mention.py @@ -35,6 +35,45 @@ logger = structlog.get_logger(__name__) +def _block_if_team_over_quota( + *, + integration: Any, + slack: Any, + channel: str, + thread_ts: str, + slack_user_id: str, + context: str, +) -> bool: + """Refuse a Slack-bot turn when the team is over its AI credits quota. + + Tach blocks ``products.slack_app`` from importing ``ee.billing``, so the + quota lookup lives here (where the temporal layer can freely import ee) + while the user-facing denial message lives in ``slack_app.backend.api`` + (where the Slack-posting helpers live). Returns True when the team was + blocked and a denial was posted. 
+ """ + from products.slack_app.backend.api import post_quota_exhausted_denial + + from ee.billing.quota_limiting import QuotaLimitingCaches, QuotaResource, is_team_limited + + if not is_team_limited( + integration.team.api_token, + QuotaResource.AI_CREDITS, + QuotaLimitingCaches.QUOTA_LIMITER_CACHE_KEY, + ): + return False + + post_quota_exhausted_denial( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context=context, + ) + return True + + def _safe_react(client: Any, channel: str, timestamp: str, name: str) -> None: try: client.reactions_add(channel=channel, timestamp=timestamp, name=name) @@ -205,6 +244,24 @@ async def run(self, inputs: PostHogCodeSlackMentionWorkflowInputs) -> None: return try: + # Gate every workflow entry on the team's AI-credits quota before any + # other activity runs. The webhook never checks quota itself (Tach keeps + # products/slack_app/backend/api.py from importing ee.billing), so this + # gate owns the decision and also covers replays and manual workflow + # starts; the follow-up and task-creation activities re-check in case the + # limit flips mid-run. Wrapped in `workflow.patched` so in-flight workflows + # from before this deploy stay deterministic on replay. + if workflow.patched("posthog-code-slack-billing-gate"): + blocked = await _execute_posthog_code_activity( + enforce_posthog_code_billing_quota_activity, + inputs, + channel, + thread_ts, + slack_user_id, + ) + if blocked: + return + followup_handled = await _execute_posthog_code_activity( forward_posthog_code_followup_activity, inputs, @@ -890,6 +947,38 @@ def classify_posthog_code_task_needs_repo_activity( return classify_task_needs_repo(event_text, thread_messages) +@activity.defn +def enforce_posthog_code_billing_quota_activity( + inputs: PostHogCodeSlackMentionWorkflowInputs, + channel: str, + thread_ts: str, + slack_user_id: str, +) -> bool: + """Block the workflow when the team has exhausted its AI-credits quota. 
+ + Returns True when a denial was posted and the workflow should stop. Called + as the first activity in the mention workflow so the bot never proceeds to + Slack roundtrips, thread fetches, or billable LLM calls (the classifier, + notably) for an over-quota team. + """ + from posthog.models.integration import Integration, SlackIntegration + + integration = Integration.objects.select_related("team").get( + id=inputs.integration_id, + kind="slack-posthog-code", + integration_id=inputs.slack_team_id, + ) + slack = SlackIntegration(integration) + return _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="task_create", + ) + + @activity.defn def post_posthog_code_no_repos_activity( inputs: PostHogCodeSlackMentionWorkflowInputs, channel: str, thread_ts: str @@ -1036,6 +1125,19 @@ def create_posthog_code_task_for_repo_activity( integration_id=inputs.slack_team_id, ) slack = SlackIntegration(integration) + + # Refuse before the :seedling: reaction or the permalink fetch: a denied + # mention should not first ack-react and then refuse a second later. 
+ if _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="task_create", + ): + return + user_message_ts = event.get("ts") if user_message_ts: _safe_react(slack.client, channel, user_message_ts, "seedling") @@ -1241,6 +1343,16 @@ def forward_posthog_code_followup_activity( ) return True + if _block_if_team_over_quota( + integration=integration, + slack=slack, + channel=channel, + thread_ts=thread_ts, + slack_user_id=slack_user_id, + context="followup", + ): + return True + if task_run.is_terminal: return _resume_task_with_new_run( mapping, diff --git a/posthog/temporal/tests/ai/test_module_integrity.py b/posthog/temporal/tests/ai/test_module_integrity.py index 0ba3e89b31f4..ec0e15e60a1e 100644 --- a/posthog/temporal/tests/ai/test_module_integrity.py +++ b/posthog/temporal/tests/ai/test_module_integrity.py @@ -54,6 +54,7 @@ def test_activities_remain_unchanged(self): "process_research_agent_activity", "summarize_llm_traces_activity", "process_slack_conversation_activity", + "enforce_posthog_code_billing_quota_activity", "resolve_posthog_code_slack_user_activity", "handle_posthog_code_rules_command_activity", "collect_posthog_code_thread_messages_activity", diff --git a/products/slack_app/backend/api.py b/products/slack_app/backend/api.py index d1552c6f9205..6eeb4b763328 100644 --- a/products/slack_app/backend/api.py +++ b/products/slack_app/backend/api.py @@ -299,6 +299,45 @@ def lookup_slack_user_id_by_email( return slack_user_id +QUOTA_EXHAUSTED_MESSAGE = ( + "Your team has used its monthly PostHog AI credits. " + "Top up at https://us.posthog.com/organization/billing to continue." +) + + +def post_quota_exhausted_denial( + *, + integration: Integration, + slack: SlackIntegration, + channel: str, + thread_ts: str, + slack_user_id: str, + context: str, +) -> None: + """Post the AI-credits denial message into a Slack thread. 
+ + Called by the workflow's quota gate after it determines the team is over + quota. Lives in this module so the Slack-posting helpers and the message + text stay co-located; the quota check itself lives in the temporal layer + (which is the only side allowed to import ``ee.billing``). + """ + logger.info( + "posthog_code_slack_blocked_by_quota", + context=context, + team_id=integration.team_id, + channel=channel, + thread_ts=thread_ts, + ) + _post_slack_user_feedback( + slack, + channel, + slack_user_id, + thread_ts, + QUOTA_EXHAUSTED_MESSAGE, + prefer_thread_message=True, + ) + + def _post_slack_user_feedback( slack: SlackIntegration, channel: str, @@ -1034,7 +1073,7 @@ def classify_task_needs_repo( 'Respond with ONLY a JSON object: {{"needs_repo": true}} or {{"needs_repo": false}}' ) try: - client = get_llm_client("slack-posthog-code") + client = get_llm_client("slack_app_routing") response = client.chat.completions.create( model="claude-haiku-4-5-20251001", messages=[{"role": "user", "content": prompt}], diff --git a/products/slack_app/backend/tests/test_followup_forwarding.py b/products/slack_app/backend/tests/test_followup_forwarding.py index 353e26d609b1..804e375ba7b9 100644 --- a/products/slack_app/backend/tests/test_followup_forwarding.py +++ b/products/slack_app/backend/tests/test_followup_forwarding.py @@ -13,6 +13,7 @@ from posthog.temporal.ai.posthog_code_slack_mention import ( PostHogCodeSlackMentionWorkflowInputs, create_posthog_code_task_for_repo_activity, + enforce_posthog_code_billing_quota_activity, forward_posthog_code_followup_activity, ) @@ -33,6 +34,16 @@ def _command_result(**kwargs): return SimpleNamespace(**defaults) +def _assert_quota_denial_posted(mock_slack_instance: MagicMock, channel: str, thread_ts: str) -> None: + denial_calls = [ + call + for call in mock_slack_instance.client.chat_postMessage.call_args_list + if call.kwargs.get("channel") == channel and call.kwargs.get("thread_ts") == thread_ts + ] + assert denial_calls, 
"Expected an in-thread denial message when over quota" + assert "PostHog AI credits" in denial_calls[0].kwargs["text"] + + class TestSlackThreadTaskMapping(TestCase): def setUp(self): self.Task = apps.get_model("tasks", "Task") @@ -401,6 +412,36 @@ def test_description_encloses_thread_context_in_a_tag(self, mock_slack_cls, mock ) assert task.description.endswith("do something") + @patch("products.tasks.backend.temporal.client.execute_task_processing_workflow") + @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + def test_quota_exceeded_blocks_task_creation_with_thread_message( + self, + _mock_is_team_limited, + mock_slack_cls, + mock_execute_workflow, + ): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + create_posthog_code_task_for_repo_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + self.user.id, + inputs.event, + [{"user": "U_ALICE", "text": "do something"}], + None, + ) + + # No task created, no workflow started. 
+ assert not self.Task.objects.filter(team=self.team).exists() + mock_execute_workflow.assert_not_called() + + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + class TestForwardPostHogCodeFollowupActivity(TestCase): def setUp(self): @@ -446,6 +487,27 @@ def test_no_mapping_returns_false(self): ) assert result is False + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") + def test_quota_exceeded_blocks_followup_with_thread_message( + self, + mock_slack_cls, + _mock_is_team_limited, + ): + self._create_mapping() + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + result = forward_posthog_code_followup_activity( + inputs, "C123", "1234.5678", "U_ALICE", "do something", "1234.5679" + ) + + # The follow-up was handled by refusal, so the workflow shouldn't fall + # through to new-task creation. + assert result is True + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + @patch("posthog.temporal.ai.posthog_code_slack_mention.execute_task_processing_workflow") @patch("posthog.temporal.ai.posthog_code_slack_mention.SlackIntegration") def test_terminal_run_resumes_same_task(self, mock_slack_cls, mock_execute_workflow): @@ -717,6 +779,52 @@ def test_connection_error_retries_and_succeeds(self, mock_slack_cls, mock_send, mock_slack_instance.client.chat_postMessage.assert_not_called() +class TestEnforcePostHogCodeBillingQuotaActivity(TestCase): + """The workflow's first activity gate. 
Returns True (and posts a denial) when + the team is over its AI-credits quota; False otherwise.""" + + def setUp(self): + self.org = Organization.objects.create(name="TestOrg") + self.team = Team.objects.create(organization=self.org, name="TestTeam") + self.integration = Integration.objects.create( + team=self.team, kind="slack-posthog-code", integration_id="T_SLACK", config={} + ) + + @patch("posthog.models.integration.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + def test_returns_true_and_posts_denial_when_over_quota(self, _mock_is_team_limited, mock_slack_cls): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + blocked = enforce_posthog_code_billing_quota_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + ) + + assert blocked is True + _assert_quota_denial_posted(mock_slack_instance, "C123", "1234.5678") + + @patch("posthog.models.integration.SlackIntegration") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=False) + def test_returns_false_and_posts_nothing_when_under_quota(self, _mock_is_team_limited, mock_slack_cls): + mock_slack_instance = MagicMock() + mock_slack_cls.return_value = mock_slack_instance + + inputs = _make_inputs(self.integration.id) + blocked = enforce_posthog_code_billing_quota_activity( + inputs, + "C123", + "1234.5678", + "U_ALICE", + ) + + assert blocked is False + mock_slack_instance.client.chat_postMessage.assert_not_called() + + class TestEventLevelDedupe(TestCase): """Verify that the workflow ID format supports event-level deduplication.""" diff --git a/products/slack_app/backend/tests/test_posthog_code_event_handler.py b/products/slack_app/backend/tests/test_posthog_code_event_handler.py index cfd149899239..685dc1b7a90f 100644 --- a/products/slack_app/backend/tests/test_posthog_code_event_handler.py +++ b/products/slack_app/backend/tests/test_posthog_code_event_handler.py @@ 
-501,3 +501,44 @@ def test_app_mention_with_missing_scopes_posts_reauth_and_skips_workflow( assert "app_mentions:read" in feedback_text if scope_value and "chat:write.customize" in scope_value: assert "chat:write.customize" not in feedback_text + + @patch("products.slack_app.backend.api._post_slack_user_feedback") + @patch("ee.billing.quota_limiting.is_team_limited", return_value=True) + @patch("products.slack_app.backend.api._posthog_code_enabled_for_integration", return_value=True) + @patch("products.slack_app.backend.api.SlackIntegration") + @patch("products.slack_app.backend.api.asyncio.run") + @patch("products.slack_app.backend.api.sync_connect") + @override_settings(DEBUG=False) + def test_over_quota_team_still_starts_workflow_for_in_workflow_enforcement( + self, + mock_sync_connect, + mock_asyncio_run, + mock_slack_cls, + _mock_flag, + _mock_is_team_limited, + mock_post_feedback, + ): + mock_slack_instance = mock_slack_cls.return_value + mock_slack_instance.missing_scopes.return_value = set() + + from products.slack_app.backend.api import ROUTE_HANDLED_LOCALLY, route_posthog_code_event_to_relevant_region + + request = self.factory.post("/slack/event-callback/", HTTP_HOST="eu.posthog.com") + event = { + "type": "app_mention", + "channel": "C001", + "user": "U123", + "ts": "1234.5678", + "thread_ts": "1234.5678", + } + + result = route_posthog_code_event_to_relevant_region(request, event, "T12345") + + assert result == ROUTE_HANDLED_LOCALLY + # Quota enforcement now lives entirely inside the workflow — the webhook + # always schedules it, and the activity-level gate posts the denial + # before any billable LLM call. This guarantees a single owner for the + # quota decision instead of dual gating. 
+ mock_sync_connect.assert_called_once() + mock_asyncio_run.assert_called_once() + mock_post_feedback.assert_not_called() diff --git a/services/llm-gateway/src/llm_gateway/api/usage.py b/services/llm-gateway/src/llm_gateway/api/usage.py index 557a61bdfd1b..e94976b571c3 100644 --- a/services/llm-gateway/src/llm_gateway/api/usage.py +++ b/services/llm-gateway/src/llm_gateway/api/usage.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from llm_gateway.auth.models import AuthenticatedUser -from llm_gateway.dependencies import get_authenticated_user +from llm_gateway.dependencies import get_authenticated_user, resolve_plan_and_quota from llm_gateway.rate_limiting.cost_throttles import CostStatus, UserCostBurstThrottle, UserCostSustainedThrottle from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import ThrottleContext @@ -17,7 +17,6 @@ PlanResolver, is_pro_plan, parse_iso_utc, - resolve_plan_info, ) logger = structlog.get_logger(__name__) @@ -35,11 +34,16 @@ class CostLimitStatus(BaseModel): exceeded: bool +class AiCreditsStatus(BaseModel): + exhausted: bool + + class UsageResponse(BaseModel): product: str user_id: int burst: CostLimitStatus sustained: CostLimitStatus + ai_credits: AiCreditsStatus is_rate_limited: bool is_pro: bool billing_period_end: datetime | None = None @@ -65,8 +69,15 @@ async def get_usage( user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], ) -> UsageResponse: runner: ThrottleRunner = request.app.state.throttle_runner - plan_info = await resolve_plan_info(request, user.user_id, product) + + plan_info, quota_status = await resolve_plan_and_quota( + request, + user_id=user.user_id, + team_id=user.team_id, + product=product, + ) now = datetime.now(tz=UTC) + ai_credits_exhausted = quota_status.limited context = ThrottleContext( user=user, @@ -75,6 +86,7 @@ async def get_usage( plan_key=plan_info.plan_key, seat_created_at=plan_info.seat_created_at, 
billing_period_start=plan_info.billing_period.current_period_start if plan_info.billing_period else None, + ai_credits_exhausted=ai_credits_exhausted, ) burst_status: CostLimitStatus | None = None @@ -118,7 +130,8 @@ async def get_usage( user_id=user.user_id, burst=burst_status, sustained=sustained_status, - is_rate_limited=burst_status.exceeded or sustained_status.exceeded, + ai_credits=AiCreditsStatus(exhausted=ai_credits_exhausted), + is_rate_limited=burst_status.exceeded or sustained_status.exceeded or ai_credits_exhausted, is_pro=is_pro_plan(plan_info.plan_key), billing_period_end=billing_period_end, ) diff --git a/services/llm-gateway/src/llm_gateway/callbacks/posthog.py b/services/llm-gateway/src/llm_gateway/callbacks/posthog.py index 64ee60acadca..34441a76220f 100644 --- a/services/llm-gateway/src/llm_gateway/callbacks/posthog.py +++ b/services/llm-gateway/src/llm_gateway/callbacks/posthog.py @@ -10,6 +10,7 @@ from llm_gateway.auth.models import resolve_distinct_id from llm_gateway.callbacks.base import InstrumentedCallback +from llm_gateway.products.config import get_product_config from llm_gateway.request_context import ( get_auth_user, get_posthog_flags, @@ -53,6 +54,15 @@ def _replace_binary_content(data: Any) -> Any: _TRUNCATION_MARKER = "[truncated: content too large for capture]" _TRUNCATABLE_FIELDS = ("$ai_output_choices", "$ai_input") + +def _is_product_billable(product: str) -> bool: + """Look up the product's billable flag in the central registry. False for + unknown products so we never accidentally bill calls we can't attribute. + """ + config = get_product_config(product) + return bool(config and config.billable) + + # Stable namespace for hashing non-UUID trace identifiers (e.g. Claude Code's # JSON-encoded session blobs sent via Anthropic's metadata.user_id) into a # deterministic UUID. 
Generated once and frozen so the same input always maps @@ -147,6 +157,7 @@ async def _on_success( "$ai_trace_id": trace_id, "$ai_span_id": str(uuid4()), "ai_product": product, + "$ai_billable": _is_product_billable(product), } posthog_properties = get_posthog_properties() or {} @@ -226,6 +237,7 @@ async def _on_failure( "$ai_is_error": True, "$ai_error": standard_logging_object.get("error_str", ""), "ai_product": product, + "$ai_billable": _is_product_billable(product), } posthog_properties = get_posthog_properties() or {} diff --git a/services/llm-gateway/src/llm_gateway/config.py b/services/llm-gateway/src/llm_gateway/config.py index 84461bb68856..f279164f9197 100644 --- a/services/llm-gateway/src/llm_gateway/config.py +++ b/services/llm-gateway/src/llm_gateway/config.py @@ -160,6 +160,10 @@ class Settings(BaseSettings): posthog_api_base_url: str = "https://us.posthog.com" plan_cache_ttl: int = 900 # 15 minutes + # Billing recomputes quota state on at most an hourly cadence, so we are + # comfortable letting a team go slightly over their limit in exchange for + # avoiding a Django roundtrip on every billable request. + quota_cache_ttl: int = 300 # 5 minutes billing_period_days: int = 30 # Anthropic -> Bedrock circuit breaker. 
When the trailing failure rate of the Anthropic diff --git a/services/llm-gateway/src/llm_gateway/dependencies.py b/services/llm-gateway/src/llm_gateway/dependencies.py index ccc3d80a3741..752357f92903 100644 --- a/services/llm-gateway/src/llm_gateway/dependencies.py +++ b/services/llm-gateway/src/llm_gateway/dependencies.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from typing import Annotated, Any @@ -10,7 +11,12 @@ from llm_gateway.auth.models import AuthenticatedUser from llm_gateway.auth.service import AuthService, get_auth_service from llm_gateway.circuit_breaker import AnthropicCircuitBreaker -from llm_gateway.products.config import ALLOWED_PRODUCTS, check_product_access, resolve_product_alias +from llm_gateway.products.config import ( + ALLOWED_PRODUCTS, + check_product_access, + get_product_config, + resolve_product_alias, +) from llm_gateway.rate_limiting.cost_refresh import ensure_costs_fresh from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import ThrottleContext @@ -19,7 +25,8 @@ get_request_id, set_throttle_context, ) -from llm_gateway.services.plan_resolver import resolve_plan_info +from llm_gateway.services.plan_resolver import PlanInfo, resolve_plan_info +from llm_gateway.services.quota_resolver import QuotaResourceStatus, resolve_quota_status logger = structlog.get_logger(__name__) @@ -149,6 +156,31 @@ async def _extract_end_user_id_from_body(request: Request) -> str | None: return None +async def resolve_plan_and_quota( + request: Request, + *, + user_id: int, + team_id: int | None, + product: str, +) -> tuple[PlanInfo, QuotaResourceStatus]: + """Fetch plan info and (for billable products) AI credits quota in parallel. + + Both calls are independent Django roundtrips on cache miss, so for billable + products we overlap them. 
For non-billable products the throttle stack + short-circuits regardless of quota state, so we skip the resolver entirely + rather than paying for the Redis GET (and the HTTP fallback on cache miss). + """ + product_config = get_product_config(product) + if product_config and product_config.billable: + plan_info, quota_status = await asyncio.gather( + resolve_plan_info(request, user_id, product), + resolve_quota_status(request, team_id), + ) + return plan_info, quota_status + plan_info = await resolve_plan_info(request, user_id, product) + return plan_info, QuotaResourceStatus(limited=False) + + async def enforce_throttles( request: Request, user: Annotated[AuthenticatedUser, Depends(enforce_product_access)], @@ -163,7 +195,12 @@ async def enforce_throttles( else: end_user_id = await _extract_end_user_id_from_body(request) - plan_info = await resolve_plan_info(request, user.user_id, product) + plan_info, quota_status = await resolve_plan_and_quota( + request, + user_id=user.user_id, + team_id=user.team_id, + product=product, + ) context = ThrottleContext( user=user, @@ -173,6 +210,7 @@ async def enforce_throttles( plan_key=plan_info.plan_key, seat_created_at=plan_info.seat_created_at, billing_period_start=plan_info.billing_period.current_period_start if plan_info.billing_period else None, + ai_credits_exhausted=quota_status.limited, ) request.state.throttle_context = context set_throttle_context(runner, context) diff --git a/services/llm-gateway/src/llm_gateway/main.py b/services/llm-gateway/src/llm_gateway/main.py index 50278966c7d9..42dec77a8ad5 100644 --- a/services/llm-gateway/src/llm_gateway/main.py +++ b/services/llm-gateway/src/llm_gateway/main.py @@ -26,6 +26,7 @@ from llm_gateway.config import Settings, get_settings from llm_gateway.db.postgres import close_db_pool, init_db_pool from llm_gateway.metrics.prometheus import DB_POOL_SIZE, get_instrumentator +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle from 
llm_gateway.rate_limiting.cost_gauge_publisher import publish_product_cost_gauges_loop from llm_gateway.rate_limiting.cost_refresh import ensure_costs_fresh from llm_gateway.rate_limiting.cost_throttles import ( @@ -37,6 +38,7 @@ from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.request_context import RequestContext, set_request_context from llm_gateway.services.plan_resolver import PlanResolver +from llm_gateway.services.quota_resolver import QuotaResolver def configure_logging(debug: bool = False) -> None: @@ -158,6 +160,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) app.state.throttle_runner = ThrottleRunner( throttles=[ + BillableCreditThrottle(), product_throttle, UserCostBurstThrottle(redis=app.state.redis), UserCostSustainedThrottle(redis=app.state.redis), @@ -186,6 +189,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: redis=app.state.redis, http_client=app.state.http_client, ) + app.state.quota_resolver = QuotaResolver( + redis=app.state.redis, + http_client=app.state.http_client, + ) logger.info("Plan resolver initialized", posthog_api_base_url=settings.posthog_api_base_url or "(not configured)") logger.info( diff --git a/services/llm-gateway/src/llm_gateway/products/config.py b/services/llm-gateway/src/llm_gateway/products/config.py index 95febfb0ba05..f5b7abc1dd3b 100644 --- a/services/llm-gateway/src/llm_gateway/products/config.py +++ b/services/llm-gateway/src/llm_gateway/products/config.py @@ -16,6 +16,10 @@ class ProductConfig: allowed_application_ids: frozenset[str] | None = frozenset() allowed_models: frozenset[str] | None = None # None = all allowed allow_api_keys: bool = True + # Tag emitted $ai_generation events with $ai_billable=true so the PHAI + # daily aggregator (posthog/tasks/usage_report.py) rolls them into the + # customer team's AI credits bucket. 
+ billable: bool = False BEDROCK_MODELS = BEDROCK_MODEL_IDS @@ -28,6 +32,25 @@ class ProductConfig: WIZARD_US_APP_ID = "019a0c79-b69d-0000-f31b-b41345208c9d" WIZARD_EU_APP_ID = "019a12d0-6edd-0000-0458-86616af3a3db" +# Shared by `posthog_code` and `slack_app` — the agent that runs in the sandbox +# is the same code regardless of where the task was initiated, so the model +# allowlist is identical. +_POSTHOG_CODE_AGENT_MODELS: Final[frozenset[str]] = frozenset( + { + "claude-opus-4-5", + "claude-opus-4-6", + "claude-opus-4-7", + "claude-sonnet-4-5", + "claude-sonnet-4-6", + "claude-haiku-4-5", + "gpt-5.5", + "gpt-5.4", + "gpt-5.3-codex", + "gpt-5.2", + "gpt-5-mini", + } +) + PRODUCTS: Final[dict[str, ProductConfig]] = { "llm_gateway": ProductConfig( allowed_application_ids=None, @@ -36,22 +59,7 @@ class ProductConfig: ), "posthog_code": ProductConfig( allowed_application_ids=frozenset({POSTHOG_CODE_US_APP_ID, POSTHOG_CODE_EU_APP_ID}), - allowed_models=frozenset( - { - "claude-opus-4-5", - "claude-opus-4-6", - "claude-opus-4-7", - "claude-sonnet-4-5", - "claude-sonnet-4-6", - "claude-haiku-4-5", - "gpt-5.5", - "gpt-5.4", - "gpt-5.3-codex", - "gpt-5.2", - "gpt-5-mini", - } - | BEDROCK_MODELS - ), + allowed_models=_POSTHOG_CODE_AGENT_MODELS | BEDROCK_MODELS, allow_api_keys=False, ), "background_agents": ProductConfig( @@ -72,6 +80,12 @@ class ProductConfig: ), allow_api_keys=False, ), + "slack_app": ProductConfig( + allowed_application_ids=frozenset({POSTHOG_CODE_US_APP_ID, POSTHOG_CODE_EU_APP_ID}), + allowed_models=_POSTHOG_CODE_AGENT_MODELS | BEDROCK_MODELS, + allow_api_keys=False, + billable=True, + ), "wizard": ProductConfig( allowed_application_ids=frozenset({WIZARD_US_APP_ID, WIZARD_EU_APP_ID}), allowed_models=None, @@ -87,10 +101,11 @@ class ProductConfig: allowed_models=None, allow_api_keys=True, ), - "slack-posthog-code": ProductConfig( + "slack_app_routing": ProductConfig( allowed_application_ids=None, allowed_models=frozenset({"claude-haiku-4-5"}), 
allow_api_keys=True, + billable=True, ), "growth": ProductConfig( allowed_application_ids=None, @@ -139,7 +154,8 @@ class ProductConfig: PRODUCT_ALIASES: Final[dict[str, str]] = { "array": "posthog_code", "twig": "posthog_code", - "slack-twig": "slack-posthog-code", + "slack-posthog-code": "slack_app_routing", + "slack-twig": "slack_app_routing", } diff --git a/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py new file mode 100644 index 000000000000..4be99145f703 --- /dev/null +++ b/services/llm-gateway/src/llm_gateway/rate_limiting/billable_credits_throttle.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from llm_gateway.products.config import get_product_config +from llm_gateway.rate_limiting.throttles import Throttle, ThrottleContext, ThrottleResult + +# Hint the client to back off for a minute. A precise expiry timestamp is +# available upstream in Redis, but with Django's 30s in-process cache and the +# gateway's resolver cache (`quota_cache_ttl`, 5 minutes by default) stacked on +# top, any computed retry window could be minutes stale anyway — a fixed minute +# is no worse and cuts the data plumbing. +_RETRY_AFTER_SECONDS = 60 + + +class BillableCreditThrottle(Throttle): + """Gate billable-product LLM calls on the team's AI credits balance. + + Reads ``ai_credits_exhausted`` from :class:`ThrottleContext`, pre-resolved + by the dependency layer (see ``resolve_quota_status``). + """ + + scope = "billable_credits" + + async def allow_request(self, context: ThrottleContext) -> ThrottleResult: + config = get_product_config(context.product) + if not (config and config.billable): + return ThrottleResult.allow() + + if not context.ai_credits_exhausted: + return ThrottleResult.allow() + + return ThrottleResult.deny( + detail=( + "Your team has used its monthly PostHog AI credits. " + "Top up at https://us.posthog.com/organization/billing to continue."
+ ), + scope=self.scope, + retry_after=_RETRY_AFTER_SECONDS, + ) diff --git a/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py b/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py index aff139e2e529..9f47757b1d89 100644 --- a/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py +++ b/services/llm-gateway/src/llm_gateway/rate_limiting/throttles.py @@ -23,6 +23,7 @@ class ThrottleContext: plan_key: str | None = None seat_created_at: str | None = None billing_period_start: str | None = None + ai_credits_exhausted: bool = False @dataclass diff --git a/services/llm-gateway/src/llm_gateway/services/quota_resolver.py b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py new file mode 100644 index 000000000000..0004f09a8953 --- /dev/null +++ b/services/llm-gateway/src/llm_gateway/services/quota_resolver.py @@ -0,0 +1,171 @@ +"""Resolves a team's quota state via the PostHog API quota_limits endpoint. + +Mirrors :mod:`llm_gateway.services.plan_resolver` — forwards the caller's +``Authorization`` header to ``GET /api/projects/{team_id}/quota_limits/`` and +caches the result per team-and-resource in the gateway's own Redis. + +Transient upstream failures (network errors, 5xx) are retried within the +request with exponential backoff. 4xx responses or exhausted retries fall +open and briefly cache ``limited=False`` so a struggling Django isn't hit on +every subsequent request. +""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import httpx +import structlog + +from llm_gateway.config import get_settings + +if TYPE_CHECKING: + from fastapi import Request + from redis.asyncio import Redis + +logger = structlog.get_logger(__name__) + +_AI_CREDITS_RESOURCE = "ai_credits" + +# Cache window for the fail-open path (4xx from Django, or transient failure +# after retries are exhausted). 
Long enough to keep a misconfigured client off +# Django's neck during an auth-failure storm; short enough that a recovered +# upstream is consulted again within a minute. +_FAIL_OPEN_CACHE_TTL_SECONDS = 60 + +# Exponential backoff between within-request retries on transient failures. +# The first attempt fires immediately; each subsequent retry waits +# ``MULTIPLIER * 2**n`` seconds, doubling the gap each step. Tune the +# multiplier to widen or tighten the overall spacing without touching the +# formula. +_MAX_RETRIES = 3 +_RETRY_BACKOFF_MULTIPLIER_SECONDS = 5 +_RETRY_DELAYS_SECONDS: tuple[float, ...] = ( + 0, + *(_RETRY_BACKOFF_MULTIPLIER_SECONDS * 2**i for i in range(_MAX_RETRIES)), +) + + +@dataclass +class QuotaResourceStatus: + limited: bool + + +class _TransientUpstreamError(Exception): + """Retryable failure: a 5xx response or a network-level error.""" + + +def _redis_key(resource_key: str, team_id: int) -> str: + return f"quota:{resource_key}:team:{team_id}" + + +async def resolve_quota_status(request: Request, team_id: int | None) -> QuotaResourceStatus: + """Resolve the team's AI credits quota state, falling open on errors.""" + if team_id is None: + return QuotaResourceStatus(limited=False) + auth_header = request.headers.get("Authorization", "") + if not auth_header: + return QuotaResourceStatus(limited=False) + + quota_resolver: QuotaResolver = request.app.state.quota_resolver + try: + return await quota_resolver.get_ai_credits_status(team_id=team_id, auth_header=auth_header) + except Exception: + logger.warning("quota_resolve_failed", team_id=team_id) + return QuotaResourceStatus(limited=False) + + +class QuotaResolver: + """Fetches team quota state from Django, caches per team.""" + + def __init__(self, redis: Redis[bytes] | None, http_client: httpx.AsyncClient): + self._redis = redis + self._http = http_client + self._cache_ttl = get_settings().quota_cache_ttl + + async def get_ai_credits_status(self, team_id: int, auth_header: str) -> 
QuotaResourceStatus: + return await self._get_resource_status(_AI_CREDITS_RESOURCE, team_id, auth_header) + + async def _get_resource_status(self, resource_key: str, team_id: int, auth_header: str) -> QuotaResourceStatus: + cached = await self._get_cached(resource_key, team_id) + if cached is not None: + return cached + + try: + status, ttl = await self._fetch_with_retry(resource_key, team_id, auth_header) + except Exception: + logger.warning("quota_fetch_failed", resource=resource_key, team_id=team_id, exc_info=True) + status, ttl = QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS + + await self._set_cached(resource_key, team_id, status, ttl) + return status + + async def _fetch_with_retry( + self, resource_key: str, team_id: int, auth_header: str + ) -> tuple[QuotaResourceStatus, int]: + """Try the upstream up to ``len(_RETRY_DELAYS_SECONDS)`` times. + + Network errors and 5xx responses are retried with growing waits between + attempts. 4xx and successful responses return immediately from + :meth:`_fetch`. If every attempt raises a transient error the last + exception is re-raised for the caller to fail open. + """ + last_exc: Exception | None = None + for delay in _RETRY_DELAYS_SECONDS: + if delay: + await asyncio.sleep(delay) + try: + return await self._fetch(resource_key, team_id, auth_header) + except _TransientUpstreamError as exc: + last_exc = exc + assert last_exc is not None + raise last_exc + + async def _fetch(self, resource_key: str, team_id: int, auth_header: str) -> tuple[QuotaResourceStatus, int]: + """One attempt against Django. 
Raises :class:`_TransientUpstreamError` on retryable failures.""" + settings = get_settings() + if not settings.posthog_api_base_url: + return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS + + url = f"{settings.posthog_api_base_url.rstrip('/')}/api/projects/{team_id}/quota_limits/" + try: + resp = await self._http.get(url, headers={"Authorization": auth_header}, timeout=2.0) + except httpx.RequestError as exc: + raise _TransientUpstreamError(str(exc)) from exc + + if resp.status_code >= 500: + raise _TransientUpstreamError(f"quota_limits returned {resp.status_code}") + if resp.status_code >= 400: + # 4xx is permanent for the lifetime of this request — bad token, + # missing team, scope mismatch. Fail open briefly so a hot loop + # with a broken token doesn't hammer Django. + return QuotaResourceStatus(limited=False), _FAIL_OPEN_CACHE_TTL_SECONDS + + data = resp.json() + resource = (data.get("limited") or {}).get(resource_key) or {} + return QuotaResourceStatus(limited=bool(resource.get("limited"))), self._cache_ttl + + async def _get_cached(self, resource_key: str, team_id: int) -> QuotaResourceStatus | None: + if not self._redis: + return None + try: + val = await self._redis.get(_redis_key(resource_key, team_id)) + if val is None: + return None + payload = json.loads(val.decode()) + return QuotaResourceStatus(limited=bool(payload.get("limited"))) + except Exception: + logger.debug("quota_cache_read_failed", resource=resource_key, team_id=team_id) + return None + + async def _set_cached(self, resource_key: str, team_id: int, status: QuotaResourceStatus, ttl: int) -> None: + if not self._redis: + return + try: + payload = json.dumps({"limited": status.limited}) + await self._redis.set(_redis_key(resource_key, team_id), payload, ex=ttl) + except Exception: + logger.debug("quota_cache_write_failed", resource=resource_key, team_id=team_id) diff --git a/services/llm-gateway/tests/callbacks/test_posthog.py 
b/services/llm-gateway/tests/callbacks/test_posthog.py index 9a860cc8ce8e..1d50233ccfaf 100644 --- a/services/llm-gateway/tests/callbacks/test_posthog.py +++ b/services/llm-gateway/tests/callbacks/test_posthog.py @@ -303,6 +303,78 @@ async def test_on_failure_includes_ai_product( props = mock_client.capture.call_args.kwargs["properties"] assert props["ai_product"] == product + @pytest.mark.asyncio + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_on_success_marks_slack_products_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + standard_logging_object: dict, + product: str, + mock_posthog_client: tuple, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = {"standard_logging_object": standard_logging_object, "litellm_params": {}} + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), + ): + await callback._on_success(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is True + + @pytest.mark.asyncio + @pytest.mark.parametrize("product", ["posthog_code", "background_agents", "wizard", "llm_gateway"]) + async def test_on_success_does_not_mark_other_products_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + standard_logging_object: dict, + product: str, + mock_posthog_client: tuple, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = {"standard_logging_object": standard_logging_object, "litellm_params": {}} + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), + ): + await callback._on_success(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is False + + 
@pytest.mark.asyncio + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_on_failure_marks_slack_products_billable( + self, + callback: PostHogCallback, + auth_user: AuthenticatedUser, + mock_posthog_client: tuple, + product: str, + ) -> None: + _, mock_client = mock_posthog_client + kwargs = { + "standard_logging_object": { + "model": "claude-sonnet-4-6", + "custom_llm_provider": "anthropic", + "error_str": "boom", + }, + "litellm_params": {}, + } + + with ( + patch("llm_gateway.callbacks.posthog.get_auth_user", return_value=auth_user), + patch("llm_gateway.callbacks.posthog.get_product", return_value=product), + ): + await callback._on_failure(kwargs, None, 0.0, 1.0, end_user_id=None) + + props = mock_client.capture.call_args.kwargs["properties"] + assert props["$ai_billable"] is True + @pytest.mark.asyncio async def test_on_success_uses_passed_end_user_id( self, diff --git a/services/llm-gateway/tests/conftest.py b/services/llm-gateway/tests/conftest.py index c691b5a8ba24..200c7b7282ba 100644 --- a/services/llm-gateway/tests/conftest.py +++ b/services/llm-gateway/tests/conftest.py @@ -9,6 +9,7 @@ from llm_gateway.auth.models import AuthenticatedUser from llm_gateway.main import http_exception_handler +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle from llm_gateway.rate_limiting.cost_throttles import ( ProductCostThrottle, UserCostBurstThrottle, @@ -17,6 +18,13 @@ from llm_gateway.rate_limiting.runner import ThrottleRunner from llm_gateway.rate_limiting.throttles import Throttle from llm_gateway.services.plan_resolver import PlanInfo +from llm_gateway.services.quota_resolver import QuotaResourceStatus + + +def _make_fake_quota_resolver() -> AsyncMock: + resolver = AsyncMock() + resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=False)) + return resolver def create_test_app( @@ -26,7 +34,9 @@ def create_test_app( from llm_gateway.api.health import 
health_router from llm_gateway.api.routes import router + quota_resolver = _make_fake_quota_resolver() default_throttles: list[Throttle] = [ + BillableCreditThrottle(), ProductCostThrottle(redis=None), UserCostBurstThrottle(redis=None), UserCostSustainedThrottle(redis=None), @@ -41,6 +51,7 @@ async def test_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.plan_resolver = AsyncMock() app.state.plan_resolver.get_plan = AsyncMock(return_value=PlanInfo(plan_key=None, seat_created_at=None)) app.state.anthropic_circuit_breaker = None + app.state.quota_resolver = quota_resolver yield app = FastAPI(title="LLM Gateway Test", lifespan=test_lifespan) diff --git a/services/llm-gateway/tests/test_billable_credits_throttle.py b/services/llm-gateway/tests/test_billable_credits_throttle.py new file mode 100644 index 000000000000..50a276af80e7 --- /dev/null +++ b/services/llm-gateway/tests/test_billable_credits_throttle.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import pytest + +from llm_gateway.auth.models import AuthenticatedUser +from llm_gateway.rate_limiting.billable_credits_throttle import BillableCreditThrottle +from llm_gateway.rate_limiting.throttles import ThrottleContext + + +def _make_user() -> AuthenticatedUser: + return AuthenticatedUser( + user_id=1, + team_id=42, + auth_method="personal_api_key", + distinct_id="distinct-1", + scopes=["llm_gateway:read"], + ) + + +def _make_context(product: str, *, ai_credits_exhausted: bool = False) -> ThrottleContext: + return ThrottleContext(user=_make_user(), product=product, ai_credits_exhausted=ai_credits_exhausted) + + +class TestBillableCreditThrottle: + @pytest.mark.asyncio + async def test_allows_non_billable_product_even_when_exhausted(self) -> None: + # `posthog_code` is non-billable; exhaustion at the context level is + # irrelevant — the throttle short-circuits before checking the flag. 
+ throttle = BillableCreditThrottle() + + result = await throttle.allow_request(_make_context(product="posthog_code", ai_credits_exhausted=True)) + + assert result.allowed is True + + @pytest.mark.asyncio + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_allows_billable_product_when_not_exhausted(self, product: str) -> None: + throttle = BillableCreditThrottle() + + result = await throttle.allow_request(_make_context(product=product, ai_credits_exhausted=False)) + + assert result.allowed is True + + @pytest.mark.asyncio + @pytest.mark.parametrize("product", ["slack_app", "slack_app_routing"]) + async def test_denies_billable_product_when_exhausted(self, product: str) -> None: + throttle = BillableCreditThrottle() + + result = await throttle.allow_request(_make_context(product=product, ai_credits_exhausted=True)) + + assert result.allowed is False + assert result.status_code == 429 + assert result.scope == "billable_credits" + assert "PostHog AI credits" in result.detail + assert result.retry_after == 60 diff --git a/services/llm-gateway/tests/test_product_config.py b/services/llm-gateway/tests/test_product_config.py index d19e8ded2e4b..a1a7beec56fa 100644 --- a/services/llm-gateway/tests/test_product_config.py +++ b/services/llm-gateway/tests/test_product_config.py @@ -255,6 +255,51 @@ def test_slack_twig_rejects_non_haiku_models(self): assert error is not None assert "not allowed" in error + def test_slack_app_routing_allows_claude_haiku_via_api_key(self): + allowed, error = check_product_access("slack_app_routing", "personal_api_key", None, "claude-haiku-4-5") + assert allowed is True + assert error is None + + def test_slack_app_routing_rejects_non_haiku_models(self): + allowed, error = check_product_access("slack_app_routing", "personal_api_key", None, "claude-sonnet-4-5") + assert allowed is False + assert error is not None + assert "not allowed" in error + + def test_slack_posthog_code_alias_still_resolves(self): + # 
Legacy alias kept for backward compat (old URL paths, Django integration kind). + allowed, error = check_product_access("slack-posthog-code", "personal_api_key", None, "claude-haiku-4-5") + assert allowed is True + assert error is None + + @pytest.mark.parametrize( + "model", + [ + "claude-opus-4-7", + "claude-sonnet-4-6", + "claude-haiku-4-5", + "gpt-5.3-codex", + ], + ) + def test_slack_app_allows_agent_models(self, model: str): + allowed, error = check_product_access("slack_app", "oauth_access_token", POSTHOG_CODE_US_APP_ID, model) + assert allowed is True + assert error is None + + def test_slack_app_rejects_api_keys(self): + allowed, error = check_product_access("slack_app", "personal_api_key", None, "claude-sonnet-4-6") + assert allowed is False + assert error is not None + assert "requires OAuth" in error + + def test_slack_app_rejects_unauthorized_oauth_app(self): + allowed, error = check_product_access( + "slack_app", "oauth_access_token", "00000000-0000-0000-0000-000000000000", "claude-sonnet-4-6" + ) + assert allowed is False + assert error is not None + assert "not authorized" in error + class TestBackwardsCompatibility: def test_twig_app_id_constants_are_aliases(self): @@ -266,10 +311,11 @@ def test_twig_app_id_constants_are_aliases(self): [ ("twig", "posthog_code"), ("array", "posthog_code"), - ("slack-twig", "slack-posthog-code"), + ("slack-twig", "slack_app_routing"), + ("slack-posthog-code", "slack_app_routing"), ], ) - def test_aliases_resolve_to_posthog_code(self, alias: str, target: str): + def test_aliases_resolve_to_canonical_product(self, alias: str, target: str): assert resolve_product_alias(alias) == target def test_twig_alias_returns_same_config_as_posthog_code(self): @@ -284,9 +330,11 @@ def test_twig_alias_validates_to_posthog_code(self): def test_array_alias_validates_to_posthog_code(self): assert validate_product("array") == "posthog_code" - def test_slack_twig_alias_resolves_to_slack_posthog_code(self): - assert 
get_product_config("slack-twig") is get_product_config("slack-posthog-code") - assert validate_product("slack-twig") == "slack-posthog-code" + def test_slack_aliases_resolve_to_slack_app_routing(self): + assert get_product_config("slack-twig") is get_product_config("slack_app_routing") + assert get_product_config("slack-posthog-code") is get_product_config("slack_app_routing") + assert validate_product("slack-twig") == "slack_app_routing" + assert validate_product("slack-posthog-code") == "slack_app_routing" class TestValidateProduct: @@ -310,7 +358,8 @@ def test_alias_resolves_to_target_product(self, alias: str, target: str): def test_resolve_product_alias_returns_alias_target(self): assert resolve_product_alias("array") == "posthog_code" assert resolve_product_alias("twig") == "posthog_code" - assert resolve_product_alias("slack-twig") == "slack-posthog-code" + assert resolve_product_alias("slack-twig") == "slack_app_routing" + assert resolve_product_alias("slack-posthog-code") == "slack_app_routing" def test_resolve_product_alias_returns_input_if_not_aliased(self): assert resolve_product_alias("wizard") == "wizard" diff --git a/services/llm-gateway/tests/test_quota_resolver.py b/services/llm-gateway/tests/test_quota_resolver.py new file mode 100644 index 000000000000..28254e6c11e5 --- /dev/null +++ b/services/llm-gateway/tests/test_quota_resolver.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import json +from collections.abc import Iterator +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from llm_gateway.services.quota_resolver import ( + _FAIL_OPEN_CACHE_TTL_SECONDS, + _RETRY_DELAYS_SECONDS, + QuotaResolver, + QuotaResourceStatus, + _redis_key, +) + + +def _make_response(status_code: int, payload: dict[str, object] | None = None) -> httpx.Response: + content = json.dumps(payload or {}).encode() + return httpx.Response(status_code, content=content, headers={"content-type": "application/json"}) + + +def 
_make_http_client(response: httpx.Response | Exception) -> MagicMock: + client = MagicMock() + if isinstance(response, Exception): + client.get = AsyncMock(side_effect=response) + else: + client.get = AsyncMock(return_value=response) + return client + + +def _make_http_client_sequence(responses: list[httpx.Response | Exception]) -> MagicMock: + client = MagicMock() + + async def _next(*args: object, **kwargs: object) -> httpx.Response: + item = responses.pop(0) + if isinstance(item, Exception): + raise item + return item + + client.get = AsyncMock(side_effect=_next) + return client + + +@pytest.fixture(autouse=True) +def _no_retry_sleep() -> Iterator[None]: + """Don't actually sleep between retries — keeps tests fast and deterministic.""" + with patch("llm_gateway.services.quota_resolver.asyncio.sleep", new=AsyncMock()): + yield + + +class TestQuotaResolver: + @pytest.mark.asyncio + async def test_fetches_and_parses_limited_response(self) -> None: + http_client = _make_http_client( + _make_response(200, {"team_id": 1, "limited": {"ai_credits": {"limited": True}}}) + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=True) + http_client.get.assert_awaited_once() + assert http_client.get.await_args.kwargs["headers"]["Authorization"] == "Bearer phx_test" + + @pytest.mark.asyncio + async def test_fetches_and_parses_unlimited_response(self) -> None: + http_client = _make_http_client( + _make_response(200, {"team_id": 1, "limited": {"ai_credits": {"limited": False}}}) + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + + @pytest.mark.asyncio + async def test_4xx_fails_open_without_retrying(self) -> None: + # 4xx is treated as a permanent failure for the 
lifetime of this request, + # so we don't burn the retry budget on a token that won't fix itself. + http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + http_client.get.assert_awaited_once() + + @pytest.mark.asyncio + async def test_retries_on_5xx_and_succeeds(self) -> None: + # A transient 503 is retried; the next attempt succeeds. + http_client = _make_http_client_sequence( + [ + _make_response(503, {}), + _make_response(200, {"limited": {"ai_credits": {"limited": True}}}), + ] + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=True) + assert http_client.get.await_count == 2 + + @pytest.mark.asyncio + async def test_retries_on_network_error_and_succeeds(self) -> None: + http_client = _make_http_client_sequence( + [ + httpx.ConnectError("boom"), + _make_response(200, {"limited": {"ai_credits": {"limited": False}}}), + ] + ) + resolver = QuotaResolver(redis=None, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + assert http_client.get.await_count == 2 + + @pytest.mark.asyncio + async def test_gives_up_after_all_retries_and_fails_open(self) -> None: + # Consecutive network errors exhaust the retry budget; we fall open + # and cache the answer for the fail-open window. 
+ http_client = _make_http_client(httpx.ConnectError("boom")) + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + resolver = QuotaResolver(redis=redis, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=False) + assert http_client.get.await_count == len(_RETRY_DELAYS_SECONDS) + redis.set.assert_awaited_once() + assert redis.set.await_args.kwargs.get("ex") == _FAIL_OPEN_CACHE_TTL_SECONDS + + @pytest.mark.asyncio + async def test_uses_cached_result_and_skips_http(self) -> None: + redis = AsyncMock() + redis.get = AsyncMock(return_value=json.dumps({"limited": True}).encode()) + http_client = _make_http_client(_make_response(200)) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + status = await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + assert status == QuotaResourceStatus(limited=True) + redis.get.assert_awaited_once_with(_redis_key("ai_credits", 42)) + http_client.get.assert_not_called() + + @pytest.mark.asyncio + async def test_writes_cache_on_miss(self) -> None: + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + http_client = _make_http_client(_make_response(200, {"limited": {"ai_credits": {"limited": True}}})) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + redis.set.assert_awaited_once() + call = redis.set.await_args + assert call.args[0] == _redis_key("ai_credits", 42) + assert json.loads(call.args[1]) == {"limited": True} + # Successful fetches use the gateway settings default of 5 minutes. + assert call.kwargs.get("ex") == 300 + + @pytest.mark.asyncio + async def test_caches_fail_open_for_full_window_on_4xx(self) -> None: + # 4xx responses (e.g. 
expired token) cache for the fail-open window so + # a hot loop with a broken token doesn't hammer Django. + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock() + http_client = _make_http_client(_make_response(401, {"detail": "no auth"})) + resolver = QuotaResolver(redis=redis, http_client=http_client) + + await resolver.get_ai_credits_status(team_id=42, auth_header="Bearer phx_test") + + redis.set.assert_awaited_once() + assert redis.set.await_args.kwargs.get("ex") == _FAIL_OPEN_CACHE_TTL_SECONDS diff --git a/services/llm-gateway/tests/test_usage.py b/services/llm-gateway/tests/test_usage.py index e9ec3e6999c7..3b75ada2366a 100644 --- a/services/llm-gateway/tests/test_usage.py +++ b/services/llm-gateway/tests/test_usage.py @@ -358,6 +358,40 @@ def test_ignores_user_id_query_param(self, authenticated_usage_client: TestClien assert response.status_code == 200 assert response.json()["user_id"] == 42 + def test_ai_credits_reported_unlimited_for_non_billable_product( + self, authenticated_usage_client: TestClient + ) -> None: + # posthog_code is not billable; ai_credits should be unlimited and not contribute + # to is_rate_limited even if the resolver thinks the team is over. 
+ from llm_gateway.services.quota_resolver import QuotaResourceStatus + + app = authenticated_usage_client.app + app.state.quota_resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=True)) + + response = authenticated_usage_client.get( + "/v1/usage/posthog_code", + headers={"Authorization": "Bearer phx_test"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["ai_credits"] == {"exhausted": False} + assert data["is_rate_limited"] is False + + def test_ai_credits_reflects_resolver_for_billable_product(self, authenticated_usage_client: TestClient) -> None: + from llm_gateway.services.quota_resolver import QuotaResourceStatus + + app = authenticated_usage_client.app + app.state.quota_resolver.get_ai_credits_status = AsyncMock(return_value=QuotaResourceStatus(limited=True)) + + response = authenticated_usage_client.get( + "/v1/usage/slack_app", + headers={"Authorization": "Bearer phx_test"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["ai_credits"] == {"exhausted": True} + assert data["is_rate_limited"] is True + def test_invalidate_plan_cache_calls_resolver(self, authenticated_usage_client: TestClient) -> None: app = authenticated_usage_client.app app.state.plan_resolver.invalidate = AsyncMock() diff --git a/services/mcp/src/api/generated.ts b/services/mcp/src/api/generated.ts index f3e9e141b277..11020534bd85 100644 --- a/services/mcp/src/api/generated.ts +++ b/services/mcp/src/api/generated.ts @@ -35076,6 +35076,21 @@ export namespace Schemas { query: EventsNode | ActionsNode | PersonsNode | DataWarehouseNode | FunnelsDataWarehouseNode | LifecycleDataWarehouseNode | EventsQuery | SessionsQuery | ActorsQuery | GroupsQuery | InsightActorsQuery | InsightActorsQueryOptions | SessionsTimelineQuery | HogQuery | HogQLQuery | HogQLMetadata | HogQLAutocomplete | SessionAttributionExplorerQuery | RevenueExampleEventsQuery | RevenueExampleDataWarehouseTablesQuery | ErrorTrackingQuery 
| ErrorTrackingSimilarIssuesQuery | ErrorTrackingBreakdownsQuery | ErrorTrackingIssueCorrelationQuery | ExperimentFunnelsQuery | ExperimentTrendsQuery | ExperimentQuery | ExperimentExposureQuery | DocumentSimilarityQuery | WebOverviewQuery | WebStatsTableQuery | WebExternalClicksTableQuery | WebGoalsQuery | WebVitalsQuery | WebVitalsPathBreakdownQuery | WebPageURLSearchQuery | WebAnalyticsExternalSummaryQuery | WebNotableChangesQuery | RevenueAnalyticsGrossRevenueQuery | RevenueAnalyticsMetricsQuery | RevenueAnalyticsMRRQuery | RevenueAnalyticsOverviewQuery | RevenueAnalyticsTopCustomersQuery | MarketingAnalyticsTableQuery | MarketingAnalyticsAggregatedQuery | NonIntegratedConversionsTableQuery | DataVisualizationNode | DataTableNode | SavedInsightNode | InsightVizNode | TrendsQuery | FunnelsQuery | RetentionQuery | PathsQuery | StickinessQuery | LifecycleQuery | FunnelCorrelationQuery | DatabaseSchemaQuery | RecordingsQuery | LogsQuery | LogAttributesQuery | LogValuesQuery | TraceSpansQuery | TraceSpansAggregationQuery | TraceSpansTreeQuery | SuggestedQuestionsQuery | TeamTaxonomyQuery | EventTaxonomyQuery | ActorsPropertyTaxonomyQuery | TracesQuery | TraceQuery | TraceNeighborsQuery | VectorSearchQuery | UsageMetricsQuery | EndpointsUsageOverviewQuery | EndpointsUsageTableQuery | EndpointsUsageTrendsQuery | PropertyValuesQuery; } + export interface QuotaResourceLimit { + /** True when the team is currently over its quota for this resource and limits are in effect. */ + limited: boolean; + } + + /** + * Per-resource limit state keyed by `QuotaResource` value. Currently only `ai_credits` is reported; additional resources may be added. + */ + export type QuotaLimitsResponseLimited = {[key: string]: QuotaResourceLimit}; + + export interface QuotaLimitsResponse { + /** Per-resource limit state keyed by `QuotaResource` value. Currently only `ai_credits` is reported; additional resources may be added. 
*/ + limited: QuotaLimitsResponseLimited; + } + export interface RecomputeResult { run: Run; counts_changed: boolean;