From 2724c1d272a372416d8cdfaba372cc2c874b5db9 Mon Sep 17 00:00:00 2001 From: Francis6-git Date: Sat, 27 Jun 2026 00:42:20 +0100 Subject: [PATCH] feat(llm): implement token usage tracking and cost monitoring (#361) --- api/app.py | 4 + api/routers/__init__.py | 4 + api/routers/llm_usage.py | 41 +++ astroml/tracking/__init__.py | 17 +- astroml/tracking/llm_usage_tracker.py | 319 ++++++++++++++++++ .../grafana/api_llm_cost_dashboard.json | 117 +++++++ .../prometheus/alert_rules_llm_cost.yml | 11 + monitoring/prometheus/prometheus.yml | 65 ++-- 8 files changed, 545 insertions(+), 33 deletions(-) create mode 100644 api/routers/llm_usage.py create mode 100644 astroml/tracking/llm_usage_tracker.py create mode 100644 monitoring/grafana/api_llm_cost_dashboard.json create mode 100644 monitoring/prometheus/alert_rules_llm_cost.yml diff --git a/api/app.py b/api/app.py index 045e9c8..e4651c7 100644 --- a/api/app.py +++ b/api/app.py @@ -54,8 +54,10 @@ validation_router, ws_router, streaming_router, + llm_usage_router, ) from api.routers.monitoring import record_latency + from api.routers.ws import poll_and_broadcast_transactions # Setup distributed tracing (issue #336) @@ -162,9 +164,11 @@ async def _latency_middleware(request: Request, call_next): app.include_router(chat_router) app.include_router(ws_router) app.include_router(streaming_router) +app.include_router(llm_usage_router) @app.get("/health", tags=["ops"]) + async def health(): return {"status": "ok"} diff --git a/api/routers/__init__.py b/api/routers/__init__.py index 4d896f9..ee5c2c8 100644 --- a/api/routers/__init__.py +++ b/api/routers/__init__.py @@ -20,8 +20,10 @@ from api.routers.validation import router as validation_router from api.routers.ws import router as ws_router from api.routers.streaming import router as streaming_router +from api.routers.llm_usage import router as llm_usage_router __all__ = [ + "accounts_router", "audit_router", "backup_router", @@ -43,4 +45,6 @@ "validation_router", "ws_router", "streaming_router", + "llm_usage_router", ] + diff --git a/api/routers/llm_usage.py b/api/routers/llm_usage.py new file mode 100644 index 0000000..8f7cbc3 --- /dev/null +++ b/api/routers/llm_usage.py @@ -0,0 +1,41 @@ +"""LLM usage and cost monitoring endpoints. + +These endpoints expose: +- recent LLM call events (all calls logged) +- rolling cost summaries + +Prometheus metrics are emitted by ``LLMUsageTracker``. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Query + +from astroml.tracking.llm_usage_tracker import default_llm_usage_tracker + +router = APIRouter(prefix="/api/v1/llm", tags=["llm"]) + + +@router.get("/usage/recent", response_model=List[Dict[str, Any]]) +def recent_llm_usage(limit: int = Query(100, ge=1, le=1000)): + """Return the most recent recorded LLM calls.""" + return default_llm_usage_tracker.recent_calls(limit=limit) + + +@router.get("/usage/summary", response_model=Dict[str, Any]) +def usage_summary(): + """Return a lightweight summary based on recent in-memory buffer.""" + events = default_llm_usage_tracker.recent_calls(limit=5000) + total_calls = len(events) + total_cost_usd = sum(float(e.get("cost_usd", 0.0) or 0.0) for e in events) + total_tokens = sum(int(e.get("total_tokens", 0) or 0) for e in events) + + return { + "total_calls": total_calls, + "total_cost_usd": round(total_cost_usd, 6), + "total_tokens": total_tokens, + "window": "in-memory-recent (up to last 5000 events)", + } + diff --git a/astroml/tracking/__init__.py b/astroml/tracking/__init__.py index e2b03d5..59cdb2e 100644 --- a/astroml/tracking/__init__.py +++ b/astroml/tracking/__init__.py @@ -1,3 +1,18 @@ +"""Tracking utilities (metrics, usage, experiment tracking, etc).""" + from .mlflow_tracker import MLflowTracker +from .llm_usage_tracker import ( + LLMUsage, + LLMPrices, + LLMUsageTracker, + default_llm_usage_tracker, +) + +__all__ = [ + "MLflowTracker", + "LLMUsage", + "LLMPrices", + "LLMUsageTracker", + "default_llm_usage_tracker", +] -__all__ = ["MLflowTracker"] diff --git a/astroml/tracking/llm_usage_tracker.py b/astroml/tracking/llm_usage_tracker.py new file mode 100644 index 0000000..902a8ae --- /dev/null +++ b/astroml/tracking/llm_usage_tracker.py @@ -0,0 +1,319 @@ +"""Token usage + cost tracking utilities for LLM calls. + +This repo currently doesn't include a concrete LLM provider integration. +To keep the feature testable and useful, this module is provider-agnostic: +callers should construct an ``LLMUsage`` object from provider responses and +pass it to ``LLMUsageTracker``. + +Integration points: +- Wrap your LLM provider call and record usage (tokens, latency, cost). +- Optionally register cost alerts (callbacks). +- Expose Prometheus metrics (if prometheus_client is installed). + +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Callable, Dict, List, Optional + +try: + from prometheus_client import Counter, Gauge, Histogram +except Exception: # pragma: no cover + Counter = Gauge = Histogram = None # type: ignore + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class LLMUsage: + """Usage details from an LLM provider response.""" + + provider: str + model: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + # provider-calculated cost in USD (preferred) + cost_usd: float + # request latency seconds + latency_s: float + # request correlation ids (optional) + request_id: Optional[str] = None + user_id: Optional[str] = None + session_id: Optional[str] = None + + +@dataclass(frozen=True) +class LLMPrices: + """Static token prices (USD per 1K tokens) for cost estimation.""" + + prompt_usd_per_1k: float + completion_usd_per_1k: float + + def estimate_cost_usd(self, prompt_tokens: int, completion_tokens: int) -> float: + return (prompt_tokens / 1000.0) * self.prompt_usd_per_1k + ( + completion_tokens / 1000.0 + ) * self.completion_usd_per_1k + + +class LLMUsageTracker: + """Tracks all LLM calls for cost/latency monitoring. + + Responsibilities: + - Record each call (in-memory ring buffer) + - Maintain rolling totals for cost + - Emit Prometheus metrics (if available) + - Invoke cost alert callbacks when thresholds are crossed + """ + + def __init__( + self, + *, + enabled: Optional[bool] = None, + alert_budget_usd_per_window: Optional[float] = None, + alert_window_s: Optional[int] = None, + ring_buffer_size: int = 5000, + prices: Optional[Dict[str, LLMPrices]] = None, + log_path: Optional[str] = None, + ): + self.enabled = ( + bool(os.environ.get("LLM_USAGE_TRACKING_ENABLED", "1")) + if enabled is None + else enabled + ) + self.alert_budget_usd_per_window = ( + float(os.environ.get("LLM_COST_ALERT_BUDGET_USD", "0")) + if alert_budget_usd_per_window is None + else alert_budget_usd_per_window + ) + self.alert_window_s = int( + os.environ.get("LLM_COST_ALERT_WINDOW_S", "3600") + if alert_window_s is None + else alert_window_s + ) + self.ring_buffer_size = int(ring_buffer_size) + self.prices = prices or {} + + self._lock = threading.Lock() + self._events: List[dict] = [] + self._events_start_idx = 0 + + self._window_start_ts = time.time() + self._window_cost_usd = 0.0 + + self._alert_callbacks: List[Callable[[dict], None]] = [] + + self._prom = {} + self._init_prometheus() + + self._log_path = log_path or os.environ.get( + "LLM_USAGE_LOG_PATH", "./llm_usage_events.jsonl" + ) + + def _init_prometheus(self) -> None: + if Counter is None: + return + + self._prom["llm_calls_total"] = Counter( + "astroml_llm_calls_total", + "Total number of LLM calls", + ["provider", "model"], + ) + self._prom["llm_tokens_total"] = Counter( + "astroml_llm_tokens_total", + "Total tokens used by LLM calls", + ["provider", "model", "token_type"], + ) + self._prom["llm_latency_seconds"] = Histogram( + "astroml_llm_latency_seconds", + "Latency of LLM calls in seconds", + ["provider", "model"], + ) + self._prom["llm_cost_usd_total"] = Counter( + "astroml_llm_cost_usd_total", + "Cumulative cost in USD for LLM calls", + ["provider", "model"], + ) + self._prom["llm_cost_budget_usd_gauge"] = Gauge( + "astroml_llm_cost_budget_usd_gauge", + "Configured LLM cost budget per alert window (USD)", + ) + try: + if self.alert_budget_usd_per_window: + self._prom["llm_cost_budget_usd_gauge"].set( + float(self.alert_budget_usd_per_window) + ) + except Exception: + pass + + def register_cost_alert_callback(self, cb: Callable[[dict], None]) -> None: + """Register a callback invoked when budget is exceeded.""" + with self._lock: + self._alert_callbacks.append(cb) + + def _push_event(self, event: dict) -> None: + if len(self._events) < self.ring_buffer_size: + self._events.append(event) + else: + # ring buffer: drop oldest + self._events[self._events_start_idx % self.ring_buffer_size] = event + self._events_start_idx += 1 + + def _get_cost_from_prices_or_pass_through(self, usage: LLMUsage) -> float: + # Cost_usd is preferred from provider. + if usage.cost_usd is not None: + return float(usage.cost_usd) + + key = f"{usage.provider}:{usage.model}" + prices = self.prices.get(key) or self.prices.get(usage.model) + if not prices: + raise ValueError( + "cost_usd missing and no prices configured for provider/model" + ) + return prices.estimate_cost_usd( + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ) + + def record_call( + self, + *, + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + latency_s: float, + cost_usd: Optional[float] = None, + request_id: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> LLMUsage: + """Record an LLM call. + + All LLM calls should pass token counts and latency. + If ``cost_usd`` is not provided, you must configure ``prices``. + """ + + total_tokens = int(prompt_tokens) + int(completion_tokens) + + usage = LLMUsage( + provider=provider, + model=model, + prompt_tokens=int(prompt_tokens), + completion_tokens=int(completion_tokens), + total_tokens=total_tokens, + cost_usd=float(cost_usd) if cost_usd is not None else None, # type: ignore[arg-type] + latency_s=float(latency_s), + request_id=request_id, + user_id=user_id, + session_id=session_id, + ) + + # Resolve cost if needed + resolved_cost_usd = ( + float(cost_usd) + if cost_usd is not None + else self._get_cost_from_prices_or_pass_through(usage) + ) + + usage_dict = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "provider": provider, + "model": model, + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + "latency_s": usage.latency_s, + "cost_usd": resolved_cost_usd, + "request_id": request_id, + "user_id": user_id, + "session_id": session_id, + } + + with self._lock: + if not self.enabled: + return usage + + self._push_event(usage_dict) + self._window_cost_usd += resolved_cost_usd + + now = time.time() + if now - self._window_start_ts >= self.alert_window_s: + self._window_start_ts = now + self._window_cost_usd = 0.0 + + if self.alert_budget_usd_per_window and resolved_cost_usd is not None: + # Trigger on window cost exceed + if self._window_cost_usd >= float(self.alert_budget_usd_per_window): + alert = { + "type": "llm_cost_budget_exceeded", + "timestamp": usage_dict["timestamp"], + "budget_usd": float(self.alert_budget_usd_per_window), + "window_s": int(self.alert_window_s), + "window_cost_usd": float(self._window_cost_usd), + "last_call": usage_dict, + } + for cb in list(self._alert_callbacks): + try: + cb(alert) + except Exception as exc: # pragma: no cover + logger.warning("LLM cost alert callback failed: %s", exc) + + # Emit Prometheus metrics + prom = self._prom + if prom: + try: + prom["llm_calls_total"].labels(provider=provider, model=model).inc() + prom["llm_tokens_total"].labels( + provider=provider, model=model, token_type="prompt" + ).inc(usage.prompt_tokens) + prom["llm_tokens_total"].labels( + provider=provider, model=model, token_type="completion" + ).inc(usage.completion_tokens) + prom["llm_latency_seconds"].labels( + provider=provider, model=model + ).observe(usage.latency_s) + prom["llm_cost_usd_total"].labels( + provider=provider, model=model + ).inc(resolved_cost_usd) + except Exception: + pass + + # Append JSONL log (all calls logged) + try: + with open(self._log_path, "a", encoding="utf-8") as f: + f.write(json.dumps(usage_dict) + "\n") + except Exception as exc: # pragma: no cover + logger.warning("Failed to write LLM usage log: %s", exc) + + return LLMUsage( + provider=provider, + model=model, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + cost_usd=resolved_cost_usd, + latency_s=usage.latency_s, + request_id=request_id, + user_id=user_id, + session_id=session_id, + ) + + def recent_calls(self, limit: int = 100) -> List[dict]: + """Return most recent recorded LLM call events.""" + with self._lock: + if limit <= 0: + return [] + return list(self._events[-limit:]) + + +# Default process-wide tracker instance +default_llm_usage_tracker = LLMUsageTracker() + diff --git a/monitoring/grafana/api_llm_cost_dashboard.json b/monitoring/grafana/api_llm_cost_dashboard.json new file mode 100644 index 0000000..babb0d4 --- /dev/null +++ b/monitoring/grafana/api_llm_cost_dashboard.json @@ -0,0 +1,117 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "id": null, + "links": [], + "panels": [ + { + "title": "LLM Calls (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 0 }, + "targets": [ + { + "expr": "rate(astroml_llm_calls_total[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "A" + } + ], + "yaxes": [ + { "format": "short", "label": "calls/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Cost USD (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 0 }, + "targets": [ + { + "expr": "rate(astroml_llm_cost_usd_total[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "B" + } + ], + "yaxes": [ + { "format": "currencyUSD", "label": "USD/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Tokens Prompt (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 8 }, + "targets": [ + { + "expr": "rate(astroml_llm_tokens_total{token_type=\"prompt\"}[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "C" + } + ], + "yaxes": [ + { "format": "short", "label": "tokens/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Tokens Completion (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 8 }, + "targets": [ + { + "expr": "rate(astroml_llm_tokens_total{token_type=\"completion\"}[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "D" + } + ], + "yaxes": [ + { "format": "short", "label": "tokens/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Latency p95", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 16 }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(astroml_llm_latency_seconds_bucket[5m]))", + "legendFormat": "{{provider}}/{{model}}", + "refId": "E" + } + ], + "yaxes": [{ "format": "s", "label": "seconds" }, { "format": "short" }] + }, + { + "title": "LLM Cost USD (cumulative)", + "type": "stat", + "gridPos": { "h": 4, "w": 6, "x": 12, "y": 16 }, + "targets": [ + { "expr": "sum(rate(astroml_llm_cost_usd_total[5m]))", "refId": "F" } + ] + } + ], + "schemaVersion": 26, + "style": "dark", + "tags": ["astroml", "llm", "cost"], + "templating": { "list": [] }, + "time": { "from": "now-1h", "to": "now" }, + "timepicker": {}, + "timezone": "", + "title": "AstroML LLM Cost Dashboard", + "uid": "astroml_llm_cost", + "version": 1 +} diff --git a/monitoring/prometheus/alert_rules_llm_cost.yml b/monitoring/prometheus/alert_rules_llm_cost.yml new file mode 100644 index 0000000..28a3d4f --- /dev/null +++ b/monitoring/prometheus/alert_rules_llm_cost.yml @@ -0,0 +1,11 @@ +groups: + - name: astroml_llm_cost_alerts + rules: + - alert: LLMBudgetExceeded + expr: sum(rate(astroml_llm_cost_usd_total[5m])) > 0 + for: 1m + labels: + severity: warning + annotations: + summary: "LLM cost budget exceeded (rate-based placeholder)" + description: "LLM cost is non-zero; configure budgets via LLMUsageTracker env vars for callback-based alerts." diff --git a/monitoring/prometheus/prometheus.yml b/monitoring/prometheus/prometheus.yml index 6b73ec3..7ce4682 100644 --- a/monitoring/prometheus/prometheus.yml +++ b/monitoring/prometheus/prometheus.yml @@ -3,8 +3,8 @@ global: scrape_interval: 15s evaluation_interval: 15s external_labels: - monitor: 'astroml-monitor' - environment: 'docker' + monitor: "astroml-monitor" + environment: "docker" # Alertmanager configuration alerting: @@ -14,79 +14,80 @@ alerting: # Alert rules files rule_files: - - 'alert_rules.yml' + - "alert_rules.yml" + - "alert_rules_llm_cost.yml" # Scrape configurations scrape_configs: # Prometheus self-monitoring - - job_name: 'prometheus' + - job_name: "prometheus" static_configs: - - targets: ['localhost:9090'] + - targets: ["localhost:9090"] # PostgreSQL exporter (requires postgres_exporter container) - - job_name: 'postgres' - metrics_path: '/metrics' + - job_name: "postgres" + metrics_path: "/metrics" static_configs: - - targets: ['postgres-exporter:9187'] + - targets: ["postgres-exporter:9187"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'postgres' + replacement: "postgres" # Redis exporter (requires redis_exporter container) - - job_name: 'redis' + - job_name: "redis" static_configs: - - targets: ['redis-exporter:9121'] + - targets: ["redis-exporter:9121"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'redis' + replacement: "redis" # Python application metrics (astroml services) - - job_name: 'astroml-ingestion' - metrics_path: '/metrics' + - job_name: "astroml-ingestion" + metrics_path: "/metrics" static_configs: - - targets: ['ingestion:8080'] + - targets: ["ingestion:8080"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'ingestion' + replacement: "ingestion" - - job_name: 'astroml-streaming' - metrics_path: '/metrics' + - job_name: "astroml-streaming" + metrics_path: "/metrics" static_configs: - - targets: ['streaming:8001'] + - targets: ["streaming:8001"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'streaming' + replacement: "streaming" # Training service metrics - - job_name: 'astroml-training' - metrics_path: '/metrics' + - job_name: "astroml-training" + metrics_path: "/metrics" static_configs: - - targets: ['training-cpu:6007', 'training-gpu:6006'] + - targets: ["training-cpu:6007", "training-gpu:6006"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'training' + replacement: "training" # Development service metrics - - job_name: 'astroml-dev' - metrics_path: '/metrics' + - job_name: "astroml-dev" + metrics_path: "/metrics" static_configs: - - targets: ['dev:8002'] + - targets: ["dev:8002"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'dev' + replacement: "dev" # Production service metrics - - job_name: 'astroml-production' - metrics_path: '/metrics' + - job_name: "astroml-production" + metrics_path: "/metrics" static_configs: - - targets: ['production:8000'] + - targets: ["production:8000"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'production' + replacement: "production"