diff --git a/api/app.py b/api/app.py index f5f5f9f..9ce6269 100644 --- a/api/app.py +++ b/api/app.py @@ -55,11 +55,18 @@ validation_router, ws_router, streaming_router, + llm_usage_router, + llm_cache_metrics_router, llm_router, reports_router, alerts_router, ) + + + from api.routers.monitoring import record_latency + + from api.routers.ws import poll_and_broadcast_transactions # Setup distributed tracing (issue #336) @@ -167,12 +174,16 @@ async def _latency_middleware(request: Request, call_next): app.include_router(chat_router) app.include_router(ws_router) app.include_router(streaming_router) +app.include_router(llm_usage_router) +app.include_router(llm_cache_metrics_router) + app.include_router(llm_router) app.include_router(reports_router) app.include_router(alerts_router) @app.get("/health", tags=["ops"]) + async def health(): return {"status": "ok"} diff --git a/api/routers/__init__.py b/api/routers/__init__.py index 2f3c75b..538729e 100644 --- a/api/routers/__init__.py +++ b/api/routers/__init__.py @@ -21,11 +21,15 @@ from api.routers.validation import router as validation_router from api.routers.ws import router as ws_router from api.routers.streaming import router as streaming_router +from api.routers.llm_usage import router as llm_usage_router +from api.routers.llm_cache_metrics import router as llm_cache_metrics_router + from api.routers.llm import router as llm_router from api.routers.reports import router as reports_router from api.routers.alerts import router as alerts_router __all__ = [ + "accounts_router", "audit_router", "backup_router", @@ -48,7 +52,9 @@ "validation_router", "ws_router", "streaming_router", + "llm_usage_router", "llm_router", "reports_router", "alerts_router", ] + diff --git a/api/routers/llm_cache_metrics.py b/api/routers/llm_cache_metrics.py new file mode 100644 index 0000000..b648693 --- /dev/null +++ b/api/routers/llm_cache_metrics.py @@ -0,0 +1,19 @@ +"""LLM semantic cache metrics endpoints.""" + +from __future__ import 
@router.get("/usage/recent", response_model=List[Dict[str, Any]])
def recent_llm_usage(limit: int = Query(100, ge=1, le=1000)):
    """List the newest LLM call events held by the process-wide tracker.

    Args:
        limit: Maximum number of events to return (validated to 1..1000).
    """
    tracker = default_llm_usage_tracker
    return tracker.recent_calls(limit=limit)


@router.get("/usage/summary", response_model=Dict[str, Any])
def usage_summary():
    """Aggregate call count, cost and tokens over the tracker's in-memory buffer.

    Missing or null ``cost_usd`` / ``total_tokens`` fields count as zero.
    """
    recent = default_llm_usage_tracker.recent_calls(limit=5000)

    cost_acc = 0
    token_acc = 0
    for event in recent:
        cost_acc += float(event.get("cost_usd", 0.0) or 0.0)
        token_acc += int(event.get("total_tokens", 0) or 0)

    summary: Dict[str, Any] = {}
    summary["total_calls"] = len(recent)
    summary["total_cost_usd"] = round(cost_acc, 6)
    summary["total_tokens"] = token_acc
    summary["window"] = "in-memory-recent (up to last 5000 events)"
    return summary
logger = logging.getLogger(__name__)


def _env_float(name: str, default: float) -> float:
    """Read a float from the environment; fall back to *default* if unset or invalid."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        # A typo in deployment config should not make the module unimportable.
        logger.warning("Ignoring invalid %s=%r; using %s", name, raw, default)
        return default


def _env_int(name: str, default: int) -> int:
    """Read an int from the environment; fall back to *default* if unset or invalid."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using %s", name, raw, default)
        return default


