Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

from api.auth.middleware import AuthMiddleware
from api.audit_middleware import AuditLoggingMiddleware
Expand All @@ -45,6 +46,7 @@
feedback_router,
fraud_router,
loyalty_router,
llm_health_router,
mentorship_router,
models_router,
monitoring_router,
Expand All @@ -61,6 +63,7 @@
)
from api.routers.monitoring import record_latency
from api.routers.ws import poll_and_broadcast_transactions
from astroml.llm import metrics as _llm_metrics

# Setup distributed tracing (issue #336)
_tracer_provider = setup_tracing()
Expand Down Expand Up @@ -168,6 +171,7 @@ async def _latency_middleware(request: Request, call_next):
app.include_router(ws_router)
app.include_router(streaming_router)
app.include_router(llm_router)
app.include_router(llm_health_router)
app.include_router(reports_router)
app.include_router(alerts_router)

Expand All @@ -177,6 +181,11 @@ async def health():
return {"status": "ok"}


@app.get("/metrics", tags=["ops"])
async def prometheus_metrics():
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)


@app.get("/api/v1", tags=["ops"])
async def api_root():
return {"version": settings.api_version, "status": "ok"}
2 changes: 2 additions & 0 deletions api/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from api.routers.ws import router as ws_router
from api.routers.streaming import router as streaming_router
from api.routers.llm import router as llm_router
from api.routers.llm_health import router as llm_health_router
from api.routers.reports import router as reports_router
from api.routers.alerts import router as alerts_router

Expand Down Expand Up @@ -49,6 +50,7 @@
"ws_router",
"streaming_router",
"llm_router",
"llm_health_router",
"reports_router",
"alerts_router",
]
20 changes: 20 additions & 0 deletions api/routers/llm_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""LLM Health and Provider Status API."""
from __future__ import annotations

from fastapi import APIRouter

from astroml.llm.health import check_all_providers, check_provider_health

router = APIRouter(prefix="/api/v1/llm", tags=["llm-health"])


@router.get("/health")
async def llm_health():
result = await check_all_providers()
return result


@router.get("/health/{provider_name}")
async def llm_provider_health(provider_name: str):
result = await check_provider_health(provider_name)
return result
33 changes: 33 additions & 0 deletions api/tests/test_llm_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Integration tests for LLM health endpoints."""
from __future__ import annotations


class TestLLMHealth:
def test_llm_health_returns_200(self, client):
resp = client.get("/api/v1/llm/health")
assert resp.status_code == 200

def test_llm_health_has_overall_status(self, client):
data = client.get("/api/v1/llm/health").json()
assert "overall_status" in data
assert "providers" in data
assert "checked_at" in data

def test_llm_provider_health_endpoint(self, client):
resp = client.get("/api/v1/llm/health/openai")
assert resp.status_code == 200
data = resp.json()
assert data["provider"] == "openai"
assert "status" in data
assert "latency_ms" in data

def test_llm_health_providers_include_expected(self, client):
data = client.get("/api/v1/llm/health").json()
assert "openai" in data["providers"]
assert "anthropic" in data["providers"]
assert "huggingface" in data["providers"]

def test_prometheus_metrics_endpoint(self, client):
resp = client.get("/metrics")
assert resp.status_code == 200
assert "astroml_llm_provider_health" in resp.text
15 changes: 5 additions & 10 deletions astroml/db/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ class ProcessedLedger(Base):
__tablename__ = "processed_ledgers"

id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
ledger_sequence: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
ledger_sequence: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
source: Mapped[str] = mapped_column(
String(256),
nullable=False,
Expand All @@ -653,14 +653,10 @@ class ProcessedLedger(Base):
nullable=False,
server_default=func.now(),
)
status: Mapped[
Literal["pending", "processing", "completed", "failed"]
] = mapped_column(
String(16),
nullable=False,
server_default="pending",
)
String(32),
status: Mapped[
Literal["pending", "processing", "completed", "failed"]
] = mapped_column(
String(16),
nullable=False,
server_default="pending",
)
Expand All @@ -679,4 +675,3 @@ class ProcessedLedger(Base):
Index("ix_processed_ledgers_status", "status"),
Index("ix_processed_ledgers_source", "source"),
)
)
3 changes: 2 additions & 1 deletion astroml/llm/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def generate_explanation(self, alert_id: int, account_id: str, pattern: str, sco
latency_ms=latency_ms
)

