diff --git a/chart/templates/configmap.yaml b/chart/templates/configmap.yaml index cfe5086..f87f64e 100644 --- a/chart/templates/configmap.yaml +++ b/chart/templates/configmap.yaml @@ -17,4 +17,6 @@ data: S3PROXY_REDIS_URL: {{ .Values.externalRedis.url | quote }} {{- end }} S3PROXY_REDIS_UPLOAD_TTL_HOURS: {{ .Values.externalRedis.uploadTtlHours | quote }} + S3PROXY_ADMIN_UI: {{ .Values.admin.enabled | quote }} + S3PROXY_ADMIN_PATH: {{ .Values.admin.path | quote }} S3PROXY_LOG_LEVEL: {{ .Values.logLevel | quote }} diff --git a/chart/values.yaml b/chart/values.yaml index 584dd80..f2f0940 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -95,6 +95,10 @@ secrets: awsAccessKeyId: "" awsSecretAccessKey: "" +admin: + enabled: false + path: "/admin" + logLevel: "DEBUG" resources: diff --git a/e2e/cluster.sh b/e2e/cluster.sh index 3633964..fa712ad 100755 --- a/e2e/cluster.sh +++ b/e2e/cluster.sh @@ -23,6 +23,9 @@ case "${1:-help}" in echo "==========================================" echo "Cluster is running in background." echo "" + echo "Admin dashboard: http://localhost:4433/admin/" + echo "Login: minioadmin / minioadmin" + echo "" echo "Run tests:" echo " ./cluster.sh postgres" echo " ./cluster.sh elasticsearch" diff --git a/e2e/docker-compose.yml b/e2e/docker-compose.yml index 7047e8e..579ee85 100644 --- a/e2e/docker-compose.yml +++ b/e2e/docker-compose.yml @@ -18,6 +18,8 @@ services: privileged: true depends_on: - registry + ports: + - "4433:4433" volumes: - /var/run/docker.sock:/var/run/docker.sock - ..:/repo @@ -437,7 +439,8 @@ services: --set redis-ha.haproxy.checkInterval=5s \ --set redis-ha.haproxy.timeout.check=10s \ --set redis-ha.haproxy.timeout.server=60s \ - --set redis-ha.haproxy.timeout.client=60s & + --set redis-ha.haproxy.timeout.client=60s \ + --set admin.enabled=true & S3PROXY_PID=$$! # 8. 
def create_auth_dependency(settings, credentials_store: dict[str, str]):
    """Build the HTTP Basic Auth dependency that guards the admin router.

    The accepted username/password pair is ``settings.admin_username`` /
    ``settings.admin_password`` when both are non-empty.  Otherwise the first
    entry of ``credentials_store`` (key = username, value = password) is used.

    Raises:
        RuntimeError: if no explicit admin credentials are set and
            ``credentials_store`` is empty.

    Returns:
        An async callable suitable for ``Depends(...)``.  It returns the
        presented credentials on success and raises a 401 ``HTTPException``
        carrying a ``WWW-Authenticate`` header otherwise.
    """
    expected_user = settings.admin_username
    expected_secret = settings.admin_password
    if not (expected_user and expected_secret):
        # No dedicated admin login configured: fall back to the first
        # credential pair in the store.
        if not credentials_store:
            raise RuntimeError("No credentials configured for admin auth")
        expected_user = next(iter(credentials_store.keys()))
        expected_secret = credentials_store[expected_user]

    async def verify(credentials: HTTPBasicCredentials = _security_dep):
        # Both comparisons are always evaluated (no short-circuit) and are
        # done on bytes so compare_digest accepts non-ASCII input.
        checks = (
            secrets.compare_digest(credentials.username.encode(), expected_user.encode()),
            secrets.compare_digest(credentials.password.encode(), expected_secret.encode()),
        )
        if all(checks):
            return credentials
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": 'Basic realm="S3Proxy Admin"'},
        )

    return verify