def _cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the cosine similarity of *a* and *b*, clamped to ``[-1.0, 1.0]``.

    Returns ``-1.0`` (i.e. "never a match") for anything that cannot be
    compared: empty or mismatched-length vectors, zero-norm vectors,
    non-numeric elements, and non-finite results.
    """
    if not a or not b or len(a) != len(b):
        return -1.0

    try:
        dot = math.fsum(x * y for x, y in zip(a, b))
        norm_a = math.fsum(x * x for x in a)
        norm_b = math.fsum(y * y for y in b)
        # ``not (x > 0)`` also rejects NaN norms.
        if not (norm_a > 0.0 and norm_b > 0.0):
            return -1.0
        sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
    except (TypeError, ValueError, OverflowError, ZeroDivisionError):
        # Cached embeddings come from JSON and may be malformed.
        return -1.0

    if not math.isfinite(sim):
        return -1.0
    # Rounding can push identical vectors to 1.0000000000000002.
    return max(-1.0, min(1.0, sim))


@dataclass(frozen=True)
class SemanticCacheConfig:
    """Configuration for semantic caching.

    Environment-backed defaults are read once, when this module is imported.
    """

    namespace: str = "llm:semantic"
    similarity_threshold: float = _env_float("LLM_CACHE_SIMILARITY_THRESHOLD", 0.88)
    ttl_seconds: int = _env_int("LLM_CACHE_TTL_SECONDS", 600)
    candidate_top_k: int = _env_int("LLM_CACHE_LOOKBACK_K", 200)


class LLMEmbeddingProvider:
    """Embedding provider interface.

    Implementations return a dense embedding vector for the input text.
    """

    def embed(self, *, text: str, model: str) -> List[float]:  # pragma: no cover
        raise NotImplementedError


class DefaultNoopEmbeddingProvider(LLMEmbeddingProvider):
    """Fallback provider that always raises.

    It exists so the module is importable without a configured embedding
    backend; callers are expected to inject a real provider.
    """

    def embed(self, *, text: str, model: str) -> List[float]:
        raise RuntimeError(
            "No embedding provider configured. "
            "Provide an embedding provider to LLMSemanticCacheWrapper."
        )


@dataclass(frozen=True)
class SemanticCacheHit:
    """A successful semantic-cache lookup."""

    response: Any
    similarity: float
    cache_key: str
    cached_at: Optional[float] = None


class LLMSemanticCache:
    """Redis-backed semantic cache.

    Storage layout (keys):
    - response payload: {namespace}:resp:{model}:{cache_id}  (via RedisCache.set/get)
    - embedding vector: {namespace}:emb:{model}:{cache_id}   (JSON list of floats)
    - metadata:         {namespace}:meta:{model}:{cache_id}  (JSON object)
    - index sorted set: {namespace}:idx:{model}:all          (score = store time)

    Lookups scan only the ``candidate_top_k`` most recently stored entries.
    """

    def __init__(
        self,
        *,
        redis_cache: "Optional[RedisCache]" = None,
        config: Optional[SemanticCacheConfig] = None,
        embedding_provider: Optional[LLMEmbeddingProvider] = None,
    ):
        self._redis_cache = redis_cache if redis_cache is not None else RedisCache()
        self._config = config or SemanticCacheConfig()
        self._embedding_provider = embedding_provider or DefaultNoopEmbeddingProvider()

        # Raw client used for the JSON/ZSET keys that bypass RedisCache's serializer.
        self._redis = self._redis_cache.client

    @property
    def config(self) -> SemanticCacheConfig:
        return self._config

    def _idx_key(self, *, model: str) -> str:
        # Single bucket. If needed, extend to hour/day buckets.
        return f"{self._config.namespace}:idx:{model}:all"

    def _resp_key(self, *, model: str, cache_id: str) -> str:
        return f"{self._config.namespace}:resp:{model}:{cache_id}"

    def _emb_key(self, *, model: str, cache_id: str) -> str:
        return f"{self._config.namespace}:emb:{model}:{cache_id}"

    def _meta_key(self, *, model: str, cache_id: str) -> str:
        return f"{self._config.namespace}:meta:{model}:{cache_id}"

    def _now_bucket_score(self) -> float:
        # Higher score = more recent.
        return time.time()

    @staticmethod
    def _decode_id(raw: Any) -> str:
        """Normalize a ZSET member (bytes or str, depending on client config) to str."""
        if isinstance(raw, (bytes, bytearray)):
            return raw.decode("utf-8")
        return str(raw)

    @staticmethod
    def _new_cache_id() -> str:
        """Return a collision-resistant id: millisecond timestamp plus random suffix."""
        return f"{time.time_ns() // 1_000_000}-{os.urandom(4).hex()}"

    def lookup(
        self,
        *,
        prompt: str,
        model: str,
        embedding_model: str,
        ttl_seconds: Optional[int] = None,
    ) -> Tuple[Optional[SemanticCacheHit], float]:
        """Find the most similar cached prompt at or above the similarity threshold.

        Args:
            prompt: Incoming prompt text.
            model: LLM model name; entries are partitioned per model.
            embedding_model: Model name passed to the embedding provider.
            ttl_seconds: Accepted for interface compatibility; not used by lookups.

        Returns:
            ``(hit, lookup_ms)``. ``hit`` is ``None`` on a miss, including when
            Redis is unavailable or the matched response has already expired.

        Raises:
            Whatever the embedding provider raises; Redis errors are logged and
            reported as a miss.
        """
        start = time.perf_counter()

        def _result(hit: Optional[SemanticCacheHit]) -> Tuple[Optional[SemanticCacheHit], float]:
            return hit, (time.perf_counter() - start) * 1000.0

        query_emb = self._embedding_provider.embed(text=prompt, model=embedding_model)

        top_k = self._config.candidate_top_k
        if top_k <= 0:
            # ZREVRANGE 0 -1 would return the whole index; treat as "cache disabled".
            return _result(None)

        idx_key = self._idx_key(model=model)
        try:
            raw_ids = self._redis.zrevrange(idx_key, 0, top_k - 1)
        except Exception as e:  # pragma: no cover
            logger.warning("Semantic cache ZREVRANGE failed: %s", e)
            return _result(None)

        candidate_ids = [self._decode_id(raw) for raw in raw_ids]
        if not candidate_ids:
            return _result(None)

        try:
            pipe = self._redis.pipeline()
            for cid in candidate_ids:
                pipe.get(self._emb_key(model=model, cache_id=cid))
            emb_blobs = pipe.execute()
        except Exception as e:  # pragma: no cover
            logger.warning("Semantic cache embedding fetch failed: %s", e)
            return _result(None)

        best: Optional[Tuple[float, str]] = None
        expired: List[str] = []
        for cid, blob in zip(candidate_ids, emb_blobs):
            if blob is None:
                # Embedding key hit its TTL but the id is still in the index.
                expired.append(cid)
                continue
            if not blob:
                continue
            try:
                emb_vec = json.loads(blob)
            except (TypeError, ValueError):
                continue
            if not isinstance(emb_vec, list):
                continue

            sim = _cosine_similarity(query_emb, emb_vec)
            if best is None or sim > best[0]:
                best = (sim, cid)

        if expired:
            # Best-effort pruning so dead ids stop consuming candidate slots.
            try:
                self._redis.zrem(idx_key, *expired)
            except Exception as e:  # pragma: no cover
                logger.debug("Semantic cache index prune failed: %s", e)

        if best is None:
            return _result(None)

        best_sim, best_cid = best
        if best_sim < self._config.similarity_threshold:
            return _result(None)

        # Read the response only through RedisCache.get so it mirrors the
        # RedisCache.set used by store(), whatever key/serialization scheme
        # RedisCache applies internally.
        resp_key = self._resp_key(model=model, cache_id=best_cid)
        try:
            response_obj = self._redis_cache.get(resp_key)
        except Exception as e:
            logger.warning("Semantic cache response read failed: %s", e)
            response_obj = None
        if response_obj is None:
            # Expired or unreadable payload: report a miss, not an empty hit.
            return _result(None)

        cached_at: Optional[float] = None
        try:
            meta_blob = self._redis.get(self._meta_key(model=model, cache_id=best_cid))
            if meta_blob:
                meta = json.loads(meta_blob)
                if isinstance(meta, dict):
                    cached_at = meta.get("cached_at")
        except Exception:
            # Metadata is informational only; never fail a hit because of it.
            cached_at = None

        hit = SemanticCacheHit(
            response=response_obj,
            similarity=float(best_sim),
            cache_key=f"{model}:{best_cid}",
            cached_at=cached_at,
        )
        return _result(hit)

    def store(
        self,
        *,
        prompt: str,
        response: Any,
        model: str,
        embedding_model: str,
        cache_id: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
    ) -> str:
        """Store *response* for *prompt* and index it for similarity lookup.

        Args:
            prompt: Prompt text whose embedding becomes the lookup key.
            response: Payload to cache (serialized by ``RedisCache.set``).
            model: LLM model name; entries are partitioned per model.
            embedding_model: Model name passed to the embedding provider.
            cache_id: Optional explicit id; a unique one is generated otherwise.
            ttl_seconds: Entry lifetime; defaults to the configured TTL.

        Returns:
            The cache id. Redis failures are logged, not raised; the id is
            returned even if nothing could be written.
        """
        ttl_seconds = ttl_seconds or self._config.ttl_seconds
        cache_id = cache_id or self._new_cache_id()

        emb_vec = self._embedding_provider.embed(text=prompt, model=embedding_model)

        resp_key = self._resp_key(model=model, cache_id=cache_id)
        emb_key = self._emb_key(model=model, cache_id=cache_id)
        meta_key = self._meta_key(model=model, cache_id=cache_id)
        idx_key = self._idx_key(model=model)

        try:
            self._redis_cache.set(resp_key, response, ttl_seconds=ttl_seconds)
        except Exception as e:  # pragma: no cover
            logger.warning("Semantic cache response set failed: %s", e)
            # Do not index an entry whose response was never written.
            return cache_id

        emb_json = json.dumps([float(x) for x in emb_vec])
        meta_json = json.dumps({"cached_at": time.time()})
        # Soft cap: keep only the most recent 10x candidate_top_k ids (min 100).
        cap = max(self._config.candidate_top_k * 10, 100)

        try:
            pipe = self._redis.pipeline()
            pipe.setex(emb_key, ttl_seconds, emb_json)
            pipe.setex(meta_key, ttl_seconds, meta_json)
            pipe.zadd(idx_key, {cache_id: self._now_bucket_score()})
            # Ranks are ascending by score, so this drops the oldest overflow.
            pipe.zremrangebyrank(idx_key, 0, -(cap + 1))
            pipe.execute()
        except Exception as e:  # pragma: no cover
            logger.warning("Semantic cache embedding/index set failed: %s", e)

        return cache_id


class SimpleDeterministicEmbeddingProvider(LLMEmbeddingProvider):
    """Developer-friendly embedding provider (NOT semantic-quality).

    Intended for unit tests / local usage when no embedding model exists.
    Produces a fixed-length vector derived from an MD5 digest of the text;
    the 16 digest bytes repeat cyclically when ``dim`` exceeds 16.
    """

    def __init__(self, dim: int = 64):
        self.dim = dim

    def embed(self, *, text: str, model: str) -> List[float]:
        import hashlib

        digest = hashlib.md5((model + ":" + text).encode("utf-8")).digest()
        # Map each byte from [0, 255] to [-1.0, 1.0].
        return [
            (digest[i % len(digest)] / 255.0) * 2.0 - 1.0 for i in range(self.dim)
        ]
logger = logging.getLogger(__name__)


class LLMProvider(Protocol):
    """Structural interface for the wrapped LLM backend."""

    def complete(self, *, model: str, prompt: str, **kwargs: Any) -> Any:  # pragma: no cover
        ...


@dataclass(frozen=True)
class LLMCachedClientConfig:
    """Settings for :class:`LLMCachedClient`.

    Environment-backed defaults are read once, when this module is imported.
    """

    redis_url: Optional[str] = os.environ.get("LLM_CACHE_REDIS_URL")
    model: Optional[str] = None
    embedding_model: str = os.environ.get("LLM_CACHE_EMBEDDING_MODEL", "text-embedding-placeholder")

    similarity_threshold: float = float(os.environ.get("LLM_CACHE_SIMILARITY_THRESHOLD", "0.88"))
    ttl_seconds: int = int(os.environ.get("LLM_CACHE_TTL_SECONDS", "600"))
    candidate_top_k: int = int(os.environ.get("LLM_CACHE_LOOKBACK_K", "200"))

    metrics_prefix: str = "llm:semantic:metrics"


class LLMCachedClient:
    """Semantic caching wrapper for an LLM provider.

    The cache is strictly best-effort: any failure in lookup, storage or
    metrics is logged and the call falls through to the wrapped provider.

    NOTE(review): entries are keyed by prompt and model only. Extra
    ``**kwargs`` (temperature, system prompt, ...) do not take part in the
    cache key, so a hit may return a response produced under different
    generation parameters - confirm this is acceptable for callers.
    """

    def __init__(
        self,
        provider: LLMProvider,
        *,
        embedding_provider: "LLMEmbeddingProvider",
        config: Optional[LLMCachedClientConfig] = None,
        redis_cache: "Optional[RedisCache]" = None,
    ):
        self._provider = provider
        self._redis_cache = redis_cache if redis_cache is not None else RedisCache()
        self._config = config or LLMCachedClientConfig()

        self._semantic_cache = LLMSemanticCache(
            redis_cache=self._redis_cache,
            config=SemanticCacheConfig(
                similarity_threshold=self._config.similarity_threshold,
                ttl_seconds=self._config.ttl_seconds,
                candidate_top_k=self._config.candidate_top_k,
            ),
            embedding_provider=embedding_provider,
        )

        self._redis = self._redis_cache.client

    def _metric_key(self, suffix: str) -> str:
        return f"{self._config.metrics_prefix}:{suffix}"

    def _incr(self, key: str, amount: int = 1) -> None:
        """Increment the counter stored at the full Redis key *key* (best-effort)."""
        try:
            self._redis.incrby(key, amount)
        except Exception as e:
            logger.debug("Semantic cache metric %s not updated: %s", key, e)

    def _observe_ms(self, key: str, value_ms: float, *, count_key: Optional[str] = None) -> None:
        """Add *value_ms* to a running sum and bump its sample counter together.

        Args:
            key: Metric suffix of the sum.
            value_ms: Observation in milliseconds.
            count_key: Metric suffix of the sample counter; defaults to
                ``key + ":n"``.
        """
        sum_key = self._metric_key(key)
        n_key = self._metric_key(count_key) if count_key else sum_key + ":n"
        try:
            pipe = self._redis.pipeline()
            pipe.incrbyfloat(sum_key, value_ms)
            pipe.incrby(n_key, 1)
            pipe.execute()
        except Exception as e:
            logger.debug("Semantic cache metric %s not updated: %s", sum_key, e)

    def complete(self, *, model: str, prompt: str, request_id: Optional[str] = None, **kwargs: Any) -> Any:
        """Return a cached response for a similar prompt, else call the provider.

        Args:
            model: LLM model name.
            prompt: Prompt text.
            request_id: Correlation id forwarded to the provider; a UUID4 is
                generated when omitted.
            **kwargs: Passed through to the provider unchanged.

        Returns:
            The cached response on a hit, otherwise the provider's response.

        Raises:
            Only what the wrapped provider raises.
        """
        request_id = request_id or str(uuid.uuid4())

        hit = None
        try:
            hit, lookup_ms = self._semantic_cache.lookup(
                prompt=prompt,
                model=model,
                embedding_model=self._config.embedding_model,
            )
        except Exception as e:
            # A broken cache or embedding backend must not take the LLM call down.
            logger.warning("Semantic cache lookup failed; calling provider: %s", e)
        else:
            # Sum and count are written together under the keys the metrics
            # reader uses: ``lookup_ms_sum`` and ``lookup_ms_n``.
            self._observe_ms("lookup_ms_sum", lookup_ms, count_key="lookup_ms_n")

        if hit is not None and hit.response is not None:
            self._incr(self._metric_key("hits"), 1)
            return hit.response

        self._incr(self._metric_key("misses"), 1)

        response = self._provider.complete(model=model, prompt=prompt, request_id=request_id, **kwargs)

        try:
            self._semantic_cache.store(
                prompt=prompt,
                response=response,
                model=model,
                embedding_model=self._config.embedding_model,
                ttl_seconds=self._config.ttl_seconds,
            )
        except Exception as e:
            logger.warning("Semantic cache store failed: %s", e)

        return response
"""Tracking utilities (metrics, usage, experiment tracking, etc)."""

from .mlflow_tracker import MLflowTracker
from .llm_usage_tracker import (
    LLMUsage,
    LLMPrices,
    LLMUsageTracker,
    default_llm_usage_tracker,
)
from .model_registry import ModelRegistry

# One authoritative export list. A second ``__all__`` assignment further down
# the module would silently replace this one and drop the LLM usage names.
__all__ = [
    "MLflowTracker",
    "ModelRegistry",
    "LLMUsage",
    "LLMPrices",
    "LLMUsageTracker",
    "default_llm_usage_tracker",
]
try:
    from prometheus_client import Counter, Gauge, Histogram
except Exception:  # pragma: no cover
    Counter = Gauge = Histogram = None  # type: ignore

logger = logging.getLogger(__name__)

# Environment values that switch a boolean flag off (compared case-insensitively).
_FALSE_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in _FALSE_ENV_VALUES


@dataclass(frozen=True)
class LLMUsage:
    """Usage details from an LLM provider response."""

    provider: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # provider-calculated cost in USD (preferred)
    cost_usd: float
    # request latency seconds
    latency_s: float
    # request correlation ids (optional)
    request_id: Optional[str] = None
    user_id: Optional[str] = None
    session_id: Optional[str] = None


@dataclass(frozen=True)
class LLMPrices:
    """Static token prices (USD per 1K tokens) for cost estimation."""

    prompt_usd_per_1k: float
    completion_usd_per_1k: float

    def estimate_cost_usd(self, prompt_tokens: int, completion_tokens: int) -> float:
        """Return the estimated USD cost for the given token counts."""
        return (prompt_tokens / 1000.0) * self.prompt_usd_per_1k + (
            completion_tokens / 1000.0
        ) * self.completion_usd_per_1k


class LLMUsageTracker:
    """Tracks all LLM calls for cost/latency monitoring.

    Responsibilities:
    - Record each call (in-memory ring buffer)
    - Maintain a rolling cost total per alert window
    - Emit Prometheus metrics (if ``prometheus_client`` is importable)
    - Invoke cost alert callbacks once the window budget is reached
    - Append every recorded call to a JSONL log file
    """

    # Prometheus collectors can be registered only once per process, so every
    # tracker instance shares the same set.
    _prom_shared: Dict[str, object] = {}
    _prom_shared_lock = threading.Lock()

    def __init__(
        self,
        *,
        enabled: Optional[bool] = None,
        alert_budget_usd_per_window: Optional[float] = None,
        alert_window_s: Optional[int] = None,
        ring_buffer_size: int = 5000,
        prices: Optional[Dict[str, LLMPrices]] = None,
        log_path: Optional[str] = None,
    ):
        """Create a tracker.

        Args:
            enabled: Record calls at all. Defaults to the
                ``LLM_USAGE_TRACKING_ENABLED`` env var ("0", "false", "no",
                "off" and the empty string disable; unset means enabled).
            alert_budget_usd_per_window: Budget that triggers alert callbacks;
                defaults to ``LLM_COST_ALERT_BUDGET_USD``. ``0`` disables alerts.
            alert_window_s: Budget window length in seconds; defaults to
                ``LLM_COST_ALERT_WINDOW_S`` (3600).
            ring_buffer_size: Number of recent events kept in memory.
            prices: Price table keyed by ``"provider:model"`` or ``"model"``,
                used when a call does not carry ``cost_usd``.
            log_path: JSONL log file; defaults to ``LLM_USAGE_LOG_PATH`` or
                ``./llm_usage_events.jsonl``.
        """
        self.enabled = (
            _env_flag("LLM_USAGE_TRACKING_ENABLED", True) if enabled is None else enabled
        )
        self.alert_budget_usd_per_window = (
            float(os.environ.get("LLM_COST_ALERT_BUDGET_USD", "0"))
            if alert_budget_usd_per_window is None
            else alert_budget_usd_per_window
        )
        self.alert_window_s = int(
            os.environ.get("LLM_COST_ALERT_WINDOW_S", "3600")
            if alert_window_s is None
            else alert_window_s
        )
        self.ring_buffer_size = int(ring_buffer_size)
        self.prices = prices or {}

        self._lock = threading.Lock()
        self._events: List[dict] = []
        self._events_start_idx = 0

        self._window_start_ts = time.time()
        self._window_cost_usd = 0.0

        self._alert_callbacks: List[Callable[[dict], None]] = []

        self._prom: Dict[str, object] = {}
        self._init_prometheus()

        self._log_path = log_path or os.environ.get(
            "LLM_USAGE_LOG_PATH", "./llm_usage_events.jsonl"
        )

    def _init_prometheus(self) -> None:
        """Create the shared Prometheus collectors on first use, else reuse them."""
        if Counter is None:
            return

        cls = LLMUsageTracker
        with cls._prom_shared_lock:
            if not cls._prom_shared:
                try:
                    metrics = {
                        "llm_calls_total": Counter(
                            "astroml_llm_calls_total",
                            "Total number of LLM calls",
                            ["provider", "model"],
                        ),
                        "llm_tokens_total": Counter(
                            "astroml_llm_tokens_total",
                            "Total tokens used by LLM calls",
                            ["provider", "model", "token_type"],
                        ),
                        "llm_latency_seconds": Histogram(
                            "astroml_llm_latency_seconds",
                            "Latency of LLM calls in seconds",
                            ["provider", "model"],
                        ),
                        "llm_cost_usd_total": Counter(
                            "astroml_llm_cost_usd_total",
                            "Cumulative cost in USD for LLM calls",
                            ["provider", "model"],
                        ),
                        "llm_cost_budget_usd_gauge": Gauge(
                            "astroml_llm_cost_budget_usd_gauge",
                            "Configured LLM cost budget per alert window (USD)",
                        ),
                    }
                except ValueError as exc:
                    # Names already registered elsewhere (e.g. module reload).
                    logger.warning("LLM Prometheus metrics unavailable: %s", exc)
                    return
                cls._prom_shared.update(metrics)
            self._prom = dict(cls._prom_shared)

        try:
            if self.alert_budget_usd_per_window:
                self._prom["llm_cost_budget_usd_gauge"].set(
                    float(self.alert_budget_usd_per_window)
                )
        except Exception as exc:
            logger.debug("Failed to set LLM budget gauge: %s", exc)

    def register_cost_alert_callback(self, cb: Callable[[dict], None]) -> None:
        """Register a callback invoked when budget is exceeded."""
        with self._lock:
            self._alert_callbacks.append(cb)

    def _push_event(self, event: dict) -> None:
        """Append *event*, overwriting the oldest slot once the buffer is full.

        Must be called with ``self._lock`` held.
        """
        if self.ring_buffer_size <= 0:
            return
        if len(self._events) < self.ring_buffer_size:
            self._events.append(event)
        else:
            # ring buffer: drop oldest
            self._events[self._events_start_idx % self.ring_buffer_size] = event
            self._events_start_idx += 1

    def _get_cost_from_prices_or_pass_through(self, usage: LLMUsage) -> float:
        """Return ``usage.cost_usd`` if set, else estimate it from the price table.

        Raises:
            ValueError: No cost was supplied and no price entry matches
                ``"provider:model"`` or ``"model"``.
        """
        # Cost_usd is preferred from provider.
        if usage.cost_usd is not None:
            return float(usage.cost_usd)

        key = f"{usage.provider}:{usage.model}"
        prices = self.prices.get(key) or self.prices.get(usage.model)
        if not prices:
            raise ValueError(
                "cost_usd missing and no prices configured for provider/model"
            )
        return prices.estimate_cost_usd(
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
        )

    def _emit_prometheus(self, usage: LLMUsage) -> None:
        """Update Prometheus collectors for one call (no-op when unavailable)."""
        prom = self._prom
        if not prom:
            return
        provider, model = usage.provider, usage.model
        try:
            prom["llm_calls_total"].labels(provider=provider, model=model).inc()
            prom["llm_tokens_total"].labels(
                provider=provider, model=model, token_type="prompt"
            ).inc(usage.prompt_tokens)
            prom["llm_tokens_total"].labels(
                provider=provider, model=model, token_type="completion"
            ).inc(usage.completion_tokens)
            prom["llm_latency_seconds"].labels(
                provider=provider, model=model
            ).observe(usage.latency_s)
            prom["llm_cost_usd_total"].labels(
                provider=provider, model=model
            ).inc(usage.cost_usd)
        except Exception as exc:
            # Metrics must never break the caller, but should not vanish silently.
            logger.debug("Failed to emit LLM Prometheus metrics: %s", exc)

    def _append_log(self, event: dict) -> None:
        """Append *event* as one JSON line to the usage log (best-effort)."""
        try:
            with open(self._log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(event) + "\n")
        except Exception as exc:  # pragma: no cover
            logger.warning("Failed to write LLM usage log: %s", exc)

    def record_call(
        self,
        *,
        provider: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        latency_s: float,
        cost_usd: Optional[float] = None,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> LLMUsage:
        """Record an LLM call.

        All LLM calls should pass token counts and latency.
        If ``cost_usd`` is not provided, you must configure ``prices``.

        Returns:
            The usage record with its resolved cost. When tracking is
            disabled nothing is stored, emitted or logged.

        Raises:
            ValueError: ``cost_usd`` is missing and no matching price exists.
        """
        prompt_tokens = int(prompt_tokens)
        completion_tokens = int(completion_tokens)

        provisional = LLMUsage(
            provider=provider,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            cost_usd=float(cost_usd) if cost_usd is not None else None,  # type: ignore[arg-type]
            latency_s=float(latency_s),
            request_id=request_id,
            user_id=user_id,
            session_id=session_id,
        )
        resolved_cost_usd = self._get_cost_from_prices_or_pass_through(provisional)

        usage = LLMUsage(
            provider=provider,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=provisional.total_tokens,
            cost_usd=resolved_cost_usd,
            latency_s=provisional.latency_s,
            request_id=request_id,
            user_id=user_id,
            session_id=session_id,
        )

        if not self.enabled:
            return usage

        usage_dict = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "provider": provider,
            "model": model,
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
            "latency_s": usage.latency_s,
            "cost_usd": resolved_cost_usd,
            "request_id": request_id,
            "user_id": user_id,
            "session_id": session_id,
        }

        alert: Optional[dict] = None
        callbacks: List[Callable[[dict], None]] = []
        with self._lock:
            self._push_event(usage_dict)

            # Roll the window BEFORE adding this call, so the call that opens
            # a new window is counted in it instead of being zeroed out.
            now = time.time()
            if now - self._window_start_ts >= self.alert_window_s:
                self._window_start_ts = now
                self._window_cost_usd = 0.0
            self._window_cost_usd += resolved_cost_usd

            budget = self.alert_budget_usd_per_window
            if budget and self._window_cost_usd >= float(budget):
                alert = {
                    "type": "llm_cost_budget_exceeded",
                    "timestamp": usage_dict["timestamp"],
                    "budget_usd": float(budget),
                    "window_s": int(self.alert_window_s),
                    "window_cost_usd": float(self._window_cost_usd),
                    "last_call": usage_dict,
                }
                callbacks = list(self._alert_callbacks)

        # Callbacks run outside the lock: the lock is not reentrant, so a
        # callback that calls back into the tracker would otherwise deadlock.
        if alert is not None:
            for cb in callbacks:
                try:
                    cb(alert)
                except Exception as exc:  # pragma: no cover
                    logger.warning("LLM cost alert callback failed: %s", exc)

        self._emit_prometheus(usage)
        self._append_log(usage_dict)

        return usage

    def recent_calls(self, limit: int = 100) -> List[dict]:
        """Return up to *limit* most recent events, oldest first."""
        if limit <= 0:
            return []
        with self._lock:
            size = len(self._events)
            # After the buffer wraps, the oldest event sits at the write cursor.
            start = self._events_start_idx % size if size else 0
            ordered = self._events[start:] + self._events[:start]
        return ordered[-limit:]


# Default process-wide tracker instance
default_llm_usage_tracker = LLMUsageTracker()
"y": 0 }, + "targets": [ + { + "expr": "rate(astroml_llm_calls_total[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "A" + } + ], + "yaxes": [ + { "format": "short", "label": "calls/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Cost USD (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 0 }, + "targets": [ + { + "expr": "rate(astroml_llm_cost_usd_total[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "B" + } + ], + "yaxes": [ + { "format": "currencyUSD", "label": "USD/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Tokens Prompt (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 8 }, + "targets": [ + { + "expr": "rate(astroml_llm_tokens_total{token_type=\"prompt\"}[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "C" + } + ], + "yaxes": [ + { "format": "short", "label": "tokens/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Tokens Completion (rate)", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 8 }, + "targets": [ + { + "expr": "rate(astroml_llm_tokens_total{token_type=\"completion\"}[5m])", + "legendFormat": "{{provider}}/{{model}}", + "refId": "D" + } + ], + "yaxes": [ + { "format": "short", "label": "tokens/s" }, + { "format": "short" } + ] + }, + { + "title": "LLM Latency p95", + "type": "graph", + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 16 }, + "targets": [ + { + "expr": "histogram_quantile(0.95, rate(astroml_llm_latency_seconds_bucket[5m]))", + "legendFormat": "{{provider}}/{{model}}", + "refId": "E" + } + ], + "yaxes": [{ "format": "s", "label": "seconds" }, { "format": "short" }] + }, + { + "title": "LLM Cost USD/s (5m rate, all models)", + "type": "stat", + "gridPos": { "h": 4, "w": 6, "x": 12, "y": 16 }, + "targets": [ + { "expr": "sum(rate(astroml_llm_cost_usd_total[5m]))", "refId": "F" } + ] + } + ], + "schemaVersion": 26, + "style": "dark", + "tags": ["astroml", "llm", "cost"], + "templating": { "list": [] }, + "time": { 
"from": "now-1h", "to": "now" }, + "timepicker": {}, + "timezone": "", + "title": "AstroML LLM Cost Dashboard", + "uid": "astroml_llm_cost", + "version": 1 +} diff --git a/monitoring/prometheus/alert_rules_llm_cost.yml b/monitoring/prometheus/alert_rules_llm_cost.yml new file mode 100644 index 0000000..28a3d4f --- /dev/null +++ b/monitoring/prometheus/alert_rules_llm_cost.yml @@ -0,0 +1,11 @@ +groups: + - name: astroml_llm_cost_alerts + rules: + - alert: LLMBudgetExceeded + expr: sum(rate(astroml_llm_cost_usd_total[5m])) > 0 + for: 1m + labels: + severity: warning + annotations: + summary: "LLM cost budget exceeded (rate-based placeholder)" + description: "LLM cost is non-zero; configure budgets via LLMUsageTracker env vars for callback-based alerts." diff --git a/monitoring/prometheus/prometheus.yml b/monitoring/prometheus/prometheus.yml index 6b73ec3..7ce4682 100644 --- a/monitoring/prometheus/prometheus.yml +++ b/monitoring/prometheus/prometheus.yml @@ -3,8 +3,8 @@ global: scrape_interval: 15s evaluation_interval: 15s external_labels: - monitor: 'astroml-monitor' - environment: 'docker' + monitor: "astroml-monitor" + environment: "docker" # Alertmanager configuration alerting: @@ -14,79 +14,80 @@ alerting: # Alert rules files rule_files: - - 'alert_rules.yml' + - "alert_rules.yml" + - "alert_rules_llm_cost.yml" # Scrape configurations scrape_configs: # Prometheus self-monitoring - - job_name: 'prometheus' + - job_name: "prometheus" static_configs: - - targets: ['localhost:9090'] + - targets: ["localhost:9090"] # PostgreSQL exporter (requires postgres_exporter container) - - job_name: 'postgres' - metrics_path: '/metrics' + - job_name: "postgres" + metrics_path: "/metrics" static_configs: - - targets: ['postgres-exporter:9187'] + - targets: ["postgres-exporter:9187"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'postgres' + replacement: "postgres" # Redis exporter (requires redis_exporter container) - - job_name: 'redis' + 
- job_name: "redis" static_configs: - - targets: ['redis-exporter:9121'] + - targets: ["redis-exporter:9121"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'redis' + replacement: "redis" # Python application metrics (astroml services) - - job_name: 'astroml-ingestion' - metrics_path: '/metrics' + - job_name: "astroml-ingestion" + metrics_path: "/metrics" static_configs: - - targets: ['ingestion:8080'] + - targets: ["ingestion:8080"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'ingestion' + replacement: "ingestion" - - job_name: 'astroml-streaming' - metrics_path: '/metrics' + - job_name: "astroml-streaming" + metrics_path: "/metrics" static_configs: - - targets: ['streaming:8001'] + - targets: ["streaming:8001"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'streaming' + replacement: "streaming" # Training service metrics - - job_name: 'astroml-training' - metrics_path: '/metrics' + - job_name: "astroml-training" + metrics_path: "/metrics" static_configs: - - targets: ['training-cpu:6007', 'training-gpu:6006'] + - targets: ["training-cpu:6007", "training-gpu:6006"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'training' + replacement: "training" # Development service metrics - - job_name: 'astroml-dev' - metrics_path: '/metrics' + - job_name: "astroml-dev" + metrics_path: "/metrics" static_configs: - - targets: ['dev:8002'] + - targets: ["dev:8002"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'dev' + replacement: "dev" # Production service metrics - - job_name: 'astroml-production' - metrics_path: '/metrics' + - job_name: "astroml-production" + metrics_path: "/metrics" static_configs: - - targets: ['production:8000'] + - targets: ["production:8000"] relabel_configs: - source_labels: [__address__] target_label: instance - replacement: 'production' + replacement: 
"production"