diff --git a/README.md b/README.md index e3523d3..41ca9f5 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,14 @@ Adrian supports entirely offline, data sovereign deployments using just a handfu Use the same `adrian.init` snippet as in the [Quickstart](#quickstart) above. The SDK defaults to `ws://localhost:8080/ws`, so a self-hosted setup needs nothing more than the API key - drop the `ws_url=` line. +### Classifier error policy + +Adrian records classifier outages, malformed classifier responses, and unparseable classifier output as `verdict_status=error` with `mad_code=""`. These are operational classifier errors, not benign `M0` findings and not synthetic malicious activity. + +The default policy remains availability-first: classifier errors fail open. In **Settings -> Policy**, enable **Fail closed on classifier error** to make BLOCK-mode tool calls return blocked responses when the classifier cannot produce a verdict. In HITL mode, actionable classifier errors are sent to the review queue and held until an operator approves or rejects them. + +Fail-closed classifier-error enforcement requires the Python SDK version shipped with this repository update. Older SDKs ignore the additive protobuf `status` and policy fields, see an empty MAD code, and continue fail-open even when the dashboard toggle is enabled. + To [reset the admin password](https://docs.adrian.secureagentics.ai/reference/backend#reset-the-admin-password), [change the model](https://docs.adrian.secureagentics.ai/reference/backend#switch-the-local-gguf) and much more check out the dedicated [Docs site](https://docs.adrian.secureagentics.ai/). ## Why Adrian is different diff --git a/backend/internal/api/handlers_events.go b/backend/internal/api/handlers_events.go index eb7cfc4..87676ee 100644 --- a/backend/internal/api/handlers_events.go +++ b/backend/internal/api/handlers_events.go @@ -84,13 +84,18 @@ func (s *Server) handleListEvents(w http.ResponseWriter, r *http.Request) { since = t } } + if status := q.Get("verdict_status"); status != "" && !validVerdictStatus(status) { + writeError(w, http.StatusBadRequest, "invalid verdict_status") + return + } filters := store.EventFilters{ - Since: since, - AgentID: q.Get("agent_id"), - SessionID: q.Get("session_id"), - EventType: q.Get("event_type"), - MinMAD: q.Get("min_mad"), + Since: since, + AgentID: q.Get("agent_id"), + SessionID: q.Get("session_id"), + EventType: q.Get("event_type"), + MinMAD: q.Get("min_mad"), + VerdictStatus: q.Get("verdict_status"), } rows, total, err := s.store.ListEvents(r.Context(), filters, pg.PerPage, pg.Offset) @@ -203,6 +208,7 @@ func eventToListItemResponse(r *store.EventListRow) eventListItemResponse { ID: r.VerdictID, MADCode: r.MADCode, Classification: r.Classification, + VerdictStatus: r.VerdictStatus, } } return resp diff --git a/backend/internal/api/handlers_reviews.go b/backend/internal/api/handlers_reviews.go index 2ecd782..e8d8847 100644 --- a/backend/internal/api/handlers_reviews.go +++ b/backend/internal/api/handlers_reviews.go @@ -39,6 +39,7 @@ type reviewDetail struct { reviewSummary EventPayload json.RawMessage `json:"event_payload,omitempty"` Classification string `json:"classification,omitempty"` + Reasoning string `json:"reasoning,omitempty"` } type reviewResolveResponse struct { @@ -49,8 +50,14 @@ type reviewResolveResponse struct { func (s *Server) handleListReviews(w http.ResponseWriter, r *http.Request) { pg := parsePagination(r) - status := r.URL.Query().Get("status") - rows, total, err := s.store.ListHitlQueue(r.Context(), status, pg.PerPage, pg.Offset) + q := r.URL.Query() + status := q.Get("status") + verdictStatus := q.Get("verdict_status") + if verdictStatus != "" && !validVerdictStatus(verdictStatus) { + writeError(w, http.StatusBadRequest, "invalid verdict_status") + return + } + rows, total, err := s.store.ListHitlQueue(r.Context(), status, verdictStatus, pg.PerPage, pg.Offset) if err != nil { writeError(w, http.StatusInternalServerError, "query failed") return @@ -81,6 +88,7 @@ func (s *Server) handleGetReview(w http.ResponseWriter, r *http.Request) { resp := reviewDetail{ reviewSummary: reviewToSummary(&row.HitlReview), Classification: row.Classification, + Reasoning: row.Reasoning, } if row.EventPayloadJSON != "" { resp.EventPayload = json.RawMessage(row.EventPayloadJSON) @@ -128,7 +136,7 @@ func (s *Server) resolveReview(w http.ResponseWriter, r *http.Request, status st EventId: row.EventID, SessionId: row.SessionID, MadCode: row.MADCode, - Status: pb.VerdictStatus_VERDICT_STATUS_OK, + Status: reviewVerdictStatusProto(row.VerdictStatus), Policy: s.policySnapshotProto(pol), Hitl: &pb.HitlResponse{ContinueExecution: continueExec}, }}, @@ -165,3 +173,14 @@ func reviewToSummary(r *store.HitlReview) reviewSummary { } return out } + +func reviewVerdictStatusProto(status string) pb.VerdictStatus { + switch status { + case "error": + return pb.VerdictStatus_VERDICT_STATUS_ERROR + case "ok": + return pb.VerdictStatus_VERDICT_STATUS_OK + default: + return pb.VerdictStatus_VERDICT_STATUS_UNSPECIFIED + } +} diff --git a/backend/internal/api/handlers_stats.go b/backend/internal/api/handlers_stats.go index 8787369..d6ff5e8 100644 --- a/backend/internal/api/handlers_stats.go +++ b/backend/internal/api/handlers_stats.go @@ -6,12 +6,13 @@ package api import "net/http" type overviewResponse struct { - TotalEvents int `json:"total_events"` - FlaggedVerdicts int `json:"flagged_verdicts"` - PendingReviews int `json:"pending_reviews"` - ActiveAgents int `json:"active_agents"` - VerdictsByMAD map[string]int `json:"verdicts_by_mad"` - Window string `json:"window"` + TotalEvents int `json:"total_events"` + FlaggedVerdicts int `json:"flagged_verdicts"` + ClassifierErrors int `json:"classifier_errors"` + PendingReviews int `json:"pending_reviews"` + ActiveAgents int `json:"active_agents"` + VerdictsByMAD map[string]int `json:"verdicts_by_mad"` + Window string `json:"window"` } type activityBucketEntry struct { @@ -31,12 +32,13 @@ func (s *Server) handleStatsOverview(w http.ResponseWriter, r *http.Request) { return } writeJSON(w, http.StatusOK, overviewResponse{ - TotalEvents: o.TotalEvents, - FlaggedVerdicts: o.FlaggedVerdicts, - PendingReviews: o.PendingReviews, - ActiveAgents: o.ActiveAgents, - VerdictsByMAD: o.VerdictsByMAD, - Window: "24h", + TotalEvents: o.TotalEvents, + FlaggedVerdicts: o.FlaggedVerdicts, + ClassifierErrors: o.ClassifierErrors, + PendingReviews: o.PendingReviews, + ActiveAgents: o.ActiveAgents, + VerdictsByMAD: o.VerdictsByMAD, + Window: "24h", }) } diff --git a/backend/internal/api/handlers_test.go b/backend/internal/api/handlers_test.go index 304f1d1..bcd091e 100644 --- a/backend/internal/api/handlers_test.go +++ b/backend/internal/api/handlers_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "google.golang.org/protobuf/proto" _ "modernc.org/sqlite" "github.com/secureagentics/Adrian/backend/internal/api" @@ -415,7 +416,7 @@ func TestProfileNameValidation(t *testing.T) { func TestStatsOverview(t *testing.T) { srv, db, _, cookie := newTestServerLoggedIn(t) - // Seed: 3 events on 2 agents, 2 verdicts (one M0, one M3), + // Seed: 3 events on 2 agents, 3 verdicts (M0, M3, classifier error), // 1 pending review, 1 agents row with last_seen recent. if _, err := db.Exec( `INSERT INTO agents (id, agent_id, last_seen) VALUES (?, 'a1', datetime('now'))`, @@ -441,6 +442,13 @@ func TestStatsOverview(t *testing.T) { t.Fatalf("seed verdict: %v", err) } } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status) + VALUES (?, ?, 'sess-stats', '', 'error', 'error')`, + uuid.NewString(), uuid.NewString(), + ); err != nil { + t.Fatalf("seed error verdict: %v", err) + } if _, err := db.Exec( `INSERT INTO hitl_queue (id, event_id, session_id, mad_code) VALUES (?, ?, 'sess-stats', 'M3')`, uuid.NewString(), uuid.NewString(), @@ -459,6 +467,9 @@ func TestStatsOverview(t *testing.T) { if int(data["flagged_verdicts"].(float64)) != 1 { t.Errorf("flagged_verdicts = %v, want 1 (only M3.b counts)", data["flagged_verdicts"]) } + if int(data["classifier_errors"].(float64)) != 1 { + t.Errorf("classifier_errors = %v, want 1", data["classifier_errors"]) + } if int(data["pending_reviews"].(float64)) != 1 { t.Errorf("pending_reviews = %v, want 1", data["pending_reviews"]) } @@ -466,8 +477,10 @@ func TestStatsOverview(t *testing.T) { t.Errorf("active_agents = %v, want 1", data["active_agents"]) } dist := data["verdicts_by_mad"].(map[string]any) - if int(dist["M0"].(float64)) != 1 || int(dist["M3"].(float64)) != 1 { - t.Errorf("verdicts_by_mad = %v, want M0=1 M3=1", dist) + if int(dist["M0"].(float64)) != 1 || + int(dist["M3"].(float64)) != 1 || + int(dist["error"].(float64)) != 1 { + t.Errorf("verdicts_by_mad = %v, want M0=1 M3=1 error=1", dist) } } @@ -523,6 +536,9 @@ func TestListVerdictsIncludesStatusAndFiltersError(t *testing.T) { if row["classification"] != "error" || row["verdict_status"] != "error" { t.Errorf("verdict row = %v, want classification/status error", row) } + if row["reasoning"] != "classifier failed" { + t.Errorf("reasoning = %v, want classifier failed", row["reasoning"]) + } } // ----------------------------------------------------------------- @@ -632,6 +648,139 @@ func TestApproveReviewPublishesToSubscriber(t *testing.T) { } } +func TestApproveErrorReviewPublishesErrorStatus(t *testing.T) { + srv, db, hub, cookie := newTestServerWithHub(t) + + const sessID = "sess-hitl-error" + eventID := uuid.NewString() + verdictID := uuid.NewString() + queueID := uuid.NewString() + + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-h', 'llm', 'r1', '{}')`, + eventID, sessID, + ); err != nil { + t.Fatalf("seed event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, reasoning) + VALUES (?, ?, ?, '', 'error', 'error', 'classifier failure: boom')`, + verdictID, eventID, sessID, + ); err != nil { + t.Fatalf("seed verdict: %v", err) + } + if _, err := db.Exec( + `INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) + VALUES (?, ?, ?, ?, '')`, + queueID, eventID, verdictID, sessID, + ); err != nil { + t.Fatalf("seed hitl_queue: %v", err) + } + + detailResp := getReq(t, srv, cookie, "/api/reviews/"+queueID) + if detailResp.StatusCode != http.StatusOK { + t.Fatalf("detail status = %d, want 200", detailResp.StatusCode) + } + detail := decodeBody(t, detailResp)["data"].(map[string]any) + if detail["reasoning"] != "classifier failure: boom" { + t.Errorf("detail.reasoning = %v, want classifier failure cause", detail["reasoning"]) + } + + ch, dereg, err := hub.Register(sessID, "test-owner") + if err != nil { + t.Fatalf("Register: %v", err) + } + defer dereg() + + resp := postJSON(t, srv, cookie, "/api/reviews/"+queueID+"/approve", map[string]any{}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + + select { + case buf := <-ch: + var frame pb.ServerFrame + if err := proto.Unmarshal(buf, &frame); err != nil { + t.Fatalf("unmarshal frame: %v", err) + } + verdict := frame.GetVerdict() + if verdict == nil { + t.Fatalf("expected Verdict, got %T", frame.Frame) + } + if verdict.GetStatus() != pb.VerdictStatus_VERDICT_STATUS_ERROR { + t.Fatalf("status = %v, want ERROR", verdict.GetStatus()) + } + if verdict.GetMadCode() != "" { + t.Fatalf("mad_code = %q, want empty", verdict.GetMadCode()) + } + if verdict.GetHitl() == nil || !verdict.GetHitl().GetContinueExecution() { + t.Fatalf("expected approve to continue execution") + } + case <-time.After(time.Second): + t.Fatal("subscriber never received the resolution frame") + } +} + +func TestListReviewsFiltersByVerdictStatus(t *testing.T) { + srv, db, _, cookie := newTestServerWithHub(t) + + const sessID = "sess-review-filter" + okEventID := uuid.NewString() + errorEventID := uuid.NewString() + okVerdictID := uuid.NewString() + errorVerdictID := uuid.NewString() + okQueueID := uuid.NewString() + errorQueueID := uuid.NewString() + + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-h', 'llm', 'r-ok', '{}'), + (?, ?, 'agent-h', 'llm', 'r-error', '{}')`, + okEventID, sessID, + errorEventID, sessID, + ); err != nil { + t.Fatalf("seed events: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status) + VALUES (?, ?, ?, 'M3', 'block', 'ok'), + (?, ?, ?, '', 'error', 'error')`, + okVerdictID, okEventID, sessID, + errorVerdictID, errorEventID, sessID, + ); err != nil { + t.Fatalf("seed verdicts: %v", err) + } + if _, err := db.Exec( + `INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) + VALUES (?, ?, ?, ?, 'M3'), + (?, ?, ?, ?, '')`, + okQueueID, okEventID, okVerdictID, sessID, + errorQueueID, errorEventID, errorVerdictID, sessID, + ); err != nil { + t.Fatalf("seed hitl_queue: %v", err) + } + + resp := getReq(t, srv, cookie, "/api/reviews?status=pending&verdict_status=error") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + data := decodeBody(t, resp)["data"].(map[string]any) + if int(data["total"].(float64)) != 1 { + t.Fatalf("total = %v, want 1", data["total"]) + } + reviews := data["reviews"].([]any) + row := reviews[0].(map[string]any) + if row["id"] != errorQueueID || row["verdict_status"] != "error" { + t.Fatalf("filtered review = %v, want only classifier-error review %q", row, errorQueueID) + } + + resp = getReq(t, srv, cookie, "/api/reviews?verdict_status=bogus") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("invalid verdict_status status = %d, want 400", resp.StatusCode) + } +} + func TestApproveReviewNoSubscriberStillResolves(t *testing.T) { srv, db, _, cookie := newTestServerWithHub(t) @@ -1072,6 +1221,66 @@ func TestEventsMinMADFilterUsesLatestVerdict(t *testing.T) { } } +func TestEventsVerdictStatusFilterUsesLatestVerdict(t *testing.T) { + srv, db, _, cookie := newTestServerLoggedIn(t) + + const sid = "sess-verdict-status" + + eOK := uuid.NewString() + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-ok', 'tool', 'r1', '{}')`, + eOK, sid, + ); err != nil { + t.Fatalf("seed ok event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, created_at) + VALUES (?, ?, ?, '', 'error', 'error', datetime('now', '-2 seconds')), + (?, ?, ?, 'M0', 'benign', 'ok', datetime('now', '-1 seconds'))`, + uuid.NewString(), eOK, sid, + uuid.NewString(), eOK, sid, + ); err != nil { + t.Fatalf("seed ok verdicts: %v", err) + } + + eError := uuid.NewString() + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-error', 'llm', 'r2', '{}')`, + eError, sid, + ); err != nil { + t.Fatalf("seed error event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, created_at) + VALUES (?, ?, ?, 'M0', 'benign', 'ok', datetime('now', '-2 seconds')), + (?, ?, ?, '', 'error', 'error', datetime('now', '-1 seconds'))`, + uuid.NewString(), eError, sid, + uuid.NewString(), eError, sid, + ); err != nil { + t.Fatalf("seed error verdicts: %v", err) + } + + resp := getReq(t, srv, cookie, "/api/events?session_id="+sid+"&verdict_status=error") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + data := decodeBody(t, resp)["data"].(map[string]any) + if int(data["total"].(float64)) != 1 { + t.Errorf("verdict_status=error total = %v, want 1", data["total"]) + } + events := data["events"].([]any) + if len(events) != 1 || events[0].(map[string]any)["id"] != eError { + t.Errorf("verdict_status=error events = %v, want only event %q", events, eError) + } + + resp = getReq(t, srv, cookie, "/api/events?session_id="+sid+"&verdict_status=bogus") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("invalid verdict_status status = %d, want 400", resp.StatusCode) + } +} + // ----------------------------------------------------------------- // MCP servers // ----------------------------------------------------------------- diff --git a/backend/internal/api/handlers_verdicts.go b/backend/internal/api/handlers_verdicts.go index 3c75e57..273517c 100644 --- a/backend/internal/api/handlers_verdicts.go +++ b/backend/internal/api/handlers_verdicts.go @@ -17,6 +17,7 @@ type verdictResponse struct { MADCode string `json:"mad_code"` Classification string `json:"classification"` VerdictStatus string `json:"verdict_status"` + Reasoning string `json:"reasoning,omitempty"` LatencyMS *int64 `json:"latency_ms,omitempty"` TokensUsed int32 `json:"tokens_used"` CreatedAt string `json:"created_at"` @@ -78,6 +79,7 @@ func verdictRowToResponse(r *store.VerdictListRow) verdictResponse { MADCode: r.MADCode, Classification: r.Classification, VerdictStatus: r.VerdictStatus, + Reasoning: r.Reasoning, LatencyMS: r.LatencyMS, TokensUsed: r.TokensUsed, CreatedAt: r.CreatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), diff --git a/backend/internal/notifications/discord_test.go b/backend/internal/notifications/discord_test.go index c05a9ab..5530f39 100644 --- a/backend/internal/notifications/discord_test.go +++ b/backend/internal/notifications/discord_test.go @@ -5,13 +5,19 @@ package notifications import ( "context" + "database/sql" "encoding/json" "io" "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" + + "github.com/google/uuid" + "github.com/secureagentics/Adrian/backend/internal/store" + _ "modernc.org/sqlite" ) func TestValidateDiscordWebhookURL(t *testing.T) { @@ -120,6 +126,54 @@ func TestSendNonDiscordURLRejected(t *testing.T) { } } +func TestDispatcherSkipsEmptyMADCode(t *testing.T) { + var posts int32 + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&posts, 1) + w.WriteHeader(http.StatusNoContent) + })) + defer mock.Close() + + origHosts := allowedHosts + allowedHosts = []string{mock.URL + "/"} + defer func() { allowedHosts = origHosts }() + + db, err := sql.Open("sqlite", "file:notifications?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + defer db.Close() + if _, err := db.Exec(` +CREATE TABLE webhooks ( + id TEXT PRIMARY KEY, + platform TEXT NOT NULL DEFAULT 'discord', + webhook_url TEXT NOT NULL, + alert_type TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1, + installed_by_user_id TEXT, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); +`); err != nil { + t.Fatalf("create webhooks: %v", err) + } + st := store.New(db) + if err := st.CreateWebhook(context.Background(), uuid.NewString(), mock.URL+"/api/webhooks/1/tok", "all", ""); err != nil { + t.Fatalf("create webhook: %v", err) + } + + d := NewDispatcher(st, "https://dash.example") + d.fanout(context.Background(), VerdictNotification{ + EventID: "ev-error", + SessionID: "sess-error", + MADCode: "", + Classification: "error", + }) + if got := atomic.LoadInt32(&posts); got != 0 { + t.Fatalf("webhook posts = %d, want 0 for empty MAD code", got) + } +} + func TestSendRespectsContextDeadline(t *testing.T) { // Server that holds the response open longer than the client's // context allows. The handler exits when r.Context() is cancelled diff --git a/backend/internal/notifications/dispatcher.go b/backend/internal/notifications/dispatcher.go index 99b56cc..692e4f1 100644 --- a/backend/internal/notifications/dispatcher.go +++ b/backend/internal/notifications/dispatcher.go @@ -69,7 +69,9 @@ func (d *Dispatcher) Run(ctx context.Context) { // would mean state outside SQLite). func (d *Dispatcher) fanout(ctx context.Context, vn VerdictNotification) { if vn.MADCode == "" || strings.HasPrefix(vn.MADCode, "M0") { - // Benign verdicts don't fan out; webhooks are for flagged events. + // Empty MAD codes (classifier errors) and M0 benign verdicts do + // not fan out; these webhooks are for real flagged MAD findings. + // Operational outage alerts should be a separate alert type. return } hooks, err := d.store.ListWebhooks(ctx, true) diff --git a/backend/internal/store/events.go b/backend/internal/store/events.go index 3549725..07014ca 100644 --- a/backend/internal/store/events.go +++ b/backend/internal/store/events.go @@ -76,6 +76,9 @@ type EventFilters struct { // Lets the dashboard surface flagged events that didn't trigger a // HITL hold (post-execution tool pairs, tool_call-less LLM pairs). MinMAD string + // VerdictStatus restricts to events whose latest verdict has this status. + // Accepts "ok" or "error"; empty = no filter. + VerdictStatus string } // InsertEvent persists one paired event and reports whether a new row @@ -265,6 +268,13 @@ func eventsWhere(f EventFilters) (string, []any) { args = append(args, t) } } + if f.VerdictStatus != "" { + parts = append(parts, "EXISTS (SELECT 1 FROM verdicts v "+ + "WHERE v.event_id = e.id "+ + "AND v.created_at = (SELECT max(v2.created_at) FROM verdicts v2 WHERE v2.event_id = e.id) "+ + "AND v.verdict_status = ?)") + args = append(args, f.VerdictStatus) + } return strings.Join(parts, " AND "), args } diff --git a/backend/internal/store/hitl.go b/backend/internal/store/hitl.go index 7aa3019..abae609 100644 --- a/backend/internal/store/hitl.go +++ b/backend/internal/store/hitl.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "errors" + "strings" "time" "github.com/google/uuid" @@ -47,28 +48,40 @@ func (s *Store) InsertHitlQueue(ctx context.Context, eventID, verdictID, session return err } -// ListHitlQueue returns rows in the requested status (default 'pending'), -// newest first, paginated. -func (s *Store) ListHitlQueue(ctx context.Context, status string, perPage, offset int) ([]*HitlReview, int, error) { +// ListHitlQueue returns rows in the requested review status (default +// 'pending') and optional verdict status, newest first, paginated. +func (s *Store) ListHitlQueue(ctx context.Context, status, verdictStatus string, perPage, offset int) ([]*HitlReview, int, error) { if status == "" { status = "pending" } + where := []string{"q.status = ?"} + args := []any{status} + if verdictStatus != "" { + where = append(where, "COALESCE(v.verdict_status, 'ok') = ?") + args = append(args, verdictStatus) + } + whereSQL := strings.Join(where, " AND ") + var total int if err := s.db.QueryRowContext(ctx, - `SELECT count(*) FROM hitl_queue WHERE status = ?`, status, + `SELECT count(*) + FROM hitl_queue q + LEFT JOIN verdicts v ON v.id = q.verdict_id + WHERE `+whereSQL, args..., ).Scan(&total); err != nil { return nil, 0, err } + queryArgs := append(append([]any{}, args...), perPage, offset) rows, err := s.db.QueryContext(ctx, `SELECT q.id, q.event_id, COALESCE(q.verdict_id, ''), COALESCE(q.session_id, ''), q.mad_code, COALESCE(v.verdict_status, 'ok'), q.status, COALESCE(q.reviewed_by, ''), COALESCE(q.reviewed_at, ''), q.created_at FROM hitl_queue q LEFT JOIN verdicts v ON v.id = q.verdict_id - WHERE q.status = ? + WHERE `+whereSQL+` ORDER BY q.created_at DESC LIMIT ? OFFSET ?`, - status, perPage, offset) + queryArgs...) if err != nil { return nil, 0, err } diff --git a/backend/internal/store/stats.go b/backend/internal/store/stats.go index 5d3ab47..abe40b3 100644 --- a/backend/internal/store/stats.go +++ b/backend/internal/store/stats.go @@ -10,11 +10,12 @@ import ( // Overview is the 24h summary the dashboard home renders. type Overview struct { - TotalEvents int - FlaggedVerdicts int - PendingReviews int - ActiveAgents int - VerdictsByMAD map[string]int + TotalEvents int + FlaggedVerdicts int + ClassifierErrors int + PendingReviews int + ActiveAgents int + VerdictsByMAD map[string]int } // ActivityBucket is one bin in the time-series response. @@ -35,15 +36,25 @@ func (s *Store) StatsOverview(ctx context.Context) (*Overview, error) { return nil, err } - // Flagged = anything other than M0/empty, i.e. an actual MAD code. + // Flagged = real non-M0 MAD findings. Classifier errors are tracked + // separately below so outages do not inflate security-finding totals. if err := s.db.QueryRowContext(ctx, `SELECT count(*) FROM verdicts WHERE created_at >= datetime('now', ?) + AND verdict_status = 'ok' AND mad_code != '' AND mad_code NOT LIKE 'M0%'`, window, ).Scan(&o.FlaggedVerdicts); err != nil { return nil, err } + if err := s.db.QueryRowContext(ctx, + `SELECT count(*) FROM verdicts + WHERE created_at >= datetime('now', ?) + AND verdict_status = 'error'`, window, + ).Scan(&o.ClassifierErrors); err != nil { + return nil, err + } + if err := s.db.QueryRowContext(ctx, `SELECT count(*) FROM hitl_queue WHERE status = 'pending'`, ).Scan(&o.PendingReviews); err != nil { @@ -59,7 +70,8 @@ func (s *Store) StatsOverview(ctx context.Context) (*Overview, error) { rows, err := s.db.QueryContext(ctx, `SELECT CASE - WHEN mad_code LIKE 'M0%' OR mad_code = '' THEN 'M0' + WHEN verdict_status = 'error' THEN 'error' + WHEN mad_code LIKE 'M0%' THEN 'M0' WHEN mad_code LIKE 'M2%' THEN 'M2' WHEN mad_code LIKE 'M3%' THEN 'M3' WHEN mad_code LIKE 'M4%' THEN 'M4' diff --git a/backend/internal/store/verdicts.go b/backend/internal/store/verdicts.go index 027490a..beaa21e 100644 --- a/backend/internal/store/verdicts.go +++ b/backend/internal/store/verdicts.go @@ -33,6 +33,7 @@ type VerdictListRow struct { MADCode string Classification string VerdictStatus string + Reasoning string LatencyMS *int64 TokensUsed int32 CreatedAt time.Time @@ -76,7 +77,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off args = append(args, perPage, offset) rows, err := s.db.QueryContext(ctx, `SELECT id, event_id, session_id, mad_code, classification, verdict_status, - latency_ms, tokens_used, created_at + COALESCE(reasoning, ''), latency_ms, tokens_used, created_at FROM verdicts WHERE `+where+` ORDER BY created_at DESC @@ -92,7 +93,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off var latency sql.NullInt64 var createdAt string if err := rows.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, - &latency, &r.TokensUsed, &createdAt); err != nil { + &r.Reasoning, &latency, &r.TokensUsed, &createdAt); err != nil { return nil, 0, err } if latency.Valid { @@ -109,14 +110,14 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off func (s *Store) GetVerdictByEventID(ctx context.Context, eventID string) (*VerdictListRow, error) { row := s.db.QueryRowContext(ctx, `SELECT id, event_id, session_id, mad_code, classification, verdict_status, - latency_ms, tokens_used, created_at + COALESCE(reasoning, ''), latency_ms, tokens_used, created_at FROM verdicts WHERE event_id = ? ORDER BY created_at DESC LIMIT 1`, eventID) r := &VerdictListRow{} var latency sql.NullInt64 var createdAt string if err := row.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, - &latency, &r.TokensUsed, &createdAt); err != nil { + &r.Reasoning, &latency, &r.TokensUsed, &createdAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index a40706b..4b5f5a2 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -303,6 +303,11 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla } verdict, err := classifier.Classify(ctx, ev, agentProfileID) if err != nil { + if ctx.Err() != nil { + slog.InfoContext(ctx, "ws.classify_cancelled", + "error", err, "event_id", ev.EventId) + return nil + } slog.WarnContext(ctx, "ws.classifier_failure", "error", err, "event_id", ev.EventId) reasoning := "classifier failure: " + err.Error() @@ -354,6 +359,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla } func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode, verdictStatus string) error { + if verdictStatus == "error" { + return dispatchErrorVerdict(ctx, sess, st, hub, ev, snap, verdictID, madCode) + } + // Mode-gated dispatch: // alert: persist verdict, do NOT notify the SDK (dashboard-only). // hitl + in-scope + actionable: persist + queue for human review, @@ -361,7 +370,7 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H // hitl + in-scope + non-actionable: forward (review would be a // no-op for the operator since the SDK never blocks on it). // hitl + out-of-scope: forward (no review queued for this code). - // block: forward all verdicts; SDK is the enforcement point. + // block: forward all OK verdicts; SDK is the enforcement point. inScope := shouldFanOut(snap, madCode) switch snap.GetMode() { case pb.Mode_MODE_ALERT: @@ -391,6 +400,38 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H return nil } + publishVerdict(ctx, sess, hub, ev, snap, madCode, verdictStatus) + return nil +} + +func dispatchErrorVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error { + switch snap.GetMode() { + case pb.Mode_MODE_ALERT: + return nil + case pb.Mode_MODE_BLOCK: + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + return nil + case pb.Mode_MODE_HITL: + if snap.GetFailClosedOnClassifierError() && isActionable(ev) { + if err := st.InsertHitlQueue(ctx, ev.EventId, verdictID, sess.sessionID, madCode); err != nil { + slog.ErrorContext(ctx, "hitl.insert_failed_fallback_publish", + "error", err, "event_id", ev.EventId, "verdict_id", verdictID) + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + } + return nil + } + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + return nil + default: + slog.WarnContext(ctx, "ws.unknown_mode_dropping_verdict", + "mode", snap.GetMode().String(), "event_id", ev.EventId) + return nil + } +} + +func publishVerdict(ctx context.Context, sess *session, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, madCode, verdictStatus string) { + warnOldSDKClassifierErrorCompatibility(ctx, sess, ev, snap, verdictStatus) + out := &pb.ServerFrame{ Frame: &pb.ServerFrame_Verdict{ Verdict: &pb.Verdict{ @@ -406,7 +447,28 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H slog.WarnContext(ctx, "ws.publish_dropped", "event_id", ev.EventId, "session_id", sess.sessionID) } - return nil +} + +func warnOldSDKClassifierErrorCompatibility(ctx context.Context, sess *session, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictStatus string) { + if verdictStatus != "error" || !snap.GetFailClosedOnClassifierError() || sess.warnedClassifierErrorCompatibility { + return + } + sess.warnedClassifierErrorCompatibility = true + slog.WarnContext(ctx, "ws.classifier_error_fail_closed_requires_updated_sdk", + "event_id", ev.EventId, + "session_id", sess.sessionID, + "message", "old SDKs ignore classifier-error status and policy fields, so fail-closed enforcement requires the updated SDK") +} + +func verdictStatusProto(status string) pb.VerdictStatus { + switch status { + case "error": + return pb.VerdictStatus_VERDICT_STATUS_ERROR + case "ok": + return pb.VerdictStatus_VERDICT_STATUS_OK + default: + return pb.VerdictStatus_VERDICT_STATUS_UNSPECIFIED + } } func verdictStatusProto(status string) pb.VerdictStatus { diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index cf70d9c..473600a 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -254,6 +254,101 @@ func TestClassifierFailurePersistsAndPublishesErrorVerdict(t *testing.T) { } } +func TestClassifierFailureAlertPersistsWithoutPublish(t *testing.T) { + db, conn := classifierFailureConn(t, "alert", false) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureToolEvent(eventID, "classifier-failure-alert")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + if err := expectNoServerFrame(conn, 250*time.Millisecond); err == nil { + t.Fatal("expected no SDK verdict in alert mode") + } + assertStoredErrorVerdict(t, db, eventID) +} + +func TestClassifierFailureHitlFailClosedQueuesActionable(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureActionableEvent(eventID, "classifier-failure-hitl")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + if err := expectNoServerFrame(conn, 250*time.Millisecond); err == nil { + t.Fatal("expected actionable fail-closed ERROR verdict to be held for HITL") + } + assertStoredErrorVerdict(t, db, eventID) + + var queued int + if err := db.QueryRow( + `SELECT count(*) FROM hitl_queue h + JOIN verdicts v ON v.id = h.verdict_id + WHERE h.event_id = ? AND h.mad_code = '' AND v.verdict_status = 'error'`, + eventID, + ).Scan(&queued); err != nil { + t.Fatalf("query hitl_queue: %v", err) + } + if queued != 1 { + t.Fatalf("queued error reviews = %d, want 1", queued) + } +} + +func TestClassifierFailureHitlFailClosedNonActionablePublishes(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureToolEvent(eventID, "classifier-failure-hitl-nonactionable")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + frame, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read verdict: %v", err) + } + if got := frame.GetVerdict().GetStatus(); got != bpb.VerdictStatus_VERDICT_STATUS_ERROR { + t.Fatalf("pushed status = %v, want ERROR", got) + } + assertStoredErrorVerdict(t, db, eventID) + + var queued int + if err := db.QueryRow(`SELECT count(*) FROM hitl_queue WHERE event_id = ?`, eventID).Scan(&queued); err != nil { + t.Fatalf("query hitl_queue: %v", err) + } + if queued != 0 { + t.Fatalf("queued reviews = %d, want 0", queued) + } +} + +func TestClassifierFailureHitlQueueFailureFallsBackToPublish(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + if _, err := db.Exec(` +CREATE TRIGGER fail_hitl_insert +BEFORE INSERT ON hitl_queue +BEGIN + SELECT RAISE(FAIL, 'forced hitl insert failure'); +END; +`); err != nil { + t.Fatalf("create hitl failure trigger: %v", err) + } + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureActionableEvent(eventID, "classifier-failure-hitl-fallback")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + frame, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read verdict: %v", err) + } + verdict := frame.GetVerdict() + if verdict.GetStatus() != bpb.VerdictStatus_VERDICT_STATUS_ERROR || verdict.GetMadCode() != "" { + t.Fatalf("pushed verdict = (%q, %v), want ('', ERROR)", verdict.GetMadCode(), verdict.GetStatus()) + } + assertStoredErrorVerdict(t, db, eventID) +} + func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { db := openInMemoryDB(t) t.Cleanup(func() { _ = db.Close() }) @@ -648,6 +743,120 @@ type fakeClassifier struct { calls *int32 } +func classifierFailureConn(t *testing.T, mode string, failClosed bool) (*sql.DB, *websocket.Conn) { + t.Helper() + + db := openInMemoryDB(t) + t.Cleanup(func() { _ = db.Close() }) + + st := store.New(db) + plaintextKey := "adr_local_test_key_classifier_failure_" + uuid.NewString() + keyHash := sha256Hex(plaintextKey) + insertAPIKey(t, db, keyHash) + + failClosedInt := 0 + if failClosed { + failClosedInt = 1 + } + if _, err := db.Exec( + `UPDATE policies SET mode = ?, fail_closed_on_classifier_error = ? WHERE id = 1`, + mode, failClosedInt, + ); err != nil { + t.Fatalf("set policy: %v", err) + } + + llm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "classifier exploded", http.StatusInternalServerError) + })) + t.Cleanup(llm.Close) + classifier := engine.NewHTTPClient(llm.URL, "test-key", "test-model", nil, nil) + + mux := http.NewServeMux() + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, classifier, ws.NewHub(), nil, nil))) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + header := http.Header{"Authorization": {"Bearer " + plaintextKey}} + conn, _, err := websocket.DefaultDialer.Dial(wsURL, header) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + if err := writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: "classifier-failure-sess-" + uuid.NewString(), SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login: %v", err) + } + if _, err := readServerFrame(conn); err != nil { + t.Fatalf("read login_ack: %v", err) + } + return db, conn +} + +func classifierFailureToolEvent(eventID, sessionID string) *bpb.PairedEvent { + return &bpb.PairedEvent{ + EventId: eventID, SessionId: sessionID, + RunId: "run-classifier-failure", + PairType: bpb.PairType_PAIR_TYPE_TOOL, + Agent: &bpb.AgentContext{AgentId: "failure-agent"}, + Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{ + ToolName: "noop", ToolCallId: "tc-classifier-failure", Input: "{}", Output: "ok", + }}, + } +} + +func classifierFailureActionableEvent(eventID, sessionID string) *bpb.PairedEvent { + return &bpb.PairedEvent{ + EventId: eventID, SessionId: sessionID, + RunId: "run-classifier-failure", + PairType: bpb.PairType_PAIR_TYPE_LLM, + Agent: &bpb.AgentContext{AgentId: "failure-agent"}, + Data: &bpb.PairedEvent_Llm{Llm: &bpb.LlmPairData{ + Model: "test-model", + Output: "calling tool", + ToolCalls: []*bpb.ToolCall{{ + Name: "noop", Id: "tc-classifier-failure", Args: "{}", + }}, + }}, + } +} + +func sendPairedEvent(conn *websocket.Conn, ev *bpb.PairedEvent) error { + return writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{ + Events: []*bpb.PairedEvent{ev}, + }}, + }) +} + +func expectNoServerFrame(conn *websocket.Conn, timeout time.Duration) error { + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return err + } + _, _, err := conn.ReadMessage() + _ = conn.SetReadDeadline(time.Time{}) + return err +} + +func assertStoredErrorVerdict(t *testing.T, db *sql.DB, eventID string) { + t.Helper() + var madCode, classification, verdictStatus string + if err := db.QueryRow( + `SELECT mad_code, classification, verdict_status FROM verdicts WHERE event_id = ?`, + eventID, + ).Scan(&madCode, &classification, &verdictStatus); err != nil { + t.Fatalf("query verdict: %v", err) + } + if madCode != "" || classification != "error" || verdictStatus != "error" { + t.Fatalf("stored verdict = (%q, %q, %q), want ('', error, error)", + madCode, classification, verdictStatus) + } +} + func (f *fakeClassifier) Classify(_ context.Context, _ *bpb.PairedEvent, _ string) (*engine.Verdict, error) { if f.calls != nil { atomic.AddInt32(f.calls, 1) @@ -736,7 +945,8 @@ func statusOrZero(r *http.Response) int { } // testSchema is the minimum subset of 001_initial_schema.sql the WS -// handler exercises (api_keys, policies, events, verdicts, mcp_servers). +// handler exercises (api_keys, policies, events, verdicts, mcp_servers, +// hitl_queue). // Embedding the full migration file here would couple the test to the // migration's evolution. const testSchema = ` @@ -798,4 +1008,15 @@ CREATE TABLE agents ( last_seen TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), metadata TEXT NOT NULL DEFAULT '{}' ); +CREATE TABLE hitl_queue ( + id TEXT PRIMARY KEY, + event_id TEXT NOT NULL UNIQUE, + verdict_id TEXT, + session_id TEXT, + mad_code TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + reviewed_by TEXT, + reviewed_at TEXT, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); ` diff --git a/backend/internal/ws/helpers.go b/backend/internal/ws/helpers.go index 0b7e42b..4f4d283 100644 --- a/backend/internal/ws/helpers.go +++ b/backend/internal/ws/helpers.go @@ -45,8 +45,8 @@ func isActionable(ev *pb.PairedEvent) bool { return llm != nil && len(llm.ToolCalls) > 0 } -// shouldFanOut decides whether a verdict's MAD code is in scope for -// the active policy. False for codes outside the M0/M2/M3/M4 set +// shouldFanOut decides whether an OK verdict's MAD code is in scope +// for the active policy. False for codes outside the M0/M2/M3/M4 set // (defensive: an unrecognised code drops rather than panics) and for // MAD families whose policy_mX flag is unset. // diff --git a/backend/internal/ws/session.go b/backend/internal/ws/session.go index cb251a3..4409e6c 100644 --- a/backend/internal/ws/session.go +++ b/backend/internal/ws/session.go @@ -15,6 +15,8 @@ type session struct { llmProvider string llmModel string loggedIn bool + + warnedClassifierErrorCompatibility bool } // agentProfileID returns the bound agent_profile_id (or nil if the diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 5e0228f..2c5ee14 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -22,8 +22,9 @@ | HTTP POST to ADRIAN_LLM_URL (OpenAI | | compatible chat-completions), strip | | blocks, parse M-code. On | -| classifier error, fail-open with synthetic | -| M0 / benign + WARN log. | +| classifier error, return an error to WS | +| ingest. WS persists verdict_status=error, | +| mad_code="", and routes by policy. | | | | | v | | internal/store SQLite (WAL) writes: events, verdicts, | @@ -47,9 +48,9 @@ +--------------------------------------------------------------------+ | adrian-frontend (Next.js container) | | Login and force-change-password, agent profiles, API keys, | -| policy editor (singleton mode + per-MAD-code toggles), HITL | -| review queue, events and verdicts feeds (REST poll), webhook | -| configuration (Discord). | +| policy editor (singleton mode, per-MAD-code toggles, classifier | +| error fail-closed flag), HITL review queue, events and verdicts | +| feeds (REST poll), webhook configuration (Discord). | +--------------------------------------------------------------------+ +--------------------------------------------------------------------+ diff --git a/frontend/app/(dashboard)/events/page.tsx b/frontend/app/(dashboard)/events/page.tsx index 236201c..0356ff9 100644 --- a/frontend/app/(dashboard)/events/page.tsx +++ b/frontend/app/(dashboard)/events/page.tsx @@ -7,7 +7,7 @@ import { AlertExplanation } from '@/components/alert-explanation' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' import { Pagination } from '@/components/pagination' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { isClassifierErrorVerdict, madBadgeColor, timeAgo, verdictBadgeColor, verdictBadgeLabel } from '@/lib/utils' import { TimeRange, sinceForRange, TimeRangeSelect } from '@/components/time-range' type EventRow = { @@ -28,7 +28,10 @@ type EventVerdict = { id: string mad_code: string classification: string + verdict_status?: string + reasoning?: string latency_ms?: number | null + created_at?: string } type EventDetail = EventRow & { @@ -40,7 +43,14 @@ type EventDetail = EventRow & { export default function EventsPage() { const [data, setData] = useState<{ events: EventRow[]; total: number }>({ events: [], total: 0 }) const [page, setPage] = useState(1) - const [filters, setFilters] = useState({ event_type: '', session_id: '', min_mad: '' }) + const [filters, setFilters] = useState(() => ({ + event_type: '', + session_id: '', + min_mad: '', + verdict_status: typeof window === 'undefined' + ? '' + : new URLSearchParams(window.location.search).get('verdict_status') || '', + })) const [range, setRange] = useState('24h') const [customSince, setCustomSince] = useState('') const [expanded, setExpanded] = useState(null) @@ -52,6 +62,7 @@ export default function EventsPage() { if (filters.event_type) params.set('event_type', filters.event_type) if (filters.session_id) params.set('session_id', filters.session_id) if (filters.min_mad) params.set('min_mad', filters.min_mad) + if (filters.verdict_status) params.set('verdict_status', filters.verdict_status) if (since) params.set('since', since) api(`/api/events?${params}`) .then(r => setData(r.data || { events: [], total: 0 })) @@ -96,6 +107,14 @@ export default function EventsPage() { + No verdict recorded for this event yet.

) : (
- - {detail?.verdict && typeof detail.verdict.latency_ms === 'number' && ( + + {typeof verdict.latency_ms === 'number' && ( - Latency: {detail.verdict.latency_ms}ms + Latency: {verdict.latency_ms}ms )}
)} - {verdict && verdict.mad_code !== 'M0' && ( + {verdict && isClassifierErrorVerdict(verdict) && verdict.reasoning && ( +

+ {verdict.reasoning} +

+ )} + {verdict && !isClassifierErrorVerdict(verdict) && verdict.mad_code !== 'M0' && (
diff --git a/frontend/app/(dashboard)/page.tsx b/frontend/app/(dashboard)/page.tsx index 1b2978f..4b1e4dc 100644 --- a/frontend/app/(dashboard)/page.tsx +++ b/frontend/app/(dashboard)/page.tsx @@ -8,6 +8,7 @@ import { madBadgeColor } from '@/lib/utils' type Overview = { total_events: number flagged_verdicts: number + classifier_errors: number pending_reviews: number active_agents: number verdicts_by_mad: Record @@ -49,9 +50,10 @@ export default function OverviewPage() { -
+
+
@@ -99,19 +101,21 @@ export default function OverviewPage() {

Verdict mix

{overview && Object.values(overview.verdicts_by_mad).some(v => v > 0) ? (
    - {(['M0', 'M2', 'M3', 'M4'] as const).map(family => { + {(['M0', 'M2', 'M3', 'M4', 'error'] as const).map(family => { const count = overview.verdicts_by_mad[family] || 0 const total = Object.values(overview.verdicts_by_mad).reduce((a, b) => a + b, 0) const pct = total ? (count / total) * 100 : 0 return (
  • - {family} + + {family === 'error' ? 'Classifier error' : family} + {count}
    diff --git a/frontend/app/(dashboard)/reviews/page.tsx b/frontend/app/(dashboard)/reviews/page.tsx index 0d3f65c..4b4af44 100644 --- a/frontend/app/(dashboard)/reviews/page.tsx +++ b/frontend/app/(dashboard)/reviews/page.tsx @@ -6,7 +6,7 @@ import { api } from '@/lib/api' import { AlertExplanation } from '@/components/alert-explanation' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { isClassifierErrorVerdict, timeAgo, verdictBadgeColor, verdictBadgeLabel } from '@/lib/utils' type ReviewSummary = { id: string @@ -14,6 +14,7 @@ type ReviewSummary = { verdict_id: string session_id: string mad_code: string + verdict_status: string status: string created_at: string } @@ -21,6 +22,7 @@ type ReviewSummary = { type ReviewDetail = ReviewSummary & { event_payload?: any classification?: string + reasoning?: string } export default function ReviewsPage() { @@ -93,7 +95,7 @@ export default function ReviewsPage() {

    Nothing waiting on you

    - When policy mode is HITL and a flagged verdict lands in scope, the SDK pauses and the event appears here. + When policy mode is HITL and a flagged verdict or fail-closed classifier error lands in scope, the SDK pauses and the event appears here.

    ) : ( @@ -109,7 +111,7 @@ export default function ReviewsPage() { }`} >
    - + {timeAgo(r.created_at)}
    @@ -126,7 +128,7 @@ export default function ReviewsPage() { ) : (
    - +
    - + {isClassifierErrorVerdict(detail) ? ( +
    +

    Classifier error

    +

    + The classifier did not return a MAD code. Approving resumes the paused SDK action; rejecting returns a blocked tool response. +

    + {detail.reasoning && ( +

    + {detail.reasoning} +

    + )} +
    + ) : ( + + )}

    Event payload

    diff --git a/frontend/app/(dashboard)/sessions/[session_id]/page.tsx b/frontend/app/(dashboard)/sessions/[session_id]/page.tsx index 806f519..0898bfc 100644 --- a/frontend/app/(dashboard)/sessions/[session_id]/page.tsx +++ b/frontend/app/(dashboard)/sessions/[session_id]/page.tsx @@ -5,12 +5,13 @@ import { useParams } from 'next/navigation' import { api } from '@/lib/api' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { verdictBadgeColor, verdictBadgeLabel, timeAgo } from '@/lib/utils' type Verdict = { id: string mad_code: string classification: string + verdict_status: string } type Entry = { @@ -92,7 +93,7 @@ export default function SessionTimelinePage() {
    {entry.verdict && ( - + )} {timeAgo(entry.created_at)}
    @@ -111,7 +112,7 @@ export default function SessionTimelinePage() {

    Verdict

    - +
    )} diff --git a/frontend/app/(dashboard)/settings/page.tsx b/frontend/app/(dashboard)/settings/page.tsx index 9f32b4d..5bacd5f 100644 --- a/frontend/app/(dashboard)/settings/page.tsx +++ b/frontend/app/(dashboard)/settings/page.tsx @@ -80,6 +80,7 @@ function PolicyTab() { policy_m2: !!policy.policy_m2, policy_m3: !!policy.policy_m3, policy_m4: !!policy.policy_m4, + fail_closed_on_classifier_error: !!policy.fail_closed_on_classifier_error, }), }) setStatus('saved') @@ -139,12 +140,27 @@ function PolicyTab() {
    )} +
    +

    + Classifier failure handling +

    + setPolicy({ ...policy, fail_closed_on_classifier_error: v })} + /> +

    + Older SDK versions ignore this flag. Update agents to the SDK version + shipped with this dashboard before relying on fail-closed enforcement. +

    +
    +
    NOTE - - Saving disconnects every active SDK session for this org so each one - reconnects with the new policy on its next event. Events already - in-flight at the moment of save are classified against the previous - policy - expect brief discrepancies for a few seconds after a change. + Mode changes apply when an SDK reconnects. The classifier-error + fail-closed flag is included on future verdicts, so BLOCK-mode timeout + decisions refresh after the next verdict snapshot reaches the SDK.
    diff --git a/frontend/lib/utils.ts b/frontend/lib/utils.ts index b538032..9d2f512 100644 --- a/frontend/lib/utils.ts +++ b/frontend/lib/utils.ts @@ -13,6 +13,13 @@ const PILL_NEUTRAL = 'bg-surface-raised text-ink-3 border-surface-border' const PILL_SOFT = 'bg-surface-overlay text-ink-2 border-surface-border' const PILL_MID = 'bg-ink-2 text-surface-raised border-ink-2' const PILL_STRONG = 'bg-ink text-surface-raised border-ink' +const PILL_ERROR = 'bg-danger/20 text-danger border-danger/40' + +export type VerdictDisplayInput = { + mad_code?: string + classification?: string + verdict_status?: string +} | null | undefined export function madBadgeColor(code: string): string { if (!code) return PILL_NEUTRAL @@ -22,6 +29,21 @@ export function madBadgeColor(code: string): string { return PILL_NEUTRAL } +export function isClassifierErrorVerdict(verdict: VerdictDisplayInput): boolean { + if (!verdict) return false + return verdict.verdict_status === 'error' || (verdict.classification === 'error' && !verdict.mad_code) +} + +export function verdictBadgeLabel(verdict: VerdictDisplayInput): string { + if (isClassifierErrorVerdict(verdict)) return 'Classifier error' + return verdict?.mad_code || 'No MAD code' +} + +export function verdictBadgeColor(verdict: VerdictDisplayInput): string { + if (isClassifierErrorVerdict(verdict)) return PILL_ERROR + return madBadgeColor(verdict?.mad_code || '') +} + export function classificationBadgeColor(cls: string): string { switch (cls) { case 'BLOCK': return PILL_STRONG diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 6f6c024..785878b 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -198,10 +198,11 @@ def init( session_id: Session identifier. Falls back to ``ADRIAN_SESSION_ID``, then to a per-cwd persistent UUID. See :mod:`adrian.session_persistence`. - block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` - before fail-open. Ignored in ``MODE_ALERT`` (no wait) and - ``MODE_HITL`` (wait indefinitely). Falls back to - ``ADRIAN_BLOCK_TIMEOUT``. + block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK``. + Timeout handling follows the server policy's + ``fail_closed_on_classifier_error`` flag. Ignored in + ``MODE_ALERT`` (no wait) and ``MODE_HITL`` (wait indefinitely). + Falls back to ``ADRIAN_BLOCK_TIMEOUT``. on_event: Callback for every paired event. on_verdict: Callback for every verdict. on_block: Callback for BLOCK-tier verdicts (M3 / M4). Notification @@ -844,11 +845,18 @@ def _extract_tool_calls( # pyright: ignore[reportUnusedFunction] def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override per-MAD policy when present. + HITL resolutions override everything: ``continue_execution=False`` + means halt, ``True`` means continue. Classifier ERROR verdicts + follow ``fail_closed_on_classifier_error``. Otherwise the per-MAD + policy bool is the sole scope authority: if the verdict's tier is + in-scope, halt; if not, continue. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution + if verdict.status == pb.VERDICT_STATUS_ERROR: + return bool(verdict.policy.fail_closed_on_classifier_error) + mad_prefix = verdict.mad_code[:2] return { "M0": verdict.policy.policy_m0, @@ -861,11 +869,9 @@ def _should_halt(verdict: pb.Verdict) -> bool: def _patch_tool_node() -> None: """Patch ToolNode for callback injection + async verdict gate. - ToolNode dispatches tools via tool.invoke (sync) even within async - Pregel. BaseTool.invoke can't await a verdict from the event loop - thread, so we add the verdict gate here on ToolNode.ainvoke - the - entry point Pregel calls before tool dispatch begins. This is a - complementary gate to BaseTool (which covers direct callers). + ToolNode stays responsible for callback injection. The verdict gate lives + on ``BaseTool`` so async ToolNode dispatch does not consume verdict futures + before individual tools run. """ try: from langgraph.prebuilt import ToolNode @@ -983,6 +989,12 @@ async def _async_gate(tool_call_id: str) -> bool: if not ws.policy_active(): return False + if tool_call_id not in ws._tool_call_id_to_event_id: # pyright: ignore[reportPrivateUsage] + # Unknown / evicted correlation: there is no producing LLM + # event to gate, so this remains fail-open even when + # classifier-error fail-closed is enabled. + return False + cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) diff --git a/sdk/python/adrian/config.py b/sdk/python/adrian/config.py index 238b54c..d4a0895 100644 --- a/sdk/python/adrian/config.py +++ b/sdk/python/adrian/config.py @@ -18,7 +18,8 @@ """Callback invoked for every verdict received. Accepts a ``VerdictContext`` with full event metadata. May be sync or -async. Fires for every MAD code the server forwards (M0 / M2 / M3 / M4). +async. Fires for every forwarded verdict, including classifier-error +verdicts whose ``mad_code`` is empty. """ type OnBlockCallback = ( @@ -97,8 +98,9 @@ class AdrianConfig: ws_url: WebSocket URL for the Adrian server. ``None`` disables the WebSocket handler. block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` - before fail-open. Ignored in ``MODE_ALERT`` (no wait) and - ``MODE_HITL`` (wait indefinitely). + before applying the server's classifier-error timeout policy. + Ignored in ``MODE_ALERT`` (no wait) and ``MODE_HITL`` (wait + indefinitely). on_event: Callback for every paired event. on_verdict: Callback for every verdict. on_block: Callback for BLOCK-tier verdicts (M3 / M4). diff --git a/sdk/python/adrian/handler.py b/sdk/python/adrian/handler.py index 3289041..99aa7be 100644 --- a/sdk/python/adrian/handler.py +++ b/sdk/python/adrian/handler.py @@ -141,6 +141,7 @@ async def handle_verdict(self, verdict: pb.Verdict) -> None: run_id=record.run_id, parent_run_id=record.parent_run_id, policy=verdict.policy, + status=verdict.status, mad_code=verdict.mad_code, hitl=hitl, ) diff --git a/sdk/python/adrian/types.py b/sdk/python/adrian/types.py index 13c1506..082b1c4 100644 --- a/sdk/python/adrian/types.py +++ b/sdk/python/adrian/types.py @@ -219,9 +219,11 @@ class VerdictContext: event_data: Original event payload TypedDict. run_id: LangChain run ID. parent_run_id: Parent run ID if nested, or ``None``. - mad_code: MAD policy code the classifier returned + status: Classifier result status. ``VERDICT_STATUS_ERROR`` means + the classifier did not produce a MAD code; ``mad_code`` is empty. + mad_code: MAD policy code the classifier returned on OK verdicts (e.g. ``"M0"``, ``"M2_C"``, ``"M4_a"``). Empty string - when no code is set (benign). + means no classifier-produced MAD code exists. policy: Org's effective execution-mode policy at the moment this verdict was decided. Carries the mode (alert / block / hitl) and per-MAD-code scope booleans. @@ -238,6 +240,7 @@ class VerdictContext: run_id: str parent_run_id: str | None policy: pb.PolicySnapshot + status: int = pb.VERDICT_STATUS_UNSPECIFIED mad_code: str = "" hitl: pb.HitlResponse | None = None diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 9b26d70..9e5e8bc 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -228,9 +228,10 @@ def __init__( self._model = "" # Server-supplied execution-mode policy. Populated when the # first ServerFrame{login_ack} arrives after each (re)connect. - # ``policy_active()`` and ``block_timeout()`` read this state - # to decide whether the patched ToolNode should wait for a - # verdict and how long. + # ``policy_active()``, ``block_timeout()``, and + # ``fail_closed_on_classifier_error()`` read this state to + # decide whether the patched ToolNode should wait for a verdict + # and how to handle classifier failures/timeouts. self._mode: int = pb.MODE_UNSPECIFIED self._policy: pb.PolicySnapshot | None = None # Set the first time a ``ServerFrame{login_ack}`` is applied. @@ -265,7 +266,8 @@ def __init__( # with the matching ``Verdict`` proto. Futures survive a # disconnect: a late verdict after reconnect still resolves # the wait; if none arrives, ``wait_for_verdict``'s timeout - # produces a natural fail-open in BLOCK mode. + # returns None and the patched ToolNode applies the current + # fail-open/fail-closed classifier-error policy. self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} # Maps LLM pair run_id → event_id so a subsequent tool call can # look up the verdict by its parent_run_id (the LLM's run_id). @@ -325,8 +327,8 @@ def policy_active(self) -> bool: def block_timeout(self, kwarg_default: float) -> float | None: """Effective per-tool-call wait timeout for the active mode. - - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open - if the server doesn't classify in time. + - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s). Timeout + handling follows ``fail_closed_on_classifier_error``. - ``MODE_HITL``: ``None``, wait indefinitely for human review. - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before registering a future. @@ -338,6 +340,12 @@ def block_timeout(self, kwarg_default: float) -> float | None: else: return 0 + def fail_closed_on_classifier_error(self) -> bool: + """Whether classifier errors/timeouts should halt tool execution.""" + return bool( + self._policy is not None and self._policy.fail_closed_on_classifier_error + ) + # -- EventHandler protocol -- async def on_paired_event(self, event: PairedEvent) -> None: @@ -703,12 +711,13 @@ def _on_login_ack(self, ack: pb.LoginAck) -> None: self._login_ack_received.set() logger.info( "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " - "policy_m3=%s policy_m4=%s", + "policy_m3=%s policy_m4=%s fail_closed_on_classifier_error=%s", pb.Mode.Name(ack.policy.mode), ack.policy.policy_m0, ack.policy.policy_m2, ack.policy.policy_m3, ack.policy.policy_m4, + ack.policy.fail_closed_on_classifier_error, ) if self._on_login_ack_cb is not None: @@ -733,10 +742,17 @@ async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: owns the cleanup: its ``finally`` pops the entry after the await returns. """ + if verdict.HasField("policy"): + # Keep the policy snapshot fresh for BLOCK-mode timeout + # decisions. Execution mode remains login-fixed for this + # release; hot-switching mode mid-session is out of scope. + self._policy = verdict.policy + logger.info( - "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", + "Verdict received: event_id=%s mad_code=%s status=%s mode=%s hitl=%s", verdict.event_id, verdict.mad_code or "-", + pb.VerdictStatus.Name(verdict.status), pb.Mode.Name(verdict.policy.mode), verdict.HasField("hitl"), ) @@ -960,9 +976,9 @@ async def wait_for_verdict( """Wait for a verdict for ``event_id``. ``timeout`` is mode-derived (see :meth:`block_timeout`): - a positive float for ``MODE_BLOCK`` (fail-open at timeout), + a positive float for ``MODE_BLOCK`` (caller applies policy at timeout), ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the - verdict, or ``None`` on timeout (fail-open). + verdict, or ``None`` on timeout. Resolved futures are kept in ``_pending_verdicts`` so a second waiter on the same event_id (e.g. BaseTool.ainvoke firing after diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 742249b..af1ba81 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -35,6 +35,7 @@ def _apply_mode( policy_m2: bool = False, policy_m3: bool = False, policy_m4: bool = False, + fail_closed_on_classifier_error: bool = False, ) -> pb.PolicySnapshot: """Drive the mode/policy state as if a LoginAck had arrived.""" policy = pb.PolicySnapshot( @@ -43,6 +44,7 @@ def _apply_mode( policy_m2=policy_m2, policy_m3=policy_m3, policy_m4=policy_m4, + fail_closed_on_classifier_error=fail_closed_on_classifier_error, ) ws._mode = mode ws._policy = policy @@ -270,6 +272,174 @@ async def _real_tool(x: str) -> str: # Fail-closed: tool should NOT have run. assert captured == [] + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.05, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode( + ws, + pb.MODE_BLOCK, + policy_m4=True, + fail_closed_on_classifier_error=True, + ) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + + async def test_error_verdict_fail_open_runs_tool(self, tmp_path: Path) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=False) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + fut = ws.register_pending("llm-evt") + fut.set_result( + pb.Verdict( + event_id="llm-evt", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + + async def test_error_verdict_fail_closed_blocks_tool( + self, + tmp_path: Path, + ) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=True) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + fut = ws.register_pending("llm-evt") + fut.set_result( + pb.Verdict( + event_id="llm-evt", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + + async def test_unknown_tool_call_stays_fail_open_when_fail_closed( + self, + tmp_path: Path, + ) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.05, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=True) + ws._connected.set() + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + class TestModeAlert: async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: diff --git a/sdk/python/tests/test_exec_modes.py b/sdk/python/tests/test_exec_modes.py index f3f5e42..0678736 100644 --- a/sdk/python/tests/test_exec_modes.py +++ b/sdk/python/tests/test_exec_modes.py @@ -168,6 +168,62 @@ async def test_out_of_scope_continues_without_hitl( assert captured == ["hi"] + async def test_error_review_approve_continues(self, tmp_path: Path) -> None: + """ERROR verdict + HITL approve still resumes the tool.""" + captured: list[str] = [] + ws = _init_with_ws(tmp_path) + policy = pb.PolicySnapshot( + mode=pb.MODE_HITL, + fail_closed_on_classifier_error=True, + ) + _apply_mode(ws, policy) + ws._tool_call_id_to_event_id["tc-1"] = "evt-1" + + verdict = pb.Verdict( + event_id="evt-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ) + verdict.hitl.continue_execution = True + ws.register_pending("evt-1").set_result(verdict) + + result = await ToolNode([_stub_tool(captured)]).ainvoke( # pyright: ignore[reportUnknownMemberType] + _ainvoke_state(), + config=_runtime_config(), + ) + + assert captured == ["hi"] + assert "BLOCKED" not in result["messages"][0].content + + async def test_error_review_reject_halts(self, tmp_path: Path) -> None: + """ERROR verdict + HITL reject blocks the tool.""" + captured: list[str] = [] + ws = _init_with_ws(tmp_path) + policy = pb.PolicySnapshot( + mode=pb.MODE_HITL, + fail_closed_on_classifier_error=True, + ) + _apply_mode(ws, policy) + ws._tool_call_id_to_event_id["tc-1"] = "evt-1" + + verdict = pb.Verdict( + event_id="evt-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ) + verdict.hitl.continue_execution = False + ws.register_pending("evt-1").set_result(verdict) + + result = await ToolNode([_stub_tool(captured)]).ainvoke( # pyright: ignore[reportUnknownMemberType] + _ainvoke_state(), + config=_runtime_config(), + ) + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + # ------------------------------------------------------------------ # Stray HITL resolution + protocol error diff --git a/sdk/python/tests/test_handler.py b/sdk/python/tests/test_handler.py index cc421ec..b5fa09e 100644 --- a/sdk/python/tests/test_handler.py +++ b/sdk/python/tests/test_handler.py @@ -12,6 +12,8 @@ from adrian.handler import AdrianCallbackHandler, extract_model_name from adrian.hooks import HookRegistry from adrian.pairing import EventPairBuffer +from adrian.proto import event_pb2 as pb +from adrian.types import EventRecord, VerdictContext from langchain_core.messages import ( AIMessage, BaseMessage, # noqa: TC002 @@ -230,3 +232,47 @@ def test_kwargs_model_name(self) -> None: def test_empty_dict(self) -> None: assert extract_model_name({}) == "unknown" + + +class TestVerdictCallbacks: + async def test_error_verdict_populates_status_without_mad_callbacks(self) -> None: + seen: list[VerdictContext] = [] + blocked: list[VerdictContext] = [] + audited: list[VerdictContext] = [] + + handler = AdrianCallbackHandler( + pair_buffer=EventPairBuffer(), + context_tracker=AgentContextTracker(), + hooks=HookRegistry(), + config=AdrianConfig( + on_verdict=seen.append, + on_block=blocked.append, + on_audit=audited.append, + ), + ) + handler._event_map["evt-error"] = EventRecord( # pyright: ignore[reportPrivateUsage] + event_type="llm", + data={ + "output": "tool call", + "tool_calls": [], + "usage": None, + }, + run_id="run-1", + parent_run_id=None, + ) + + await handler.handle_verdict( + pb.Verdict( + event_id="evt-error", + session_id="sess-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=pb.PolicySnapshot(fail_closed_on_classifier_error=True), + ), + ) + + assert len(seen) == 1 + assert seen[0].status == pb.VERDICT_STATUS_ERROR + assert seen[0].mad_code == "" + assert blocked == [] + assert audited == [] diff --git a/sdk/python/tests/test_ws.py b/sdk/python/tests/test_ws.py index c28762e..6e79ac7 100644 --- a/sdk/python/tests/test_ws.py +++ b/sdk/python/tests/test_ws.py @@ -299,6 +299,26 @@ async def __anext__(self) -> bytes: assert resolved.event_id == "evt-1" assert resolved.mad_code == "M4_a" + async def test_verdict_frame_refreshes_policy_without_switching_mode(self) -> None: + client = WebSocketClient("ws://x", "s", api_key="k") + client._mode = pb.MODE_ALERT + client._policy = pb.PolicySnapshot( + mode=pb.MODE_ALERT, + fail_closed_on_classifier_error=False, + ) + + verdict = pb.Verdict( + event_id="evt-1", + policy=pb.PolicySnapshot( + mode=pb.MODE_BLOCK, + fail_closed_on_classifier_error=True, + ), + ) + await client._on_verdict_frame(verdict) + + assert client._mode == pb.MODE_ALERT + assert client.fail_closed_on_classifier_error() is True + # ------------------------------------------------------------------ # Block-mode primitives