import metrics +from ..state.redis import get_redis, is_using_redis + +if TYPE_CHECKING: + from ..config import Settings + from ..handlers import S3ProxyHandler + +logger = structlog.get_logger(__name__) + +ADMIN_KEY_PREFIX = "s3proxy:admin:" +ADMIN_TTL_SECONDS = 30 + + +# --------------------------------------------------------------------------- +# Rate tracker — sliding window over Prometheus counters +# --------------------------------------------------------------------------- + + +class RateTracker: + """Tracks counter snapshots over a sliding window to compute per-minute rates.""" + + def __init__(self, window_seconds: int = 600): + self._window = window_seconds + self._snapshots: deque[tuple[float, dict[str, float]]] = deque() + + def record(self, counters: dict[str, float]) -> None: + now = time.monotonic() + self._snapshots.append((now, counters)) + cutoff = now - self._window - 10 + while len(self._snapshots) > 2 and self._snapshots[0][0] < cutoff: + self._snapshots.popleft() + + def rate_per_minute(self, key: str) -> float: + if len(self._snapshots) < 2: + return 0.0 + oldest_ts, oldest_vals = self._snapshots[0] + newest_ts, newest_vals = self._snapshots[-1] + elapsed = newest_ts - oldest_ts + if elapsed < 1: + return 0.0 + delta = newest_vals.get(key, 0) - oldest_vals.get(key, 0) + return max(0.0, delta / elapsed * 60) + + def history(self, key: str, max_points: int = 60) -> list[float]: + """Return per-minute rate history as a list of floats for sparklines.""" + if len(self._snapshots) < 2: + return [] + rates: list[float] = [] + for i in range(1, len(self._snapshots)): + prev_ts, prev_vals = self._snapshots[i - 1] + curr_ts, curr_vals = self._snapshots[i] + elapsed = curr_ts - prev_ts + if elapsed < 0.1: + continue + delta = curr_vals.get(key, 0) - prev_vals.get(key, 0) + rates.append(round(max(0.0, delta / elapsed * 60), 1)) + if len(rates) > max_points: + step = len(rates) / max_points + rates = [rates[int(i * step)] for i in range(max_points)] + 
return rates + + +_rate_tracker = RateTracker(window_seconds=600) + + +# --------------------------------------------------------------------------- +# Request log — ring buffer for live feed +# --------------------------------------------------------------------------- + + +@dataclass(slots=True, frozen=True) +class RequestEntry: + """Single request log entry for the live feed.""" + + timestamp: float + method: str + path: str + operation: str + status: int + duration_ms: float + size: int + crypto: str + + +class RequestLog: + """Fixed-size ring buffer of recent requests for the live feed.""" + + ENCRYPT_OPS = frozenset({ + "PutObject", "UploadPart", "UploadPartCopy", + "CompleteMultipartUpload", "CopyObject", + }) + DECRYPT_OPS = frozenset({"GetObject"}) + + def __init__(self, maxlen: int = 200): + self._entries: deque[RequestEntry] = deque(maxlen=maxlen) + + def record( + self, + method: str, + path: str, + operation: str, + status: int, + duration: float, + size: int, + ) -> None: + crypto = "" + if operation in self.ENCRYPT_OPS: + crypto = "encrypt" + elif operation in self.DECRYPT_OPS: + crypto = "decrypt" + self._entries.append(RequestEntry( + timestamp=time.time(), + method=method, + path=path[:120], + operation=operation, + status=status, + duration_ms=round(duration * 1000, 1), + size=size, + crypto=crypto, + )) + + def recent(self, limit: int = 50) -> list[dict]: + """Return most recent entries as dicts, newest first.""" + entries = list(self._entries) + entries.reverse() + return [asdict(e) for e in entries[:limit]] + + +_request_log = RequestLog(maxlen=200) + + +def record_request( + method: str, + path: str, + operation: str, + status: int, + duration: float, + size: int, +) -> None: + """Record a completed request to the live feed log.""" + _request_log.record(method, path, operation, status, duration, size) + + +# --------------------------------------------------------------------------- +# Prometheus helpers +# 
# ---------------------------------------------------------------------------
# Prometheus helpers
# ---------------------------------------------------------------------------


def _read_gauge(gauge) -> float:
    """Return the current value of an unlabeled gauge."""
    # NOTE(review): `_value` looks like a private attribute of the metric
    # object — confirm it is stable across client-library upgrades.
    return gauge._value.get()


def _read_counter(counter) -> float:
    """Return the current value of an unlabeled counter."""
    return counter._value.get()


def _read_labeled_counter_sum(counter) -> float:
    """Sum every ``*_total`` sample of a labeled counter across all label sets."""
    return sum(
        (
            sample.value
            for sample in counter.collect()[0].samples
            if sample.name.endswith("_total")
        ),
        0.0,
    )


def _read_labeled_gauge_sum(gauge) -> float:
    """Sum every sample of a labeled gauge across all label sets."""
    return sum((sample.value for sample in gauge.collect()[0].samples), 0.0)


def _read_errors_by_class() -> tuple[float, float, float]:
    """Read 4xx, 5xx, 503 counts from REQUEST_COUNT labels.

    503 responses are counted both in the 503 total and in the 5xx total.
    """
    errors_4xx = 0.0
    errors_5xx = 0.0
    errors_503 = 0.0
    for sample in metrics.REQUEST_COUNT.collect()[0].samples:
        if not sample.name.endswith("_total"):
            continue
        status = str(sample.labels.get("status", ""))
        if status.startswith("4"):
            errors_4xx += sample.value
        elif status == "503":
            errors_503 += sample.value
            errors_5xx += sample.value
        elif status.startswith("5"):
            errors_5xx += sample.value
    return errors_4xx, errors_5xx, errors_503


# ---------------------------------------------------------------------------
# Collectors
# ---------------------------------------------------------------------------


