-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path
diff_tests_test_admin_rbac.py
More file actions
132 lines (132 loc) · 4.14 KB
/
diff_tests_test_admin_rbac.py
File metadata and controls
132 lines (132 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
diff --git a/tests/test_admin_rbac.py b/tests/test_admin_rbac.py
new file mode 100644
index 00000000..bee3fac6
--- /dev/null
+++ b/tests/test_admin_rbac.py
@@ -0,0 +1,126 @@
+import types
+import pytest
+from datetime import datetime, timedelta, timezone
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.testclient import TestClient
+from starlette.requests import Request
+from starlette.responses import Response
+from types import SimpleNamespace
+
+from app.middleware.api_key_security import ApiKeySecurityMiddleware
+from app.middleware.tenant_context import require_admin_role
+from app.services.api_key_security import api_key_security_service
+
+@pytest.fixture
+def anyio_backend():
+    """Select the asyncio backend for tests marked ``@pytest.mark.anyio``."""
+    backend_name = "asyncio"
+    return backend_name
+
+
+@pytest.fixture
+def api_keys():
+    """Return the raw API key strings used by these tests, keyed by role.
+
+    ``admin_client`` maps each value to a fake (api_key, tenant) pair.
+    """
+    keys = dict(admin="admin-key", client="client-key")
+    return keys
+
+
+@pytest.fixture
+def admin_client(api_keys, monkeypatch):
+    """Yield a ``TestClient`` for a minimal app guarded by the API-key middleware.
+
+    The app exposes one route, ``GET /admin/settings``, protected by the
+    ``require_admin_role`` dependency. Key validation and usage tracking are
+    swapped for in-memory fakes so no database is touched:
+
+    * ``api_keys["admin"]`` resolves to an active key of tenant ``admin-tenant``
+      (plan ``"Bank"``);
+    * ``api_keys["client"]`` resolves to an active key of tenant
+      ``client-tenant`` (plan ``"SME"``);
+    * any other key resolves to ``(None, None)``.
+
+    NOTE(review): which attribute ``require_admin_role`` uses to tell the two
+    tenants apart (plan, tenant id, ...) is not visible here -- confirm.
+
+    Patches go through ``monkeypatch`` so pytest undoes them after the test,
+    even on failure, and restores class attributes exactly (including any
+    descriptor wrapper or an attribute inherited from a base class).
+    """
+    app = FastAPI()
+
+    @app.exception_handler(HTTPException)
+    async def http_exception_handler(request, exc):
+        # Forward exc.headers (e.g. WWW-Authenticate) instead of dropping them.
+        # NOTE(review): Starlette places user middleware outside its exception
+        # middleware, so this handler likely never sees HTTPExceptions raised
+        # by ApiKeySecurityMiddleware.dispatch itself -- confirm.
+        return JSONResponse(
+            status_code=exc.status_code,
+            content={"detail": exc.detail},
+            headers=getattr(exc, "headers", None),
+        )
+
+    app.add_middleware(ApiKeySecurityMiddleware)
+
+    @app.get("/admin/settings", dependencies=[Depends(require_admin_role)])
+    async def settings():
+        return {"status": "ok"}
+
+    now = datetime.now(timezone.utc)
+
+    def fake_key_and_tenant(role, display_name, plan):
+        """Build the (api_key, tenant) pair the fake validator returns for *role*."""
+        tenant_id = f"{role}-tenant"
+        api_key = SimpleNamespace(
+            id=f"{role}-key-id",
+            tenant_id=tenant_id,
+            name=f"{display_name} Key",
+            prefix=f"sk_{role}",
+            active=True,
+            expires_at=now + timedelta(days=90),
+            last_rotated_at=now,
+            ip_whitelist=[],
+        )
+        tenant = SimpleNamespace(
+            id=tenant_id,
+            name=f"{display_name} Tenant",
+            active=True,
+            plan=plan,
+        )
+        return api_key, tenant
+
+    key_map = {
+        api_keys["admin"]: fake_key_and_tenant("admin", "Admin", "Bank"),
+        api_keys["client"]: fake_key_and_tenant("client", "Client", "SME"),
+    }
+
+    async def fake_validate(self, db, api_key):
+        # Unknown keys mimic a failed lookup.
+        return key_map.get(api_key, (None, None))
+
+    async def fake_track_usage(self, db, api_key_id, tenant_id, endpoint):
+        return None
+
+    monkeypatch.setattr(ApiKeySecurityMiddleware, "_validate_api_key", fake_validate)
+    monkeypatch.setattr(
+        api_key_security_service,
+        "track_usage",
+        types.MethodType(fake_track_usage, api_key_security_service),
+    )
+
+    # raise_server_exceptions=False: an exception escaping the app comes back
+    # as a 500 response (shown via response.text in the assertions) instead of
+    # being re-raised into the test. The context manager closes the client.
+    with TestClient(app, raise_server_exceptions=False) as client:
+        yield client
+
+
+@pytest.mark.anyio
+async def test_api_key_security_missing_key_raises_401():
+    """A request with no credentials is rejected with 401 and never forwarded."""
+
+    async def downstream_app(scope, receive, send):
+        # Placeholder ASGI app; dispatch() is driven directly below, so the
+        # wrapped app is never invoked.
+        raise AssertionError("wrapped ASGI app must not be called")
+
+    middleware = ApiKeySecurityMiddleware(downstream_app)
+
+    # Fill in the ASGI-required HTTP scope keys (notably query_string) so the
+    # middleware can read request.query_params / request.url without KeyError.
+    # "client" is deliberately left out, matching the original request shape.
+    scope = {
+        "type": "http",
+        "asgi": {"version": "3.0"},
+        "http_version": "1.1",
+        "method": "GET",
+        "scheme": "http",
+        "root_path": "",
+        "path": "/admin/settings",
+        "raw_path": b"/admin/settings",
+        "query_string": b"",
+        "headers": [],
+    }
+
+    async def receive():
+        return {"type": "http.request", "body": b"", "more_body": False}
+
+    request = Request(scope, receive)
+
+    forwarded_requests = []
+
+    async def call_next(forwarded_request):
+        forwarded_requests.append(forwarded_request)
+        return Response()
+
+    with pytest.raises(HTTPException) as exc_info:
+        await middleware.dispatch(request, call_next)
+
+    assert exc_info.value.status_code == 401
+    assert forwarded_requests == [], "unauthenticated request must not reach call_next"
+
+
+def test_admin_endpoint_rejects_non_admin(admin_client, api_keys):
+    """A valid key from the non-admin ``client-tenant`` is refused with 403."""
+    client_token = api_keys["client"]
+    auth_headers = {"Authorization": "Bearer " + client_token}
+    response = admin_client.get("/admin/settings", headers=auth_headers)
+    assert response.status_code == 403, response.text
+
+
+def test_admin_endpoint_allows_admin(admin_client, api_keys):
+    """An admin key passes the middleware and role check and reaches the handler."""
+    response = admin_client.get(
+        "/admin/settings",
+        headers={"Authorization": f"Bearer {api_keys['admin']}"},
+    )
+    assert response.status_code == 200, response.text
+    # Pin the route's payload so a 200 produced anywhere else cannot pass.
+    assert response.json() == {"status": "ok"}