# Cache the response
self.cache.set(prompt, response)

return response
except Exception as e:
provider_name = self.provider.__class__.__name__.replace("Provider", "").lower()
global_tracker.record_error(provider_name)
return f"Error generating explanation: {str(e)}"

def _build_prompt(self, account_id: str, pattern: str, score: float, transactions: List[Dict[str, Any]]) -> str:
Expand Down
113 changes: 113 additions & 0 deletions astroml/llm/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""LLM Provider health checks."""
import asyncio
import os
import time
from typing import Any, Dict

import aiohttp

PROVIDER_ENDPOINTS = {
"openai": {
"url": "https://api.openai.com/v1/models",
"method": "GET",
"headers": lambda key: {"Authorization": f"Bearer {key}"},
},
"anthropic": {
"url": "https://api.anthropic.com/v1/messages",
"method": "HEAD",
"headers": lambda key: {
"x-api-key": key,
"anthropic-version": "2023-06-01",
},
},
"huggingface": {
"url": "https://api-inference.huggingface.co/status",
"method": "GET",
"headers": lambda key: {"Authorization": f"Bearer {key}"},
},
}


def _get_api_key(provider_name: str) -> str:
env_key = f"{provider_name.upper()}_API_KEY"
return os.getenv(env_key, "")


async def check_provider_health(
provider_name: str, timeout: float = 5.0
) -> Dict[str, Any]:
start = time.perf_counter()
if provider_name not in PROVIDER_ENDPOINTS:
latency_ms = (time.perf_counter() - start) * 1000
return {
"provider": provider_name,
"status": "unknown",
"latency_ms": round(latency_ms, 2),
"error": "Provider not supported for health checks",
}

api_key = _get_api_key(provider_name)
if not api_key:
latency_ms = (time.perf_counter() - start) * 1000
return {
"provider": provider_name,
"status": "unhealthy",
"latency_ms": round(latency_ms, 2),
"error": "API key not configured",
}

config = PROVIDER_ENDPOINTS[provider_name]

try:
async with aiohttp.ClientSession() as session:
async with session.request(
method=config["method"],
url=config["url"],
headers=config["headers"](api_key),
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
latency_ms = (time.perf_counter() - start) * 1000
healthy = 200 <= response.status < 300
return {
"provider": provider_name,
"status": "healthy" if healthy else "unhealthy",
"latency_ms": round(latency_ms, 2),
"http_status": response.status,
}
except Exception as e:
latency_ms = (time.perf_counter() - start) * 1000
return {
"provider": provider_name,
"status": "unhealthy",
"latency_ms": round(latency_ms, 2),
"error": str(e),
}


async def check_all_providers() -> Dict[str, Any]:
providers = list(PROVIDER_ENDPOINTS.keys())
results = await asyncio.gather(
*(check_provider_health(p) for p in providers),
return_exceptions=True,
)

provider_statuses = {}
for result in results:
if isinstance(result, Exception):
provider_statuses["unknown"] = {
"provider": "unknown",
"status": "unhealthy",
"latency_ms": 0,
"error": str(result),
}
else:
provider_statuses[result["provider"]] = result

all_healthy = all(
r.get("status") == "healthy" for r in provider_statuses.values()
)
return {
"overall_status": "healthy" if all_healthy else "degraded",
"providers": provider_statuses,
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
32 changes: 32 additions & 0 deletions astroml/llm/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from prometheus_client import Counter, Gauge, Histogram

LLM_REQUESTS_TOTAL = Counter(
"astroml_llm_requests_total",
"Total LLM API requests",
["provider", "status"],
)

LLM_REQUEST_LATENCY_SECONDS = Histogram(
"astroml_llm_request_latency_seconds",
"LLM API request latency in seconds",
["provider"],
buckets=[0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
)

LLM_COST_USD_TOTAL = Counter(
"astroml_llm_cost_usd_total",
"Total LLM API cost in USD",
["provider"],
)

LLM_TOKENS_TOTAL = Counter(
"astroml_llm_tokens_total",
"Total LLM tokens processed",
["provider", "token_type"],
)

LLM_PROVIDER_HEALTH = Gauge(
"astroml_llm_provider_health",
"LLM provider health status (1=healthy, 0=unhealthy)",
["provider"],
)
Loading