def collect_pod_identity(settings: Settings, start_time: float) -> dict:
    """Collect pod identity for the header banner.

    *start_time* is compared against ``time.monotonic()``, so it must be a
    monotonic-clock reading.  The KEK itself is never returned — only the
    first 16 hex characters of its SHA-256 digest.
    """
    return {
        "pod_name": os.environ.get("HOSTNAME", "unknown"),
        "uptime_seconds": int(time.monotonic() - start_time),
        "storage_backend": "Redis (HA)" if is_using_redis() else "In-memory",
        "kek_fingerprint": hashlib.sha256(settings.kek).hexdigest()[:16],
    }


def collect_health() -> dict:
    """Collect health metrics with error counts."""
    memory_reserved = _read_gauge(metrics.MEMORY_RESERVED_BYTES)
    memory_limit = _read_gauge(metrics.MEMORY_LIMIT_BYTES)
    # Guard against a zero/unset limit to avoid division by zero.
    usage_pct = round(memory_reserved / memory_limit * 100, 1) if memory_limit > 0 else 0
    errors_4xx, errors_5xx, errors_503 = _read_errors_by_class()

    return {
        "memory_reserved_bytes": int(memory_reserved),
        "memory_limit_bytes": int(memory_limit),
        "memory_usage_pct": usage_pct,
        "requests_in_flight": int(_read_labeled_gauge_sum(metrics.REQUESTS_IN_FLIGHT)),
        "errors_4xx": int(errors_4xx),
        "errors_5xx": int(errors_5xx),
        "errors_503": int(errors_503),
    }


def collect_throughput() -> dict:
    """Collect throughput counters and compute per-minute rates.

    Side effect: every call records a new snapshot in the module-level rate
    tracker, so the returned rates and history depend on previous calls.
    """
    encrypt_ops = 0.0
    decrypt_ops = 0.0
    for sample in metrics.ENCRYPTION_OPERATIONS.collect()[0].samples:
        if not sample.name.endswith("_total"):
            continue
        operation = sample.labels.get("operation")
        # Accumulate rather than overwrite so that several series sharing the
        # same `operation` label value are all counted.
        if operation == "encrypt":
            encrypt_ops += sample.value
        elif operation == "decrypt":
            decrypt_ops += sample.value

    total_requests = _read_labeled_counter_sum(metrics.REQUEST_COUNT)
    bytes_encrypted = _read_counter(metrics.BYTES_ENCRYPTED)
    bytes_decrypted = _read_counter(metrics.BYTES_DECRYPTED)
    errors_4xx, errors_5xx, errors_503 = _read_errors_by_class()

    counters = {
        "requests": total_requests,
        "encrypt_ops": encrypt_ops,
        "decrypt_ops": decrypt_ops,
        "bytes_encrypted": bytes_encrypted,
        "bytes_decrypted": bytes_decrypted,
        "errors_4xx": errors_4xx,
        "errors_5xx": errors_5xx,
        "errors_503": errors_503,
    }
    _rate_tracker.record(counters)

    return {
        "rates": {
            "requests_per_min": round(_rate_tracker.rate_per_minute("requests"), 1),
            "encrypt_per_min": round(_rate_tracker.rate_per_minute("encrypt_ops"), 1),
            "decrypt_per_min": round(_rate_tracker.rate_per_minute("decrypt_ops"), 1),
            "bytes_encrypted_per_min": int(_rate_tracker.rate_per_minute("bytes_encrypted")),
            "bytes_decrypted_per_min": int(_rate_tracker.rate_per_minute("bytes_decrypted")),
            "errors_4xx_per_min": round(_rate_tracker.rate_per_minute("errors_4xx"), 1),
            "errors_5xx_per_min": round(_rate_tracker.rate_per_minute("errors_5xx"), 1),
            "errors_503_per_min": round(_rate_tracker.rate_per_minute("errors_503"), 1),
        },
        "history": {
            "requests_per_min": _rate_tracker.history("requests"),
            "encrypt_per_min": _rate_tracker.history("encrypt_ops"),
            "decrypt_per_min": _rate_tracker.history("decrypt_ops"),
            "bytes_encrypted_per_min": _rate_tracker.history("bytes_encrypted"),
            "bytes_decrypted_per_min": _rate_tracker.history("bytes_decrypted"),
        },
    }


# ---------------------------------------------------------------------------
# Redis pod metrics publishing (multi-pod view)
# ---------------------------------------------------------------------------


async def publish_pod_metrics(pod_data: dict) -> None:
    """Publish this pod's metrics to Redis so other pods can read them.

    No-op when Redis is not in use.  The entry is written under
    ``ADMIN_KEY_PREFIX + pod_name`` with a TTL of ``ADMIN_TTL_SECONDS``.
    Publishing is best-effort: any failure is logged at debug level and
    swallowed so the dashboard keeps working.
    """
    if not is_using_redis():
        return
    try:
        client = get_redis()
        pod_name = pod_data["pod"]["pod_name"]
        key = f"{ADMIN_KEY_PREFIX}{pod_name}"
        await client.set(key, json.dumps(pod_data).encode(), ex=ADMIN_TTL_SECONDS)
    except Exception:
        logger.debug("Failed to publish pod metrics to Redis", exc_info=True)
async def read_all_pod_metrics() -> list[dict]:
    """Read all pods' metrics from Redis. Returns empty list if not using Redis.

    Entries that are missing, unparseable, or not JSON objects are skipped so
    that one bad key does not hide every other pod.  Any other failure (for
    example a Redis error) is logged at debug level and yields an empty list.
    The result is sorted by pod name.
    """
    if not is_using_redis():
        return []
    try:
        client = get_redis()
        pods = []
        async for key in client.scan_iter(match=f"{ADMIN_KEY_PREFIX}*", count=100):
            data = await client.get(key)
            if not data:
                # Key expired between the scan and the read.
                continue
            try:
                entry = json.loads(data)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError.
                logger.debug("Skipping unparseable pod metrics entry", exc_info=True)
                continue
            if isinstance(entry, dict):
                pods.append(entry)
        pods.sort(key=lambda p: p.get("pod", {}).get("pod_name", ""))
        return pods
    except Exception:
        logger.debug("Failed to read pod metrics from Redis", exc_info=True)
        return []


# ---------------------------------------------------------------------------
# Formatters
# ---------------------------------------------------------------------------


def _format_bytes(n: int) -> str:
    """Format bytes to human-readable string.

    Values below 1024 are printed verbatim with a ``B`` suffix; larger values
    use one decimal place and the largest 1024-based unit (KB … PB) for which
    the *displayed* number stays below 1024.
    """
    if abs(n) < 1024:
        return f"{n} B"
    value = n / 1024
    for unit in ("KB", "MB", "GB", "TB"):
        # Compare the rounded value so e.g. 1048575 bytes renders as "1.0 MB"
        # instead of the misleading "1024.0 KB".
        if abs(round(value, 1)) < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    return f"{value:.1f} PB"


def _format_uptime(seconds: int) -> str:
    """Format seconds to human-readable uptime string.

    Days and hours are omitted when zero; minutes are always shown.
    """
    days, remainder = divmod(seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, _ = divmod(remainder, 60)
    parts = []
    if days:
        parts.append(f"{days}d")
    if hours:
        parts.append(f"{hours}h")
    parts.append(f"{minutes}m")
    return " ".join(parts)


# ---------------------------------------------------------------------------
# Aggregate
# ---------------------------------------------------------------------------


async def collect_all(
    settings: Settings,
    handler: S3ProxyHandler,
    start_time: float,
) -> dict:
    """Collect all dashboard data and publish to Redis for multi-pod view.

    Args:
        settings: Passed to ``collect_pod_identity``.
        handler: Accepted for interface compatibility; not read here.
        start_time: Passed to ``collect_pod_identity`` to compute uptime.

    Returns:
        This pod's ``pod`` / ``health`` / ``throughput`` / ``formatted`` data
        plus ``request_log`` (the 50 most recent requests of this process) and
        ``all_pods`` (whatever every pod has published to Redis).
    """
    pod = collect_pod_identity(settings, start_time)
    health = collect_health()
    throughput = collect_throughput()

    local_data = {
        "pod": pod,
        "health": health,
        "throughput": throughput,
        "formatted": {
            "memory_reserved": _format_bytes(health["memory_reserved_bytes"]),
            "memory_limit": _format_bytes(health["memory_limit_bytes"]),
            "uptime": _format_uptime(pod["uptime_seconds"]),
            "bytes_encrypted_per_min": _format_bytes(
                throughput["rates"]["bytes_encrypted_per_min"]
            ),
            "bytes_decrypted_per_min": _format_bytes(
                throughput["rates"]["bytes_decrypted_per_min"]
            ),
        },
    }

    # Publish this pod's data to Redis so other pods can see it.  The call is
    # awaited, but it is best-effort: publish failures are swallowed inside
    # publish_pod_metrics.
    await publish_pod_metrics(local_data)

    # Read all pods from Redis (includes this pod's just-published data)
    all_pods = await read_all_pod_metrics()

    return {
        **local_data,
        "request_log": _request_log.recent(50),
        "all_pods": all_pods,
    }
b/s3proxy/admin/templates.py @@ -0,0 +1,292 @@ +"""HTML template for admin dashboard.""" + +DASHBOARD_HTML = """\ + + + + + +S3Proxy Admin + + + + +
+
+
loading...
+
3s
+
+
+uptime - +KEK - +- +
+
+
+ +
+
Health
+
Memory-
+
In-Flight-
+
Errors
+
+4xx 0/min +5xx 0/min +503 0/min +
+
+ +
+
Throughput
+
+
0
/min
requests
+
0
/min · 0 B/min
encrypt
+
0
/min · 0 B/min
decrypt
+
+
+ +
+
Live Feed
+
+ + + +
TimeMethodPathOpStatusLatencySize
+
Waiting for requests...
+
+
+ + + + +""" diff --git a/s3proxy/app.py b/s3proxy/app.py index ec755e6..e73f2b7 100644 --- a/s3proxy/app.py +++ b/s3proxy/app.py @@ -5,6 +5,7 @@ import logging import os import sys +import time import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -78,8 +79,10 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: handler = S3ProxyHandler(settings, credentials_store, multipart_manager) # Store in app.state for route access + app.state.settings = settings app.state.handler = handler app.state.verifier = verifier + app.state.start_time = time.monotonic() yield @@ -106,7 +109,15 @@ def create_app(settings: Settings | None = None) -> FastAPI: app = FastAPI(title="S3Proxy", lifespan=lifespan, docs_url=None, redoc_url=None) _register_exception_handlers(app) - _register_routes(app) + _register_health_routes(app) + + if settings.admin_ui: + from .admin import create_admin_router + + admin_router = create_admin_router(settings, credentials_store) + app.include_router(admin_router, prefix=settings.admin_path) + + _register_catch_all(app) return app @@ -116,7 +127,21 @@ def _register_exception_handlers(app: FastAPI) -> None: @app.exception_handler(HTTPException) async def s3_exception_handler(request: Request, exc: HTTPException): - """Return S3-compatible error response with request ID.""" + """Return S3-compatible error response with request ID. + + Exceptions with custom headers (e.g. WWW-Authenticate from admin auth) + are returned as-is with their original headers preserved. + """ + # Pass through non-S3 exceptions that carry custom headers (e.g. 
admin auth 401) + if not isinstance(exc, S3Error) and getattr(exc, "headers", None): + from fastapi.responses import JSONResponse + + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + headers=exc.headers, + ) + request_id = str(uuid.uuid4()).replace("-", "").upper()[:16] if isinstance(exc, S3Error): @@ -143,8 +168,8 @@ async def s3_exception_handler(request: Request, exc: HTTPException): ) -def _register_routes(app: FastAPI) -> None: - """Register health check and proxy routes.""" +def _register_health_routes(app: FastAPI) -> None: + """Register health check and metrics routes.""" @app.get("/healthz") @app.get("/readyz") @@ -155,6 +180,10 @@ async def health(): async def metrics(): return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) + +def _register_catch_all(app: FastAPI) -> None: + """Register S3 proxy catch-all route. Must be called last.""" + @app.api_route( "/{path:path}", methods=["GET", "PUT", "POST", "DELETE", "HEAD"], diff --git a/s3proxy/config.py b/s3proxy/config.py index 5d6b5d1..8135894 100644 --- a/s3proxy/config.py +++ b/s3proxy/config.py @@ -49,6 +49,16 @@ class Settings(BaseSettings): # Logging log_level: str = Field(default="INFO", description="Log level (DEBUG, INFO, WARNING, ERROR)") + # Admin dashboard settings + admin_ui: bool = Field(default=False, description="Enable admin dashboard") + admin_path: str = Field(default="/admin", description="URL path prefix for admin dashboard") + admin_username: str = Field( + default="", description="Admin dashboard username (default: AWS access key)" + ) + admin_password: str = Field( + default="", description="Admin dashboard password (default: AWS secret key)" + ) + # Cached KEK derived from encrypt_key (computed once in model_post_init) _kek: bytes = PrivateAttr() diff --git a/s3proxy/request_handler.py b/s3proxy/request_handler.py index 763bc53..407cb5e 100644 --- a/s3proxy/request_handler.py +++ b/s3proxy/request_handler.py @@ -13,6 +13,7 @@ 
from structlog.stdlib import BoundLogger from . import concurrency, crypto +from .admin.collectors import record_request from .errors import S3Error, raise_for_client_error, raise_for_exception from .handlers import S3ProxyHandler from .metrics import ( @@ -133,6 +134,10 @@ async def handle_proxy_request( REQUESTS_IN_FLIGHT.labels(method=method).dec() REQUEST_COUNT.labels(method=method, operation=operation, status=status_code).inc() REQUEST_DURATION.labels(method=method, operation=operation).observe(duration) + record_request( + method, path, operation, status_code, duration, + int(request.headers.get("content-length", "0") or "0"), + ) if reserved_memory > 0: await concurrency.release_memory(reserved_memory) diff --git a/s3proxy/state/manager.py b/s3proxy/state/manager.py index 7c6159a..cbdf796 100644 --- a/s3proxy/state/manager.py +++ b/s3proxy/state/manager.py @@ -44,6 +44,29 @@ def _storage_key(self, bucket: str, key: str, upload_id: str) -> str: """Generate storage key for upload state.""" return f"{bucket}:{key}:{upload_id}" + async def list_active_uploads(self) -> list[dict]: + """List active uploads for admin dashboard. DEKs are never exposed.""" + keys = await self._store.list_keys() + uploads = [] + for key in keys: + data = await self._store.get(key) + if data is None: + continue + state = deserialize_upload_state(data) + if state is None: + continue + uploads.append( + { + "bucket": state.bucket, + "key": state.key, + "upload_id": self._truncate_id(state.upload_id), + "parts_count": len(state.parts), + "created_at": state.created_at.isoformat(), + "total_plaintext_size": state.total_plaintext_size, + } + ) + return uploads + async def create_upload( self, bucket: str, diff --git a/s3proxy/state/storage.py b/s3proxy/state/storage.py index 3380c20..c2f7394 100644 --- a/s3proxy/state/storage.py +++ b/s3proxy/state/storage.py @@ -48,6 +48,11 @@ async def get_and_delete(self, key: str) -> bytes | None: """Atomically get and delete value. 
Returns None if not found.""" ... + @abstractmethod + async def list_keys(self) -> list[str]: + """List all stored keys.""" + ... + @abstractmethod async def update(self, key: str, updater: Updater, ttl_seconds: int) -> bytes | None: """Atomically update value using updater function. @@ -69,6 +74,9 @@ class MemoryStateStore(StateStore): def __init__(self) -> None: self._store: dict[str, bytes] = {} + async def list_keys(self) -> list[str]: + return list(self._store.keys()) + async def get(self, key: str) -> bytes | None: return self._store.get(key) @@ -108,6 +116,13 @@ def _key(self, key: str) -> str: """Get prefixed key.""" return f"{self._prefix}{key}" + async def list_keys(self) -> list[str]: + keys: list[str] = [] + async for key in self._client.scan_iter(match=f"{self._prefix}*", count=100): + k = key.decode() if isinstance(key, bytes) else key + keys.append(k.removeprefix(self._prefix)) + return keys + async def get(self, key: str) -> bytes | None: return await self._client.get(self._key(key)) diff --git a/tests/unit/test_admin.py b/tests/unit/test_admin.py new file mode 100644 index 0000000..3d6b928 --- /dev/null +++ b/tests/unit/test_admin.py @@ -0,0 +1,436 @@ +"""Tests for admin dashboard.""" + +import base64 +import hashlib +import os +import time +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from s3proxy.admin.collectors import ( + RateTracker, + RequestLog, + _format_bytes, + _format_uptime, + collect_health, + collect_pod_identity, + collect_throughput, + record_request, +) +from s3proxy.config import Settings + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def admin_settings(): + return Settings( + host="http://localhost:9000", + encrypt_key="test-key-for-admin", + admin_ui=True, + admin_path="/admin", + ) + + +@pytest.fixture +def 
admin_disabled_settings(): + return Settings( + host="http://localhost:9000", + encrypt_key="test-key-for-admin", + admin_ui=False, + ) + + +@pytest.fixture +def admin_credentials(): + return ("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") + + +@pytest.fixture +def admin_app(admin_settings, admin_credentials): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": admin_credentials[0], + "AWS_SECRET_ACCESS_KEY": admin_credentials[1], + }, + ): + from s3proxy.app import create_app + + return create_app(admin_settings) + + +@pytest.fixture +def admin_disabled_app(admin_disabled_settings, admin_credentials): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": admin_credentials[0], + "AWS_SECRET_ACCESS_KEY": admin_credentials[1], + }, + ): + from s3proxy.app import create_app + + return create_app(admin_disabled_settings) + + +@pytest.fixture +def client(admin_app): + with TestClient(admin_app) as c: + yield c + + +@pytest.fixture +def disabled_client(admin_disabled_app): + with TestClient(admin_disabled_app) as c: + yield c + + +def _basic_auth_header(username: str, password: str) -> dict: + token = base64.b64encode(f"{username}:{password}".encode()).decode() + return {"Authorization": f"Basic {token}"} + + +# ============================================================================ +# Auth Tests +# ============================================================================ + + +class TestAdminAuth: + def test_no_credentials_returns_401(self, client): + response = client.get("/admin/") + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + + def test_wrong_credentials_returns_401(self, client): + headers = _basic_auth_header("wrong", "wrong") + response = client.get("/admin/", headers=headers) + assert response.status_code == 401 + + def test_valid_credentials_returns_200(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response = 
client.get("/admin/", headers=headers) + assert response.status_code == 200 + + def test_custom_admin_credentials(self, admin_credentials): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": admin_credentials[0], + "AWS_SECRET_ACCESS_KEY": admin_credentials[1], + }, + ): + from s3proxy.app import create_app + + settings = Settings( + host="http://localhost:9000", + encrypt_key="test-key", + admin_ui=True, + admin_username="myadmin", + admin_password="mysecret", + ) + app = create_app(settings) + with TestClient(app) as c: + # AWS creds should NOT work + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + assert c.get("/admin/", headers=headers).status_code == 401 + + # Custom creds should work + headers = _basic_auth_header("myadmin", "mysecret") + assert c.get("/admin/", headers=headers).status_code == 200 + + +# ============================================================================ +# Dashboard HTML Tests +# ============================================================================ + + +class TestDashboardHTML: + def test_returns_html(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response = client.get("/admin/", headers=headers) + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + def test_contains_expected_sections(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + html = client.get("/admin/", headers=headers).text + assert "S3Proxy Admin" in html + assert "Health" in html + assert "Throughput" in html + assert "Live Feed" in html + + def test_no_sensitive_data_in_html(self, client, admin_credentials, admin_settings): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + html = client.get("/admin/", headers=headers).text + # Raw key should never appear + assert admin_settings.encrypt_key not in html + # KEK bytes should never appear 
+ kek_hex = admin_settings.kek.hex() + assert kek_hex not in html + # AWS secret key should never appear + assert admin_credentials[1] not in html + + +# ============================================================================ +# API Status Endpoint Tests +# ============================================================================ + + +class TestApiStatus: + def test_returns_json(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response = client.get("/admin/api/status", headers=headers) + assert response.status_code == 200 + assert "application/json" in response.headers["content-type"] + + def test_contains_expected_keys(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + data = client.get("/admin/api/status", headers=headers).json() + assert "pod" in data + assert "health" in data + assert "throughput" in data + assert "request_log" in data + assert "formatted" in data + assert "all_pods" in data + + def test_pod_has_identity_fields(self, client, admin_credentials, admin_settings): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + data = client.get("/admin/api/status", headers=headers).json() + pod = data["pod"] + expected_fp = hashlib.sha256(admin_settings.kek).hexdigest()[:16] + assert pod["kek_fingerprint"] == expected_fp + assert "pod_name" in pod + assert "uptime_seconds" in pod + assert "storage_backend" in pod + + def test_no_sensitive_data_in_json(self, client, admin_credentials, admin_settings): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response_text = client.get("/admin/api/status", headers=headers).text + # Raw encryption key + assert admin_settings.encrypt_key not in response_text + # Full KEK hex + assert admin_settings.kek.hex() not in response_text + # AWS secret key + assert admin_credentials[1] not in response_text + + def test_requires_auth(self, client): + 
response = client.get("/admin/api/status") + assert response.status_code == 401 + + def test_x_served_by_header(self, client, admin_credentials): + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response = client.get("/admin/api/status", headers=headers) + assert "x-served-by" in response.headers + + +# ============================================================================ +# Route Priority Tests +# ============================================================================ + + +class TestRoutePriority: + def test_admin_not_caught_by_s3_catchall(self, client, admin_credentials): + """Admin routes should return HTML/JSON, not S3 XML.""" + headers = _basic_auth_header(admin_credentials[0], admin_credentials[1]) + response = client.get("/admin/", headers=headers) + assert response.status_code == 200 + assert "application/xml" not in response.headers.get("content-type", "") + + def test_admin_disabled_falls_through(self, disabled_client): + """When admin is disabled, /admin should be caught by S3 catch-all.""" + response = disabled_client.get("/admin/") + # Will be caught by the S3 proxy catch-all (may error, but should be XML) + assert response.status_code != 200 or "application/xml" in response.headers.get( + "content-type", "" + ) + + +# ============================================================================ +# Collector Tests +# ============================================================================ + + +class TestCollectors: + def test_pod_identity(self, admin_settings): + start = time.monotonic() + result = collect_pod_identity(admin_settings, start) + assert "pod_name" in result + assert "uptime_seconds" in result + assert "storage_backend" in result + expected_fp = hashlib.sha256(admin_settings.kek).hexdigest()[:16] + assert result["kek_fingerprint"] == expected_fp + assert len(result["kek_fingerprint"]) == 16 + # Must not contain the actual key + assert admin_settings.kek.hex() not in str(result) + + def 
test_health_keys(self): + result = collect_health() + assert "memory_reserved_bytes" in result + assert "memory_limit_bytes" in result + assert "memory_usage_pct" in result + assert "requests_in_flight" in result + assert "errors_4xx" in result + assert "errors_5xx" in result + assert "errors_503" in result + + def test_throughput_keys(self): + result = collect_throughput() + rates = result["rates"] + assert "requests_per_min" in rates + assert "encrypt_per_min" in rates + assert "decrypt_per_min" in rates + assert "bytes_encrypted_per_min" in rates + assert "bytes_decrypted_per_min" in rates + assert "errors_4xx_per_min" in rates + assert "errors_5xx_per_min" in rates + assert "errors_503_per_min" in rates + history = result["history"] + assert "requests_per_min" in history + assert "bytes_encrypted_per_min" in history + assert "bytes_decrypted_per_min" in history + + def test_format_bytes(self): + assert _format_bytes(0) == "0 B" + assert _format_bytes(1023) == "1023 B" + assert _format_bytes(1024) == "1.0 KB" + assert _format_bytes(1048576) == "1.0 MB" + assert _format_bytes(1073741824) == "1.0 GB" + + def test_format_uptime(self): + assert _format_uptime(30) == "0m" + assert _format_uptime(60) == "1m" + assert _format_uptime(3661) == "1h 1m" + assert _format_uptime(90061) == "1d 1h 1m" + + +# ============================================================================ +# Rate Tracker Tests +# ============================================================================ + + +class TestRateTracker: + def test_empty_tracker_returns_zero(self): + tracker = RateTracker() + assert tracker.rate_per_minute("requests") == 0.0 + + def test_single_snapshot_returns_zero(self): + tracker = RateTracker() + tracker.record({"requests": 100}) + assert tracker.rate_per_minute("requests") == 0.0 + + def test_rate_computation(self): + tracker = RateTracker(window_seconds=300) + # Simulate two snapshots 60 seconds apart + tracker._snapshots.clear() + 
tracker._snapshots.append((1000.0, {"requests": 100})) + tracker._snapshots.append((1060.0, {"requests": 200})) + # 100 requests in 60 seconds = 100/min + assert tracker.rate_per_minute("requests") == 100.0 + + def test_rate_unknown_key_returns_zero(self): + tracker = RateTracker() + tracker._snapshots.clear() + tracker._snapshots.append((1000.0, {"requests": 100})) + tracker._snapshots.append((1060.0, {"requests": 200})) + assert tracker.rate_per_minute("nonexistent") == 0.0 + + def test_pruning(self): + tracker = RateTracker(window_seconds=10) + now = time.monotonic() + # Add old snapshots well before the window + for i in range(50): + tracker._snapshots.append((now - 100 + i, {"x": float(i)})) + tracker.record({"x": 100.0}) + # Old entries beyond window + 10s buffer should be pruned + assert len(tracker._snapshots) < 50 + + def test_history_empty(self): + tracker = RateTracker() + assert tracker.history("requests") == [] + + def test_history_single_snapshot(self): + tracker = RateTracker() + tracker.record({"requests": 100}) + assert tracker.history("requests") == [] + + def test_history_computation(self): + tracker = RateTracker() + tracker._snapshots.clear() + # 3 snapshots 60s apart: 100→200→400 + tracker._snapshots.append((1000.0, {"requests": 100})) + tracker._snapshots.append((1060.0, {"requests": 200})) + tracker._snapshots.append((1120.0, {"requests": 400})) + hist = tracker.history("requests") + assert len(hist) == 2 + assert hist[0] == 100.0 # (200-100)/60*60 + assert hist[1] == 200.0 # (400-200)/60*60 + + def test_history_downsampling(self): + tracker = RateTracker() + tracker._snapshots.clear() + for i in range(101): + tracker._snapshots.append((1000.0 + i * 3, {"x": float(i * 10)})) + hist = tracker.history("x", max_points=20) + assert len(hist) == 20 + + +# ============================================================================ +# Request Log Tests +# ============================================================================ + + +class 
TestRequestLog: + def test_empty_log(self): + log = RequestLog(maxlen=10) + assert log.recent() == [] + + def test_record_and_recent(self): + log = RequestLog(maxlen=10) + log.record("GET", "/bucket/key", "GetObject", 200, 0.05, 1024) + entries = log.recent(10) + assert len(entries) == 1 + assert entries[0]["method"] == "GET" + assert entries[0]["operation"] == "GetObject" + assert entries[0]["status"] == 200 + assert entries[0]["crypto"] == "decrypt" + assert entries[0]["duration_ms"] == 50.0 + assert entries[0]["size"] == 1024 + + def test_encrypt_crypto_tag(self): + log = RequestLog(maxlen=10) + log.record("PUT", "/bucket/key", "PutObject", 200, 0.1, 2048) + assert log.recent()[0]["crypto"] == "encrypt" + + def test_no_crypto_for_list(self): + log = RequestLog(maxlen=10) + log.record("GET", "/bucket/", "ListObjects", 200, 0.02, 0) + assert log.recent()[0]["crypto"] == "" + + def test_maxlen_eviction(self): + log = RequestLog(maxlen=5) + for i in range(10): + log.record("GET", f"/b/k{i}", "GetObject", 200, 0.01, 0) + entries = log.recent(10) + assert len(entries) == 5 + # Newest first + assert entries[0]["path"] == "/b/k9" + + def test_newest_first(self): + log = RequestLog(maxlen=10) + log.record("GET", "/first", "GetObject", 200, 0.01, 0) + log.record("PUT", "/second", "PutObject", 200, 0.01, 0) + entries = log.recent() + assert entries[0]["path"] == "/second" + assert entries[1]["path"] == "/first" + + def test_record_request_function(self): + from s3proxy.admin.collectors import _request_log + + initial = len(_request_log.recent(200)) + record_request("HEAD", "/bucket/key", "HeadObject", 200, 0.003, 0) + assert len(_request_log.recent(200)) == initial + 1