diff --git a/.dockerignore b/.dockerignore index 5b62e5f31f07..f1821beb2ac2 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,6 +4,7 @@ .devcontainer models backends +volumes examples/chatbot-ui/models backend/go/image/stablediffusion-ggml/build/ backend/go/*/build diff --git a/backend/backend.proto b/backend/backend.proto index bf07f3bd408c..18b86c205ce2 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -37,6 +37,22 @@ service Backend { rpc Rerank(RerankRequest) returns (RerankResult) {} + // TokenClassify runs a token-classification (NER) model on the + // supplied text and returns each detected entity span. Used by the + // PII redactor's optional NER tier — the regex tier still handles + // formatted hits cheaply, while this catches names, locations, and + // other unformatted PII that regex misses. + rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {} + + // Score evaluates the model's joint log-probability of each + // supplied candidate continuation given a shared prompt. The + // prompt's KV cache is computed once and reused across candidates. + // Used for routing-policy multi-label classification, reranking, + // calibrated confidence, and reward-model scoring — any task where + // the consumer wants the model's confidence in a pre-specified + // continuation rather than a generated one. + rpc Score(ScoreRequest) returns (ScoreResponse) {} + rpc GetMetrics(MetricsRequest) returns (MetricsResponse); rpc VAD(VADRequest) returns (VADResponse) {} @@ -81,6 +97,76 @@ message MetricsResponse { int32 prompt_tokens_processed = 5; } +// TokenClassifyRequest carries the text to classify plus an optional +// score threshold. The transformers backend interprets threshold as +// the minimum confidence to include in the response; 0 = include all. +message TokenClassifyRequest { + string text = 1; + float threshold = 2; +} + +// TokenClassifyEntity is one detected entity span. Byte offsets are +// into the original UTF-8 text — start..end is a half-open range that +// addresses the substring corresponding to entity_group. +// +// entity_group follows HuggingFace's aggregated-tag convention (e.g. +// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" / +// "SSN" depending on the model). The redactor's per-pattern action +// map keys off this string. +message TokenClassifyEntity { + string entity_group = 1; + int32 start = 2; + int32 end = 3; + float score = 4; + string text = 5; +} + +message TokenClassifyResponse { + repeated TokenClassifyEntity entities = 1; +} + +// ScoreRequest carries one shared prompt and one or more continuations +// to score against it. The backend tokenises the prompt once and reuses +// the resulting KV cache across all candidates in this request. +message ScoreRequest { + string prompt = 1; + repeated string candidates = 2; + // Return per-token logprobs for each candidate when true. Default + // false to keep the wire response small; the joint log_prob field + // covers the common ranking case. + bool include_token_logprobs = 3; + // When true, the response also populates length_normalized_log_prob + // (joint log-prob divided by candidate token count). Useful when + // candidates differ in length and the consumer wants a per-token + // measure comparable across them (PMI-style scoring). + bool length_normalize = 4; +} + +// CandidateScore is one row in the ScoreResponse, matching by index +// the candidate in ScoreRequest.candidates. +message CandidateScore { + // Sum of log P(token_i | prompt, candidate_token_ #include #include +#include #include #include #include @@ -3095,6 +3096,207 @@ class BackendServiceImpl final : public backend::Backend::Service { return grpc::Status::OK; } + // Score returns the model's joint log-probability of each candidate + // continuation given a shared prompt. + // + // WHY bypass the slot/task queue: upstream server_context exposes + // get_llama_context as "main thread only" and the slot loop's + // update_slots() owns the context whenever a task is in flight. + // No public synchronization primitive is available — so Score is + // unsafe to call concurrently with active generation through this + // backend. In practice routing-classifier calls happen before the + // request is routed to a generation backend, so the model used + // for Score is typically idle. Concurrent Score calls are + // serialised by a local mutex; KV-cache state is isolated behind + // a dedicated sequence ID cleared between candidates. + // + // A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE + // and routes scoring through the slot loop would be the correct + // long-term fix; tracked as a follow-up. + // + // Perf TODO (measured: ~450 ms warm for 3 candidates on Arch- + // Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes + // `prompt + candidate` from scratch for every candidate, throwing + // away the prompt's KV cache between iterations. A smarter + // version would: + // 1. Decode just the prompt once into score_seq_id. + // 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a + // per-candidate sequence id. + // 3. For each candidate, decode only its tokens onto the copy + // (continuing from the saved prompt state), read logits. + // 4. llama_memory_seq_rm the copy. + // Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms, + // 6-candidate calls 630 ms -> ~220 ms. Single source-file change, + // no proto / Go-side changes needed. Worth doing once routing is + // wired into the middleware and Score is on the hot path of every + // chat request. + grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override { + auto auth = checkAuth(context); + if (!auth.ok()) return auth; + if (params_base.model.path.empty()) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); + } + if (request->candidates_size() == 0) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty"); + } + + // Serialise concurrent Score calls. The slot loop is still + // free to race with us — see the class comment above. + static std::mutex score_mutex; + std::lock_guard score_lock(score_mutex); + + llama_context * lctx = ctx_server.get_llama_context(); + if (lctx == nullptr) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)"); + } + const llama_vocab * vocab = ctx_server.impl->vocab; + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + const int32_t n_ctx = llama_n_ctx(lctx); + llama_memory_t mem = llama_get_memory(lctx); + + // The KV-cache is sized to seq_to_stream.size() at load + // (typically equal to n_slots, often 1). Sequence IDs must + // be in [0, n_seq_max), so we can't pick a high-value + // "private" ID — we have to share with the slot. We clear + // the cache before AND after each candidate to keep + // scoring isolated from whatever state the slot held, and + // the static mutex above guarantees no other Score call is + // racing in the meantime. The slot loop is still free to + // race (see comment on this method) — Score must not run + // concurrently with generation through this backend. + const llama_seq_id score_seq_id = 0; + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + + // Tokenize the shared prompt once with add_special=true so + // BOS is prepended when the model requires it. parse_special + // keeps chat-template markers in the prompt intact. + const std::string prompt = request->prompt(); + std::vector prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true); + const int32_t prompt_len = (int32_t) prompt_tokens.size(); + + for (int ci = 0; ci < request->candidates_size(); ci++) { + const std::string & candidate_text = request->candidates(ci); + + // Re-tokenize prompt + candidate as a single string. BPE + // merges across the boundary can shift the tokenization + // versus tokenize(prompt) ++ tokenize(candidate), so we + // find the divergence point against prompt_tokens. + std::vector full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true); + int32_t divergence = prompt_len; + const int32_t min_len = std::min(prompt_len, (int32_t) full_tokens.size()); + for (int32_t i = 0; i < min_len; i++) { + if (prompt_tokens[i] != full_tokens[i]) { + divergence = i; + break; + } + } + const int32_t cand_len = (int32_t) full_tokens.size() - divergence; + backend::CandidateScore * cs = response->add_candidates(); + cs->set_num_tokens(cand_len); + if (cand_len <= 0) { + cs->set_log_prob(0.0); + if (request->length_normalize()) { + cs->set_length_normalized_log_prob(0.0); + } + continue; + } + if (divergence < 1) { + // Need at least one prior token (typically BOS) to + // predict the first candidate token's logit. Tokeniser + // models without BOS + an empty prompt fall in here. + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate"); + } + if ((int32_t) full_tokens.size() > n_ctx) { + return grpc::Status(grpc::StatusCode::OUT_OF_RANGE, + "Score: prompt+candidate exceeds context size (got " + + std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")"); + } + + // Build a batch covering the entire prompt+candidate. We + // need logits at (divergence-1) onward — those are the + // predictions for each candidate token. + llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1); + for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) { + batch.token[i] = full_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = score_seq_id; + // logits[i] is "do we want the prediction *for the + // next token*, computed from this position?" + // We want predictions for candidate tokens at + // positions divergence .. full_tokens.size()-1, which + // come from logits at positions (divergence-1) .. + // (full_tokens.size()-2). + bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1); + batch.logits[i] = need_logit ? 1 : 0; + } + batch.n_tokens = (int32_t) full_tokens.size(); + + // Decode the batch. If decode fails (e.g. KV slot + // exhaustion), surface as INTERNAL — the caller will + // typically fall back to a sampling-based classifier. + int decode_err = llama_decode(lctx, batch); + if (decode_err != 0) { + llama_batch_free(batch); + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + return grpc::Status(grpc::StatusCode::INTERNAL, + "llama_decode failed during Score: " + std::to_string(decode_err)); + } + + // Sum log-probabilities of the actual candidate tokens. + double total_log_prob = 0.0; + for (int32_t k = 0; k < cand_len; k++) { + // The k-th candidate token sits at full_tokens index + // (divergence + k). Its predicting logit is at batch + // position (divergence + k - 1). + int32_t logit_pos = divergence + k - 1; + const float * logits = llama_get_logits_ith(lctx, logit_pos); + if (logits == nullptr) { + llama_batch_free(batch); + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + return grpc::Status(grpc::StatusCode::INTERNAL, + "llama_get_logits_ith returned null at position " + std::to_string(logit_pos)); + } + llama_token target_token = full_tokens[divergence + k]; + + // Compute log_softmax(logits)[target_token] with the + // max-subtraction stability trick. + float max_logit = logits[0]; + for (int32_t v = 1; v < n_vocab; v++) { + if (logits[v] > max_logit) max_logit = logits[v]; + } + double sum_exp = 0.0; + for (int32_t v = 0; v < n_vocab; v++) { + sum_exp += std::exp((double)(logits[v] - max_logit)); + } + double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp); + total_log_prob += token_log_prob; + + if (request->include_token_logprobs()) { + backend::TokenLogProb * tlp = cs->add_tokens(); + std::string piece = common_token_to_piece(lctx, target_token); + tlp->set_token(piece); + tlp->set_log_prob(token_log_prob); + } + } + + cs->set_log_prob(total_log_prob); + if (request->length_normalize() && cand_len > 0) { + cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len); + } + + llama_batch_free(batch); + // Drop this candidate's KV-cache contribution so the next + // candidate starts from a clean state. Without this, the + // next decode would conflict at positions 0..N-1 for our + // sequence ID. + llama_memory_seq_rm(mem, score_seq_id, -1, -1); + } + + return grpc::Status::OK; + } + grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override { auto auth = checkAuth(context); if (!auth.ok()) return auth; diff --git a/backend/go/local-store/store.go b/backend/go/local-store/store.go index e2ad540987ad..9961a12a044b 100644 --- a/backend/go/local-store/store.go +++ b/backend/go/local-store/store.go @@ -1,7 +1,22 @@ package main -// This is a wrapper to statisfy the GRPC service interface -// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +// LocalAI's in-process vector store, exposed as a gRPC backend. Keep +// the implementation here — NOT in a pkg/ library imported by the main +// LocalAI process. The whole point of the gRPC surface is that vector +// storage is a backend like any other (local-store, qdrant, pinecone, +// ...) and can be swapped without changing the routing/recognition +// code that consumes it. +// +// Storage is a sorted parallel-slice (keys [][]float32, values +// [][]byte). Set/Delete preserve the sort so Get can binary-search. +// Find scans linearly and uses a heap to keep the top-K — fine for +// the tens-to-thousands range. The "normalized fast path" (Find when +// every stored key has unit magnitude AND the query is normalized) +// skips the per-item magnitude calculation. +// +// Concurrency: base.SingleThread serialises gRPC calls so the +// non-thread-safe slice/heap manipulation here is sound. + import ( "container/heap" "fmt" @@ -10,30 +25,27 @@ import ( "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" - - "github.com/mudler/xlog" + "github.com/mudler/LocalAI/pkg/store" ) type Store struct { base.SingleThread - // The sorted keys - keys [][]float32 - // The sorted values + keys [][]float32 values [][]byte - // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions - // TODO: Should we normalize incoming keys if they are not instead? + // keysAreNormalized stays true until any non-unit-magnitude key + // is added; once false, the magnitude-aware fallback path is + // used by Find. Re-evaluated only at Set time, never again on + // its own — a deletion of the offending key does NOT flip it + // back to true (the bookkeeping cost would dominate the gain). keysAreNormalized bool - // The first key decides the length of the keys - keyLen int -} -// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because -// that's theoretically best for memory layout and cache locality, but this isn't optimized yet. -type Pair struct { - Key []float32 - Value []byte + // keyLen is the dimension of every stored key. -1 means "no + // keys yet, dimension is open". Dimension mismatch on Set is + // rejected so cosine similarity (which requires equal-length + // vectors) doesn't silently mis-match. + keyLen int } func NewStore() *Store { @@ -45,477 +57,286 @@ func NewStore() *Store { } } -func compareSlices(k1, k2 []float32) int { - assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - return slices.Compare(k1, k2) -} - -func hasKey(unsortedSlice [][]float32, target []float32) bool { - return slices.ContainsFunc(unsortedSlice, func(k []float32) bool { - return compareSlices(k, target) == 0 - }) -} - -func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) { - return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int { - return compareSlices(k, t) - }) -} - -func isSortedPairs(kvs []Pair) bool { - for i := 1; i < len(kvs); i++ { - if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 { - return false - } - } - - return true -} - -func isSortedKeys(keys [][]float32) bool { - for i := 1; i < len(keys); i++ { - if compareSlices(keys[i-1], keys[i]) > 0 { - return false - } - } - - return true -} - -func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 { - ks := make([][]float32, len(keys)) - - for i, k := range keys { - ks[i] = k.Floats - } - - slices.SortFunc(ks, compareSlices) - - assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys))) - assert(isSortedKeys(ks), "keys are not sorted") - - return ks -} - +// Load is a no-op — local-store has no on-disk artefact. opts.Model is +// just a namespace identifier; isolation is already handled upstream +// (ModelLoader spawns a fresh local-store process per (backend, +// model) tuple, so each namespace is its own Store{} instance). func (s *Store) Load(opts *pb.ModelOptions) error { - // local-store is an in-memory vector store with no on-disk artefact to - // load — opts.Model is just a namespace identifier. The old `!= ""` guard - // rejected any non-empty model name with "not implemented", which broke - // callers that pass a namespace to isolate embedding spaces (face vs. - // voice biometrics both go through local-store but need distinct stores - // so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace - // isolation is already handled upstream: ModelLoader spawns a fresh - // local-store process per (backend, model) tuple, so each namespace is - // its own Store{} instance. Nothing to do here beyond accepting the load. _ = opts return nil } -// Sort the incoming kvs and merge them with the existing sorted kvs func (s *Store) StoresSet(opts *pb.StoresSetOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") + keys := store.UnwrapKeys(opts.Keys) + values := store.UnwrapValues(opts.Values) + if len(keys) == 0 { + return fmt.Errorf("local-store: Set: no keys to add") } - - if len(opts.Keys) != len(opts.Values) { - return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values)) + if len(keys) != len(values) { + return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values)) } if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } + s.keyLen = len(keys[0]) + } else if len(keys[0]) != s.keyLen { + return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen) } - kvs := make([]Pair, len(opts.Keys)) - - for i, k := range opts.Keys { - if s.keysAreNormalized && !isNormalized(k.Floats) { - s.keysAreNormalized = false - var sample []float32 - if len(s.keys) > 5 { - sample = k.Floats[:5] - } else { - sample = k.Floats - } - xlog.Debug("Key is not normalized", "sample", sample) - } - - kvs[i] = Pair{ - Key: k.Floats, - Value: opts.Values[i].Bytes, - } - } - - slices.SortFunc(kvs, func(a, b Pair) int { - return compareSlices(a.Key, b.Key) - }) - - assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys))) - assert(isSortedPairs(kvs), "keys are not sorted") - - l := len(kvs) + len(s.keys) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - i, j := 0, 0 - for { - if i+j >= l { - break - } - - if i >= len(kvs) { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - continue - } - - if j >= len(s.keys) { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - continue + kvs := make([]incomingPair, len(keys)) + for i, k := range keys { + if len(k) != s.keyLen { + return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen) } - - c := compareSlices(kvs[i].Key, s.keys[j]) - if c < 0 { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - } else if c > 0 { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - } else { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - j++ + if s.keysAreNormalized && !isNormalized(k) { + s.keysAreNormalized = false } + kvs[i] = incomingPair{key: k, value: values[i]} } - assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l)) - assert(isSortedKeys(merge_ks), "merge keys are not sorted") - - s.keys = merge_ks - s.values = merge_vs + slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }) + merged := mergeSortedPairs(s.keys, s.values, kvs) + s.keys = merged.keys + s.values = merged.values return nil } func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to delete") + keys := store.UnwrapKeys(opts.Keys) + if len(keys) == 0 { + return fmt.Errorf("local-store: Delete: no keys to delete") } - - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - ks := sortIntoKeySlicese(opts.Keys) - - l := len(s.keys) - len(ks) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - tail_ks := s.keys - tail_vs := s.values - for _, k := range ks { - j, found := findInSortedSlice(tail_ks, k) - - if found { - merge_ks = append(merge_ks, tail_ks[:j]...) - merge_vs = append(merge_vs, tail_vs[:j]...) - tail_ks = tail_ks[j+1:] - tail_vs = tail_vs[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k)) + if s.keyLen != -1 { + for i, k := range keys { + if len(k) != s.keyLen { + return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen) + } } - - xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs)) } - - merge_ks = append(merge_ks, tail_ks...) - merge_vs = append(merge_vs, tail_vs...) - - assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys))) - - s.keys = merge_ks - s.values = merge_vs - - assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l)) - assert(isSortedKeys(s.keys), "keys are not sorted") - assert(func() bool { - for _, k := range ks { - if _, found := findInSortedSlice(s.keys, k); found { - return false - } + sortedKeys := append([][]float32(nil), keys...) + slices.SortFunc(sortedKeys, slices.Compare[[]float32]) + + mergedK := make([][]float32, 0, len(s.keys)) + mergedV := make([][]byte, 0, len(s.keys)) + tailK := s.keys + tailV := s.values + for _, k := range sortedKeys { + j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32]) + if ok { + mergedK = append(mergedK, tailK[:j]...) + mergedV = append(mergedV, tailV[:j]...) + tailK = tailK[j+1:] + tailV = tailV[j+1:] } - return true - }(), "Keys to delete still present") - - if len(s.keys) != l { - xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l) } - + mergedK = append(mergedK, tailK...) + mergedV = append(mergedV, tailV...) + s.keys = mergedK + s.values = mergedV return nil } +// StoresGet fetches values for the given keys. Missing keys are +// omitted from the result rather than reported as an error — callers +// compare returned-key length against requested-key length to detect +// them. Returned slices are aligned. func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) { - pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys)) - pbValues := make([]*pb.StoresValue, 0, len(opts.Keys)) - ks := sortIntoKeySlicese(opts.Keys) - + keys := store.UnwrapKeys(opts.Keys) if len(s.keys) == 0 { - xlog.Debug("Get: No keys in store") + return pb.StoresGetResult{}, nil } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) + if s.keyLen != -1 { + for i, k := range keys { + if len(k) != s.keyLen { + return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen) + } } } - - tail_k := s.keys - tail_v := s.values - for i, k := range ks { - j, found := findInSortedSlice(tail_k, k) - - if found { - pbKeys = append(pbKeys, &pb.StoresKey{ - Floats: k, - }) - pbValues = append(pbValues, &pb.StoresValue{ - Bytes: tail_v[j], - }) - - tail_k = tail_k[j+1:] - tail_v = tail_v[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k)) + sortedKeys := append([][]float32(nil), keys...) + slices.SortFunc(sortedKeys, slices.Compare[[]float32]) + + var foundKeys [][]float32 + var foundValues [][]byte + tailK := s.keys + tailV := s.values + for _, k := range sortedKeys { + j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32]) + if !ok { + continue } + foundKeys = append(foundKeys, tailK[j]) + foundValues = append(foundValues, tailV[j]) + tailK = tailK[j+1:] + tailV = tailV[j+1:] } - - if len(pbKeys) != len(opts.Keys) { - xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys)) - } - return pb.StoresGetResult{ - Keys: pbKeys, - Values: pbValues, + Keys: store.WrapKeys(foundKeys), + Values: store.WrapValues(foundValues), }, nil } -func isNormalized(k []float32) bool { - var sum float64 - - for _, v := range k { - v64 := float64(v) - sum += v64 * v64 +// StoresFind returns the topK nearest stored entries by cosine +// similarity, ordered most-similar first. An empty store returns +// empty slices and no error. +func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { + query := opts.Key.Floats + topK := int(opts.TopK) + if topK < 1 { + return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK) } - - s := math.Sqrt(sum) - - return s >= 0.99 && s <= 1.01 -} - -// TODO: This we could replace with handwritten SIMD code -func normalizedCosineSimilarity(k1, k2 []float32) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - var dot float32 - for i := range len(k1) { - dot += k1[i] * k2[i] + if len(s.keys) == 0 { + return pb.StoresFindResult{}, nil } - - assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot)) - - // 2.0 * (1.0 - dot) would be the Euclidean distance - return dot -} - -type PriorityItem struct { - Similarity float32 - Key []float32 - Value []byte -} - -type PriorityQueue []*PriorityItem - -func (pq PriorityQueue) Len() int { return len(pq) } - -func (pq PriorityQueue) Less(i, j int) bool { - // Inverted because the most similar should be at the top - return pq[i].Similarity < pq[j].Similarity -} - -func (pq PriorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] -} - -func (pq *PriorityQueue) Push(x any) { - item := x.(*PriorityItem) - *pq = append(*pq, item) -} - -func (pq *PriorityQueue) Pop() any { - old := *pq - n := len(old) - item := old[n-1] - *pq = old[0 : n-1] - return item -} - -func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - for i, k := range s.keys { - sim := normalizedCosineSimilarity(tk, k) - heap.Push(&top_ks, &PriorityItem{ - Similarity: sim, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) - } + if len(query) != s.keyLen { + return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen) } - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) - - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, - } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, - } + var keys [][]float32 + var values [][]byte + var sims []float32 + if s.keysAreNormalized && isNormalized(query) { + keys, values, sims = s.findNormalized(query, topK) + } else { + keys, values, sims = s.findFallback(query, topK) } - return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, - Similarities: similarities, + Keys: store.WrapKeys(keys), + Values: store.WrapValues(values), + Similarities: sims, }, nil } -func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - var dot, mag2 float64 - for i := range len(k1) { - dot += float64(k1[i] * k2[i]) - mag2 += float64(k2[i] * k2[i]) +func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) { + pq := make(priorityQueue, 0, topK) + heap.Init(&pq) + for i, k := range s.keys { + var dot float32 + for j := range k { + dot += query[j] * k[j] + } + heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]}) + if pq.Len() > topK { + heap.Pop(&pq) + } } - - sim := float32(dot / (mag1 * math.Sqrt(mag2))) - - assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim)) - - return sim + return drainPQ(&pq) } -func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - var mag1 float64 - for _, v := range tk { - mag1 += float64(v * v) +func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) { + var qmag float64 + for _, v := range query { + qmag += float64(v) * float64(v) } - mag1 = math.Sqrt(mag1) - + qmag = math.Sqrt(qmag) + pq := make(priorityQueue, 0, topK) + heap.Init(&pq) for i, k := range s.keys { - dist := cosineSimilarity(tk, k, mag1) - heap.Push(&top_ks, &PriorityItem{ - Similarity: dist, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) + var dot, kmag float64 + for j := range k { + dot += float64(query[j]) * float64(k[j]) + kmag += float64(k[j]) * float64(k[j]) } - } - - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) - - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, + denom := qmag * math.Sqrt(kmag) + var sim float32 + if denom > 0 { + sim = float32(dot / denom) } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, + heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]}) + if pq.Len() > topK { + heap.Pop(&pq) } } - - return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, - Similarities: similarities, - }, nil + return drainPQ(&pq) } -func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - - if len(tk) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen) +func isNormalized(k []float32) bool { + var sum float64 + for _, v := range k { + sum += float64(v) * float64(v) } + mag := math.Sqrt(sum) + return mag >= 0.99 && mag <= 1.01 +} - if opts.TopK < 1 { - return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK) - } +type incomingPair struct { + key []float32 + value []byte +} - if s.keyLen == -1 { - s.keyLen = len(opts.Key.Floats) - } else { - if len(opts.Key.Floats) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen) - } - } +type pairs struct { + keys [][]float32 + values [][]byte +} - if s.keysAreNormalized && isNormalized(tk) { - return s.StoresFindNormalized(opts) - } else { - if s.keysAreNormalized { - var sample []float32 - if len(s.keys) > 5 { - sample = tk[:5] - } else { - sample = tk +// mergeSortedPairs merges (existing, incoming) into a fresh sorted +// slice. Equal keys take the incoming value — Set is upsert. +func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs { + l := len(existingK) + len(incoming) + mk := make([][]float32, 0, l) + mv := make([][]byte, 0, l) + i, j := 0, 0 + for i < len(incoming) || j < len(existingK) { + switch { + case j >= len(existingK): + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + case i >= len(incoming): + mk = append(mk, existingK[j]) + mv = append(mv, existingV[j]) + j++ + default: + c := slices.Compare(incoming[i].key, existingK[j]) + switch { + case c < 0: + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + case c > 0: + mk = append(mk, existingK[j]) + mv = append(mv, existingV[j]) + j++ + default: + mk = append(mk, incoming[i].key) + mv = append(mv, incoming[i].value) + i++ + j++ } - xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample) } + } + return pairs{keys: mk, values: mv} +} + +type priorityItem struct { + similarity float32 + key []float32 + value []byte +} + +type priorityQueue []*priorityItem + +func (pq priorityQueue) Len() int { return len(pq) } +func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity } +func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } +func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) } +func (pq *priorityQueue) Pop() any { + old := *pq + n := len(old) + item := old[n-1] + *pq = old[0 : n-1] + return item +} - return s.StoresFindFallback(opts) +func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) { + n := pq.Len() + keys = make([][]float32, n) + values = make([][]byte, n) + similarities = make([]float32, n) + for i := n - 1; i >= 0; i-- { + item := heap.Pop(pq).(*priorityItem) + keys[i] = item.key + values[i] = item.value + similarities[i] = item.similarity } + return keys, values, similarities } diff --git a/backend/go/local-store/store_suite_test.go b/backend/go/local-store/store_suite_test.go new file mode 100644 index 000000000000..63affb46bb75 --- /dev/null +++ b/backend/go/local-store/store_suite_test.go @@ -0,0 +1,13 @@ +package main + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestLocalStore(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "local-store test suite") +} diff --git a/backend/go/local-store/store_test.go b/backend/go/local-store/store_test.go new file mode 100644 index 000000000000..2043647c027d --- /dev/null +++ b/backend/go/local-store/store_test.go @@ -0,0 +1,284 @@ +package main + +// Regression suite for the local-store gRPC backend. Exercises the +// Stores{Set,Get,Find,Delete} surface — the only public contract. +// Callers (face/voice recognition, the routing KNN classifier) reach +// this code via grpc.Backend, so testing at the wire-shaped boundary +// matches the production import shape. + +import ( + "math" + "math/rand/v2" + "testing" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("StoresSet", func() { + It("rejects empty input", func() { + Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail") + }) + + It("rejects key/value length mismatch", func() { + err := NewStore().StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0, 0}}), + Values: wrapValues([][]byte{[]byte("a"), []byte("b")}), + }) + Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail") + }) + + It("rejects dimension mismatch on later add", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")}) + err := s.StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0}}), + Values: wrapValues([][]byte{[]byte("2d")}), + }) + Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail") + }) + + It("rejects dimension mismatch within batch", func() { + err := NewStore().StoresSet(&pb.StoresSetOptions{ + Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}), + Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}), + }) + Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail") + }) + + It("merges sorted and updates existing key", func() { + s := NewStore() + mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")}) + mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")}) + Expect(s.keys).To(HaveLen(3)) + got := singleGet(s, []float32{0.1, 0, 0}) + Expect(string(got)).To(Equal("a-updated")) + }) +}) + +var _ = Describe("StoresGet", func() { + It("round-trips multi-key", func() { + s := NewStore() + mustSet(s, + [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, + [][]byte{[]byte("a"), []byte("b"), []byte("c")}, + ) + res, err := s.StoresGet(&pb.StoresGetOptions{ + Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}), + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + }) + + It("omits missing keys rather than erroring", func() { + s := NewStore() + mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")}) + res, err := s.StoresGet(&pb.StoresGetOptions{ + Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}), + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(1)) + }) +}) + +var _ = Describe("StoresDelete", func() { + It("removes and preserves sort", func() { + s := NewStore() + mustSet(s, + [][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}}, + [][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")}, + ) + Expect(s.StoresDelete(&pb.StoresDeleteOptions{ + Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}), + })).To(Succeed()) + Expect(s.keys).To(HaveLen(2)) + }) + + It("tolerates missing keys", func() { + s := NewStore() + mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")}) + Expect(s.StoresDelete(&pb.StoresDeleteOptions{ + Keys: wrapKeys([][]float32{{0.9, 0, 0}}), + })).To(Succeed(), "delete of missing key should succeed") + Expect(s.keys).To(HaveLen(1)) + }) +}) + +var _ = Describe("StoresFind", func() { + It("returns normalized top-K", func() { + s := NewStore() + mustSet(s, + [][]float32{ + normalizeVec([]float32{1, 0, 0}), + normalizeVec([]float32{0, 1, 0}), + normalizeVec([]float32{0, 0, 1}), + }, + [][]byte{[]byte("x"), []byte("y"), []byte("z")}, + ) + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})}, + TopK: 2, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity") + Expect(string(res.Values[0].Bytes)).To(Equal("x")) + }) + + It("falls back for non-normalized keys", func() { + s := NewStore() + mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")}) + Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1") + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{4, 0, 0}}, + TopK: 1, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(string(res.Values[0].Bytes)).To(Equal("x")) + Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99))) + Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01))) + }) + + It("rejects zero topK", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")}) + _, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0, 0}}, + TopK: 0, + }) + Expect(err).To(HaveOccurred(), "Find with topK=0 should fail") + }) + + It("rejects dimension mismatch", func() { + s := NewStore() + mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")}) + _, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0}}, + TopK: 1, + }) + Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail") + }) + + It("returns empty result on empty store", func() { + res, err := NewStore().StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: []float32{1, 0, 0}}, + TopK: 5, + }) + Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed") + Expect(res.Keys).To(BeEmpty()) + }) + + It("handles topK larger than store", func() { + s := NewStore() + mustSet(s, + [][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})}, + [][]byte{[]byte("x"), []byte("y")}, + ) + res, err := s.StoresFind(&pb.StoresFindOptions{ + Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})}, + TopK: 10, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Keys).To(HaveLen(2)) + }) +}) + +var _ = Describe("StoresLoad", func() { + It("is a no-op", func() { + Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed()) + }) +}) + +func BenchmarkStoresFindNormalized(b *testing.B) { + const dim = 768 + for _, n := range []int{8, 32, 128, 512} { + b.Run(fmtN(n), func(b *testing.B) { + s := buildStore(b, n, dim) + query := normalizeVec(randVec(dim, 42)) + req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1} + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.StoresFind(req); err != nil { + b.Fatal(err) + } + } + }) + } +} + +// --- test helpers --- + +func mustSet(s *Store, keys [][]float32, values [][]byte) { + ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed()) +} + +func singleGet(s *Store, key []float32) []byte { + res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})}) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + if len(res.Values) == 0 { + return nil + } + return res.Values[0].Bytes +} + +func wrapKeys(in [][]float32) []*pb.StoresKey { + out := make([]*pb.StoresKey, len(in)) + for i, k := range in { + out[i] = &pb.StoresKey{Floats: k} + } + return out +} + +func wrapValues(in [][]byte) []*pb.StoresValue { + out := make([]*pb.StoresValue, len(in)) + for i, v := range in { + out[i] = &pb.StoresValue{Bytes: v} + } + return out +} + +func buildStore(tb testing.TB, n, dim int) *Store { + tb.Helper() + s := NewStore() + keys := make([][]float32, n) + values := make([][]byte, n) + for i := 0; i < n; i++ { + keys[i] = normalizeVec(randVec(dim, int64(i)+1)) + values[i] = []byte{byte(i)} + } + if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil { + tb.Fatal(err) + } + return s +} + +func randVec(dim int, seed int64) []float32 { + r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef)) + v := make([]float32, dim) + for i := range v { + v[i] = float32(r.NormFloat64()) + } + return v +} + +func normalizeVec(v []float32) []float32 { + var sum float64 + for _, x := range v { + sum += float64(x) * float64(x) + } + mag := math.Sqrt(sum) + if mag == 0 { + return v + } + out := make([]float32, len(v)) + for i, x := range v { + out[i] = float32(float64(x) / mag) + } + return out +} + +func fmtN(n int) string { + return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n] +} diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index f2f70acb3214..a8c1840b3c46 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -26,7 +26,7 @@ XPU=os.environ.get("XPU", "0") == "1" import transformers as transformers_module -from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria +from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline from scipy.io import wavfile from sentence_transformers import SentenceTransformer @@ -200,6 +200,21 @@ def LoadModel(self, request, context): autoTokenizer = False self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode) self.SentenceTransformer = True + elif request.Type == "TokenClassification": + # NER / PII tagging via HuggingFace's token-classification + # pipeline. aggregation_strategy="simple" merges B-/I- tags + # into single spans and gives byte offsets back. The + # tokenizer is bundled inside the pipeline, so we skip the + # AutoTokenizer load below. + autoTokenizer = False + self.tokenClassifier = pipeline( + "token-classification", + model=model_name, + aggregation_strategy="simple", + device=0 if self.CUDA else -1, + trust_remote_code=request.TrustRemoteCode, + ) + self.TokenClassification = True else: # Generic: dynamically resolve model class from transformers model_type = TYPE_ALIASES.get(request.Type, request.Type) @@ -253,6 +268,39 @@ def LoadModel(self, request, context): return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(message="Model loaded successfully", success=True) + def TokenClassify(self, request, context): + # Runs HuggingFace's token-classification pipeline and returns + # the aggregated entity spans. The pipeline gives us byte + # offsets via aggregation_strategy="simple" (set at load + # time), so the caller can slice the original text without + # re-tokenising on the Go side. + if not getattr(self, "TokenClassification", False): + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("model was not loaded as Type=TokenClassification") + return backend_pb2.TokenClassifyResponse() + try: + results = self.tokenClassifier(request.text) + except Exception as err: + print("TokenClassify error:", err, file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"token-classification failed: {err}") + return backend_pb2.TokenClassifyResponse() + + threshold = request.threshold if request.threshold > 0 else 0.0 + entities = [] + for r in results: + score = float(r.get("score", 0.0)) + if score < threshold: + continue + entities.append(backend_pb2.TokenClassifyEntity( + entity_group=str(r.get("entity_group") or r.get("entity") or ""), + start=int(r.get("start", 0)), + end=int(r.get("end", 0)), + score=score, + text=str(r.get("word", "")), + )) + return backend_pb2.TokenClassifyResponse(entities=entities) + def Embedding(self, request, context): set_seed(request.Seed) # Tokenize input diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py index 967c4420c051..74598660b6f8 100644 --- a/backend/python/vllm/backend.py +++ b/backend/python/vllm/backend.py @@ -356,6 +356,133 @@ async def Free(self, request, context): except Exception as e: return backend_pb2.Result(success=False, message=str(e)) + async def Score(self, request, context): + """ + Joint log-probability of each candidate continuation given the + shared prompt. Used by routing-policy multi-label classification + (read the distribution rather than asking the model to emit a + single argmax label), reranking, and reward-model scoring. + + Implementation uses vLLM's `prompt_logprobs` to recover the + per-token log P(token_i | tokens_= len(prompt_logprobs) or prompt_logprobs[position] is None: + continue + entry = prompt_logprobs[position] + lp_obj = entry.get(tok_id) + if lp_obj is not None: + lp = lp_obj.logprob + else: + # Token not in top-K; vLLM's top-1 may miss it. + # Fall back to the lowest available logprob in the + # entry — a conservative lower-bound on the true + # log P, biased against this candidate. + lp = min(v.logprob for v in entry.values()) + total += lp + if request.include_token_logprobs: + tokens_proto.append(backend_pb2.TokenLogProb( + token=self.tokenizer.decode([tok_id]), + log_prob=lp, + )) + + cs = backend_pb2.CandidateScore( + log_prob=total, + num_tokens=num_candidate_tokens, + ) + if request.length_normalize and num_candidate_tokens > 0: + cs.length_normalized_log_prob = total / num_candidate_tokens + if tokens_proto: + cs.tokens.extend(tokens_proto) + results.append(cs) + + return backend_pb2.ScoreResponse(candidates=results) + except Exception as e: + print(f"Score error: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.ScoreResponse() + async def _predict(self, request, context, streaming=False): # Build the sampling parameters # NOTE: this must stay in sync with the vllm backend diff --git a/core/application/application.go b/core/application/application.go index 852324e74203..7a34279c9064 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -9,11 +9,18 @@ import ( corebackend "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/services/agentpool" "github.com/mudler/LocalAI/core/services/facerecognition" "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/cloudproxy/mitm" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/core/services/voicerecognition" "github.com/mudler/LocalAI/core/templates" pkggrpc "github.com/mudler/LocalAI/pkg/grpc" @@ -51,6 +58,22 @@ type Application struct { faceRegistry facerecognition.Registry voiceRegistry voicerecognition.Registry authDB *gorm.DB + metricsService *monitoring.LocalAIMetricsService + statsRecorder *billing.Recorder + fallbackUser *auth.User + piiRedactor *pii.Redactor + piiEvents pii.EventStore + mitmCA atomic.Pointer[mitm.CA] + mitmServer atomic.Pointer[mitm.Server] + mitmMutex sync.Mutex // serializes Stop+Start; readers use atomic loads + // mitmHostConflicts records duplicate-host claims across model configs. + // Non-empty disables the MITM listener until resolved — the strict + // 1-to-1 host↔model invariant the dispatcher relies on. Read by + // /api/middleware/status so the admin UI can surface the cause. + mitmHostConflicts atomic.Pointer[map[string][]string] + routerDecisions router.DecisionStore + routerRegistry *router.Registry + admissionLimiter *admission.Limiter watchdogMutex sync.Mutex watchdogStop chan bool p2pMutex sync.Mutex @@ -185,6 +208,103 @@ func (a *Application) AuthDB() *gorm.DB { return a.authDB } +// MetricsService returns the OTel + Prometheus metric service. nil when +// --disable-metrics is set or initialisation failed at startup. +// +// The service is created in startup.go before any counter is registered +// so that otel.SetMeterProvider runs early enough for the billing +// recorder's counters to bind to the Prom-backed provider rather than +// the no-op global. core/http/app.go reuses this instance instead of +// constructing its own — two providers would orphan one set of counters +// behind whichever provider lost the SetMeterProvider race. +func (a *Application) MetricsService() *monitoring.LocalAIMetricsService { + return a.metricsService +} + +// StatsRecorder returns the billing recorder used by the usage +// middleware. It is non-nil whenever stats are not explicitly disabled +// — i.e., the no-auth single-user path still gets a working recorder +// (in-memory by default). Routes register UsageMiddleware against this +// recorder regardless of auth state. +func (a *Application) StatsRecorder() *billing.Recorder { + return a.statsRecorder +} + +// FallbackUser is the synthetic "local" user that UsageMiddleware uses +// to attribute requests when no authenticated user is on the context +// (i.e., --auth is off). nil when auth is on, since real users are +// always available there. +func (a *Application) FallbackUser() *auth.User { + return a.fallbackUser +} + +// PIIRedactor returns the regex-tier PII redactor or nil if PII +// filtering is disabled. The chat-route middleware uses this to apply +// redaction before dispatch. +func (a *Application) PIIRedactor() *pii.Redactor { + return a.piiRedactor +} + +// PIIEvents returns the PII event store. Same nil-when-disabled +// semantics as PIIRedactor; admin REST and MCP read tools call List +// against it. +func (a *Application) PIIEvents() pii.EventStore { + return a.piiEvents +} + +// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the +// MITM listener is disabled. +func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() } + +// MITMServer returns the running MITM proxy or nil. +func (a *Application) MITMServer() *mitm.Server { return a.mitmServer.Load() } + +// MITMHostConflicts returns a snapshot of host→[]model-name pairs that +// are claimed by 2+ model configs. Empty when the 1-to-1 invariant +// holds. Non-empty disables the MITM listener — read by the admin +// status endpoint to explain why. +func (a *Application) MITMHostConflicts() map[string][]string { + p := a.mitmHostConflicts.Load() + if p == nil { + return nil + } + return *p +} + +// MITMHostOwners returns the host→model-name map, useful for the +// admin status endpoint. The lookup is recomputed on each call to +// stay current with model-config edits without needing a +// MITMRestart. +func (a *Application) MITMHostOwners() map[string]string { + if a.backendLoader == nil { + return nil + } + return a.backendLoader.MITMHostOwners().Owners +} + +// RouterDecisions returns the routing decision store. nil when stats +// are disabled (--disable-stats); the RouteModel middleware skips the +// log write in that case but still rewrites requests. +func (a *Application) RouterDecisions() router.DecisionStore { + return a.routerDecisions +} + +// RouterClassifierRegistry returns the process-wide classifier cache. +// Shared between the OpenAI and Anthropic route middlewares so the +// admin stats endpoint sees every live classifier — and so a +// classifier built on the OpenAI route is reused on Anthropic. +func (a *Application) RouterClassifierRegistry() *router.Registry { + return a.routerRegistry +} + +// AdmissionLimiter returns the per-model admission limiter. The +// admission middleware uses it to gate concurrent requests; the +// admin status surface reads InFlight/Capacity from it for live +// load visibility. +func (a *Application) AdmissionLimiter() *admission.Limiter { + return a.admissionLimiter +} + // StartupConfig returns the original startup configuration (from env vars, before file loading) func (a *Application) StartupConfig() *config.ApplicationConfig { return a.startupConfig @@ -255,6 +375,15 @@ func (a *Application) start() error { a.modelLoader, a.galleryService, ) + // Wire usage tracking so the assistant's get_usage_stats tool + // returns real data; nil values keep the tool returning a clear + // "unavailable" error if startup ran with --disable-stats. + assistantClient.StatsRecorder = a.statsRecorder + assistantClient.FallbackUser = a.fallbackUser + // PII filter — same nil-or-real wiring. + assistantClient.PIIRedactor = a.piiRedactor + assistantClient.PIIEvents = a.piiEvents + assistantClient.RouterDecisions = a.routerDecisions if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil { // Why log+continue instead of fail: the assistant is an optional // feature; a failure here must not take down the whole server. diff --git a/core/application/embedder.go b/core/application/embedder.go new file mode 100644 index 000000000000..53105e5e4379 --- /dev/null +++ b/core/application/embedder.go @@ -0,0 +1,159 @@ +package application + +import ( + "context" + "fmt" + "strings" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/routing/router" + "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/store" +) + +// adapterConfig resolves a model name to its runtime ModelConfig +// for the router-side adapter, or nil when the name is unknown. +func (a *Application) adapterConfig(modelName string) *config.ModelConfig { + cfg, err := a.backendLoader.LoadModelConfigFileByNameDefaultOptions(modelName, a.applicationConfig) + if err != nil || cfg == nil { + return nil + } + return cfg +} + +// ScorerFactory returns a router.Scorer bound to the named model, or +// nil when the model is not loadable. The router uses this to obtain +// joint log-probabilities of policy labels under the configured +// classifier model — multi-label routing without asking the model to +// emit a single argmax label (which off-the-shelf classifier-tuned +// models like Arch-Router struggle with via grammar constraint). +func (a *Application) ScorerFactory() func(modelName string) router.Scorer { + return func(modelName string) router.Scorer { + cfg := a.adapterConfig(modelName) + if cfg == nil { + return nil + } + return &modelScorer{ + modelLoader: a.modelLoader, + modelConfig: cfg, + appConfig: a.applicationConfig, + } + } +} + +type modelScorer struct { + modelLoader *model.ModelLoader + modelConfig *config.ModelConfig + appConfig *config.ApplicationConfig +} + +func (m *modelScorer) Score(ctx context.Context, prompt string, candidates []string) ([]router.CandidateScore, error) { + fn, err := backend.ModelScore(prompt, candidates, backend.ScoreOptions{LengthNormalize: true}, m.modelLoader, *m.modelConfig, m.appConfig) + if err != nil { + return nil, err + } + raw, err := fn(ctx) + if err != nil { + return nil, err + } + out := make([]router.CandidateScore, len(raw)) + for i, c := range raw { + out[i] = router.CandidateScore{ + LogProb: c.LogProb, + LengthNormalizedLogProb: c.LengthNormalizedLogProb, + NumTokens: c.NumTokens, + } + } + return out, nil +} + +// EmbedderFactory returns a router.Embedder bound to the named model, +// or nil when the model is not loadable. The L2 embedding cache uses +// this to embed router probes before searching the vector store. +func (a *Application) EmbedderFactory() func(modelName string) router.Embedder { + return func(modelName string) router.Embedder { + cfg := a.adapterConfig(modelName) + if cfg == nil { + return nil + } + return &modelEmbedder{ + modelLoader: a.modelLoader, + modelConfig: cfg, + appConfig: a.applicationConfig, + } + } +} + +type modelEmbedder struct { + modelLoader *model.ModelLoader + modelConfig *config.ModelConfig + appConfig *config.ApplicationConfig +} + +func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, error) { + fn, err := backend.ModelEmbedding(text, nil, e.modelLoader, *e.modelConfig, e.appConfig) + if err != nil { + return nil, err + } + return fn() +} + +// VectorStoreFactory returns a router.VectorStore bound to the named +// collection. The local-store backend is loaded once per collection +// name; each router model gets its own backend process via the +// model.ModelLoader cache keyed by storeName. +func (a *Application) VectorStoreFactory() func(storeName string) router.VectorStore { + return func(storeName string) router.VectorStore { + if storeName == "" { + return nil + } + return &localVectorStore{ + appConfig: a.applicationConfig, + modelLoader: a.modelLoader, + storeName: storeName, + } + } +} + +type localVectorStore struct { + appConfig *config.ApplicationConfig + modelLoader *model.ModelLoader + storeName string +} + +func (s *localVectorStore) backend(ctx context.Context) (grpc.Backend, error) { + _ = ctx // local-store load is synchronous; ctx unused here for symmetry with the interface. + return backend.StoreBackend(s.modelLoader, s.appConfig, s.storeName, "") +} + +func (s *localVectorStore) Search(ctx context.Context, vec []float32) (float64, []byte, bool, error) { + be, err := s.backend(ctx) + if err != nil { + return 0, nil, false, fmt.Errorf("vector store load: %w", err) + } + _, values, similarities, err := store.Find(ctx, be, vec, 1) + if err != nil { + // local-store's Find returns "existing length is -1" when no + // keys have been inserted yet. Surface that as a clean miss so + // the cache layer doesn't treat it as a failure and skip the + // follow-up Insert. + if strings.Contains(err.Error(), "existing length is -1") { + return 0, nil, false, nil + } + return 0, nil, false, fmt.Errorf("vector store find: %w", err) + } + if len(values) == 0 || len(similarities) == 0 { + return 0, nil, false, nil + } + return float64(similarities[0]), values[0], true, nil +} + +func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) error { + be, err := s.backend(ctx) + if err != nil { + return fmt.Errorf("vector store load: %w", err) + } + return store.SetSingle(ctx, be, vec, payload) +} diff --git a/core/application/mitm.go b/core/application/mitm.go new file mode 100644 index 000000000000..ed506e5f7231 --- /dev/null +++ b/core/application/mitm.go @@ -0,0 +1,146 @@ +package application + +import ( + "errors" + "fmt" + "path/filepath" + "sort" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/cloudproxy/mitm" + "github.com/mudler/xlog" +) + +func startMITMProxy(app *Application, options *config.ApplicationConfig) error { + app.mitmMutex.Lock() + defer app.mitmMutex.Unlock() + return startMITMLocked(app, options) +} + +func startMITMLocked(app *Application, options *config.ApplicationConfig) error { + // Validate the host↔model-config 1-to-1 invariant before binding + // the listener. Two configs claiming the same host means the + // dispatcher would have ambiguous PII settings; refuse to start + // rather than silently picking one. The conflict map is published + // for /api/middleware/status to surface in the UI. + ownership := app.backendLoader.MITMHostOwners() + if len(ownership.Conflicts) > 0 { + conflicts := ownership.Conflicts + app.mitmHostConflicts.Store(&conflicts) + hosts := make([]string, 0, len(conflicts)) + for h := range conflicts { + hosts = append(hosts, h) + } + sort.Strings(hosts) + xlog.Error("mitm: refusing to start — duplicate host claims across model configs", + "hosts", hosts, + "conflicts", conflicts, + ) + return errors.New("mitm: configuration error: duplicate host claims (see /api/middleware/status)") + } + app.mitmHostConflicts.Store(nil) + + caDir := options.MITMCADir + if caDir == "" { + base := options.DataPath + if base == "" { + base = "." + } + caDir = filepath.Join(base, "mitm-ca") + } + + if app.mitmCA.Load() == nil { + ca, err := mitm.LoadOrCreateCA(caDir) + if err != nil { + return fmt.Errorf("ca: %w", err) + } + app.mitmCA.Store(ca) + } + + // Allowlist is exactly the set of hosts claimed by model configs. + // No global list — admins add hosts by creating an MITM model + // config (template available in the Add Model UI). When no config + // claims any host, the listener still starts but every CONNECT + // tunnels through unmodified. + effectiveHosts := make([]string, 0, len(ownership.Owners)) + for h := range ownership.Owners { + effectiveHosts = append(effectiveHosts, h) + } + sort.Strings(effectiveHosts) + + // Per-host PII gate inherits from the owning model's pii.enabled. + // A backend without "proxy-" prefix and no explicit pii.enabled + // resolves to false → host is intercepted but the regex pass is + // skipped (audit events still record). + var piiDisabled []string + for host, modelName := range ownership.Owners { + cfg, exists := app.backendLoader.GetModelConfig(modelName) + if !exists { + continue + } + if !cfg.PIIIsEnabled() { + piiDisabled = append(piiDisabled, host) + } + } + + handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{ + Redactor: app.piiRedactor, + EventStore: app.piiEvents, + HostsWithPIIDisabled: piiDisabled, + }) + + srv, err := mitm.NewServer(mitm.Config{ + Addr: options.MITMListen, + CA: app.mitmCA.Load(), + InterceptHosts: effectiveHosts, + Handler: handler, + EventStore: app.piiEvents, + }) + if err != nil { + return fmt.Errorf("server: %w", err) + } + if err := srv.Start(); err != nil { + return fmt.Errorf("listen: %w", err) + } + app.mitmServer.Store(srv) + + xlog.Info("mitm: cloudproxy listener started", + "addr", srv.Addr(), + "ca_dir", caDir, + "intercept_hosts", effectiveHosts, + "model_owned_hosts", len(ownership.Owners), + "pii_disabled_hosts", len(piiDisabled), + ) + return nil +} + +// StopMITM is idempotent. +func (a *Application) StopMITM() error { + a.mitmMutex.Lock() + defer a.mitmMutex.Unlock() + stopMITMLocked(a) + return nil +} + +// RestartMITM reuses the existing CA so trusted clients keep +// working across listener flips. +func (a *Application) RestartMITM() error { + a.mitmMutex.Lock() + defer a.mitmMutex.Unlock() + stopMITMLocked(a) + if a.applicationConfig.MITMListen == "" { + xlog.Info("mitm: cloudproxy listener stays disabled (no listen address)") + return nil + } + return startMITMLocked(a, a.applicationConfig) +} + +func stopMITMLocked(a *Application) { + srv := a.mitmServer.Load() + if srv == nil { + return + } + srv.Stop() + a.mitmServer.Store(nil) + xlog.Info("mitm: cloudproxy listener stopped") +} diff --git a/core/application/runtime_settings_branding_test.go b/core/application/runtime_settings_branding_test.go index 9f173864ebc3..6300f4456adc 100644 --- a/core/application/runtime_settings_branding_test.go +++ b/core/application/runtime_settings_branding_test.go @@ -87,6 +87,28 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() { }) }) + // MITM listener address. The file is the only source — no env var + // exists — so a regression here means an admin who configured the + // listener via /api/settings loses it after a reboot, even though + // the value is still on disk in the volume. (Intercept hosts now + // live in model YAML mitm.hosts: blocks, not runtime_settings.json.) + Describe("MITM fields", func() { + It("loads mitm_listen", func() { + cfg := &config.ApplicationConfig{DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`)} + loadRuntimeSettingsFromFile(cfg) + Expect(cfg.MITMListen).To(Equal(":8443")) + }) + + It("does not override an explicit CLI flag", func() { + cfg := &config.ApplicationConfig{ + DynamicConfigsDir: seedSettings(`{"mitm_listen": ":8443"}`), + MITMListen: ":9999", // simulate WithMITMListen(":9999") + } + loadRuntimeSettingsFromFile(cfg) + Expect(cfg.MITMListen).To(Equal(":9999"), "CLI flag must win over the persisted file value") + }) + }) + // The Agent Pool block has a mix of zero and non-zero defaults // (Enabled=true, EmbeddingModel="granite-...", MaxChunkingSize=400, // VectorEngine="chromem", AgentHubURL="https://agenthub.localai.io"). diff --git a/core/application/startup.go b/core/application/startup.go index ab50936e28bd..e63e05c4ab4f 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -15,7 +15,12 @@ import ( "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/jobs" + "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/core/services/storage" "github.com/mudler/LocalAI/pkg/vram" coreStartup "github.com/mudler/LocalAI/core/startup" @@ -128,6 +133,111 @@ func New(opts ...config.AppOption) (*Application, error) { }() } + // Initialize the OTel + Prometheus metric pipeline before any + // counter is created. monitoring.NewLocalAIMetricsService calls + // otel.SetMeterProvider, so any subsequent otel.Meter() call — + // including billing.NewRecorder below — sees the real provider + // rather than the no-op global. Initialising metrics later (in + // core/http/app.go) leaves billing's counters bound to a no-op + // meter and never reaches /metrics. We deliberately ignore + // DisableMetrics here for ordering purposes; the HTTP middleware + // that records api_call histograms is still gated. + if !options.DisableMetrics { + ms, err := monitoring.NewLocalAIMetricsService() + if err != nil { + xlog.Error("failed to initialize metrics provider", "error", err) + } else { + application.metricsService = ms + // Bind the billing package's counters to the same meter the + // metrics service exports. Without this, billing's counters + // resolve via the OTel global and never reach /metrics. + billing.SetMeter(ms.Meter) + } + } + + // Wire the routing-module billing recorder. The recorder runs in + // every mode (auth on/off, distributed/single-node) so that token + // tracking is not gated on auth — a no-auth single-user box still + // gets dashboards and `/api/usage` populated. + // + // fallbackUser is wired *unconditionally* when stats are enabled. + // UsageMiddleware uses it as the attribution source whenever + // auth.GetUser(c) is nil — that covers (a) no-auth deployments and + // (b) internal callers under auth-on (cron flushers, distributed + // worker callbacks) that hit a recordable endpoint without a user + // in context. The billing.user_id_present invariant still rejects + // empty IDs; LocalUser() returns a stable UUID per data path. + if !options.DisableStats { + var statsBackend billing.StatsBackend + switch { + case application.authDB != nil: + statsBackend = billing.NewGormBackend(application.authDB, 0, 0) + xlog.Info("stats: using auth DB for usage records") + default: + statsBackend = billing.NewMemoryBackend(0) + xlog.Info("stats: using in-memory ring buffer (no-auth single-user mode)") + } + application.fallbackUser = billing.LocalUser(options.DataPath) + application.statsRecorder = billing.NewRecorder(statsBackend) + xlog.Info("stats: fallback user wired", "local_user_id", application.fallbackUser.ID) + } else { + xlog.Info("stats: disabled by --disable-stats") + } + + // Wire the regex PII filter. Default-on: a single-user box gets + // the built-in pattern set the first time it starts, with email/ + // phone/SSN/credit-card on mask and api_key_prefix on block. If + // the operator wants different actions, --pii-config points at a + // YAML file that overrides per-id; --disable-pii turns it off + // entirely. + if !options.DisablePII { + patterns, err := pii.LoadConfig(options.PIIConfigPath) + if err != nil { + return nil, fmt.Errorf("pii config: %w", err) + } + application.piiRedactor = pii.NewRedactor(patterns) + application.piiEvents = pii.NewMemoryEventStore(0) + // Apply persisted per-pattern overrides — admins toggling + // action/disabled via the UI and clicking "Save to disk" land + // here on the next start. Bad ids are warned and ignored so a + // stale entry doesn't block startup. + for id, ov := range options.PIIPatternOverrides { + if ov.Action != nil { + if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil { + xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err) + continue + } + } + if ov.Disabled != nil { + if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil { + xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err) + } + } + } + xlog.Info("pii: filter enabled", + "patterns", len(patterns), + "config_path", options.PIIConfigPath, + "persisted_overrides", len(options.PIIPatternOverrides), + ) + } else { + xlog.Info("pii: disabled by --disable-pii") + } + + // Wire the routing decision log. Always-on when stats are enabled — + // the per-router admin page reads this as the live activity feed + // and as input to drift checks for subsystem 5. + if !options.DisableStats { + application.routerDecisions = router.NewMemoryDecisionStore(0) + } + // Process-wide classifier cache shared across all route middlewares so + // the embedding-cache stats endpoint sees a single source of truth. + application.routerRegistry = router.NewRegistry() + + // Subsystem 5: admission control. Limiter is always wired so a + // model that gains a limits: block via gallery install or YAML + // edit takes effect on the next restart without conditional plumbing. + application.admissionLimiter = admission.New() + // Wire JobStore for DB-backed task/job persistence whenever auth DB is available. // This ensures tasks and jobs survive restarts in both single-node and distributed modes. if application.authDB != nil && application.agentJobService != nil { @@ -291,6 +401,20 @@ func New(opts ...config.AppOption) (*Application, error) { loadRuntimeSettingsFromFile(options) } + // Wire the cloudproxy MITM listener. Opt-in: empty MITMListen + // means "no MITM" — operators must explicitly choose to start + // it because clients have to install the generated CA cert. + // The handler reuses the global redactor + event store so an + // admin who's already configured PII filtering for direct API + // traffic doesn't need a parallel config for MITM traffic. + // Runs after loadRuntimeSettingsFromFile so a listener configured + // via /api/settings is brought back up across restarts. + if options.MITMListen != "" { + if err := startMITMProxy(application, options); err != nil { + return nil, fmt.Errorf("mitm: startup: %w", err) + } + } + application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging) // turn off any process that was started by GRPC if the context is canceled @@ -573,6 +697,25 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) { options.Branding.FaviconFile = *settings.FaviconFile } + // MITM listener address. The CLI flag WithMITMListen populates + // options at startup; if the user configured MITM via /api/settings + // after the fact, only the file holds the value. Apply when the + // CLI flag did not already set it. (Intercept hosts now live in + // model YAML mitm.hosts: rather than runtime_settings.json.) + if settings.MITMListen != nil && options.MITMListen == "" { + options.MITMListen = *settings.MITMListen + } + + // PII pattern overrides — file is the only source; CLI flags don't + // reach into this map. Apply unconditionally when present; the + // redactor wiring below sees the result on first construction. + if settings.PIIPatternOverrides != nil { + options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides)) + for id, ov := range *settings.PIIPatternOverrides { + options.PIIPatternOverrides[id] = ov + } + } + // Backend upgrade flags if settings.AutoUpgradeBackends != nil { if !options.AutoUpgradeBackends { diff --git a/core/backend/score.go b/core/backend/score.go new file mode 100644 index 000000000000..2284b02f9f44 --- /dev/null +++ b/core/backend/score.go @@ -0,0 +1,96 @@ +package backend + +import ( + "context" + "fmt" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" +) + +// ScoreOptions controls a single Score request. +type ScoreOptions struct { + // IncludeTokenLogprobs returns per-token log-probability detail for + // each candidate. Off by default — the joint LogProb is enough for + // ranking; callers that need calibration / entropy over the token + // stream opt in. + IncludeTokenLogprobs bool + // LengthNormalize divides the joint log-prob by the candidate's + // token count. Useful when comparing candidates of different + // lengths — without it, longer candidates score lower by default. + LengthNormalize bool +} + +// CandidateScore is the per-candidate result. Mirrors pb.CandidateScore +// but avoids leaking the proto type to consumers. +type CandidateScore struct { + LogProb float64 + LengthNormalizedLogProb float64 + NumTokens int + Tokens []TokenLogProb +} + +type TokenLogProb struct { + Token string + LogProb float64 +} + +// ModelScore loads the backend for modelConfig and returns a closure +// that scores `candidates` against `prompt`. The closure is bound to +// the loaded model so callers can keep it around for repeat scoring +// within the same request without re-resolving the backend. +func ModelScore(prompt string, candidates []string, opts ScoreOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]CandidateScore, error), error) { + modelOpts := ModelOptions(modelConfig, appConfig) + inferenceModel, err := loader.Load(modelOpts...) + if err != nil { + recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) + return nil, err + } + b, ok := inferenceModel.(grpc.Backend) + if !ok { + return nil, fmt.Errorf("scoring not supported by backend %q", modelConfig.Backend) + } + if len(candidates) == 0 { + return nil, fmt.Errorf("Score: candidates must be non-empty") + } + return func(ctx context.Context) ([]CandidateScore, error) { + resp, err := b.Score(ctx, &pb.ScoreRequest{ + Prompt: prompt, + Candidates: candidates, + IncludeTokenLogprobs: opts.IncludeTokenLogprobs, + LengthNormalize: opts.LengthNormalize, + }) + if err != nil { + return nil, err + } + return scoreResponseToCandidates(resp, opts.IncludeTokenLogprobs), nil + }, nil +} + +// scoreResponseToCandidates converts the wire-format pb response into +// the value type consumed by callers. Extracted to keep ModelScore's +// closure trivial and so the conversion can be unit-tested without a +// real backend. +func scoreResponseToCandidates(resp *pb.ScoreResponse, includeTokens bool) []CandidateScore { + if resp == nil { + return nil + } + out := make([]CandidateScore, len(resp.Candidates)) + for i, c := range resp.Candidates { + cs := CandidateScore{ + LogProb: c.LogProb, + LengthNormalizedLogProb: c.LengthNormalizedLogProb, + NumTokens: int(c.NumTokens), + } + if includeTokens && len(c.Tokens) > 0 { + cs.Tokens = make([]TokenLogProb, len(c.Tokens)) + for j, t := range c.Tokens { + cs.Tokens[j] = TokenLogProb{Token: t.Token, LogProb: t.LogProb} + } + } + out[i] = cs + } + return out +} diff --git a/core/backend/score_test.go b/core/backend/score_test.go new file mode 100644 index 000000000000..48193efab6b9 --- /dev/null +++ b/core/backend/score_test.go @@ -0,0 +1,63 @@ +package backend + +import ( + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("scoreResponseToCandidates", func() { + It("returns nil for a nil response", func() { + Expect(scoreResponseToCandidates(nil, false)).To(BeNil()) + }) + + It("returns an empty slice when the response has no candidates", func() { + Expect(scoreResponseToCandidates(&pb.ScoreResponse{}, false)).To(BeEmpty()) + }) + + It("copies LogProb / LengthNormalizedLogProb / NumTokens for every candidate", func() { + resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{ + {LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2}, + {LogProb: -7.5, LengthNormalizedLogProb: -1.5, NumTokens: 5}, + }} + got := scoreResponseToCandidates(resp, false) + Expect(got).To(HaveLen(2)) + Expect(got[0].LogProb).To(Equal(-2.0)) + Expect(got[0].LengthNormalizedLogProb).To(Equal(-1.0)) + Expect(got[0].NumTokens).To(Equal(2)) + Expect(got[1].LogProb).To(Equal(-7.5)) + Expect(got[1].NumTokens).To(Equal(5)) + }) + + It("omits per-token detail when includeTokens=false even if the wire response carries it", func() { + // Defensive: if the backend over-reports we still respect the + // caller's opt-in so consumers don't pay marshaling for data + // they didn't ask for. + resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{ + LogProb: -1.0, + Tokens: []*pb.TokenLogProb{{Token: "hi", LogProb: -1.0}}, + }}} + got := scoreResponseToCandidates(resp, false) + Expect(got).To(HaveLen(1)) + Expect(got[0].Tokens).To(BeNil()) + }) + + It("populates per-token detail when includeTokens=true", func() { + resp := &pb.ScoreResponse{Candidates: []*pb.CandidateScore{{ + LogProb: -3.0, + NumTokens: 2, + Tokens: []*pb.TokenLogProb{ + {Token: "Hello", LogProb: -1.0}, + {Token: " world", LogProb: -2.0}, + }, + }}} + got := scoreResponseToCandidates(resp, true) + Expect(got).To(HaveLen(1)) + Expect(got[0].Tokens).To(HaveLen(2)) + Expect(got[0].Tokens[0].Token).To(Equal("Hello")) + Expect(got[0].Tokens[0].LogProb).To(Equal(-1.0)) + Expect(got[0].Tokens[1].Token).To(Equal(" world")) + Expect(got[0].Tokens[1].LogProb).To(Equal(-2.0)) + }) +}) diff --git a/core/cli/run.go b/core/cli/run.go index 079cc8ffdfdf..25bbd8143b46 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -155,6 +155,10 @@ type RunCMD struct { AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"` Version bool + + // Cloud-proxy MITM listener (off by default). + MITMListen string `env:"LOCALAI_MITM_LISTEN" help:"Address (host:port) for the cloudproxy MITM listener. Empty = disabled. Clients set HTTPS_PROXY=http://:. Intercept hosts are declared per-model via the model YAML mitm.hosts: block; create one from the Add Model UI." group:"middleware"` + MITMCADir string `env:"LOCALAI_MITM_CA_DIR" type:"path" help:"Directory holding the MITM proxy CA cert + key. Defaults to /mitm-ca." group:"middleware"` } func (r *RunCMD) Run(ctx *cliContext.Context) error { @@ -213,6 +217,8 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithLoadToMemory(r.LoadToMemory), config.WithMachineTag(r.MachineTag), config.WithAPIAddress(r.Address), + config.WithMITMListen(r.MITMListen), + config.WithMITMCADir(r.MITMCADir), config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), config.WithLlamaCPPTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") diff --git a/core/config/application_config.go b/core/config/application_config.go index 39f76b9e0b6f..95e21666f7d7 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -39,6 +39,54 @@ type ApplicationConfig struct { P2PNetworkID string Federated bool + // DisableStats turns off per-request token tracking. By default the + // routing module's billing recorder runs in every mode (including + // no-auth single-user) so dashboards and `/api/usage` are immediately + // useful; set this to opt out of that, e.g., for ephemeral CI runs + // or privacy-strict deployments where no token-count history should + // touch disk or memory. + DisableStats bool + + // PIIConfigPath points to an optional YAML file describing the PII + // pattern set. When empty, the routing/pii module's DefaultPatterns() + // (email, phone, SSN, credit card, IPv4, API key prefixes) are + // loaded with their default actions. Each entry overrides the + // matching default by ID: + // + // patterns: + // - id: email + // action: route_local # downgrade default mask -> route_local + // - id: ssn + // action: block # upgrade default mask -> block + // + // Unknown ids are rejected with a clear error at startup. + PIIConfigPath string + + // DisablePII turns the regex PII filter off entirely. Default + // (false) enables it on the OpenAI chat completions route. + DisablePII bool + + // MITMListen is the address (host:port) the cloudproxy MITM + // listener binds on. Empty disables the MITM proxy entirely. + // Use case: redacting PII from Claude Code / Codex CLI traffic + // without LocalAI holding the upstream API key. Clients set + // HTTPS_PROXY=http://localai:port and trust the CA cert + // LocalAI exposes at /api/middleware/proxy-ca.crt. + MITMListen string + + // MITMCADir holds the persisted MITM proxy CA cert and private + // key. The CA is generated on first start; subsequent starts + // reload it so clients keep trusting the same root. The key + // file is mode 0600. + MITMCADir string + + + // PIIPatternOverrides applies persisted per-id deltas (action, + // disabled) to the live redactor at startup. Loaded from + // runtime_settings.json and applied right after pii.NewRedactor. + // nil/empty leaves the YAML defaults in place. + PIIPatternOverrides map[string]PIIPatternRuntimeOverride + DisableWebUI bool OllamaAPIRootEndpoint bool EnforcePredownloadScans bool @@ -585,6 +633,45 @@ func WithDataPath(dataPath string) AppOption { } } +// WithDisableStats turns off the billing recorder. CLI: --disable-stats. +func WithDisableStats(disable bool) AppOption { + return func(o *ApplicationConfig) { + o.DisableStats = disable + } +} + +// WithPIIConfigPath points the routing PII filter at a YAML config +// file. CLI: --pii-config. +func WithPIIConfigPath(path string) AppOption { + return func(o *ApplicationConfig) { + o.PIIConfigPath = path + } +} + +// WithDisablePII turns the regex PII filter off. CLI: --disable-pii. +func WithDisablePII(disable bool) AppOption { + return func(o *ApplicationConfig) { + o.DisablePII = disable + } +} + +// WithMITMListen sets the address the cloudproxy MITM listener +// binds on. Empty = disabled. CLI: --mitm-listen. +func WithMITMListen(addr string) AppOption { + return func(o *ApplicationConfig) { + o.MITMListen = addr + } +} + +// WithMITMCADir sets the directory used to persist the MITM proxy +// CA cert + key. CLI: --mitm-ca-dir. +func WithMITMCADir(dir string) AppOption { + return func(o *ApplicationConfig) { + o.MITMCADir = dir + } +} + + func WithDynamicConfigDir(dynamicConfigsDir string) AppOption { return func(o *ApplicationConfig) { o.DynamicConfigsDir = dynamicConfigsDir @@ -978,6 +1065,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { logoHorizontalFile := o.Branding.LogoHorizontalFile faviconFile := o.Branding.FaviconFile + mitmListen := o.MITMListen + return RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, WatchdogIdleEnabled: &watchdogIdle, @@ -1030,6 +1119,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { LogoFile: &logoFile, LogoHorizontalFile: &logoHorizontalFile, FaviconFile: &faviconFile, + MITMListen: &mitmListen, } } @@ -1252,6 +1342,10 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req o.Branding.FaviconFile = *settings.FaviconFile } + if settings.MITMListen != nil { + o.MITMListen = *settings.MITMListen + } + // Note: ApiKeys requires special handling (merging with startup keys) - handled in caller return requireRestart diff --git a/core/config/meta/registry.go b/core/config/meta/registry.go index 99f9e0298fd6..2d600c16acf0 100644 --- a/core/config/meta/registry.go +++ b/core/config/meta/registry.go @@ -320,5 +320,172 @@ func DefaultRegistry() map[string]FieldMetaOverride { Description: "Enable CUDA for diffusers", Order: 82, }, + + // --- PII filtering (per-model) --- + "pii.enabled": { + Section: "other", + Label: "PII Filtering Enabled", + Description: "Enable PII redaction middleware for this model. Unset means use the default (off for local backends, on for proxy-* / cloud-hosted backends).", + Component: "toggle", + Order: 200, + }, + "pii.patterns": { + Section: "other", + Label: "PII Pattern Overrides", + Description: "Override the global default action for specific patterns on this model. Patterns not listed here inherit the global action (Settings → Middleware → Filtering).", + Component: "pii-pattern-list", + Order: 201, + }, + + // --- Cloud passthrough proxy --- + // These only have an effect when Backend is set to a + // "proxy-*" name (e.g. proxy-openai, proxy-anthropic). When + // the upstream URL is empty, the model fails closed — the + // chat handler does NOT silently fall back to the local + // gRPC pipeline. + "proxy.upstream_url": { + Section: "other", + Label: "Proxy Upstream URL", + Description: "Full POST endpoint of the upstream provider (e.g. https://api.openai.com/v1/chat/completions). Only used when Backend starts with proxy-.", + Component: "input", + Order: 210, + }, + "proxy.api_key_env": { + Section: "other", + Label: "Proxy API Key Env Var", + Description: "Name of the environment variable holding the upstream API key. Reading from env keeps the secret out of the YAML and the admin UI.", + Component: "input", + Order: 211, + }, + "proxy.upstream_model": { + Section: "other", + Label: "Proxy Upstream Model", + Description: "Model name sent to the upstream. Leave empty to forward the client's model field unchanged. Useful when the LocalAI alias differs from the upstream's canonical name.", + Component: "input", + Order: 212, + }, + "proxy.request_timeout_seconds": { + Section: "other", + Label: "Proxy Request Timeout (seconds)", + Description: "Caps the upstream HTTP request duration. 0 disables the deadline; the request still ends when the client disconnects.", + Component: "number", + Min: f64(0), + Order: 213, + }, + + // --- MITM intercept hosts --- + // Each host listed here is claimed by this model config; the + // cloudproxy MITM listener (see Middleware → MITM Proxy) uses + // THIS config's pii: settings to filter the intercepted traffic. + // A host claimed by two configs is a critical error — the + // listener refuses to start until resolved. + "mitm.hosts": { + Section: "other", + Label: "MITM Intercept Hosts", + Description: "Hostnames the cloudproxy MITM proxy terminates TLS for on behalf of this model config. PII filtering and pattern overrides flow from this model when the host is intercepted. Each host must be unique across all configs.", + Component: "string-list", + Order: 220, + }, + + // --- Router --- + // Routing turns this model config into a dispatcher: the + // classifier scores every policy label as a continuation of + // the routing prompt and picks the first candidate whose + // labels are a superset of the active set. The Routing tab of + // the middleware admin page surfaces every model with a router + // block. + "router.classifier": { + Section: "other", + Label: "Classifier", + Description: "Picks a candidate by scoring every policy label against the prompt. Only \"score\" is shipped today; it asks the classifier_model to rank each label and reads off the softmax. Empty defaults to \"score\".", + Component: "select", + Options: []FieldOption{ + {Value: "score", Label: "Score (Arch-Router-style)"}, + }, + Order: 230, + }, + "router.classifier_model": { + Section: "other", + Label: "Classifier Model", + Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation. Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold.", + Component: "model-select", + AutocompleteProvider: ProviderModelsChat, + Order: 231, + }, + "router.fallback": { + Section: "other", + Label: "Fallback Model", + Description: "Model used when no candidate's labels cover the classifier's active label set, or when the classifier errors. Empty means router failures bubble up as HTTP 500 — fail-fast, not silent-bypass.", + Component: "model-select", + AutocompleteProvider: ProviderModelsChat, + Order: 232, + }, + "router.activation_threshold": { + Section: "other", + Label: "Activation Threshold", + Description: "Softmax-probability floor a policy must clear to join the active label set for a request. Higher → single-label dominant routes; lower → more multi-label activations. 0 picks the package default (0.15). On Arch-Router-1.5B a value around 0.40 keeps the dominant label clean without losing genuine compound activations.", + Component: "slider", + Min: f64(0), + Max: f64(1), + Step: f64(0.05), + Order: 233, + }, + "router.classifier_cache_size": { + Section: "other", + Label: "Classifier L1 Cache Size", + Description: "Bounded LRU keyed on (case-folded, whitespace-trimmed) prompt — amortises the classifier round-trip across verbatim repeats common in agent loops. 0 here means \"use the default\" (1024); the cache cannot be disabled from YAML.", + Component: "number", + Min: f64(0), + Order: 234, + }, + "router.policies": { + Section: "other", + Label: "Policies", + Description: "Label vocabulary the classifier scores over. Each policy has a label and a short natural-language description fed verbatim to the classifier model. Short action-oriented sentences work best (\"writing or debugging code\"; \"small talk\").", + Component: "router-policies", + Order: 235, + }, + "router.candidates": { + Section: "other", + Label: "Candidates", + Description: "Routing table: each entry binds a downstream model to a set of policy labels it can serve. Order matters — the middleware picks the FIRST candidate whose labels are a superset of the active set, so list candidates smallest → largest.", + Component: "router-candidates", + Order: 236, + }, + "router.embedding_cache.embedding_model": { + Section: "other", + Label: "L2 Cache: Embedding Model", + Description: "Embedding model used by the L2 decision cache. Embeds incoming probes and looks them up in the per-router local-store collection. Empty disables the cache entirely. nomic-embed-text-v1.5 is the recommended default.", + Component: "model-select", + AutocompleteProvider: ProviderModels, + Order: 237, + }, + "router.embedding_cache.similarity_threshold": { + Section: "other", + Label: "L2 Cache: Similarity Threshold", + Description: "Cosine-similarity floor a cache candidate must clear to count as a hit. 0 picks the package default (0.80). Re-tune per embedding model — the histogram on the Routing tab shows where the cosine distribution actually sits.", + Component: "slider", + Min: f64(0), + Max: f64(1), + Step: f64(0.01), + Order: 238, + }, + "router.embedding_cache.confidence_threshold": { + Section: "other", + Label: "L2 Cache: Confidence Threshold", + Description: "Minimum top-label probability a classifier decision must have to be inserted into the cache. 0 picks the package default (0.60). Uncertain decisions are skipped so they can't poison future paraphrases.", + Component: "slider", + Min: f64(0), + Max: f64(1), + Step: f64(0.05), + Order: 239, + }, + "router.embedding_cache.store_name": { + Section: "other", + Label: "L2 Cache: Store Name", + Description: "Optional override for the local-store collection used by this router's cache. Empty defaults to \"router-cache-\". Two routers sharing a store_name share their cache (rare).", + Component: "input", + Order: 240, + }, } } diff --git a/core/config/mitm_host_owners_test.go b/core/config/mitm_host_owners_test.go new file mode 100644 index 000000000000..b3879858b55f --- /dev/null +++ b/core/config/mitm_host_owners_test.go @@ -0,0 +1,133 @@ +package config_test + +import ( + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/config" +) + +// MITMHostOwners is the load-bearing piece of D2 — a duplicate host +// across model configs is a critical error that disables the listener. +// The test exercises both happy paths (no duplicates → clean Owners +// map) and conflict detection (two configs on one host → entry in +// Conflicts naming both). + +var _ = Describe("ModelConfigLoader.MITMHostOwners", func() { + var ( + dir string + loader *config.ModelConfigLoader + ) + + writeYAML := func(name, body string) { + path := filepath.Join(dir, name+".yaml") + Expect(os.WriteFile(path, []byte(body), 0o644)).To(Succeed()) + Expect(loader.ReadModelConfig(path)).To(Succeed()) + } + + BeforeEach(func() { + var err error + dir, err = os.MkdirTemp("", "mitm-host-owners-test-*") + Expect(err).ToNot(HaveOccurred()) + loader = config.NewModelConfigLoader(dir) + }) + + AfterEach(func() { + _ = os.RemoveAll(dir) + }) + + It("returns empty maps when no model declares mitm.hosts", func() { + writeYAML("plain", `name: plain +backend: llama-cpp +`) + got := loader.MITMHostOwners() + Expect(got.Owners).To(BeEmpty()) + Expect(got.Conflicts).To(BeEmpty()) + }) + + It("indexes hosts to the owning model name", func() { + writeYAML("claude", `name: claude +backend: proxy-anthropic +mitm: + hosts: + - api.anthropic.com +`) + writeYAML("openai", `name: openai +backend: proxy-openai +mitm: + hosts: + - api.openai.com + - api.openai.azure.com +`) + got := loader.MITMHostOwners() + Expect(got.Owners).To(Equal(map[string]string{ + "api.anthropic.com": "claude", + "api.openai.com": "openai", + "api.openai.azure.com": "openai", + })) + Expect(got.Conflicts).To(BeEmpty()) + }) + + It("normalises case and trims whitespace before indexing", func() { + writeYAML("claude", `name: claude +backend: proxy-anthropic +mitm: + hosts: + - " API.ANTHROPIC.com " +`) + got := loader.MITMHostOwners() + Expect(got.Owners).To(HaveKey("api.anthropic.com")) + }) + + It("detects two configs claiming the same host as a conflict", func() { + // The 1-to-1 invariant the D2 dispatcher relies on: a host + // claimed twice means the owner lookup is ambiguous, so the + // caller must NOT start the MITM listener until resolved. + writeYAML("alpha", `name: alpha +backend: proxy-anthropic +mitm: + hosts: + - api.anthropic.com +`) + writeYAML("beta", `name: beta +backend: proxy-anthropic +mitm: + hosts: + - api.anthropic.com +`) + got := loader.MITMHostOwners() + Expect(got.Conflicts).To(HaveKey("api.anthropic.com")) + Expect(got.Conflicts["api.anthropic.com"]).To(ConsistOf("alpha", "beta")) + }) + + It("treats the same host listed twice within ONE config as a no-op (not a conflict)", func() { + // A single config repeating a host is benign — same owner + // either way. The conflict signal must be cross-config only. + writeYAML("dup", `name: dup +backend: llama-cpp +mitm: + hosts: + - api.example.com + - api.example.com +`) + got := loader.MITMHostOwners() + Expect(got.Owners).To(Equal(map[string]string{"api.example.com": "dup"})) + Expect(got.Conflicts).To(BeEmpty()) + }) + + It("ignores empty/whitespace-only host entries", func() { + writeYAML("sloppy", `name: sloppy +backend: llama-cpp +mitm: + hosts: + - "" + - " " + - api.real.com +`) + got := loader.MITMHostOwners() + Expect(got.Owners).To(Equal(map[string]string{"api.real.com": "sloppy"})) + }) +}) diff --git a/core/config/model_config.go b/core/config/model_config.go index f14bc4a4e408..84bb7ff5680b 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -95,8 +95,273 @@ type ModelConfig struct { Options []string `yaml:"options,omitempty" json:"options,omitempty"` Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"` - MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"` - Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"` + MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"` + Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"` + PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"` + Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"` + Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"` + MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"` + Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"` +} + +// @Description Admission-control limits applied per request. The +// admission middleware enforces these before invoking the handler; +// requests that exceed a limit get 503 with a Retry-After hint so +// clients back off rather than pile on. Per-model so cloud passthroughs +// can have a stricter ceiling than local models. +type LimitsConfig struct { + // MaxConcurrent caps simultaneous in-flight requests for this + // model. 0 = unlimited (default). Useful for cloud-passthrough + // configs where the upstream rate-limits aggressively, or for + // local backends whose memory budget tops out before LocalAI's + // queue depth would. + MaxConcurrent int `yaml:"max_concurrent,omitempty" json:"max_concurrent,omitempty"` + + // RetryAfterSeconds advises clients how long to wait before + // retrying when admission rejects. 0 defaults to 1s — enough to + // let an in-flight request finish on a busy local model. The + // value is sent verbatim in the Retry-After response header. + RetryAfterSeconds int `yaml:"retry_after_seconds,omitempty" json:"retry_after_seconds,omitempty"` +} + +// @Description MITM intercept binding for the model. When the cloudproxy +// MITM listener is enabled and any host listed here appears in a CONNECT, +// the proxy uses THIS model config's pii: settings to filter the +// intercepted body. Strict 1-to-1: a host claimed by two configs is a +// configuration error and disables the MITM listener until resolved. +// +// Lets an admin pair a host (api.anthropic.com) with the model's +// PII overrides without maintaining a parallel per-host map. +type MITMModelConfig struct { + // Hosts is the list of hostnames this model claims for MITM + // interception. Each entry must be unique across all model configs. + Hosts []string `yaml:"hosts,omitempty" json:"hosts,omitempty"` +} + +// @Description Cloud passthrough proxy configuration. When the backend +// name starts with "proxy-" and a non-empty UpstreamURL is set, the +// chat / messages handler bypasses the gRPC backend pipeline and +// forwards the request to the upstream provider, streaming the SSE +// response back to the client untouched (apart from the streaming PII +// filter, which still runs because it operates on extracted token +// text rather than the wire envelope). +// +// The provider is inferred from Backend ("proxy-openai" → openai +// chat-completions wire shape; "proxy-anthropic" → anthropic messages +// wire shape). No request-shape translation is performed in the MVP — +// the client must speak the same wire format as the upstream. +type ProxyConfig struct { + // UpstreamURL is the full POST endpoint, e.g. + // https://api.openai.com/v1/chat/completions or + // https://api.anthropic.com/v1/messages. Empty disables the + // proxy bail even when Backend is proxy-*. + UpstreamURL string `yaml:"upstream_url,omitempty" json:"upstream_url,omitempty"` + + // APIKeyEnv names the environment variable holding the upstream + // API key. Reading from env (rather than embedding the key in + // YAML) keeps secrets out of the config files and the admin UI. + APIKeyEnv string `yaml:"api_key_env,omitempty" json:"api_key_env,omitempty"` + + // UpstreamModel overrides the model name sent to the upstream. + // Useful when the LocalAI-facing model alias differs from the + // upstream's canonical name (e.g. local "claude-strict" maps to + // upstream "claude-3-5-sonnet-20241022"). Empty means forward + // the client's model field unchanged. + UpstreamModel string `yaml:"upstream_model,omitempty" json:"upstream_model,omitempty"` + + // RequestTimeoutSeconds caps the upstream request duration. 0 + // means no per-request timeout (only the request context, which + // is bound to the client connection, applies). + RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"` +} + +// IsCloudProxy returns true when this model is configured to forward +// requests to an external provider rather than running through the +// local gRPC backend pipeline. The Backend prefix is the gating +// signal (it also drives the PII default-on rule in PIIIsEnabled); +// UpstreamURL must additionally be non-empty so a half-configured +// proxy fails closed instead of silently routing nowhere. +func (c *ModelConfig) IsCloudProxy() bool { + return strings.HasPrefix(c.Backend, "proxy-") && c.Proxy.UpstreamURL != "" +} + +// @Description Intelligent routing configuration. When a model declares +// a Router block, requests addressed to it are reclassified at runtime +// and dispatched to one of the named candidates. The router rewrites +// input.Model in-place, then the standard model-resolution path picks +// up the resolved config — meaning ACL checks, disabled-state, and +// per-model PII still run against the chosen target. +// +// Depth-1 invariant: candidates must NOT themselves carry a Router +// block. The router's "smart-router → claude-strict → proxy-anthropic" +// chain is fine, but "router-A → router-B → claude" is rejected at +// config load to keep the dispatch graph acyclic and predictable. The +// middleware also asserts depth ≤ 1 at runtime as a defensive check. +type RouterConfig struct { + // Classifier picks the implementation. Only "score" ships today: + // it asks the classifier model to score every Policy label as a + // continuation of the routing prompt and reads off the + // distribution. Empty defaults to "score". + Classifier string `yaml:"classifier,omitempty" json:"classifier,omitempty"` + + // Policies is the label vocabulary the classifier scores over. + // Each policy carries a natural-language description that ends up + // in the system prompt the classifier model sees — short, action- + // oriented sentences work best ("writing or debugging code", + // "small talk", ...). The Score classifier picks the subset of + // labels whose softmax probability passes ActivationThreshold. + Policies []RouterPolicy `yaml:"policies,omitempty" json:"policies,omitempty"` + + // Candidates is the routing table — each entry binds a downstream + // model to a set of labels it can serve. The middleware picks the + // FIRST candidate whose Labels are a superset of the active label + // set from the classifier. Admins order this list smallest → + // largest so a query that needs one label routes to the smallest + // capable model, while a query that needs multiple falls to a + // bigger candidate that covers them all. + Candidates []RouterCandidate `yaml:"candidates,omitempty" json:"candidates,omitempty"` + + // Fallback is the model used when no candidate matches the active + // label set, or when the classifier returns nothing above + // threshold. Empty fallback means router failures bubble up as + // 500 — fail-fast, not silent-bypass. + Fallback string `yaml:"fallback,omitempty" json:"fallback,omitempty"` + + // ClassifierModel names the model the Score classifier scores + // against (Arch-Router-1.5B is the canonical choice). + ClassifierModel string `yaml:"classifier_model,omitempty" json:"classifier_model,omitempty"` + + // ClassifierCacheSize bounds the per-prompt memo cache that + // amortises the classifier round-trip across repeat probes. + // 0 disables the cache. Default 1024. + ClassifierCacheSize int `yaml:"classifier_cache_size,omitempty" json:"classifier_cache_size,omitempty"` + + // ActivationThreshold is the softmax-probability floor a policy + // must clear to be considered "active" for the request. 0 + // defaults to a sensible value (~0.15) inside the classifier. + // Higher → narrower routes (single-label dominant); lower → + // more multi-label activations. + ActivationThreshold float64 `yaml:"activation_threshold,omitempty" json:"activation_threshold,omitempty"` + + // EmbeddingCache configures the L2 cache that maps prompt + // embeddings to past decisions, so semantically-similar prompts + // reuse a classification instead of re-running the classifier + // model. Omit the block to disable. See router/embedding_cache.go. + EmbeddingCache *EmbeddingCacheConfig `yaml:"embedding_cache,omitempty" json:"embedding_cache,omitempty"` +} + +// EmbeddingCacheConfig configures the L2 embedding-similarity decision +// cache. Pairs naturally with a larger / slower classifier model: the +// classifier round-trip is amortised across paraphrases of the same +// intent. The cache uses the standard /v1/embeddings backend for +// vector generation and the local-store gRPC surface for KNN search. +type EmbeddingCacheConfig struct { + // EmbeddingModel names the loaded LocalAI model used to embed + // router prompts. Required when the cache is enabled. Any model + // that supports the Embeddings gRPC primitive works; + // nomic-embed-text-v1.5 is the recommended default. + EmbeddingModel string `yaml:"embedding_model" json:"embedding_model"` + + // SimilarityThreshold is the cosine-similarity floor a cache + // candidate must clear to be treated as a hit. 0 picks the + // package default (0.80). Higher → fewer false hits, higher miss + // rate; lower → more aggressive sharing across paraphrases. + SimilarityThreshold float64 `yaml:"similarity_threshold,omitempty" json:"similarity_threshold,omitempty"` + + // ConfidenceThreshold is the minimum classifier top-label + // probability for a decision to be inserted into the cache. 0 + // picks the package default (0.60). Uncertain decisions are not + // cached so they can't poison future paraphrases. + ConfidenceThreshold float64 `yaml:"confidence_threshold,omitempty" json:"confidence_threshold,omitempty"` + + // StoreName overrides the local-store collection name used for + // this router's cache. Empty defaults to "router-cache-" + // where is the parent model name. Useful when two + // router models should share a cache (rare). + StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"` +} + +// RouterPolicy is one entry in the label vocabulary. The label string +// is what the classifier model emits and what candidates reference in +// their Labels field; the description is the natural-language hint +// fed to the classifier so it can match user intent against the label +// space. +type RouterPolicy struct { + Label string `yaml:"label" json:"label"` + Description string `yaml:"description" json:"description"` +} + +// RouterCandidate names a downstream model and the policy labels it +// is willing to serve. Labels are matched as a set: the middleware +// picks the first candidate whose Labels is a superset of the +// classifier's active set. +type RouterCandidate struct { + Model string `yaml:"model" json:"model"` + Labels []string `yaml:"labels" json:"labels"` +} + +// HasRouter returns true when the model declares a router config with +// at least one candidate. Used by the RouteModel middleware to decide +// whether to engage the classifier. +func (c *ModelConfig) HasRouter() bool { + return len(c.Router.Candidates) > 0 +} + +// @Description PII filtering configuration. PII redaction is per-model so +// that local models don't pay the latency or behaviour change of regex +// scanning, while cloud-bound traffic (proxy-* backends) can default to +// on. Setting Enabled explicitly always wins over the backend default. +type PIIConfig struct { + // Enabled toggles redaction for this model. When unset (zero value), + // the resolved default depends on Backend: any backend whose name + // starts with "proxy-" defaults to true, everything else to false. + // A pointer is used so the absence of the YAML key is distinguishable + // from explicit false. + Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` + + // Patterns lets a model upgrade or downgrade individual pattern + // actions (mask | block | route_local) relative to the global + // defaults loaded from --pii-config / DefaultPatterns. Pattern IDs + // not listed inherit the global action. The regex itself stays + // global — only the action is settable per-model. + Patterns []PIIPatternOverride `yaml:"patterns,omitempty" json:"patterns,omitempty"` +} + +// @Description Per-model action override for a single PII pattern. +type PIIPatternOverride struct { + ID string `yaml:"id" json:"id"` + Action string `yaml:"action" json:"action"` +} + +// PIIIsEnabled returns the resolved PII state for this model. Single +// source of truth for the gating decision so the middleware and the +// /api/middleware/status admin view agree. +func (c *ModelConfig) PIIIsEnabled() bool { + if c.PII.Enabled != nil { + return *c.PII.Enabled + } + return strings.HasPrefix(c.Backend, "proxy-") +} + +// PIIPatternOverrides returns the per-pattern action overrides as a map +// keyed by pattern ID. The values are the raw action strings — the pii +// package validates and converts them. +// +// Returned via the documented modelPIIConfig interface in +// core/services/routing/pii/middleware.go without taking a config +// dependency on this package. +func (c *ModelConfig) PIIPatternOverrides() map[string]string { + if len(c.PII.Patterns) == 0 { + return nil + } + out := make(map[string]string, len(c.PII.Patterns)) + for _, p := range c.PII.Patterns { + if p.ID == "" { + continue + } + out[p.ID] = p.Action + } + return out } // @Description MCP configuration diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index 32b2bb38a03a..89f4bc5cb1ee 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -388,6 +388,49 @@ func (bcl *ModelConfigLoader) Preload(modelPath string) error { return nil } +// MITMHostOwnership is the result of mapping intercept hosts to the +// model configs that claim them. The invariant the dispatcher relies +// on: every host belongs to AT MOST one model config. Any duplicate +// is surfaced via Conflicts and disables the MITM listener until +// resolved — a half-applied "first wins" rule would silently mask +// configuration drift, so we fail loud. +type MITMHostOwnership struct { + // Owners maps lowercase hostname → owning model name. Empty when + // no model declares mitm.hosts. + Owners map[string]string + // Conflicts lists hosts claimed by 2+ configs, with the names of + // the configs that claim them. Non-empty Conflicts means callers + // must NOT start the MITM listener. + Conflicts map[string][]string +} + +// MITMHostOwners walks every loaded ModelConfig's mitm.hosts, builds +// the host→owner index, and reports any duplicates. The lookup table +// is hostname-lowercased to match the Server's allowlist semantics. +func (bcl *ModelConfigLoader) MITMHostOwners() MITMHostOwnership { + bcl.Lock() + defer bcl.Unlock() + owners := map[string]string{} + collisions := map[string][]string{} + for name, cfg := range bcl.configs { + for _, h := range cfg.MITM.Hosts { + h = strings.ToLower(strings.TrimSpace(h)) + if h == "" { + continue + } + if existing, ok := owners[h]; ok && existing != name { + if _, seen := collisions[h]; !seen { + collisions[h] = []string{existing} + } + collisions[h] = append(collisions[h], name) + continue + } + owners[h] = name + } + } + return MITMHostOwnership{Owners: owners, Conflicts: collisions} +} + // LoadModelConfigsFromPath reads all the configurations of the models from a path // (non-recursive) func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error { diff --git a/core/config/proxy_test.go b/core/config/proxy_test.go new file mode 100644 index 000000000000..55a9d322d5a2 --- /dev/null +++ b/core/config/proxy_test.go @@ -0,0 +1,31 @@ +package config + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("IsCloudProxy", func() { + cases := []struct { + name string + backend string + upstreamURL string + want bool + }{ + {"local backend, no proxy config", "llama-cpp", "", false}, + {"local backend with stray upstream URL", "llama-cpp", "https://api.openai.com", false}, + {"proxy backend without URL fails closed", "proxy-openai", "", false}, + {"proxy-openai with URL", "proxy-openai", "https://api.openai.com/v1/chat/completions", true}, + {"proxy-anthropic with URL", "proxy-anthropic", "https://api.anthropic.com/v1/messages", true}, + {"proxy-unknown-vendor with URL", "proxy-grok", "https://example.com", true}, + } + for _, tc := range cases { + It(tc.name, func() { + cfg := ModelConfig{ + Backend: tc.backend, + Proxy: ProxyConfig{UpstreamURL: tc.upstreamURL}, + } + Expect(cfg.IsCloudProxy()).To(Equal(tc.want)) + }) + } +}) diff --git a/core/config/runtime_settings.go b/core/config/runtime_settings.go index 3fb16233e7dc..a7211293b896 100644 --- a/core/config/runtime_settings.go +++ b/core/config/runtime_settings.go @@ -89,4 +89,26 @@ type RuntimeSettings struct { LogoFile *string `json:"logo_file,omitempty"` LogoHorizontalFile *string `json:"logo_horizontal_file,omitempty"` FaviconFile *string `json:"favicon_file,omitempty"` + + // Cloud-proxy MITM listener. MITMCADir is intentionally NOT + // exposed at runtime — the CA dir is a startup-only path and + // changing it after the CA has been generated would orphan + // trusted clients. + MITMListen *string `json:"mitm_listen,omitempty"` + + // PII pattern overrides — keyed by pattern id, applied to the live + // redactor at startup and persisted by POST /api/pii/patterns/persist. + // Distinguishes from --pii-config (which replaces the entire + // pattern set) by only carrying the per-id action/enabled deltas + // against the global default catalog. + PIIPatternOverrides *map[string]PIIPatternRuntimeOverride `json:"pii_pattern_overrides,omitempty"` +} + +// PIIPatternRuntimeOverride captures the persistable deltas an admin +// has applied to a single global PII pattern. Both fields are pointers +// so an override that only flips Disabled doesn't have to also restate +// Action (and vice versa). +type PIIPatternRuntimeOverride struct { + Action *string `json:"action,omitempty"` + Disabled *bool `json:"disabled,omitempty"` } diff --git a/core/config/runtime_settings_persist_test.go b/core/config/runtime_settings_persist_test.go index b2f61c10a9fa..a36acb0d26ae 100644 --- a/core/config/runtime_settings_persist_test.go +++ b/core/config/runtime_settings_persist_test.go @@ -51,6 +51,25 @@ var _ = Describe("RuntimeSettings persistence helpers", func() { }) }) + // MITM round trip pins the contract that loadRuntimeSettingsFromFile + // MITM listener address must survive a write/read round trip so the + // next process restart can bring the listener back up. (Intercept + // hosts now live in model YAML rather than runtime_settings.json.) + Describe("MITM round trip", func() { + It("preserves mitm_listen across read/write", func() { + listen := ":8443" + Expect(cfg.WritePersistedSettings(config.RuntimeSettings{ + MITMListen: &listen, + })).To(Succeed()) + + got, err := cfg.ReadPersistedSettings() + Expect(err).ToNot(HaveOccurred()) + + Expect(got.MITMListen).ToNot(BeNil()) + Expect(*got.MITMListen).To(Equal(":8443")) + }) + }) + // PreserveOnSaveDoesNotClobberAssets reproduces the user-reported // regression: an admin uploads a logo, then clicks Save on the // Settings page. The Save body still has the stale pre-upload diff --git a/core/http/app.go b/core/http/app.go index 99d11bd69c5c..33a54fb47dd0 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -25,7 +25,6 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services/finetune" "github.com/mudler/LocalAI/core/services/galleryop" - "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/quantization" @@ -212,19 +211,18 @@ func API(application *application.Application) (*echo.Echo, error) { e.Use(middleware.Recover()) } - // Metrics middleware - if !application.ApplicationConfig().DisableMetrics { - metricsService, err := monitoring.NewLocalAIMetricsService() - if err != nil { - return nil, err - } - - if metricsService != nil { - e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) - e.Server.RegisterOnShutdown(func() { - metricsService.Shutdown() - }) - } + // Metrics middleware. The metric service was created in + // application.start() so the OTel global provider is set before any + // counter is registered (the routing-module billing recorder relies + // on this). We reuse that instance here rather than calling + // monitoring.NewLocalAIMetricsService a second time, which would + // create a second provider, second prometheus exporter, and orphan + // whichever instance lost the SetMeterProvider race. + if metricsService := application.MetricsService(); metricsService != nil { + e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) + e.Server.RegisterOnShutdown(func() { + _ = metricsService.Shutdown() + }) } // Health Checks should always be exempt from auth, so register these first @@ -267,10 +265,9 @@ func API(application *application.Application) (*echo.Echo, error) { e.Static("/generated-videos", videoPath) } - // Initialize usage recording when auth DB is available - if application.AuthDB() != nil { - httpMiddleware.InitUsageRecorder(application.AuthDB()) - } + // Usage recording is initialised in application/startup.go and + // surfaced via application.StatsRecorder(); routes wire UsageMiddleware + // against that recorder regardless of auth state. // Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is // the role of the exempt-path logic inside the middleware. @@ -357,6 +354,13 @@ func API(application *application.Application) (*echo.Echo, error) { // Register auth routes (login, callback, API keys, user management) routes.RegisterAuthRoutes(e, application) + // Register routing-module usage endpoints. Unlike /api/auth/usage + // these go through the StatsRecorder and work in no-auth single-user + // mode by attributing requests to the synthetic "local" user. + routes.RegisterUsageRoutes(e, application) + routes.RegisterPIIRoutes(e, application) + routes.RegisterMiddlewareRoutes(e, application) + routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Create opcache for tracking UI operations (used by both UI and LocalAI routes) diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go index 31c3202b2a6e..b227b0454277 100644 --- a/core/http/auth/usage.go +++ b/core/http/auth/usage.go @@ -9,6 +9,18 @@ import ( ) // UsageRecord represents a single API request's token usage. +// +// Model semantics: Model is the legacy column kept for backward-compatible +// aggregation; new code should write RequestedModel (what the client asked +// for) and ServedModel (what actually ran after routing). When no router +// is in play, all three are equal. +// +// PreFilterPromptTokens vs PromptTokens: PromptTokens is the count after +// PII redaction (i.e., what the backend processed and was billed for). +// PreFilterPromptTokens is the count of the original prompt before any +// PII filtering; PostFilterPromptTokens duplicates PromptTokens for +// queryability symmetry. For non-PII paths PreFilterPromptTokens == +// PostFilterPromptTokens == PromptTokens. type UsageRecord struct { ID uint `gorm:"primaryKey;autoIncrement"` UserID string `gorm:"size:36;index:idx_usage_user_time"` @@ -20,6 +32,22 @@ type UsageRecord struct { TotalTokens int64 Duration int64 // milliseconds CreatedAt time.Time `gorm:"index:idx_usage_user_time"` + + // Routing extension fields. Nullable / zero-valued for legacy rows. + RequestedModel string `gorm:"size:255;index"` + ServedModel string `gorm:"size:255;index"` + PreFilterPromptTokens int64 // tokens the client sent before PII redaction + PostFilterPromptTokens int64 // tokens after redaction (== PromptTokens unless filter shrunk it) + CachedTokens int64 // backend-reported KV-cache hit tokens + PrefillTokens int64 // backend-reported prefill tokens (subset of prompt) + DraftTokens int64 // speculative-decoding draft tokens + PricingVersionID string `gorm:"size:64;index"` // FK to pricing_version; "" when no pricing was applied + CostUSD float64 // computed at insert when pricing is available; 0 with empty PricingVersionID = unknown + + // Cross-subsystem correlation. Empty when the subsystem didn't run. + CorrelationID string `gorm:"size:64;index"` + RouterDecisionID string `gorm:"size:64;index"` + PIIEventID string `gorm:"size:64"` } // RecordUsage inserts a usage record. diff --git a/core/http/endpoints/anthropic/anthropic_suite_test.go b/core/http/endpoints/anthropic/anthropic_suite_test.go new file mode 100644 index 000000000000..0b88b92f24cf --- /dev/null +++ b/core/http/endpoints/anthropic/anthropic_suite_test.go @@ -0,0 +1,13 @@ +package anthropic + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAnthropic(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Anthropic test suite") +} diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go index 62e58a4a1889..448eb23c97e1 100644 --- a/core/http/endpoints/anthropic/messages.go +++ b/core/http/endpoints/anthropic/messages.go @@ -10,10 +10,13 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/cloudproxy" + "github.com/mudler/LocalAI/core/services/routing/pii" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" @@ -27,7 +30,7 @@ import ( // @Param request body schema.AnthropicRequest true "query params" // @Success 200 {object} schema.AnthropicResponse "Response" // @Router /v1/messages [post] -func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc { +func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc { return func(c echo.Context) error { id := uuid.New().String() @@ -47,6 +50,13 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg) + // Cloud-proxy bail. Same gate as the OpenAI chat endpoint — + // when Backend = proxy-* and an upstream URL is set, skip + // the local pipeline and forward to the upstream provider. + if cfg.IsCloudProxy() { + return forwardCloudProxyAnthropic(c, cfg, input, piiRedactor, piiEvents) + } + // Convert Anthropic messages to OpenAI format for internal processing openAIMessages := convertAnthropicToOpenAIMessages(input) @@ -132,7 +142,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) if input.Stream { - return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) + return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator, piiRedactor, piiEvents) } return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator) @@ -313,17 +323,45 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic xlog.Debug("Anthropic Response", "response", string(respData)) } + middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion) + return c.JSON(200, resp) } // end MCP iteration loop return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached") } -func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error { +func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") + // Per-stream PII filter — same gating as the OpenAI chat path. The + // filter is wire-format-agnostic; we feed it the text portion of + // each text_delta and emit only what's safe to send. The filter + // holds back a tail of size MaxPatternLength-1 so a pattern split + // across chunk boundaries still gets masked. When PII is disabled + // for this model the filter is nil and emits flow unchanged. + var streamPIIFilter *pii.StreamFilter + if piiRedactor != nil && cfg.PIIIsEnabled() { + correlationID := c.Request().Header.Get("x-request-id") + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + var overrides map[string]pii.Action + if raw := cfg.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]pii.Action, len(raw)) + for ovid, action := range raw { + switch pii.Action(action) { + case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal: + overrides[ovid] = pii.Action(action) + } + } + } + streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID) + } + // Send message_start event messageStart := schema.AnthropicStreamEvent{ Type: "message_start", @@ -403,6 +441,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq if len(toolCalls) > toolCallsEmitted { if !inToolCall && currentBlockIndex == 0 { + drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex)) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: intPtr(currentBlockIndex), @@ -443,14 +482,20 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq } if !inToolCall && token != "" { - sendAnthropicSSE(c, schema.AnthropicStreamEvent{ - Type: "content_block_delta", - Index: intPtr(0), - Delta: &schema.AnthropicStreamDelta{ - Type: "text_delta", - Text: token, - }, - }) + out := token + if streamPIIFilter != nil { + out = streamPIIFilter.Push(token) + } + if out != "" { + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_delta", + Index: intPtr(0), + Delta: &schema.AnthropicStreamDelta{ + Type: "text_delta", + Text: out, + }, + }) + } } return true } @@ -488,14 +533,20 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq // didn't already stream it (autoparser clears raw text, so // accumulatedContent will be empty in that case). if deltaContent != "" && !inToolCall && accumulatedContent == "" { - sendAnthropicSSE(c, schema.AnthropicStreamEvent{ - Type: "content_block_delta", - Index: intPtr(0), - Delta: &schema.AnthropicStreamDelta{ - Type: "text_delta", - Text: deltaContent, - }, - }) + out := deltaContent + if streamPIIFilter != nil { + out = streamPIIFilter.Push(deltaContent) + } + if out != "" { + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_delta", + Index: intPtr(0), + Delta: &schema.AnthropicStreamDelta{ + Type: "text_delta", + Text: out, + }, + }) + } } // Emit tool_use blocks from ChatDeltas @@ -503,6 +554,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq collectedToolCalls = deltaToolCalls if !inToolCall && currentBlockIndex == 0 { + drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex)) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: intPtr(currentBlockIndex), @@ -606,7 +658,9 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 { parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig) if len(parsed) > 0 { - // Close the text content block + // Close the text content block (after flushing any + // residual the streaming PII filter held back). + drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex)) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: intPtr(currentBlockIndex), @@ -646,8 +700,12 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq } } - // No MCP tools to execute, close stream + // No MCP tools to execute, close stream. drainStreamPIIToText + // flushes any residual the streaming PII filter held back as + // part of its trailing pattern-window before we close the + // text content block. if !inToolCall { + drainStreamPIIToText(c, streamPIIFilter, intPtr(0)) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: intPtr(0), @@ -673,6 +731,8 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq Type: "message_stop", }) + middleware.StampUsage(c, input.Model, tokenUsage.Prompt, tokenUsage.Completion) + return nil } // end MCP iteration loop @@ -693,6 +753,30 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool { func intPtr(i int) *int { return &i } +// drainStreamPIIToText flushes any residual the streaming PII filter +// has been holding back as part of its trailing pattern-window, and +// emits it as one final text_delta into the named block before the +// caller closes that block. Drain is idempotent: calling it twice on +// the same filter returns "" the second time. Safe to call with a nil +// filter (no-op). +func drainStreamPIIToText(c echo.Context, sf *pii.StreamFilter, index *int) { + if sf == nil { + return + } + residual := sf.Drain() + if residual == "" { + return + } + sendAnthropicSSE(c, schema.AnthropicStreamEvent{ + Type: "content_block_delta", + Index: index, + Delta: &schema.AnthropicStreamDelta{ + Type: "text_delta", + Text: residual, + }, + }) +} + func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) { data, err := json.Marshal(event) if err != nil { @@ -888,3 +972,39 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions() } + +// forwardCloudProxyAnthropic mirrors the OpenAI cloud-proxy fork: +// builds the streaming PII filter (when applicable) and forwards +// the request body to the upstream Anthropic-shaped provider via +// the wire-format-faithful proxy. The model name swap and the +// upstream auth headers are applied inside cloudproxy; the pii +// filter is constructed here because the auth/correlation context +// only exists in the echo handler. +func forwardCloudProxyAnthropic(c echo.Context, cfg *config.ModelConfig, input *schema.AnthropicRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error { + body, err := json.Marshal(input) + if err != nil { + return sendAnthropicError(c, 400, "invalid_request_error", "cloudproxy: marshal request: "+err.Error()) + } + + var streamFilter *pii.StreamFilter + if input.Stream && piiRedactor != nil && cfg.PIIIsEnabled() { + correlationID := c.Request().Header.Get("x-request-id") + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + var overrides map[string]pii.Action + if raw := cfg.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]pii.Action, len(raw)) + for ovid, action := range raw { + switch pii.Action(action) { + case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal: + overrides[ovid] = pii.Action(action) + } + } + } + streamFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID) + } + + return cloudproxy.Forward(c, cfg, body, streamFilter) +} diff --git a/core/http/endpoints/anthropic/messages_pii_test.go b/core/http/endpoints/anthropic/messages_pii_test.go new file mode 100644 index 000000000000..91e5297e4f31 --- /dev/null +++ b/core/http/endpoints/anthropic/messages_pii_test.go @@ -0,0 +1,114 @@ +package anthropic + +import ( + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/services/routing/pii" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// drainStreamPIIToText is called from four sites in messages.go and is +// the load-bearing primitive for "the streaming filter has buffered +// some bytes that the request just ended on; flush them as a final +// text_delta event before closing the content block". A regression +// here would silently truncate the last few bytes of an assistant +// response on every PII-enabled stream — invisible without coverage. + +// newTestFilter compiles the default patterns and returns a filter +// that holds back its trailing pattern-window; pushing a short string +// (shorter than holdLen) keeps the bytes inside Drain. +func newTestFilter() *pii.StreamFilter { + patterns, err := pii.Compile(pii.DefaultPatterns()) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + red := pii.NewRedactor(patterns) + return pii.NewStreamFilter(red, nil, nil, "", "") +} + +// newTestContext builds a recording echo context — the recorder +// captures the SSE bytes drainStreamPIIToText writes. +func newTestContext() (echo.Context, *httptest.ResponseRecorder) { + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader("{}")) + rec := httptest.NewRecorder() + return echo.New().NewContext(req, rec), rec +} + +var _ = Describe("drainStreamPIIToText", func() { + It("is a no-op when the filter is nil", func() { + c, rec := newTestContext() + drainStreamPIIToText(c, nil, intPtr(0)) + Expect(rec.Body.Len()).To(Equal(0), "nil filter wrote %d bytes: %q", rec.Body.Len(), rec.Body.String()) + }) + + It("emits nothing when the drain is empty", func() { + // A filter with nothing buffered should not emit a phantom event; + // otherwise every non-PII response would close with an empty + // text_delta that pollutes downstream parsers. + sf := newTestFilter() + c, rec := newTestContext() + drainStreamPIIToText(c, sf, intPtr(0)) + Expect(rec.Body.Len()).To(Equal(0), "empty drain wrote %d bytes: %q", rec.Body.Len(), rec.Body.String()) + }) + + It("flushes residual buffered bytes as a text_delta event", func() { + sf := newTestFilter() + // Push less than holdLen so all bytes are retained until Drain. + // "tail" is short enough that no pattern is plausible. + out := sf.Push("tail") + Expect(out).To(Equal(""), "Push of short text emitted %q; want all bytes held", out) + + c, rec := newTestContext() + drainStreamPIIToText(c, sf, intPtr(2)) + + body := rec.Body.String() + // Wire format: "event: content_block_delta\ndata: {…}\n\n" + Expect(body).To(ContainSubstring("event: content_block_delta")) + Expect(body).To(ContainSubstring(`"type":"content_block_delta"`)) + Expect(body).To(ContainSubstring(`"index":2`)) + Expect(body).To(ContainSubstring(`"text":"tail"`)) + Expect(body).To(ContainSubstring(`"type":"text_delta"`)) + Expect(strings.HasSuffix(body, "\n\n")).To(BeTrue(), "SSE event missing trailing blank line: %q", body) + }) + + It("is idempotent across consecutive drains", func() { + // Two consecutive Drains: the filter returns "" the second time, + // so the second drainStreamPIIToText must emit nothing. The + // production path in messages.go has at least four call sites + // that may overlap (currentBlockIndex==0 emergency path + the + // unconditional drain near the end of the stream); without + // idempotence we'd duplicate the residual on the wire. + sf := newTestFilter() + sf.Push("tail") + + c1, rec1 := newTestContext() + drainStreamPIIToText(c1, sf, intPtr(0)) + first := rec1.Body.Len() + Expect(first).NotTo(Equal(0), "first drain emitted nothing") + + c2, rec2 := newTestContext() + drainStreamPIIToText(c2, sf, intPtr(0)) + Expect(rec2.Body.Len()).To(Equal(0), "second drain wrote %d bytes; want idempotent no-op: %q", rec2.Body.Len(), rec2.Body.String()) + }) + + It("masks redacted residual instead of leaking it", func() { + // The held tail must travel through the redactor on Drain. If + // the bytes happen to form a complete pattern at end-of-stream, + // the residual emit must contain the mask placeholder, not the + // raw value. + sf := newTestFilter() + // "alice@example.com" is 17 bytes. holdLen for default patterns + // is well above 17, so this stays buffered until Drain, which + // then redacts it. + out := sf.Push("alice@example.com") + Expect(out).To(Equal(""), "Push emitted bytes early: %q", out) + + c, rec := newTestContext() + drainStreamPIIToText(c, sf, intPtr(0)) + body := rec.Body.String() + Expect(body).NotTo(ContainSubstring("alice@example.com"), "raw email leaked in residual emit: %q", body) + Expect(body).To(ContainSubstring("[REDACTED:email]"), "residual emit missing mask placeholder: %q", body) + }) +}) diff --git a/core/http/endpoints/localai/api_instructions.go b/core/http/endpoints/localai/api_instructions.go index 103c87443209..8166bb3a196f 100644 --- a/core/http/endpoints/localai/api_instructions.go +++ b/core/http/endpoints/localai/api_instructions.go @@ -92,6 +92,30 @@ var instructionDefs = []instructionDef{ Tags: []string{"branding"}, Intro: "GET /api/branding is public so the login screen can render the configured logo before authentication. Text fields are saved through POST /api/settings; binary assets (logo, horizontal logo, favicon) use multipart upload at /api/branding/asset/{kind} and are served back from /branding/asset/{kind}.", }, + { + Name: "usage-and-billing", + Description: "Per-user token usage and request counts, with optional cost tracking", + Tags: []string{"usage"}, + Intro: "GET /api/usage returns the current user's token usage in time-bucketed form (day/week/month/all). In single-user no-auth mode the records are attributed to a synthetic local user with stable UUID, so this endpoint and the dashboard work without --auth. /api/usage/all is the cluster-wide view and requires admin (the local user is admin in single-user mode). UsageRecord fields include RequestedModel/ServedModel and PreFilter/PostFilterPromptTokens for routing- and PII-aware accounting.", + }, + { + Name: "pii-filtering", + Description: "Inspect and tune the regex PII filter applied to chat requests", + Tags: []string{"pii"}, + Intro: "GET /api/pii/patterns lists the active pattern set with each one's action (mask, block, route_local). GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id (admin or local-user only). POST /api/pii/test dry-runs the redactor against an admin-supplied string. Default patterns: email, phone, SSN, credit card (Luhn), IPv4, common API key prefixes (sk-, pk-, ghp_, github_pat_). PII is per-model: by default it is OFF for non-proxy backends and ON for backends starting with proxy-* (cloud passthroughs). Opt in with `pii: { enabled: true }` in a model's YAML; use `pii: { patterns: [{id, action}] }` to upgrade or downgrade individual actions for that model. Override global default actions via --pii-config pii.yaml; --disable-pii turns the filter off entirely.", + }, + { + Name: "middleware-admin", + Description: "Inspect and configure the routing-module middleware (PII filter and routing)", + Tags: []string{"middleware", "pii", "router"}, + Intro: "GET /api/middleware/status is the single round-trip the /app/middleware admin page reads to render the current state: active PII patterns and their actions, every model's resolved enabled/override state, recent event count, and the active routing models with their classifier configurations. Admin-only (the synthetic local user is admin in no-auth mode). PUT /api/pii/patterns/:id changes a pattern's action in-process — TRANSIENT, lost on restart. To persist, edit --pii-config YAML. GET /api/router/decisions returns the routing decision log filtered by correlation_id / user_id / router_model. The same surface is exposed as MCP tools (`get_middleware_status`, `set_pii_pattern_action`, `get_router_decisions`) for agent-driven configuration.", + }, + { + Name: "intelligent-routing", + Description: "Per-model `router:` configuration that classifies requests and rewrites the served model", + Tags: []string{"router"}, + Intro: "Add a `router:` block to a ModelConfig to turn it into a routing model. The block declares a classifier (today: `feature` — handcrafted rules over prompt length and code-fence presence), a list of candidates (label + downstream model + optional rule), and a fallback. When a client addresses the routing model, the RouteModel middleware invokes the classifier, picks a candidate, and rewrites input.Model — the standard model-resolution path then runs ACL, disabled-state, and per-model PII against the chosen target. Depth-1 invariant: candidates must NOT themselves carry a `router:` block; runtime check returns 500 on violation. Decisions are logged to GET /api/router/decisions and surfaced in the /app/middleware Routing tab.", + }, } // swaggerState holds parsed swagger spec data, initialised once. diff --git a/core/http/endpoints/localai/api_instructions_test.go b/core/http/endpoints/localai/api_instructions_test.go index 35bdfa2399d9..70ae717659ad 100644 --- a/core/http/endpoints/localai/api_instructions_test.go +++ b/core/http/endpoints/localai/api_instructions_test.go @@ -39,7 +39,7 @@ var _ = Describe("API Instructions Endpoints", func() { instructions, ok := resp["instructions"].([]any) Expect(ok).To(BeTrue()) - Expect(instructions).To(HaveLen(12)) + Expect(instructions).To(HaveLen(16)) // Verify each instruction has required fields and correct URL format for _, s := range instructions { @@ -74,6 +74,10 @@ var _ = Describe("API Instructions Endpoints", func() { "monitoring", "agents", "face-recognition", + "usage-and-billing", + "pii-filtering", + "middleware-admin", + "intelligent-routing", )) }) }) diff --git a/core/http/endpoints/localai/mcp.go b/core/http/endpoints/localai/mcp.go index 541d4963b301..a849e8a2fbb5 100644 --- a/core/http/endpoints/localai/mcp.go +++ b/core/http/endpoints/localai/mcp.go @@ -61,7 +61,11 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // The legacy /v1/mcp/chat/completions endpoint never opts into the // in-process LocalAI Assistant tool surface — pass nil holder so the // assistant branch in chat.go is unreachable from this code path. - chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil) + // Stream-side PII filter is also nil: this legacy endpoint pre-dates + // the per-model PII config and is kept for backward compatibility. + // The request-side middleware on the main chat route handles + // filtering for the standard /v1/chat/completions path. + chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil, nil, nil) return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) diff --git a/core/http/endpoints/localai/score.go b/core/http/endpoints/localai/score.go new file mode 100644 index 000000000000..cfbdd1d4261e --- /dev/null +++ b/core/http/endpoints/localai/score.go @@ -0,0 +1,90 @@ +package localai + +import ( + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/model" +) + +// ScoreRequest is the wire format for POST /api/score. Mirrors the +// gRPC ScoreRequest one-to-one — the endpoint exists primarily to +// smoke-test the new Score primitive end-to-end without writing a +// custom gRPC client. Production routing will call backend.ModelScore +// directly via the router-side adapter. +type ScoreRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Candidates []string `json:"candidates"` + IncludeTokenLogprobs bool `json:"include_token_logprobs,omitempty"` + LengthNormalize bool `json:"length_normalize,omitempty"` +} + +type ScoreResponseCandidate struct { + LogProb float64 `json:"log_prob"` + LengthNormalizedLogProb float64 `json:"length_normalized_log_prob,omitempty"` + NumTokens int `json:"num_tokens"` + Tokens []ScoreTokenLP `json:"tokens,omitempty"` +} + +type ScoreTokenLP struct { + Token string `json:"token"` + LogProb float64 `json:"log_prob"` +} + +type ScoreResponse struct { + Model string `json:"model"` + Candidates []ScoreResponseCandidate `json:"candidates"` +} + +// ScoreEndpoint exposes the Score gRPC primitive over HTTP. Admin-only — +// scoring loads a model and runs inference, same risk surface as +// /v1/chat/completions. +func ScoreEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + var req ScoreRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(400, "invalid request body: "+err.Error()) + } + if req.Model == "" { + return echo.NewHTTPError(400, "model is required") + } + if len(req.Candidates) == 0 { + return echo.NewHTTPError(400, "candidates must be non-empty") + } + + modelConfig, err := cl.LoadModelConfigFileByNameDefaultOptions(req.Model, appConfig) + if err != nil || modelConfig == nil { + return echo.NewHTTPError(404, "model not found: "+req.Model) + } + + fn, err := backend.ModelScore(req.Prompt, req.Candidates, backend.ScoreOptions{ + IncludeTokenLogprobs: req.IncludeTokenLogprobs, + LengthNormalize: req.LengthNormalize, + }, ml, *modelConfig, appConfig) + if err != nil { + return echo.NewHTTPError(500, "failed to bind scorer: "+err.Error()) + } + results, err := fn(c.Request().Context()) + if err != nil { + return echo.NewHTTPError(500, "score call failed: "+err.Error()) + } + + out := ScoreResponse{Model: req.Model, Candidates: make([]ScoreResponseCandidate, len(results))} + for i, r := range results { + out.Candidates[i] = ScoreResponseCandidate{ + LogProb: r.LogProb, + LengthNormalizedLogProb: r.LengthNormalizedLogProb, + NumTokens: r.NumTokens, + } + if req.IncludeTokenLogprobs && len(r.Tokens) > 0 { + toks := make([]ScoreTokenLP, len(r.Tokens)) + for j, t := range r.Tokens { + toks[j] = ScoreTokenLP{Token: t.Token, LogProb: t.LogProb} + } + out.Candidates[i].Tokens = toks + } + } + return c.JSON(200, out) + } +} diff --git a/core/http/endpoints/localai/settings.go b/core/http/endpoints/localai/settings.go index 0e29e39ccefe..1db87e313dcb 100644 --- a/core/http/endpoints/localai/settings.go +++ b/core/http/endpoints/localai/settings.go @@ -253,6 +253,16 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc { } } + if settings.MITMListen != nil { + if err := app.RestartMITM(); err != nil { + xlog.Error("Failed to restart MITM proxy", "error", err) + return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ + Success: false, + Error: "Settings saved but failed to restart MITM proxy: " + err.Error(), + }) + } + } + // Restart P2P if P2P settings changed p2pChanged := settings.P2PToken != nil || settings.P2PNetworkID != nil || settings.Federated != nil if p2pChanged { diff --git a/core/http/endpoints/mcp/localai_assistant_test.go b/core/http/endpoints/mcp/localai_assistant_test.go index bf701b0e9517..a37e3234e06f 100644 --- a/core/http/endpoints/mcp/localai_assistant_test.go +++ b/core/http/endpoints/mcp/localai_assistant_test.go @@ -74,6 +74,34 @@ func (stubClient) GetBranding(_ context.Context) (*localaitools.Branding, error) func (stubClient) SetBranding(_ context.Context, _ localaitools.SetBrandingRequest) (*localaitools.Branding, error) { return &localaitools.Branding{InstanceName: "LocalAI"}, nil } +func (stubClient) GetUsageStats(_ context.Context, _ localaitools.UsageStatsQuery) (*localaitools.UsageStats, error) { + return &localaitools.UsageStats{Viewer: localaitools.UsageViewer{ID: "stub", Name: "stub"}, Period: "month"}, nil +} +func (stubClient) ListPIIPatterns(_ context.Context) ([]localaitools.PIIPattern, error) { + return nil, nil +} +func (stubClient) GetPIIEvents(_ context.Context, _ localaitools.PIIEventsQuery) ([]localaitools.PIIEvent, error) { + return nil, nil +} +func (stubClient) TestPIIRedaction(_ context.Context, req localaitools.PIIRedactTestRequest) (*localaitools.PIIRedactTestResult, error) { + return &localaitools.PIIRedactTestResult{Redacted: req.Text}, nil +} +func (stubClient) SetPIIPatternAction(_ context.Context, _ localaitools.PIIPatternActionUpdate) error { + return nil +} +func (stubClient) PersistPIIPatterns(_ context.Context) error { return nil } +func (stubClient) GetMiddlewareStatus(_ context.Context) (*localaitools.MiddlewareStatus, error) { + return &localaitools.MiddlewareStatus{ + PII: localaitools.MiddlewarePIIStatus{ + EnabledGlobally: true, + Patterns: []localaitools.PIIPattern{}, + Models: []localaitools.MiddlewarePIIModel{}, + }, + }, nil +} +func (stubClient) GetRouterDecisions(_ context.Context, _ localaitools.RouterDecisionsQuery) ([]localaitools.RouterDecision, error) { + return []localaitools.RouterDecision{}, nil +} var _ = Describe("LocalAIAssistantHolder", func() { var ctx context.Context diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 0951a88ccde1..3ea6215e3563 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -10,9 +10,12 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/cloudproxy" + "github.com/mudler/LocalAI/core/services/routing/pii" "github.com/mudler/LocalAI/pkg/functions" reason "github.com/mudler/LocalAI/pkg/reasoning" @@ -72,7 +75,7 @@ func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) [ // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc { +func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc { process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool, id string, created int) error { initialMessage := schema.OpenAIResponse{ ID: id, @@ -449,6 +452,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator xlog.Debug("Chat endpoint configuration read", "config", config) + // Cloud-proxy bail. When the resolved model is configured as + // a cloud passthrough (Backend = proxy-* and a non-empty + // upstream URL), bypass the entire local pipeline — + // templating, MCP injection, gRPC backend — and forward the + // request to the upstream provider. The streaming PII + // filter still runs because its input is per-token text + // extracted from the wire envelope, not the envelope itself. + if config.IsCloudProxy() { + return forwardCloudProxyOpenAI(c, config, input, piiRedactor, piiEvents) + } + funcs := input.Functions shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() strictMode := false @@ -683,6 +697,42 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator c.Response().Header().Set("Connection", "keep-alive") c.Response().Header().Set("X-Correlation-ID", id) + // Per-stream PII filter: when the resolved model has PII + // enabled (per the per-model gate the request-side + // middleware also reads), wrap the response content so + // values that span chunk boundaries still get masked. The + // filter is gated on the same ModelConfig accessor as the + // request middleware, so a user that disabled PII on the + // model gets no filter on either direction. + var streamPIIFilter *pii.StreamFilter + if piiRedactor != nil && config.PIIIsEnabled() { + correlationID := c.Response().Header().Get("X-Correlation-ID") + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + // Per-model action overrides go through the same map + // the request-side middleware uses; convert raw YAML + // strings to typed Actions and drop unknowns. + var overrides map[string]pii.Action + if raw := config.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]pii.Action, len(raw)) + for id, action := range raw { + switch pii.Action(action) { + case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal: + overrides[id] = pii.Action(action) + } + } + } + streamPIIFilter = pii.NewStreamFilter( + piiRedactor, + overrides, + piiEvents, + correlationID, + userID, + ) + } + mcpStreamMaxIterations := 10 if config.Agent.MaxIterations > 0 { mcpStreamMaxIterations = config.Agent.MaxIterations @@ -739,7 +789,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls) } } - // Collect content for MCP conversation history and automatic tool parsing fallback + // Collect content for MCP conversation history and automatic tool parsing fallback. + // We collect the RAW (unfiltered) content so the model's tool-call + // markup keeps parsing correctly even when PII redaction would mask + // substrings. if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { if s, ok := ev.Choices[0].Delta.Content.(string); ok { collectedContent += s @@ -747,6 +800,39 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator collectedContent += *sp } } + // Stream-side PII filter: feed the content delta + // through the buffered-emit filter. The filter + // holds back a tail to handle pattern boundaries + // across chunks, so a Push may legitimately + // return "" — drop the chunk in that case rather + // than emitting an empty Delta to the wire. + if streamPIIFilter != nil && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { + var raw string + switch v := ev.Choices[0].Delta.Content.(type) { + case string: + raw = v + case *string: + if v != nil { + raw = *v + } + } + filtered := streamPIIFilter.Push(raw) + if filtered == "" { + // Fully buffered — skip this chunk's + // content. Still emit non-content chunks + // (role, tool_calls). When this delta is + // content-only and we buffer it, drop the + // whole event to avoid a vestigial + // {"delta":{}} on the wire. + if ev.Choices[0].Delta.Role == "" && len(ev.Choices[0].Delta.ToolCalls) == 0 && ev.Choices[0].Delta.Reasoning == nil { + continue + } + // Mixed delta — strip content, keep the rest. + ev.Choices[0].Delta.Content = nil + } else { + ev.Choices[0].Delta.Content = filtered + } + } // OpenAI streaming spec: intermediate chunks must NOT // carry a `usage` field. Strip the tracking copy // before marshalling — usage is delivered via the @@ -797,7 +883,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // still trying to send (e.g., after client disconnect). The goroutine // calls close(responses) when done, which terminates the drain. if input.Context.Err() != nil { - go func() { for range responses {} }() + go func() { + for range responses { + } + }() <-ended } @@ -892,6 +981,31 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } } + // Drain the per-stream PII filter before the stop chunk + // so any text held back by the buffered-emit invariant + // reaches the client as a regular content delta. We + // emit it as a chunk WITHOUT a finish_reason so the + // next "stop" chunk still terminates the stream. + if streamPIIFilter != nil { + residual := streamPIIFilter.Drain() + if residual != "" { + drainResp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, + Choices: []schema.Choice{{ + Delta: &schema.Message{Content: residual}, + Index: 0, + }}, + Object: "chat.completion.chunk", + } + if drainBytes, err := json.Marshal(drainResp); err == nil { + _, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", drainBytes) + c.Response().Flush() + } + } + } + // No MCP tools to execute, send final stop message finishReason := FinishReasonStop if toolsCalled && len(input.Tools) > 0 { @@ -916,6 +1030,14 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator Object: "chat.completion.chunk", } respData, _ := json.Marshal(resp) + + pt, ct := 0, 0 + if usage != nil { + pt = usage.PromptTokens + ct = usage.CompletionTokens + } + middleware.StampUsage(c, input.Model, pt, ct) + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) // Trailing usage chunk per OpenAI spec: emit only when the @@ -1290,6 +1412,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator respData, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(respData)) + middleware.StampUsage(c, input.Model, usage.PromptTokens, usage.CompletionTokens) + // Return the prediction in the response body return c.JSON(200, resp) } // end MCP iteration loop @@ -1336,3 +1460,43 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall return "", nil } + +// forwardCloudProxyOpenAI builds the streaming PII filter (when this +// model has PII enabled) and hands the request off to the cloudproxy +// package. The chat endpoint is the only place the OpenAI request +// lands as a parsed *schema.OpenAIRequest, so we do the model rewrite +// + body marshalling here rather than inside cloudproxy (which keeps +// that package free of schema imports). +// +// Defined here rather than as a method on cloudproxy because the +// caller owns the lifecycle of the PII filter — every other streaming +// path in this file constructs its own filter inline, and we keep the +// proxy fork structurally similar. +func forwardCloudProxyOpenAI(c echo.Context, cfg *config.ModelConfig, input *schema.OpenAIRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error { + body, err := json.Marshal(input) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "cloudproxy: marshal request: "+err.Error()) + } + + var streamFilter *pii.StreamFilter + if input.Stream && piiRedactor != nil && cfg.PIIIsEnabled() { + correlationID := c.Response().Header().Get("X-Correlation-ID") + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + var overrides map[string]pii.Action + if raw := cfg.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]pii.Action, len(raw)) + for ovid, action := range raw { + switch pii.Action(action) { + case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal: + overrides[ovid] = pii.Action(action) + } + } + } + streamFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID) + } + + return cloudproxy.Forward(c, cfg, body, streamFilter) +} diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index f81e13e6a9b9..fdcd310cfee6 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -9,10 +9,12 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/middleware" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" @@ -25,7 +27,7 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { +func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc { process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { created := int(time.Now().Unix()) @@ -111,6 +113,31 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") } + // Per-stream PII filter — same gating as chat. /v1/completions + // has no chat-message structure, so request-side PII isn't + // wired here, but the response-side filter still catches PII + // trained into the model. Filter is nil when this model has + // PII disabled. + var streamPIIFilter *pii.StreamFilter + if piiRedactor != nil && config.PIIIsEnabled() { + correlationID := id + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } + var overrides map[string]pii.Action + if raw := config.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]pii.Action, len(raw)) + for ovid, action := range raw { + switch pii.Action(action) { + case pii.ActionMask, pii.ActionBlock, pii.ActionRouteLocal: + overrides[ovid] = pii.Action(action) + } + } + } + streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID) + } + predInput := config.PromptStrings[0] templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ @@ -143,12 +170,28 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva } // Capture running cumulative usage for the optional trailer // emitted after the final stop chunk when include_usage=true. + // Done before the PII filter so a fully-buffered chunk + // (which we drop from the wire) still contributes to the + // running total. if ev.Usage != nil { latestUsage = ev.Usage } // OpenAI streaming spec: intermediate chunks must NOT // carry a `usage` field. Strip the tracking copy now. ev.Usage = nil + // Run the per-chunk text through the streaming PII + // filter. The filter holds back a tail to handle + // pattern boundaries, so a Push may legitimately + // return "" — drop the chunk's text rather than + // emitting a 0-token delta. Choice.Text is the only + // content surface in /v1/completions chunks. + if streamPIIFilter != nil && ev.Choices[0].Text != "" { + filtered := streamPIIFilter.Push(ev.Choices[0].Text) + if filtered == "" { + continue + } + ev.Choices[0].Text = filtered + } respData, err := json.Marshal(ev) if err != nil { xlog.Debug("Failed to marshal response", "error", err) @@ -194,6 +237,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva } } + // Flush any residual the streaming PII filter held back as + // part of its trailing pattern-window. Emit it as one final + // text-bearing chunk before the synthetic stop chunk so the + // completion body remains a contiguous text stream. + if streamPIIFilter != nil { + if residual := streamPIIFilter.Drain(); residual != "" { + residualResp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, + Choices: []schema.Choice{{Index: 0, Text: residual}}, + Object: "text_completion", + } + if data, err := json.Marshal(residualResp); err == nil { + _, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(data)) + } + } + } + stopReason := FinishReasonStop resp := &schema.OpenAIResponse{ ID: id, @@ -208,6 +270,14 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva Object: "text_completion", } respData, _ := json.Marshal(resp) + + pt, ct := 0, 0 + if latestUsage != nil { + pt = latestUsage.PromptTokens + ct = latestUsage.CompletionTokens + } + middleware.StampUsage(c, input.Model, pt, ct) + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) // Trailing usage chunk per OpenAI spec: emit only when the caller @@ -274,6 +344,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) + middleware.StampUsage(c, input.Model, totalTokenUsage.Prompt, totalTokenUsage.Completion) + // Return the prediction in the response body return c.JSON(200, resp) } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 9a51989167fb..5258fddb1f53 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -98,6 +98,8 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) + middleware.StampUsage(c, input.Model, totalTokenUsage.Prompt, totalTokenUsage.Completion) + // Return the prediction in the response body return c.JSON(200, resp) } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index 517881f66313..96fd5efc8aac 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -102,6 +102,15 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) + // LocalAI's embeddings endpoint does not currently track per-call + // token counts (the gRPC Embedding RPC returns a vector, not a + // usage block), so we stamp with zeros. The point of stamping is + // that the billing pipeline still sees the request and emits the + // localai_billed_requests_total counter; without this the call + // would be silently dropped by the unrecorded-counter path. When + // embeddings learn to report usage, swap the zeros for real counts. + middleware.StampUsage(c, input.Model, 0, 0) + // Return the prediction in the response body return c.JSON(200, resp) } diff --git a/core/http/middleware/admission.go b/core/http/middleware/admission.go new file mode 100644 index 000000000000..c79066925d3b --- /dev/null +++ b/core/http/middleware/admission.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "strconv" + "sync/atomic" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/pii" +) + +// AdmissionControl runs after RouteModel so the limit applies to the +// SERVED model — a router fanout that lands on a saturated downstream +// model gets rejected even though the requested router-model has slack. +// +// On reject: HTTP 503, Retry-After header, error JSON. An audit row +// goes into the shared event store under KindAdmission so admins see +// rejection rates alongside PII and proxy events. +// +// Models without limits.max_concurrent (the common case) hit a fast +// no-op path — Acquire returns immediately for max <= 0. +func AdmissionControl(limiter *admission.Limiter, events pii.EventStore) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return next(c) + } + max := cfg.Limits.MaxConcurrent + release, ok := limiter.Acquire(cfg.Name, max) + if !ok { + retryAfter := admission.RetryAfter(cfg.Limits.RetryAfterSeconds) + recordAdmissionRejection(events, cfg.Name, retryAfter) + c.Response().Header().Set("Retry-After", strconv.Itoa(int(retryAfter.Seconds()))) + return c.JSON(http.StatusServiceUnavailable, map[string]any{ + "error": map[string]any{ + "type": "admission_rejected", + "message": fmt.Sprintf("model %q is at capacity (max_concurrent=%d); retry after %s", cfg.Name, max, retryAfter), + }, + }) + } + defer release() + return next(c) + } + } +} + +// admissionEventSeq scopes IDs across the process so rapid +// rejections under load get unique row IDs without coordinating +// with the rest of the event-store ID schemes. +var admissionEventSeq atomic.Uint64 + +func recordAdmissionRejection(events pii.EventStore, modelName string, retryAfter time.Duration) { + if events == nil { + return + } + statusCode := http.StatusServiceUnavailable + durMS := retryAfter.Milliseconds() + id := fmt.Sprintf("adm_%d_%s", admissionEventSeq.Add(1), randHex(4)) + _ = events.Record(context.Background(), pii.PIIEvent{ + ID: id, + Kind: pii.KindAdmission, + Host: modelName, + StatusCode: statusCode, + DurationMS: durMS, + CreatedAt: time.Now().UTC(), + }) +} + +func randHex(n int) string { + b := make([]byte, n) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} diff --git a/core/http/middleware/admission_test.go b/core/http/middleware/admission_test.go new file mode 100644 index 000000000000..841a2dd47d76 --- /dev/null +++ b/core/http/middleware/admission_test.go @@ -0,0 +1,118 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/services/routing/admission" + "github.com/mudler/LocalAI/core/services/routing/pii" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// recordingStore captures admission rows so the test can assert +// the audit trail without standing up the full pii event store. +type recordingStore struct { + mu sync.Mutex + events []pii.PIIEvent +} + +func (r *recordingStore) Record(_ context.Context, e pii.PIIEvent) error { + r.mu.Lock() + defer r.mu.Unlock() + r.events = append(r.events, e) + return nil +} +func (r *recordingStore) List(_ context.Context, _ pii.ListQuery) ([]pii.PIIEvent, error) { + return nil, nil +} +func (r *recordingStore) Count(_ context.Context) (int, error) { return 0, nil } +func (r *recordingStore) Close() error { return nil } + +func runAdmission(lim *admission.Limiter, store *recordingStore, cfg *config.ModelConfig, handler echo.HandlerFunc) (*httptest.ResponseRecorder, error) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := echo.New().NewContext(req, rec) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + mw := AdmissionControl(lim, store) + err := mw(handler)(c) + return rec, err +} + +var _ = Describe("Admission", func() { + It("allows when under limit", func() { + lim := admission.New() + cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 2}} + cfg.Name = "m" + rec, err := runAdmission(lim, &recordingStore{}, cfg, func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("rejects when full", func() { + // Saturate the limiter outside the middleware, then a request + // at the same model gets 503 with a Retry-After header. + lim := admission.New() + release, ok := lim.Acquire("busy", 1) + Expect(ok).To(BeTrue(), "setup acquire should succeed") + defer release() + + cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 1, RetryAfterSeconds: 3}} + cfg.Name = "busy" + store := &recordingStore{} + handlerCalled := false + rec, err := runAdmission(lim, store, cfg, func(c echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "ok") + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusServiceUnavailable)) + Expect(rec.Header().Get("Retry-After")).To(Equal("3")) + Expect(handlerCalled).To(BeFalse(), "handler should not run when admission rejects") + Expect(rec.Body.String()).To(ContainSubstring("admission_rejected")) + Expect(store.events).To(HaveLen(1)) + Expect(store.events[0].Kind).To(Equal(pii.KindAdmission)) + Expect(store.events[0].Host).To(Equal("busy"), "audit row carries the model name") + }) + + It("no limit configured is no-op", func() { + // MaxConcurrent=0 means unlimited — handler always runs and no + // audit row is written even after many calls. + lim := admission.New() + cfg := &config.ModelConfig{} + cfg.Name = "open" + store := &recordingStore{} + for i := 0; i < 10; i++ { + rec, err := runAdmission(lim, store, cfg, func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusOK)) + } + Expect(store.events).To(BeEmpty()) + }) + + It("releases after handler", func() { + // One slot, two SEQUENTIAL requests: the second succeeds because + // the first's release runs on handler return. + lim := admission.New() + cfg := &config.ModelConfig{Limits: config.LimitsConfig{MaxConcurrent: 1}} + cfg.Name = "tight" + for i := 0; i < 3; i++ { + rec, err := runAdmission(lim, &recordingStore{}, cfg, func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusOK)) + } + }) +}) diff --git a/core/http/middleware/context_keys.go b/core/http/middleware/context_keys.go new file mode 100644 index 000000000000..d1983c88259c --- /dev/null +++ b/core/http/middleware/context_keys.go @@ -0,0 +1,50 @@ +package middleware + +// Context keys used by routing-module middlewares to communicate with +// the usage recorder. Unlike the legacy CONTEXT_LOCALS_KEY_* constants +// (which exist for backward-compatible callers), these are the +// canonical names for new fields. +const ( + // ContextKeyRequestedModel is set by content-router middleware to + // the model name the client originally asked for, before any router + // remapping. UsageMiddleware writes this into UsageRecord.RequestedModel. + ContextKeyRequestedModel = "routing.requested_model" + + // ContextKeyServedModel is set by content-router middleware to the + // model that actually handled the request (post-routing). When no + // router runs, callers may leave this unset and the response-reported + // model name is used as the served value. + ContextKeyServedModel = "routing.served_model" + + // ContextKeyPreFilterPromptTokens / ContextKeyPostFilterPromptTokens + // are set by the PII middleware to record how many prompt tokens + // the user sent vs how many made it past redaction. When both are + // zero or unset, UsageMiddleware uses the response-reported prompt + // token count for both — i.e., no filter ran. + ContextKeyPreFilterPromptTokens = "routing.pre_filter_prompt_tokens" + ContextKeyPostFilterPromptTokens = "routing.post_filter_prompt_tokens" + + // ContextKeyCorrelationID is the join key threaded across PII + // events, router decisions, admission events, and usage records. + // trace.go middleware sets X-Correlation-ID on the response; this + // key mirrors the same value into echo.Context for in-process + // propagation without re-parsing the header. + ContextKeyCorrelationID = "routing.correlation_id" + + // ContextKeyPromptTokens / ContextKeyCompletionTokens / ContextKeyTotalTokens + // are the canonical token counts the request handler measured. Stamping + // these from the handler is the only reliable path for streaming + // responses, where the SSE chunks may not include a usage block (OpenAI + // requires stream_options.include_usage; Anthropic uses a separate + // message_delta event shape). UsageMiddleware prefers these context + // values over body-parsing. + ContextKeyPromptTokens = "routing.prompt_tokens" + ContextKeyCompletionTokens = "routing.completion_tokens" + ContextKeyTotalTokens = "routing.total_tokens" + + // ContextKeyResponseModel is the model name the handler committed to + // in its response payload. UsageMiddleware uses it when neither the + // router nor the body-parse path has produced one. Distinct from + // ContextKeyServedModel, which is the router's resolved choice. + ContextKeyResponseModel = "routing.response_model" +) diff --git a/core/http/middleware/route_model.go b/core/http/middleware/route_model.go new file mode 100644 index 000000000000..2d0336d69816 --- /dev/null +++ b/core/http/middleware/route_model.go @@ -0,0 +1,444 @@ +package middleware + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "hash/fnv" + "slices" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/router" + "github.com/mudler/xlog" + "gopkg.in/yaml.v3" +) + +// ScorerFactory returns a router.Scorer bound to a named classifier +// model. The Score classifier uses it to compute joint log-prob of +// every policy label against the routing prompt and read off the +// softmax distribution. +type ScorerFactory func(modelName string) router.Scorer + +// EmbedderFactory returns a router.Embedder bound to a named model. +// Used by the L2 embedding cache. Returning nil signals "model not +// loadable" — the middleware then falls back to the uncached +// classifier so routing still happens. +type EmbedderFactory func(modelName string) router.Embedder + +// VectorStoreFactory returns a router.VectorStore bound to a named +// collection. Each router model's cache lives in its own collection +// (see ClassifierDeps.storeNameFor) so two routers can't poison each +// other's hits. +type VectorStoreFactory func(storeName string) router.VectorStore + +// ClassifierDeps bundles the backend factories the router middleware +// needs to build a classifier and its optional L2 cache. Bundled into +// one struct because RouteModel already takes many positional +// arguments — additions to the dependency surface go here instead of +// growing the signature. +// +// Embedder and VectorStore are optional: when both are non-nil and the +// router config declares an embedding_cache block, the score +// classifier is wrapped in EmbeddingCacheClassifier. Otherwise the +// score classifier runs unwrapped and the embedding-cache YAML is +// ignored with a warning. +type ClassifierDeps struct { + Scorer ScorerFactory + Embedder EmbedderFactory + VectorStore VectorStoreFactory + + // Registry is the shared classifier cache. Both the OpenAI and + // Anthropic routes pass the same registry so the admin stats + // endpoint sees every live classifier. Nil falls back to a local + // registry — tests that don't need cross-route stats use this. + Registry *router.Registry +} + +// ProbeExtractor pulls the prompt content out of a parsed request so +// the classifier can inspect it without taking a dependency on the +// schema package. One extractor per request shape — wired by the +// route registration site (mirrors the piiadapter pattern). +// +// Returns ok=false when the parsed value isn't the expected type — the +// middleware then passes through without engaging the router. +type ProbeExtractor func(parsed any) (router.Probe, bool) + +// RouteModel runs after SetModelAndConfig and the schema-specific +// SetXRequest, looks at the resolved model's Router config, and (when +// present) reclassifies the request to one of the candidates. +// +// The middleware: +// +// 1. Loads MODEL_CONFIG from the echo context. If nil or HasRouter() +// is false, passes through. +// 2. Extracts the probe via the supplied ProbeExtractor. +// 3. Invokes the classifier matching cfg.Router.Classifier. Today +// only "feature" is supported; unknown classifiers fall back to +// cfg.Router.Fallback (or fail when none is set). +// 4. Resolves the chosen candidate to its model name. Reloads the +// ModelConfig for that model and asserts depth-1 (the candidate +// must NOT itself have a Router). Violation returns 500 — config +// bug, not a request bug. +// 5. Updates input.Model in place, replaces MODEL_CONFIG with the +// candidate's config, and stamps RequestedModel/ServedModel on the +// context so UsageMiddleware records the routing. +// 6. Writes a DecisionRecord to the store for the admin page. +// +// store may be nil when --disable-stats turns off the routing log; +// classification still runs. +// +// Composition with SmartRouter (distributed mode): this middleware +// only does *model* selection. Node selection still happens in +// SmartRouter.Route() downstream of this middleware. +func RouteModel(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, fallbackUser *auth.User, extractor ProbeExtractor, deps ClassifierDeps) echo.MiddlewareFunc { + registry := deps.Registry + if registry == nil { + registry = router.NewRegistry() + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil || !cfg.HasRouter() { + return next(c) + } + + parsed := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST) + if parsed == nil { + return next(c) + } + + probe, probeOK := extractor(parsed) + if !probeOK { + return next(c) + } + + classifier, classifierErr := getOrBuildClassifier(registry, cfg, deps) + if classifierErr != nil { + xlog.Warn("router: classifier unavailable — falling back", + "router_model", cfg.Name, "classifier", cfg.Router.Classifier, "error", classifierErr) + if cfg.Router.Fallback == "" { + return echo.NewHTTPError(503, "router classifier unavailable and no fallback configured") + } + return rewriteRequest(c, parsed, cfg, cfg.Router.Fallback, []string{router.LabelFallback}, router.Decision{Labels: []string{router.LabelFallback}}, router.LabelFallback, store, fallbackUser, loader, appConfig, next) + } + + start := time.Now() + decision, err := classifier.Classify(c.Request().Context(), probe) + if err != nil { + xlog.Warn("router: classifier returned error — using fallback", + "router_model", cfg.Name, "error", err, "latency_ms", time.Since(start).Milliseconds()) + if cfg.Router.Fallback == "" { + return echo.NewHTTPError(503, "router classification failed: "+err.Error()) + } + return rewriteRequest(c, parsed, cfg, cfg.Router.Fallback, []string{router.LabelFallback}, router.Decision{Labels: []string{router.LabelFallback}, Latency: time.Since(start)}, classifier.Name(), store, fallbackUser, loader, appConfig, next) + } + + candidate := matchCandidate(cfg.Router.Candidates, decision.Labels) + if candidate == "" { + xlog.Warn("router: no candidate covers active labels — using fallback", + "router_model", cfg.Name, "labels", decision.Labels) + if cfg.Router.Fallback == "" { + return echo.NewHTTPError(500, "no candidate covers active labels: "+strings.Join(decision.Labels, ",")) + } + candidate = cfg.Router.Fallback + } + + return rewriteRequest(c, parsed, cfg, candidate, decision.Labels, decision, classifier.Name(), store, fallbackUser, loader, appConfig, next) + } + } +} + +// rewriteRequest swaps the resolved model from the router to the +// chosen candidate, asserts the depth-1 invariant on the new config, +// records the decision, and continues. Pulled out so the classifier- +// success and fallback paths share one rewrite implementation. +func rewriteRequest(c echo.Context, parsed any, routerCfg *config.ModelConfig, candidateModel string, labels []string, decision router.Decision, classifierName string, store router.DecisionStore, fallbackUser *auth.User, loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, next echo.HandlerFunc) error { + candidateCfg, err := loader.LoadModelConfigFileByNameDefaultOptions(candidateModel, appConfig) + if err != nil || candidateCfg == nil { + xlog.Error("router: failed to load candidate config", + "router_model", routerCfg.Name, "candidate", candidateModel, "error", err) + return echo.NewHTTPError(500, "router candidate not loadable: "+candidateModel) + } + + // Depth-1 invariant: the resolved candidate must NOT itself be a + // router. Chained routers turn dispatch into a graph traversal — + // a configuration we deliberately reject. The check is at runtime + // because gallery installs can introduce a Router on a previously- + // flat model after startup. + if candidateCfg.HasRouter() { + xlog.Error("router: depth-1 invariant violated — candidate is itself a router", + "router_model", routerCfg.Name, "candidate", candidateModel) + return echo.NewHTTPError(500, "router candidate is itself a router (depth-1 invariant)") + } + + if req, ok := parsed.(schema.LocalAIRequest); ok { + req.ModelName(&candidateModel) + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, candidateCfg) + c.Set(ContextKeyRequestedModel, routerCfg.Name) + c.Set(ContextKeyServedModel, candidateModel) + + if store != nil { + correlationID, _ := c.Get(ContextKeyCorrelationID).(string) + if correlationID == "" { + correlationID = c.Response().Header().Get("X-Correlation-ID") + } + userID := "" + if u := auth.GetUser(c); u != nil { + userID = u.ID + } else if fallbackUser != nil { + userID = fallbackUser.ID + } + _ = store.Record(context.Background(), router.DecisionRecord{ + ID: newDecisionID(), + CorrelationID: correlationID, + UserID: userID, + RouterModel: routerCfg.Name, + RequestedModel: routerCfg.Name, + ServedModel: candidateModel, + Classifier: classifierName, + Label: strings.Join(labels, ","), + Score: decision.Score, + LatencyMs: decision.Latency.Milliseconds(), + Cached: decision.Cached, + CacheSimilarity: decision.CacheSimilarity, + CreatedAt: time.Now().UTC(), + }) + } + + return next(c) +} + +func getOrBuildClassifier(registry *router.Registry, cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) { + fp := routerConfigFingerprint(cfg.Router) + if cached, ok := registry.Get(cfg.Name, fp); ok { + return cached, nil + } + c, err := buildClassifier(cfg, deps) + if err != nil { + return nil, err + } + registry.Put(cfg.Name, fp, c) + return c, nil +} + +// routerConfigFingerprint is a stable cache key for a RouterConfig. +// FNV-64 over the YAML form — equality-only, not cryptographic. +// YAML-marshal picks up any future field added to RouterConfig +// without this function needing to be touched. +func routerConfigFingerprint(rc config.RouterConfig) uint64 { + bytes, err := yaml.Marshal(rc) + if err != nil { + // Marshalling a value type can't fail in practice; fall + // back to a hash that varies per call so we don't quietly + // share a cache entry across distinct configs. + return uint64(time.Now().UnixNano()) + } + h := fnv.New64a() + h.Write(bytes) + return h.Sum64() +} + +func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) { + rc := cfg.Router + classifier := rc.Classifier + if classifier == "" { + classifier = router.ClassifierScore + } + if classifier != router.ClassifierScore { + return nil, fmt.Errorf("router: unknown classifier %q (only %q is supported)", classifier, router.ClassifierScore) + } + if rc.ClassifierModel == "" { + return nil, fmt.Errorf("router classifier score requires classifier_model") + } + if deps.Scorer == nil { + return nil, fmt.Errorf("router classifier score unavailable: no scorer factory wired") + } + scorer := deps.Scorer(rc.ClassifierModel) + if scorer == nil { + return nil, fmt.Errorf("router classifier score: classifier_model %q not loadable", rc.ClassifierModel) + } + if len(rc.Policies) == 0 { + return nil, fmt.Errorf("router classifier score requires at least one policy") + } + policies := make([]router.ScorePolicy, 0, len(rc.Policies)) + for _, p := range rc.Policies { + if p.Label == "" { + return nil, fmt.Errorf("router classifier score: policy with empty label") + } + if p.Description == "" { + return nil, fmt.Errorf("router classifier score: policy %q has no description", p.Label) + } + policies = append(policies, router.ScorePolicy{ + Label: p.Label, + Description: p.Description, + }) + } + // Validate that every label referenced by a candidate is declared + // as a policy — otherwise the classifier would emit labels no + // candidate covers, and the routing always falls back. + policyLabels := make(map[string]struct{}, len(policies)) + for _, p := range policies { + policyLabels[p.Label] = struct{}{} + } + for _, c := range rc.Candidates { + if c.Model == "" { + return nil, fmt.Errorf("router classifier score: candidate has empty model field") + } + if len(c.Labels) == 0 { + return nil, fmt.Errorf("router classifier score: candidate %q has no labels", c.Model) + } + for _, l := range c.Labels { + if _, ok := policyLabels[l]; !ok { + return nil, fmt.Errorf("router classifier score: candidate %q references unknown label %q (not in policies)", c.Model, l) + } + } + } + cacheCap := rc.ClassifierCacheSize + if cacheCap == 0 { + cacheCap = 1024 + } + score := router.NewScoreClassifier(policies, scorer, cacheCap, rc.ActivationThreshold) + + if rc.EmbeddingCache == nil { + return score, nil + } + wrapped, err := wrapWithEmbeddingCache(cfg, score, deps) + if err != nil { + // Caching plumbing problems must not break routing — log, + // drop the cache layer, and return the uncached score + // classifier. The admin UI surfaces the warning via the + // classifier-build error path used elsewhere. + xlog.Warn("router: embedding cache disabled", + "router_model", cfg.Name, "error", err) + return score, nil + } + return wrapped, nil +} + +func wrapWithEmbeddingCache(cfg *config.ModelConfig, inner router.Classifier, deps ClassifierDeps) (router.Classifier, error) { + ec := cfg.Router.EmbeddingCache + if ec.EmbeddingModel == "" { + return nil, fmt.Errorf("embedding_cache requires embedding_model") + } + if deps.Embedder == nil || deps.VectorStore == nil { + return nil, fmt.Errorf("embedding cache factories not wired") + } + embedder := deps.Embedder(ec.EmbeddingModel) + if embedder == nil { + return nil, fmt.Errorf("embedding_model %q not loadable", ec.EmbeddingModel) + } + storeName := ec.StoreName + if storeName == "" { + storeName = "router-cache-" + cfg.Name + } + vstore := deps.VectorStore(storeName) + if vstore == nil { + return nil, fmt.Errorf("vector store %q not loadable", storeName) + } + return router.NewEmbeddingCacheClassifier(inner, embedder, vstore, ec.SimilarityThreshold, ec.ConfidenceThreshold), nil +} + +// matchCandidate picks the FIRST candidate whose Labels are a +// superset of the active label set. Admins order the candidates list +// smallest → largest, so a request that needs one label routes to +// the smallest capable model and one that needs multiple falls to +// the first bigger candidate that covers them all. Returns empty +// string when no candidate matches; the caller falls back. +func matchCandidate(candidates []config.RouterCandidate, active []string) string { + if len(active) == 0 { + return "" + } + for _, c := range candidates { + if labelSetCovers(c.Labels, active) { + return c.Model + } + } + return "" +} + +// labelSetCovers returns true when every element of needed appears +// in have. Label sets are typically <10 entries so the linear scan +// is fine. +func labelSetCovers(have, needed []string) bool { + for _, n := range needed { + if !slices.Contains(have, n) { + return false + } + } + return true +} + +func newDecisionID() string { + var b [12]byte + _, _ = rand.Read(b[:]) + return "rd_" + hex.EncodeToString(b[:]) +} + +// OpenAIProbe extracts a router.Probe from a parsed *schema.OpenAIRequest. +// Concatenates message contents (string-form or text blocks of the +// structured `[]any` content) so the classifier sees a single corpus +// for length and content-shape rules. Image blocks are skipped — a +// future multimodal classifier can take a different route. +func OpenAIProbe(parsed any) (router.Probe, bool) { + req, ok := parsed.(*schema.OpenAIRequest) + if !ok || req == nil { + return router.Probe{}, false + } + var b strings.Builder + for i := range req.Messages { + switch ct := req.Messages[i].Content.(type) { + case string: + b.WriteString(ct) + b.WriteByte('\n') + case []any: + for _, block := range ct { + if bm, ok := block.(map[string]any); ok && bm["type"] == "text" { + if t, ok := bm["text"].(string); ok { + b.WriteString(t) + b.WriteByte('\n') + } + } + } + } + } + return router.Probe{ + Prompt: b.String(), + }, true +} + +// AnthropicProbe is the AnthropicRequest analogue of OpenAIProbe. +func AnthropicProbe(parsed any) (router.Probe, bool) { + req, ok := parsed.(*schema.AnthropicRequest) + if !ok || req == nil { + return router.Probe{}, false + } + var b strings.Builder + for i := range req.Messages { + switch ct := req.Messages[i].Content.(type) { + case string: + b.WriteString(ct) + b.WriteByte('\n') + case []any: + for _, block := range ct { + if bm, ok := block.(map[string]any); ok && bm["type"] == "text" { + if t, ok := bm["text"].(string); ok { + b.WriteString(t) + b.WriteByte('\n') + } + } + } + } + } + return router.Probe{ + Prompt: b.String(), + }, true +} + diff --git a/core/http/middleware/route_model_test.go b/core/http/middleware/route_model_test.go new file mode 100644 index 000000000000..b19d5d3d5fde --- /dev/null +++ b/core/http/middleware/route_model_test.go @@ -0,0 +1,266 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/router" + "github.com/mudler/LocalAI/pkg/system" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" +) + +// The RouteModel middleware wires the score classifier into request +// rewriting. The classifier itself is covered in +// router/score_test.go — these specs pin the middleware-level +// behaviour: candidate matching against the active label set, the +// fallback path, and the depth-1 invariant. + +var _ = Describe("RouteModel middleware (score classifier)", func() { + var ( + modelDir string + appConfig *config.ApplicationConfig + loader *config.ModelConfigLoader + store *fakeDecisionStore + ) + + BeforeEach(func() { + d, err := os.MkdirTemp("", "router-test-*") + Expect(err).NotTo(HaveOccurred()) + modelDir = d + appConfig = &config.ApplicationConfig{ + Context: context.Background(), + SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}}, + } + loader = config.NewModelConfigLoader(modelDir) + store = &fakeDecisionStore{} + }) + + AfterEach(func() { + _ = os.RemoveAll(modelDir) + }) + + It("routes to a candidate whose labels cover the active set", func() { + // 3 policies, 2 candidates. Small model has [casual-chat], + // bigger has [code-generation, math-reasoning, casual-chat]. + // A query that activates code-generation should fall to the + // bigger candidate because it's the only one that covers it. + routerCfg := newScoreRouterModel(modelDir, "smart-router") + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + + s := &stubScorer{labelToLogProb: map[string]float64{ + "code-generation": -0.05, // dominant + "casual-chat": -3.0, + "math-reasoning": -4.0, + }} + rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("debug my Go null pointer"), stubScorerFactory(s)) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Body.String()).To(Equal("served:big-model")) + Expect(store.records).To(HaveLen(1)) + Expect(store.records[0].ServedModel).To(Equal("big-model")) + Expect(store.records[0].Label).To(ContainSubstring("code-generation")) + }) + + It("prefers the smaller candidate when both cover the active set", func() { + // Both candidates list casual-chat. Admins order small → + // big, so a casual-chat-only request must route to small. + routerCfg := newScoreRouterModel(modelDir, "smart-router") + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + + s := &stubScorer{labelToLogProb: map[string]float64{ + "code-generation": -5.0, + "casual-chat": -0.05, // dominant + "math-reasoning": -5.0, + }} + rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("hi"), stubScorerFactory(s)) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Body.String()).To(Equal("served:small-model")) + }) + + It("falls back when no candidate covers the active label set", func() { + // Only the bigger candidate covers math-reasoning. We + // deliberately drop it from the candidates list so neither + // matches; expect Fallback to fire. + routerCfg := newScoreRouterModel(modelDir, "smart-router") + // Remove the second candidate so coverage gap appears. + routerCfg.Router.Candidates = routerCfg.Router.Candidates[:1] + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "qwen3-0.6b") + + s := &stubScorer{labelToLogProb: map[string]float64{ + "code-generation": -5.0, + "casual-chat": -5.0, + "math-reasoning": -0.05, // dominant — but no candidate has it + }} + rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("3 apples cost $2.40"), stubScorerFactory(s)) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Body.String()).To(Equal("served:qwen3-0.6b")) + }) + + It("rejects candidates that reference unknown labels at build time", func() { + routerCfg := newScoreRouterModel(modelDir, "smart-router") + routerCfg.Router.Candidates = append(routerCfg.Router.Candidates, config.RouterCandidate{ + Model: "broken", + Labels: []string{"nonexistent-label"}, + }) + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + writeCandidate(modelDir, "broken") + writeCandidate(modelDir, "qwen3-0.6b") + + s := &stubScorer{labelToLogProb: map[string]float64{ + "code-generation": -0.05, + "casual-chat": -3.0, + "math-reasoning": -4.0, + }} + rec, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("debug something"), stubScorerFactory(s)) + // Unknown-label config bug surfaces via the + // classifier-unavailable path, which falls through to the + // configured Fallback. + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Body.String()).To(Equal("served:qwen3-0.6b")) + }) + + It("returns 500 when the candidate is itself a router (depth-1 invariant)", func() { + // The candidate model is itself a router. We must reject + // the dispatch — chained routers are deliberately + // disallowed. + routerCfg := newScoreRouterModel(modelDir, "smart-router") + // Bend the test setup: replace one of the candidate-model + // configs with a nested-router config. + nestedRouter := newScoreRouterModel(modelDir, "small-model") + Expect(os.WriteFile(filepath.Join(modelDir, "small-model.yaml"), []byte(toYAML(nestedRouter)), 0o644)).To(Succeed()) + writeCandidate(modelDir, "big-model") + writeCandidate(modelDir, "qwen3-0.6b") + + s := &stubScorer{labelToLogProb: map[string]float64{ + "code-generation": -5.0, + "casual-chat": -0.05, + "math-reasoning": -5.0, + }} + _, err := runRouter(loader, appConfig, store, routerCfg, openAIChat("hi"), stubScorerFactory(s)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("depth-1 invariant")) + }) +}) + +// --- helpers --- + +// stubScorer scores each candidate label according to a fixed +// label→log-prob map; per-token length is faked at 2 tokens so length +// normalisation is a no-op. +type stubScorer struct { + labelToLogProb map[string]float64 +} + +func (s *stubScorer) Score(_ context.Context, _ string, candidates []string) ([]router.CandidateScore, error) { + out := make([]router.CandidateScore, len(candidates)) + for i, c := range candidates { + lp := s.labelToLogProb[c] + out[i] = router.CandidateScore{ + LogProb: lp * 2, + LengthNormalizedLogProb: lp, + NumTokens: 2, + } + } + return out, nil +} + +func stubScorerFactory(s *stubScorer) ScorerFactory { + return func(string) router.Scorer { return s } +} + +type fakeDecisionStore struct { + records []router.DecisionRecord +} + +func (f *fakeDecisionStore) Record(_ context.Context, r router.DecisionRecord) error { + f.records = append(f.records, r) + return nil +} + +func (f *fakeDecisionStore) List(_ context.Context, _ router.DecisionListQuery) ([]router.DecisionRecord, error) { + out := append([]router.DecisionRecord(nil), f.records...) + return out, nil +} + +func (f *fakeDecisionStore) Close() error { return nil } +func (f *fakeDecisionStore) Count(_ context.Context) (int, error) { return len(f.records), nil } + +// newScoreRouterModel builds a smart-router config with 3 policies +// and 2 candidates (small with one label, bigger with all three). +// Admins are expected to order candidates small → large; the +// middleware picks the first whose labels are a superset of the +// active set. +func newScoreRouterModel(modelDir, name string) *config.ModelConfig { + cfg := &config.ModelConfig{ + Name: name, + Router: config.RouterConfig{ + Classifier: "score", + ClassifierModel: "arch-router", + Fallback: "qwen3-0.6b", + Policies: []config.RouterPolicy{ + {Label: "code-generation", Description: "writing or debugging code"}, + {Label: "casual-chat", Description: "small talk"}, + {Label: "math-reasoning", Description: "arithmetic and word problems"}, + }, + Candidates: []config.RouterCandidate{ + {Model: "small-model", Labels: []string{"casual-chat"}}, + {Model: "big-model", Labels: []string{"code-generation", "casual-chat", "math-reasoning"}}, + }, + }, + } + Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(toYAML(cfg)), 0o644)).To(Succeed()) + return cfg +} + +func writeCandidate(modelDir, name string) { + body := "name: " + name + "\nbackend: mock-backend\n" + Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed()) +} + +func toYAML(cfg *config.ModelConfig) string { + b, err := yaml.Marshal(cfg) + Expect(err).NotTo(HaveOccurred()) + return string(b) +} + +func openAIChat(content string) *schema.OpenAIRequest { + req := &schema.OpenAIRequest{ + Messages: []schema.Message{ + {Role: "user", Content: content}, + }, + } + req.Model = "smart-router" + return req +} + +func runRouter(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, routerCfg *config.ModelConfig, parsed any, scorerFactory ScorerFactory) (*httptest.ResponseRecorder, error) { + mw := RouteModel(loader, appConfig, store, nil, OpenAIProbe, ClassifierDeps{Scorer: scorerFactory}) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := echo.New().NewContext(req, rec) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, routerCfg) + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, parsed) + handler := mw(func(c echo.Context) error { + // Final hand-off — echo back which model the middleware + // resolved so the spec can assert routing without exercising + // the full chat pipeline. + served, _ := c.Get(ContextKeyServedModel).(string) + return c.String(http.StatusOK, "served:"+served) + }) + err := handler(c) + return rec, err +} diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 9e713c0316f8..7e9bafa63dd3 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -1,9 +1,11 @@ package middleware import ( + "bufio" "bytes" "io" "mime" + "net" "net/http" "slices" "sync" @@ -80,6 +82,16 @@ func (w *bodyWriter) Flush() { } } +// Hijack lets WebSocket upgraders (gorilla/websocket) reach the +// underlying connection. Without this, gorilla's Hijacker type-assertion +// fails on the wrapped writer and the handshake returns 500. +func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := w.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, http.ErrNotSupported +} + func initializeTracing(maxItems int) { tracingMaxItems = maxItems doInitializeTracing() diff --git a/core/http/middleware/usage.go b/core/http/middleware/usage.go index b82c1ee3f506..b26347d90c98 100644 --- a/core/http/middleware/usage.go +++ b/core/http/middleware/usage.go @@ -2,74 +2,19 @@ package middleware import ( "bytes" + "context" "encoding/json" - "sync" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/billing" "github.com/mudler/xlog" - "gorm.io/gorm" ) -const ( - usageFlushInterval = 5 * time.Second - usageMaxPending = 5000 -) - -// usageBatcher accumulates usage records and flushes them to the DB periodically. -type usageBatcher struct { - mu sync.Mutex - pending []*auth.UsageRecord - db *gorm.DB -} - -func (b *usageBatcher) add(r *auth.UsageRecord) { - b.mu.Lock() - b.pending = append(b.pending, r) - b.mu.Unlock() -} - -func (b *usageBatcher) flush() { - b.mu.Lock() - batch := b.pending - b.pending = nil - b.mu.Unlock() - - if len(batch) == 0 { - return - } - - if err := b.db.Create(&batch).Error; err != nil { - xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err) - // Re-queue failed records with a cap to avoid unbounded growth - b.mu.Lock() - if len(b.pending) < usageMaxPending { - b.pending = append(batch, b.pending...) - } - b.mu.Unlock() - } -} - -var batcher *usageBatcher - -// InitUsageRecorder starts a background goroutine that periodically flushes -// accumulated usage records to the database. -func InitUsageRecorder(db *gorm.DB) { - if db == nil { - return - } - batcher = &usageBatcher{db: db} - go func() { - ticker := time.NewTicker(usageFlushInterval) - defer ticker.Stop() - for range ticker.C { - batcher.flush() - } - }() -} - -// usageResponseBody is the minimal structure we need from the response JSON. +// usageResponseBody is the minimal structure we need from an OpenAI-shaped +// JSON response. Anthropic responses are decoded separately because their +// usage block uses different field names (input_tokens / output_tokens). type usageResponseBody struct { Model string `json:"model"` Usage *struct { @@ -79,18 +24,47 @@ type usageResponseBody struct { } `json:"usage"` } -// UsageMiddleware extracts token usage from OpenAI-compatible response JSON -// and records it per-user. -func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { +// anthropicResponseBody covers /v1/messages JSON responses. +type anthropicResponseBody struct { + Model string `json:"model"` + Usage *struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + } `json:"usage"` +} + +// UsageMiddleware records token usage for inference requests via the +// billing.Recorder. Two paths produce a record: +// +// 1. Handler-stamped (preferred): the request handler called +// middleware.StampUsage with the canonical token counts before +// returning. This is the only reliable path for streaming responses +// — clients rarely set OpenAI's stream_options.include_usage, and +// Anthropic's usage lives in a separate message_delta event. +// 2. Body-parsed (fallback): the response is parsed for an OpenAI- or +// Anthropic-shaped usage block. Used by passthrough proxies and +// foreign endpoints. +// +// Recorder being nil (e.g., --disable-stats) makes the middleware a +// transparent pass-through. fallbackUser is used when auth.GetUser(c) +// returns nil; without it, an unauthenticated request would be dropped. +// +// Every request that fails to produce a record ticks +// localai_usage_unrecorded_total so silent billing misses are observable. +func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if db == nil || batcher == nil { + if recorder == nil { return next(c) } startTime := time.Now() - // Wrap response writer to capture body + // Wrap response writer to capture body for the fallback parser. + // When the handler stamps the context we never read this buffer, + // so the cost is the per-chunk Write going through one extra + // indirection — accepted overhead in exchange for one billing + // path that works for both stamping and body-parse callers. resBody := new(bytes.Buffer) origWriter := c.Response().Writer mw := &bodyWriter{ @@ -101,71 +75,189 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { handlerErr := next(c) - // Restore original writer c.Response().Writer = origWriter - // Only record on successful responses + endpoint := c.Request().URL.Path + if c.Response().Status < 200 || c.Response().Status >= 300 { return handlerErr } - // Get authenticated user user := auth.GetUser(c) if user == nil { + user = fallbackUser + } + if user == nil || user.ID == "" { + billing.CountUnrecorded(context.Background(), endpoint, "no_user") return handlerErr } - // Try to parse usage from response - responseBytes := resBody.Bytes() - if len(responseBytes) == 0 { + model, prompt, completion, total, ok := tokensFromContext(c) + if !ok { + model, prompt, completion, total, ok = tokensFromBody(resBody.Bytes(), c.Response().Header().Get("Content-Type")) + } + if !ok { + billing.CountUnrecorded(context.Background(), endpoint, "no_usage") return handlerErr } - // Check content type - ct := c.Response().Header().Get("Content-Type") - isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json")) - isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream")) + requested, served := modelsFromContext(c, model) + pre, post := promptTokensFromContext(c, prompt) - if !isJSON && !isSSE { - return handlerErr + record := &auth.UsageRecord{ + UserID: user.ID, + UserName: user.Name, + Model: model, + Endpoint: endpoint, + PromptTokens: prompt, + CompletionTokens: completion, + TotalTokens: total, + Duration: time.Since(startTime).Milliseconds(), + CreatedAt: startTime, + RequestedModel: requested, + ServedModel: served, + PreFilterPromptTokens: pre, + PostFilterPromptTokens: post, + CorrelationID: correlationIDFromContext(c), } - var resp usageResponseBody - if isSSE { - last, ok := lastSSEData(responseBytes) - if !ok { - return handlerErr - } - if err := json.Unmarshal(last, &resp); err != nil { - return handlerErr - } - } else { - if err := json.Unmarshal(responseBytes, &resp); err != nil { - return handlerErr - } + if err := recorder.Record(context.Background(), record); err != nil { + xlog.Error("usage middleware: recorder.Record failed", "error", err, "user", user.ID, "model", model) + billing.CountUnrecorded(context.Background(), endpoint, "record_failed") } - if resp.Usage == nil { - return handlerErr - } + return handlerErr + } + } +} - record := &auth.UsageRecord{ - UserID: user.ID, - UserName: user.Name, - Model: resp.Model, - Endpoint: c.Request().URL.Path, - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - Duration: time.Since(startTime).Milliseconds(), - CreatedAt: startTime, - } +// tokensFromContext returns canonical token counts stamped by a handler +// via middleware.StampUsage. Returns ok=false when no stamp is present +// — the caller then tries the body-parse fallback. +// +// A model name without token counts is not considered "stamped" because a +// record with zero tokens looks the same as a never-recorded request to +// later analytics; the second condition is what gates ok. +func tokensFromContext(c echo.Context) (model string, prompt, completion, total int64, ok bool) { + if v, found := c.Get(ContextKeyResponseModel).(string); found { + model = v + } + pPresent := false + cPresent := false + if v, found := c.Get(ContextKeyPromptTokens).(int64); found { + prompt = v + pPresent = true + } + if v, found := c.Get(ContextKeyCompletionTokens).(int64); found { + completion = v + cPresent = true + } + if v, found := c.Get(ContextKeyTotalTokens).(int64); found { + total = v + } else { + total = prompt + completion + } + ok = pPresent || cPresent + return +} - batcher.add(record) +// tokensFromBody covers the passthrough-proxy / foreign-endpoint case +// where no handler stamps the context. Returns ok=false on any parse +// failure or missing-usage; the caller increments the unrecorded counter. +func tokensFromBody(responseBytes []byte, contentType string) (model string, prompt, completion, total int64, ok bool) { + if len(responseBytes) == 0 { + return + } + isJSON := contentType == "" || contentType == "application/json" || bytes.HasPrefix([]byte(contentType), []byte("application/json")) + isSSE := bytes.HasPrefix([]byte(contentType), []byte("text/event-stream")) + if !isJSON && !isSSE { + return + } - return handlerErr + payload := responseBytes + if isSSE { + // For SSE, the canonical usage chunk is the *last* non-[DONE] data + // line. OpenAI clients only emit one if stream_options.include_usage + // is set; Anthropic emits a final message_delta with usage. Both + // fit the "last data: line" rule. + last, lastOk := lastSSEData(responseBytes) + if !lastOk { + return } + payload = last + } + + // Try OpenAI shape first (handles /v1/chat/completions, /v1/completions, + // /v1/embeddings, /v1/edits, and any proxy that translates to OpenAI). + // A usage block whose token fields all decoded to zero is ambiguous — + // it could be an Anthropic body that happens to have a `usage` key — + // so fall through to the Anthropic parser instead of recording zeros. + var openAI usageResponseBody + if err := json.Unmarshal(payload, &openAI); err == nil && openAI.Usage != nil { + if openAI.Usage.PromptTokens != 0 || openAI.Usage.CompletionTokens != 0 || openAI.Usage.TotalTokens != 0 { + model = openAI.Model + prompt = openAI.Usage.PromptTokens + completion = openAI.Usage.CompletionTokens + total = openAI.Usage.TotalTokens + if total == 0 { + total = prompt + completion + } + ok = true + return + } + } + + // Fall through to Anthropic shape (proxy passthrough territory). + var ant anthropicResponseBody + if err := json.Unmarshal(payload, &ant); err == nil && ant.Usage != nil { + if ant.Usage.InputTokens != 0 || ant.Usage.OutputTokens != 0 { + model = ant.Model + prompt = ant.Usage.InputTokens + completion = ant.Usage.OutputTokens + total = prompt + completion + ok = true + return + } + } + + return +} + +// modelsFromContext returns (requested, served) using context-set values +// when present, falling back to the response-reported model for both. +// The router middleware (subsystem 2 of the routing plan) populates +// these; until it lands they are equal. +func modelsFromContext(c echo.Context, fallback string) (string, string) { + requested := fallback + served := fallback + if v, ok := c.Get(ContextKeyRequestedModel).(string); ok && v != "" { + requested = v + } + if v, ok := c.Get(ContextKeyServedModel).(string); ok && v != "" { + served = v + } + return requested, served +} + +func promptTokensFromContext(c echo.Context, fallback int64) (int64, int64) { + pre := fallback + post := fallback + if v, ok := c.Get(ContextKeyPreFilterPromptTokens).(int64); ok && v > 0 { + pre = v + } + if v, ok := c.Get(ContextKeyPostFilterPromptTokens).(int64); ok && v > 0 { + post = v + } + return pre, post +} + +func correlationIDFromContext(c echo.Context) string { + if v, ok := c.Get(ContextKeyCorrelationID).(string); ok { + return v } + // X-Correlation-ID header is set by trace.go middleware; read it as a + // fallback if the echo-context binding hasn't been populated yet. + return c.Response().Header().Get("X-Correlation-ID") } // lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]". diff --git a/core/http/middleware/usage_stamp.go b/core/http/middleware/usage_stamp.go new file mode 100644 index 000000000000..7e82ab7444b1 --- /dev/null +++ b/core/http/middleware/usage_stamp.go @@ -0,0 +1,33 @@ +package middleware + +import "github.com/labstack/echo/v4" + +// StampUsage records the canonical token counts on the echo context so +// UsageMiddleware can attribute the request without parsing the response +// body. Handlers must call this for every successful response — the +// body-parse fallback is reserved for foreign endpoints (e.g., the cloud +// passthrough proxy). +// +// model is the name written into the response payload; passing it here +// is what lets the middleware fill the UsageRecord even when the handler +// abbreviates or rewrites the user-supplied model. Empty values are +// ignored so partial information is still useful (e.g., embeddings calls +// where completion is always 0). +// +// prompt and completion accept int because that's the native width of +// LocalAI's TokenUsage / OpenAIUsage structs (token counts never come +// close to overflow). Conversion to int64 happens once, here, so call +// sites stay free of casts. +func StampUsage(c echo.Context, model string, prompt, completion int) { + if c == nil { + return + } + if model != "" { + c.Set(ContextKeyResponseModel, model) + } + p := int64(prompt) + cp := int64(completion) + c.Set(ContextKeyPromptTokens, p) + c.Set(ContextKeyCompletionTokens, cp) + c.Set(ContextKeyTotalTokens, p+cp) +} diff --git a/core/http/middleware/usage_test.go b/core/http/middleware/usage_test.go new file mode 100644 index 000000000000..818861515d81 --- /dev/null +++ b/core/http/middleware/usage_test.go @@ -0,0 +1,225 @@ +package middleware_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/services/routing/billing" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// captureBackend collects records the recorder forwards. We assert on +// it directly rather than going through StatsBackend.Aggregate because +// these tests verify the middleware -> recorder hop, not aggregation +// (which has its own tests in routing/billing). +type captureBackend struct { + records []*auth.UsageRecord +} + +func (c *captureBackend) Record(_ context.Context, r *auth.UsageRecord) error { + c.records = append(c.records, r) + return nil +} +func (c *captureBackend) Aggregate(_ context.Context, _ billing.AggregateQuery) ([]auth.UsageBucket, error) { + return nil, nil +} +func (c *captureBackend) Close() error { return nil } + +var _ = Describe("UsageMiddleware", func() { + mockChat := func(usage string) echo.HandlerFunc { + return func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "application/json") + body := fmt.Sprintf(`{"model":"qwen-7b","usage":%s}`, usage) + return c.String(http.StatusOK, body) + } + } + + It("records under the synthetic local user when auth is off", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + fallback := &auth.User{ID: "local-uuid", Name: "local", Provider: auth.ProviderLocal} + + e := echo.New() + e.POST("/v1/chat/completions", + mockChat(`{"prompt_tokens":12,"completion_tokens":8,"total_tokens":20}`), + httpMiddleware.UsageMiddleware(rec, fallback), + ) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(cap.records).To(HaveLen(1)) + r := cap.records[0] + Expect(r.UserID).To(Equal("local-uuid")) + Expect(r.UserName).To(Equal("local")) + Expect(r.Model).To(Equal("qwen-7b")) + Expect(r.PromptTokens).To(Equal(int64(12))) + Expect(r.CompletionTokens).To(Equal(int64(8))) + Expect(r.TotalTokens).To(Equal(int64(20))) + }) + + It("does nothing when recorder is nil (--disable-stats)", func() { + fallback := &auth.User{ID: "local-uuid", Name: "local"} + e := echo.New() + e.POST("/v1/chat/completions", + mockChat(`{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}`), + httpMiddleware.UsageMiddleware(nil, fallback), + ) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusOK)) + // no panic, no record — recorder=nil is the disable-stats path + }) + + It("skips when neither auth nor fallback user is available", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + + e := echo.New() + e.POST("/v1/chat/completions", + mockChat(`{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5}`), + httpMiddleware.UsageMiddleware(rec, nil), + ) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(cap.records).To(BeEmpty()) + }) + + It("ignores 5xx responses (no usage to attribute)", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + fallback := &auth.User{ID: "local-uuid", Name: "local"} + + e := echo.New() + e.POST("/v1/chat/completions", + func(c echo.Context) error { + return c.String(http.StatusInternalServerError, `{"error":"boom"}`) + }, + httpMiddleware.UsageMiddleware(rec, fallback), + ) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusInternalServerError)) + Expect(cap.records).To(BeEmpty()) + }) + + It("records via context-stamped tokens when handler called StampUsage (streaming-safe path)", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + fallback := &auth.User{ID: "local-uuid", Name: "local"} + + // Simulate a streaming chat handler that emits SSE chunks WITHOUT a + // terminal usage block (the common case — clients rarely set + // stream_options.include_usage). The handler stamps the canonical + // counts on the context just before returning. UsageMiddleware + // must record from the stamp, not from body parsing. + streamingHandler := func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(c.Response().Writer, "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n") + _, _ = fmt.Fprint(c.Response().Writer, "data: [DONE]\n\n") + httpMiddleware.StampUsage(c, "qwen-7b", 9, 5) + return nil + } + + e := echo.New() + e.POST("/v1/chat/completions", + streamingHandler, + httpMiddleware.UsageMiddleware(rec, fallback), + ) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(cap.records).To(HaveLen(1)) + Expect(cap.records[0].PromptTokens).To(Equal(int64(9))) + Expect(cap.records[0].CompletionTokens).To(Equal(int64(5))) + Expect(cap.records[0].TotalTokens).To(Equal(int64(14))) + Expect(cap.records[0].Model).To(Equal("qwen-7b")) + }) + + It("falls back to Anthropic body shape when no stamp is present", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + fallback := &auth.User{ID: "local-uuid", Name: "local"} + + // Simulates a passthrough proxy / foreign endpoint: no handler stamp, + // so the middleware must parse the response body. Anthropic's shape + // uses input_tokens / output_tokens, not the OpenAI names. + anthropicHandler := func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "application/json") + body := `{"model":"claude-sonnet","usage":{"input_tokens":15,"output_tokens":7}}` + return c.String(http.StatusOK, body) + } + + e := echo.New() + e.POST("/v1/messages", + anthropicHandler, + httpMiddleware.UsageMiddleware(rec, fallback), + ) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(cap.records).To(HaveLen(1)) + Expect(cap.records[0].PromptTokens).To(Equal(int64(15))) + Expect(cap.records[0].CompletionTokens).To(Equal(int64(7))) + Expect(cap.records[0].TotalTokens).To(Equal(int64(22))) + Expect(cap.records[0].Model).To(Equal("claude-sonnet")) + }) + + It("populates RequestedModel/ServedModel from echo context when set", func() { + cap := &captureBackend{} + rec := billing.NewRecorder(cap) + fallback := &auth.User{ID: "local-uuid", Name: "local"} + + // A pre-handler stand-in for the future router middleware: it + // rewrites Served and remembers the original Requested. Once the + // real router lands, this is exactly the contract it must keep. + setRouterContext := func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set(httpMiddleware.ContextKeyRequestedModel, "auto") + c.Set(httpMiddleware.ContextKeyServedModel, "qwen-7b") + return next(c) + } + } + + e := echo.New() + e.POST("/v1/chat/completions", + mockChat(`{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}`), + httpMiddleware.UsageMiddleware(rec, fallback), + setRouterContext, + ) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(cap.records).To(HaveLen(1)) + Expect(cap.records[0].RequestedModel).To(Equal("auto")) + Expect(cap.records[0].ServedModel).To(Equal("qwen-7b")) + }) +}) diff --git a/core/http/react-ui/e2e/middleware-page.spec.js b/core/http/react-ui/e2e/middleware-page.spec.js new file mode 100644 index 000000000000..0c254e1278d0 --- /dev/null +++ b/core/http/react-ui/e2e/middleware-page.spec.js @@ -0,0 +1,308 @@ +import { test, expect } from '@playwright/test' + +// Mocked fixture covering the three things the page renders: +// - PII pattern catalogue (action badges, action-change buttons) +// - Per-model resolved PII state (one with default off, one with proxy default on, one with explicit YAML) +// - Recent events feed (the page must NEVER show the redacted content) +const MOCK_STATUS = { + pii: { + enabled_globally: true, + default_enabled_for_backends: ['proxy-*'], + patterns: [ + { id: 'email', description: 'Email addresses', action: 'mask', max_match_length: 254 }, + { id: 'ssn', description: 'US Social Security Numbers', action: 'mask', max_match_length: 11 }, + { id: 'api_key_prefix', description: 'API key prefixes', action: 'block', max_match_length: 200 }, + ], + models: [ + { name: 'qwen-7b', backend: 'llama-cpp', enabled: false, explicit: false, default_for_backend: false, overrides: null }, + { name: 'claude-sonnet', backend: 'proxy-anthropic', enabled: true, explicit: false, default_for_backend: true, overrides: null }, + { name: 'claude-strict', backend: 'proxy-anthropic', enabled: true, explicit: true, default_for_backend: true, overrides: { ssn: 'block' } }, + ], + recent_event_count: 2, + }, + router: { + configured: true, + models: [ + { + name: 'smart-router', + classifier: 'score', + fallback: 'qwen-7b', + policies: [ + { label: 'casual-chat', description: 'small talk' }, + { label: 'code-generation', description: 'writing or debugging code' }, + ], + candidates: [ + { model: 'qwen-3b', labels: ['casual-chat'] }, + { model: 'qwen-coder', labels: ['code-generation', 'casual-chat'] }, + ], + embedding_cache: { + embedding_model: 'nomic-embed-text-v1.5', + similarity_threshold: 0.80, + confidence_threshold: 0.60, + store_name: '', + stats: { + hits: 31, + misses: 1, + near_misses: 56, + low_confidence: 29, + embedder_errors: 0, + store_errors: 0, + // peak [0.4, 0.6) for paraphrases, secondary in [0.8, 1.0) for near-exact matches + similarity_buckets: [0, 0, 0, 1, 22, 16, 3, 7, 19, 19], + }, + }, + }, + ], + recent_decision_count: 1, + available_classifiers: ['score'], + }, +} + +const MOCK_DECISIONS = { + decisions: [ + { + id: 'rd_a1', correlation_id: 'corr-1', user_id: 'local', + router_model: 'smart-router', requested_model: 'smart-router', served_model: 'qwen-3b', + classifier: 'score', label: 'casual-chat', score: 0.91, latency_ms: 15, + cached: true, cache_similarity: 0.92, + created_at: '2026-05-06T11:00:00Z', + }, + ], +} + +const MOCK_EVENTS = { + events: [ + { + id: 'pii_aaa', kind: 'pii', correlation_id: 'corr-1', user_id: 'local', + direction: 'in', pattern_id: 'email', byte_offset: 12, length: 17, + hash_prefix: 'ff8d9819', action: 'mask', + created_at: '2026-05-06T10:00:00Z', + }, + { + id: 'proxy_connect_1', kind: 'proxy_connect', + host: 'api.openai.com', intercepted: true, + created_at: '2026-05-06T10:01:00Z', + }, + { + id: 'proxy_connect_2', kind: 'proxy_connect', + host: 'github.com', intercepted: false, + created_at: '2026-05-06T10:02:00Z', + }, + { + id: 'proxy_traffic_1', kind: 'proxy_traffic', correlation_id: 'corr-2', + host: 'api.openai.com', + bytes_sent: 412, bytes_received: 1228, status_code: 200, duration_ms: 240, + created_at: '2026-05-06T10:03:00Z', + }, + ], +} + +test.describe('Middleware page — admin in no-auth mode', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ authEnabled: false, staticApiKeyRequired: false, providers: [] }), + }) + ) + await page.route('**/api/middleware/status', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_STATUS) }) + ) + await page.route('**/api/pii/events?**', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_EVENTS) }) + ) + await page.route('**/api/router/decisions?**', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_DECISIONS) }) + ) + }) + + test('Filtering tab renders pattern catalogue and per-model state', async ({ page }) => { + await page.goto('/app/middleware') + + // Pattern table — at least one pattern id visible. + await expect(page.getByText('email').first()).toBeVisible() + await expect(page.getByText('api_key_prefix').first()).toBeVisible() + + // Per-model state — each model's name is visible. + await expect(page.getByText('qwen-7b').first()).toBeVisible() + await expect(page.getByText('claude-strict').first()).toBeVisible() + + // Default-policy banner mentions proxy-*. + await expect(page.getByText(/proxy-\*/).first()).toBeVisible() + }) + + test('Routing tab renders configured routers and recent decisions', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Routing/i }).click() + // Active router model name visible. + await expect(page.getByText('smart-router').first()).toBeVisible() + // Candidate model names visible. + await expect(page.getByText('qwen-coder').first()).toBeVisible() + await expect(page.getByText('qwen-3b').first()).toBeVisible() + // Decision row visible — label and served model. + await expect(page.getByText('casual-chat').first()).toBeVisible() + }) + + test('Routing tab renders embedding-cache stats and similarity histogram', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Routing/i }).click() + + // Embedding model name surfaces in the cache column. + await expect(page.getByText('nomic-embed-text-v1.5').first()).toBeVisible() + + // Hit-rate badge: 31 hits / (31 + 56 + 1) = 35% rounded. + await expect(page.getByText(/35% hit/i).first()).toBeVisible() + + // h/n/m counter row visible. + await expect(page.getByText(/31h\/56n\/1m/).first()).toBeVisible() + + // Skipped (low-confidence) counter visible. + await expect(page.getByText(/29 skipped/).first()).toBeVisible() + + // Threshold marker text matches the configured 0.80. + await expect(page.getByText(/sim ≥ 0\.8/).first()).toBeVisible() + + // Histogram bars rendered with hover titles that include the + // bucket range and count. Bucket 4 (peak) has count 22; the + //
with that exact title is the structural assertion. + await expect( + page.locator('div[title="[0.4, 0.5): 22"]') + ).toBeVisible() + // Bucket 8 (just at threshold) has count 19. + await expect( + page.locator('div[title="[0.8, 0.9): 19"]') + ).toBeVisible() + }) + + test('Routing tab shows a cached decision with cache_similarity', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Routing/i }).click() + + // The decision row exposes the cached flag and the cosine that + // produced the hit so admins can correlate with the histogram. + await expect(page.getByText('corr-1')).toBeVisible() + }) + + test('Events tab renders rows but never the redacted content', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Events/i }).click() + // Hash prefix is visible — that's how admins audit recurring leaks. + await expect(page.getByText('ff8d9819')).toBeVisible() + // The page only ever shows fields the EventStore stores. The matched + // value (e.g. "alice@example.com") would never appear because it's + // not in the payload — explicit asserting absence here is the + // contract the design relies on. + await expect(page.getByText(/@example\.com/)).toHaveCount(0) + }) + + test('Events tab renders proxy_connect rows with intercept decision', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Events/i }).click() + + // Both intercept and tunnel decisions visible. + const interceptRow = page.locator('tr').filter({ hasText: 'api.openai.com' }).first() + await expect(interceptRow).toContainText(/intercepted/i) + const tunnelRow = page.locator('tr').filter({ hasText: 'github.com' }).first() + await expect(tunnelRow).toContainText(/tunneled/i) + }) + + test('Events tab renders proxy_traffic byte counts and status', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Events/i }).click() + + // The traffic row formats as "HTTP 200 · ↑412B ↓1.2KB · 240ms". + // We assert on the durable parts: status code, byte values, duration unit. + const trafficRow = page.locator('tr').filter({ hasText: 'corr-2' }).first() + await expect(trafficRow).toContainText('HTTP 200') + await expect(trafficRow).toContainText('412B') + await expect(trafficRow).toContainText(/1\.2\s*KB/i) + await expect(trafficRow).toContainText('240ms') + }) + + test('Events kind filter narrows the table to the chosen kind', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Events/i }).click() + + // Default = All: pii row + 2 connect rows + 1 traffic row visible. + await expect(page.getByText('ff8d9819')).toBeVisible() + await expect(page.getByText('github.com')).toBeVisible() + + // Click "PII" filter — proxy rows must disappear. + await page.getByRole('button', { name: /^PII$/ }).click() + await expect(page.getByText('ff8d9819')).toBeVisible() + await expect(page.getByText('github.com')).toHaveCount(0) + await expect(page.getByText('HTTP 200')).toHaveCount(0) + + // Click "Proxy traffic" — only the traffic row remains. + await page.getByRole('button', { name: /Proxy traffic/i }).click() + await expect(page.getByText('HTTP 200')).toBeVisible() + await expect(page.getByText('ff8d9819')).toHaveCount(0) + await expect(page.getByText('github.com')).toHaveCount(0) + + // Click "Proxy connect" — both connect rows visible, no PII or traffic. + await page.getByRole('button', { name: /Proxy connect/i }).click() + await expect(page.locator('tr').filter({ hasText: 'github.com' })).toHaveCount(1) + await expect(page.locator('tr').filter({ hasText: 'api.openai.com' }).filter({ hasText: 'intercepted' })).toHaveCount(1) + await expect(page.getByText('HTTP 200')).toHaveCount(0) + await expect(page.getByText('ff8d9819')).toHaveCount(0) + + // Click "All" — everything back. + await page.getByRole('button', { name: /^All$/ }).click() + await expect(page.getByText('ff8d9819')).toBeVisible() + await expect(page.getByText('HTTP 200')).toBeVisible() + }) + + test('Events tab shows the kind badge for each row', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Events/i }).click() + + // The Kind column header is present. + await expect(page.locator('th').filter({ hasText: /^Kind$/ })).toBeVisible() + // At least one cell renders each of the three kinds. Scope to + // elements so the "PII" filter button doesn't match. + await expect(page.locator('span').getByText(/^pii$/i).first()).toBeVisible() + await expect(page.getByText(/^proxy connect$/i).first()).toBeVisible() + await expect(page.getByText(/^proxy traffic$/i).first()).toBeVisible() + }) + + test('PUT /api/pii/patterns/:id fires when an action button is clicked', async ({ page }) => { + let putHit = null + await page.route('**/api/pii/patterns/email', (route) => { + if (route.request().method() === 'PUT') { + putHit = JSON.parse(route.request().postData() || '{}') + route.fulfill({ contentType: 'application/json', body: JSON.stringify({ id: 'email', action: putHit.action, persisted: false }) }) + } else { + route.continue() + } + }) + + await page.goto('/app/middleware') + // Click the email row's "block" button (currently mask, so block is + // enabled). Use a precise locator that matches the inner button. + const emailRow = page.locator('tr').filter({ hasText: 'email' }).first() + await emailRow.getByRole('button', { name: 'block' }).click() + + await expect.poll(() => putHit).toEqual({ action: 'block' }) + }) +}) + +test.describe('Middleware page — non-admin under auth-on', () => { + test('redirects to /app when the user is not admin', async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + authEnabled: true, + staticApiKeyRequired: false, + providers: ['local'], + user: { id: 'bob', name: 'Bob', role: 'user', provider: 'local' }, + }), + }) + ) + + await page.goto('/app/middleware') + // RequireAdmin redirects non-admin viewers; the URL must not stay on /middleware. + await page.waitForURL(/\/app(?!\/middleware)/, { timeout: 5000 }) + expect(page.url()).not.toMatch(/\/middleware/) + }) +}) diff --git a/core/http/react-ui/e2e/router-template.spec.js b/core/http/react-ui/e2e/router-template.spec.js new file mode 100644 index 000000000000..72431854efe4 --- /dev/null +++ b/core/http/react-ui/e2e/router-template.spec.js @@ -0,0 +1,219 @@ +import { test, expect } from '@playwright/test' + +// Router template + structured editor regression tests. +// +// The historical regression was: the "Create routing model" button +// loaded the model editor with an array-shaped `router.candidates` +// value, which crashed when a code-editor field received it instead +// of a string ("(intermediate value).split is not a function"). +// +// The current schema is also covered: +// - classifier=score is the only shipped classifier +// - router.policies surfaces in its own structured editor (label + +// description rows with duplicate detection) +// - router.candidates is the structured {model, labels[]} editor; +// labels are chips populated from router.policies via FormContext +// - router.embedding_cache.* surface as labelled fields with the +// correct components (model-select / slider) +// - router.activation_threshold and the two embedding_cache slider +// fields render with slider min/max/step from the registry + +const ROUTER_METADATA = { + sections: [ + { id: 'general', label: 'General', icon: 'settings', order: 0 }, + { id: 'other', label: 'Other', icon: 'more-horizontal', order: 100 }, + ], + fields: [ + { path: 'name', yaml_key: 'name', go_type: 'string', ui_type: 'string', + section: 'general', label: 'Model Name', component: 'input', order: 0 }, + { + path: 'router.classifier', yaml_key: 'classifier', go_type: 'string', ui_type: 'string', + section: 'other', label: 'Classifier', component: 'select', + options: [{ value: 'score', label: 'Score (Arch-Router-style)' }], + description: 'Picks a candidate by scoring every policy label against the prompt. Only "score" is shipped today.', + order: 230, + }, + { + path: 'router.classifier_model', yaml_key: 'classifier_model', go_type: 'string', ui_type: 'string', + section: 'other', label: 'Classifier Model', component: 'model-select', autocomplete_provider: 'models:chat', + description: 'Loaded LocalAI model the score classifier asks to rank each policy label.', + order: 231, + }, + { + path: 'router.fallback', yaml_key: 'fallback', go_type: 'string', ui_type: 'string', + section: 'other', label: 'Fallback Model', component: 'model-select', autocomplete_provider: 'models:chat', + description: 'Model used when no candidate covers the active label set.', + order: 232, + }, + { + path: 'router.activation_threshold', yaml_key: 'activation_threshold', go_type: 'float64', ui_type: 'float', + section: 'other', label: 'Activation Threshold', component: 'slider', + min: 0, max: 1, step: 0.05, + description: 'Softmax-probability floor a policy must clear to join the active label set.', + order: 233, + }, + { + path: 'router.policies', yaml_key: 'policies', go_type: '[]RouterPolicy', ui_type: 'object', + section: 'other', label: 'Policies', component: 'router-policies', + description: 'Label vocabulary the classifier scores over.', + order: 235, + }, + { + path: 'router.candidates', yaml_key: 'candidates', go_type: '[]RouterCandidate', ui_type: 'object', + section: 'other', label: 'Candidates', component: 'router-candidates', + description: 'Routing table: each entry binds a downstream model to a set of policy labels.', + order: 236, + }, + { + path: 'router.embedding_cache.embedding_model', yaml_key: 'embedding_model', go_type: 'string', ui_type: 'string', + section: 'other', label: 'L2 Cache: Embedding Model', component: 'model-select', autocomplete_provider: 'models', + description: 'Embedding model used by the L2 decision cache.', + order: 237, + }, + { + path: 'router.embedding_cache.similarity_threshold', yaml_key: 'similarity_threshold', go_type: 'float64', ui_type: 'float', + section: 'other', label: 'L2 Cache: Similarity Threshold', component: 'slider', + min: 0, max: 1, step: 0.01, + description: 'Cosine-similarity floor a cache candidate must clear to count as a hit.', + order: 238, + }, + ], +} + +const MIDDLEWARE_STATUS = { + pii: { enabled_globally: false, patterns: [], models: [], recent_event_count: 0 }, + router: { configured: false, models: [], recent_decision_count: 0, available_classifiers: ['score'] }, + mitm: { running: false, listen_addr: '', configured_addr: '', host_owners: {}, host_conflicts: {}, models: [], ca_available: false, ca_cert_url: '' }, +} + +test.describe('Router template — create flow', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ authEnabled: false, staticApiKeyRequired: false, providers: [] }), + }) + ) + await page.route('**/api/middleware/status', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify(MIDDLEWARE_STATUS) }) + ) + await page.route('**/api/router/decisions?**', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify({ decisions: [] }) }) + ) + await page.route('**/api/pii/events?**', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify({ events: [] }) }) + ) + await page.route('**/api/models/config-metadata*', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify(ROUTER_METADATA) }) + ) + await page.route('**/api/models/config-metadata/autocomplete/**', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify({ values: [] }) }) + ) + + // Surface any uncaught render-time error so the assertion fails + // with a useful message rather than the test silently passing. + page.on('pageerror', (err) => { + throw new Error(`uncaught page error: ${err.message}`) + }) + }) + + test('Routing tab links to the model editor with the router template loaded', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Routing/i }).click() + + // Empty-state button is the primary CTA. + await page.getByRole('button', { name: /Create routing model/i }).click() + + // Editor loads on a /app/model-editor URL with template=router. + await expect(page).toHaveURL(/\/app\/model-editor.*template=router/) + }) + + test('Router template renders without crashing on structured candidates/policies', async ({ page }) => { + // Navigate straight to the create-with-template URL. This was the + // regression that crashed with "(intermediate value).split is not + // a function" when the template's array-shaped router.candidates + // fell into a code-editor wrapper. + await page.goto('/app/model-editor?template=router') + + // The react-router error overlay must not appear. + await expect(page.getByText(/Unexpected Application Error/i)).toHaveCount(0) + + // Editor surface visible. Template URL is "create mode", so the + // heading reads "Add Model" rather than "Model Editor". + await expect(page.locator('h1.page-title')).toBeVisible({ timeout: 10_000 }) + + // Top-level field labels seeded by the template are visible. + // embedding_cache.* fields are surfaced via "Add Field" search + // rather than active by default — separate spec covers them. + await expect(page.getByText('Classifier').first()).toBeVisible() + await expect(page.getByText('Policies').first()).toBeVisible() + await expect(page.getByText('Candidates').first()).toBeVisible() + await expect(page.getByText('Activation Threshold').first()).toBeVisible() + }) + + test('Classifier select offers only the score option', async ({ page }) => { + await page.goto('/app/model-editor?template=router') + + // SearchableSelect renders the current option's *label* inside the + // trigger button. After the schema cleanup the only option is + // "Score (Arch-Router-style)", pre-selected by the template. + await expect(page.getByText('Score (Arch-Router-style)').first()).toBeVisible({ timeout: 10_000 }) + }) + + test('Policies editor renders structured rows with label + description fields', async ({ page }) => { + await page.goto('/app/model-editor?template=router') + + // The template seeds three example policies. Their labels are + // pre-populated in input fields with monospace styling — the + // editor signature is "Add policy" button + label/description + // input pairs. + await expect(page.getByRole('button', { name: /Add policy/i }).first()).toBeVisible() + + // Pre-seeded labels visible as input values. RouterPoliciesEditor + // renders each label in an input with a recognisable placeholder; + // assert on their values by position. + const labelInputs = page.locator('input[placeholder^="label ("]') + await expect(labelInputs.nth(0)).toHaveValue('code-generation') + await expect(labelInputs.nth(1)).toHaveValue('casual-chat') + await expect(labelInputs.nth(2)).toHaveValue('math-reasoning') + }) + + test('Candidates editor renders {model, labels} rows with policy-aware label chips', async ({ page }) => { + await page.goto('/app/model-editor?template=router') + + // "Add candidate" is the signature of the new RouterCandidatesEditor. + await expect(page.getByRole('button', { name: /Add candidate/i }).first()).toBeVisible() + + // Each candidate row should expose move-up/move-down controls, + // a model picker, and label chips. The chip for a known policy + // label appears as a button with the policy's label text. + // Pre-seeded template: candidate[0] has labels=['casual-chat']; + // candidate[1] has labels=['code-generation', 'casual-chat', 'math-reasoning']. + // + // The chips appear inside a flex row of buttons. Using getByRole + // with the exact name catches typos/regressions cleanly. + await expect(page.getByRole('button', { name: 'casual-chat' }).first()).toBeVisible() + await expect(page.getByRole('button', { name: 'code-generation' }).first()).toBeVisible() + await expect(page.getByRole('button', { name: 'math-reasoning' }).first()).toBeVisible() + }) + + test('Adding a duplicate policy label flags the duplicate row', async ({ page }) => { + await page.goto('/app/model-editor?template=router') + + // Add a new empty policy row, then type a duplicate of the + // existing 'casual-chat'. The duplicate detection in + // RouterPoliciesEditor sets a warning border via inline style. + await page.getByRole('button', { name: /Add policy/i }).first().click() + + // Find the newly-added empty label input (placeholder catches it). + const newLabel = page.locator('input[placeholder*="label (e.g. code-generation)"]').last() + await newLabel.fill('casual-chat') + + // Both rows now hold the same label. The duplicate-detection + // logic flags the row visually; we assert on the title attribute + // RouterPoliciesEditor sets on the input when duplicate=true. + await expect( + page.locator('input[title="Duplicate label — candidates won\'t be able to distinguish them"]').first() + ).toBeVisible() + }) +}) diff --git a/core/http/react-ui/e2e/usage-dashboard.spec.js b/core/http/react-ui/e2e/usage-dashboard.spec.js new file mode 100644 index 000000000000..a27bf40064be --- /dev/null +++ b/core/http/react-ui/e2e/usage-dashboard.spec.js @@ -0,0 +1,148 @@ +import { test, expect } from '@playwright/test' + +// Mock usage payload as the new /api/usage endpoint returns it. +const MOCK_USAGE = { + viewer: { id: 'local-uuid', name: 'local', role: 'admin', provider: 'local' }, + totals: { + prompt_tokens: 1234, + completion_tokens: 567, + total_tokens: 1801, + request_count: 42, + }, + usage: [ + { + bucket: '2026-05-05', + model: 'qwen-7b', + user_id: 'local-uuid', + user_name: 'local', + prompt_tokens: 1234, + completion_tokens: 567, + total_tokens: 1801, + request_count: 42, + }, + ], +} + +const MOCK_USAGE_AUTH_USER = { + ...MOCK_USAGE, + viewer: { id: 'alice-uuid', name: 'Alice', role: 'user', provider: 'local' }, +} + +// Two scenarios: +// 1. No-auth single-user box: /api/auth/status returns authEnabled:false +// and the page must call /api/usage and render the local user's data. +// 2. Auth-on regular user: status returns authEnabled:true and the page +// keeps using /api/auth/usage as before. +// +// The point of these specs is the "prevent accidental removal" guarantee +// the user asked for: if anyone gates the Usage page behind auth again, +// scenario 1 fails immediately. + +test.describe('Usage page — single-user no-auth mode', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + authEnabled: false, + staticApiKeyRequired: false, + providers: [], + }), + }) + ) + + // The new no-auth code path. If anyone reverts Usage.jsx to + // /api/auth/usage in single-user mode, this route is never hit and + // the test fails because no usage data renders. + let usageHits = 0 + await page.route('**/api/usage?**', (route) => { + usageHits++ + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_USAGE), + }) + }) + // The synthetic local user has admin role, so Usage.jsx also pulls + // the cluster-wide view from /api/usage/all to populate displayTotals. + await page.route('**/api/usage/all?**', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_USAGE), + }) + ) + page.usageHits = () => usageHits + }) + + test('Usage entry is visible in sidebar without auth', async ({ page }) => { + await page.goto('/app') + const systemSection = page.locator('button.sidebar-section-toggle', { hasText: 'System' }) + await systemSection.click() + const usageLink = page.locator('a.nav-item[href="/app/usage"]') + await expect(usageLink).toBeVisible() + }) + + test('navigating to /app/usage renders the dashboard with local-user data', async ({ page }) => { + await page.goto('/app/usage') + + // The page used to bail with "Usage tracking unavailable" when authEnabled=false. + // We assert the *opposite*: data is rendered and the empty-state text is absent. + await expect(page.getByText('Usage tracking unavailable')).toHaveCount(0) + + // The total-tokens stat card is one of the first things rendered after + // a successful /api/usage call. We assert the formatted number "1.8K" + // is present (formatNumber in Usage.jsx renders 1801 as "1.8K"). + await expect(page.getByText('1.8K').first()).toBeVisible() + }) +}) + +test.describe('Usage page — auth on', () => { + test.beforeEach(async ({ page }) => { + // RequireAuth redirects to /login when user is null, so the status + // response must include a resolved user for auth-on specs to reach + // the Usage page at all. + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + authEnabled: true, + staticApiKeyRequired: false, + providers: ['local'], + user: { id: 'alice-uuid', name: 'Alice', role: 'user', provider: 'local' }, + }), + }) + ) + await page.route('**/api/auth/me', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + user: { id: 'alice-uuid', name: 'Alice', role: 'user', provider: 'local' }, + permissions: {}, + }), + }) + ) + await page.route('**/api/auth/usage?**', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_USAGE_AUTH_USER), + }) + ) + await page.route('**/api/auth/quota', (route) => + route.fulfill({ contentType: 'application/json', body: JSON.stringify({ quotas: [] }) }) + ) + }) + + test('Usage page calls /api/auth/usage when auth is on', async ({ page }) => { + let authUsageHit = false + await page.route('**/api/auth/usage?**', (route) => { + authUsageHit = true + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_USAGE_AUTH_USER), + }) + }) + + await page.goto('/app/usage') + await expect(page.getByText('1.8K').first()).toBeVisible() + expect(authUsageHit).toBe(true) + }) +}) diff --git a/core/http/react-ui/e2e/users-tab-gating.spec.js b/core/http/react-ui/e2e/users-tab-gating.spec.js new file mode 100644 index 000000000000..f683d215f527 --- /dev/null +++ b/core/http/react-ui/e2e/users-tab-gating.spec.js @@ -0,0 +1,74 @@ +import { test, expect } from '@playwright/test' + +// Two surfaces enforce single-user (no-auth) gating for the Users page: +// 1. Sidebar entry: hidden via the `authOnly: true` flag in Sidebar.jsx +// (filterItem returns false when `!authEnabled`). +// 2. Direct URL navigation: RequireAuthEnabled wrapping the /app/users +// route in router.jsx redirects to /app when authEnabled is false. +// +// Without (2), an old bookmark or pasted URL would land on a page rendered +// against admin-only `/api/auth/admin/users` data — which doesn't exist +// when auth is off — and the user sees a confusing empty/error state. +// +// These specs are the "prevent accidental removal" guarantee — if anyone +// drops the gating, /app/users stays open in single-user mode and the +// test fails on the redirect or the visible sidebar item. + +test.describe('Users tab — single-user no-auth mode', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + authEnabled: false, + staticApiKeyRequired: false, + providers: [], + }), + }) + ) + }) + + test('sidebar does not list Users entry', async ({ page }) => { + await page.goto('/app') + const systemSection = page.locator('button.sidebar-section-toggle', { hasText: 'System' }) + await systemSection.click() + // The Users page link uses /app/users; if Sidebar's authOnly gate + // regresses (or someone removes the flag), this assertion fails. + const usersLink = page.locator('a.nav-item[href="/app/users"]') + await expect(usersLink).toHaveCount(0) + }) + + test('direct navigation to /app/users redirects to /app', async ({ page }) => { + await page.goto('/app/users') + // RequireAuthEnabled performs the redirect synchronously, but the URL + // change is async — wait for it before asserting. + await page.waitForURL(/\/app(?!\/users)/, { timeout: 5000 }) + expect(page.url()).toMatch(/\/app(\/?$|\/(?!users))/) + }) +}) + +test.describe('Users tab — auth on', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/auth/status', (route) => + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify({ + authEnabled: true, + staticApiKeyRequired: false, + providers: ['local'], + // Mark the viewer as admin so the sidebar's adminOnly gate also + // passes; the test then exercises the authOnly path in isolation. + user: { id: 'admin-uuid', name: 'Admin', role: 'admin', provider: 'local' }, + }), + }) + ) + }) + + test('sidebar lists Users entry when auth is on', async ({ page }) => { + await page.goto('/app') + const systemSection = page.locator('button.sidebar-section-toggle', { hasText: 'System' }) + await systemSection.click() + const usersLink = page.locator('a.nav-item[href="/app/users"]') + await expect(usersLink).toBeVisible() + }) +}) diff --git a/core/http/react-ui/public/locales/en/nav.json b/core/http/react-ui/public/locales/en/nav.json index 9f5218a19ee8..ac85d49794db 100644 --- a/core/http/react-ui/public/locales/en/nav.json +++ b/core/http/react-ui/public/locales/en/nav.json @@ -36,6 +36,7 @@ "mcpJobs": "MCP CI Jobs", "usage": "Usage", "users": "Users", + "middleware": "Middleware", "backends": "Backends", "traces": "Traces", "nodes": "Nodes", diff --git a/core/http/react-ui/src/components/ConfigFieldRenderer.jsx b/core/http/react-ui/src/components/ConfigFieldRenderer.jsx index f2c80885dbe8..ccf5bf05c155 100644 --- a/core/http/react-ui/src/components/ConfigFieldRenderer.jsx +++ b/core/http/react-ui/src/components/ConfigFieldRenderer.jsx @@ -5,6 +5,10 @@ import SearchableSelect from './SearchableSelect' import SearchableModelSelect from './SearchableModelSelect' import AutocompleteInput from './AutocompleteInput' import CodeEditor from './CodeEditor' +import StructuredCodeEditor from './StructuredCodeEditor' +import PIIPatternListEditor from './PIIPatternListEditor' +import RouterCandidatesEditor from './RouterCandidatesEditor' +import RouterPoliciesEditor from './RouterPoliciesEditor' // Map autocomplete provider to SearchableModelSelect capability const PROVIDER_TO_CAPABILITY = { @@ -300,8 +304,17 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove, ) } - // Code editor + // Code editor. Two flavours: + // - Plain CodeEditor when the form value is a string (Go template + // blobs etc. — what the original `code-editor` shipped for). + // - StructuredCodeEditor when the form value is a structured + // object/array (e.g. `router.candidates`, where the canonical + // value is `[{label, model, rules}, ...]`). The wrapper keeps a + // YAML representation in the textarea while publishing the + // parsed structure back to form state, so the save flow can + // unflatten it into the YAML file cleanly. if (component === 'code-editor') { + const isStructured = value !== null && value !== undefined && typeof value !== 'string' return (
@@ -310,7 +323,9 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove,
{description}
- + {isStructured + ? + : }
) } @@ -345,6 +360,57 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove, ) } + // Router candidates — routing table editor. Each row is + // {model, labels[]}; the labels picker reads from router.policies + // via FormContext so candidate labels match the declared vocabulary. + if (component === 'router-candidates') { + return ( +
+
+
+
+
{description}
+
+
+ +
+ ) + } + + // Router policies — label vocabulary editor. Each row is + // {label, description}; the description ends up verbatim in the + // routing system prompt sent to the classifier model. + if (component === 'router-policies') { + return ( +
+
+
+
+
{description}
+
+
+ +
+ ) + } + + // PII pattern list — per-model action overrides for named patterns. + // The pattern catalog is loaded from /api/pii/patterns at render time + // so new built-in patterns surface automatically. + if (component === 'pii-pattern-list') { + return ( +
+
+
+
+
{description}
+
+
+ +
+ ) + } + // Map editor if (component === 'map-editor') { return ( diff --git a/core/http/react-ui/src/components/PIIPatternListEditor.jsx b/core/http/react-ui/src/components/PIIPatternListEditor.jsx new file mode 100644 index 000000000000..558f4cd6ab2d --- /dev/null +++ b/core/http/react-ui/src/components/PIIPatternListEditor.jsx @@ -0,0 +1,120 @@ +import { useState, useEffect, useMemo } from 'react' +import { apiUrl } from '../utils/basePath' +import SearchableSelect from './SearchableSelect' + +const ACTION_OPTIONS = [ + { value: 'mask', label: 'Mask — replace with a [REDACTED:id] placeholder' }, + { value: 'block', label: 'Block — reject the request (request side) / mask in stream' }, + { value: 'route_local', label: 'Route local — keep text, force local-only routing' }, +] + +export default function PIIPatternListEditor({ value, onChange }) { + const items = Array.isArray(value) ? value : [] + + const [catalog, setCatalog] = useState([]) + const [loadError, setLoadError] = useState(null) + + useEffect(() => { + let cancelled = false + fetch(apiUrl('/api/pii/patterns')) + .then(r => r.ok ? r.json() : Promise.reject(new Error(`HTTP ${r.status}`))) + .then(data => { if (!cancelled) setCatalog(data?.patterns || []) }) + .catch(err => { if (!cancelled) setLoadError(err.message) }) + return () => { cancelled = true } + }, []) + + const idOptions = useMemo(() => + catalog.map(p => ({ + value: p.id, + label: p.description ? `${p.id} — ${p.description}` : p.id, + })), + [catalog] + ) + + // Patterns already chosen — exclude from the "add row" select so each + // pattern only appears once per model. + const usedIDs = new Set(items.map(it => it?.id).filter(Boolean)) + const availableForAdd = idOptions.filter(o => !usedIDs.has(o.value)) + + const update = (index, key, val) => { + const next = items.map((it, i) => + i === index ? { ...it, [key]: val } : it + ) + onChange(next) + } + + const remove = (index) => { + onChange(items.filter((_, i) => i !== index)) + } + + const add = (id) => { + const cat = catalog.find(c => c.id === id) + onChange([...items, { id, action: cat?.action || 'mask' }]) + } + + return ( +
+ {loadError && ( +
+ Could not load pattern catalog: {loadError}. You can still type IDs manually. +
+ )} + + {items.length === 0 && ( +
+ No overrides — every pattern uses its global default action. Add a row below to + tighten or relax the action for a specific pattern on this model. +
+ )} + + {items.map((row, i) => { + const cat = catalog.find(c => c.id === row?.id) + const idLabel = cat?.description ? `${row.id} — ${cat.description}` : (row?.id || '') + // Show the chosen id even if the catalog hasn't loaded yet (or + // the YAML references an unknown pattern), so users can edit + // without losing context. + const idItems = [ + ...(row?.id && !idOptions.some(o => o.value === row.id) + ? [{ value: row.id, label: idLabel }] + : []), + ...idOptions.filter(o => o.value === row?.id || !usedIDs.has(o.value)), + ] + return ( +
+ update(i, 'id', v)} + options={idItems} + placeholder="Pattern..." + style={{ flex: '1 1 220px', minWidth: 200 }} + /> + update(i, 'action', v)} + options={ACTION_OPTIONS} + placeholder="Action..." + style={{ flex: '1 1 240px', minWidth: 220 }} + /> + +
+ ) + })} + + {availableForAdd.length > 0 && ( +
+ v && add(v)} + options={availableForAdd} + placeholder="+ Add pattern override..." + style={{ flex: '1 1 220px', minWidth: 200 }} + /> +
+ )} +
+ ) +} diff --git a/core/http/react-ui/src/components/RequireAuthEnabled.jsx b/core/http/react-ui/src/components/RequireAuthEnabled.jsx new file mode 100644 index 000000000000..c71ca3f07b99 --- /dev/null +++ b/core/http/react-ui/src/components/RequireAuthEnabled.jsx @@ -0,0 +1,16 @@ +import { Navigate } from 'react-router-dom' +import { useAuth } from '../context/AuthContext' + +// RequireAuthEnabled gates routes that only make sense when auth is on. +// User management is the canonical example: in single-user (no-auth) +// mode there is exactly one synthetic local user, so the page would +// either be empty or expose admin tools that have nothing to manage. +// +// We redirect to /app rather than render a "not available" page so that +// stale bookmarks don't leave the user on a dead-end screen. +export default function RequireAuthEnabled({ children }) { + const { authEnabled, loading } = useAuth() + if (loading) return null + if (!authEnabled) return + return children +} diff --git a/core/http/react-ui/src/components/RouterCandidatesEditor.jsx b/core/http/react-ui/src/components/RouterCandidatesEditor.jsx new file mode 100644 index 000000000000..5d744c8d4639 --- /dev/null +++ b/core/http/react-ui/src/components/RouterCandidatesEditor.jsx @@ -0,0 +1,185 @@ +import { useMemo } from 'react' +import { useFormContext } from '../contexts/FormContext' +import SearchableModelSelect from './SearchableModelSelect' + +// RouterCandidatesEditor renders the routing table for a router model. +// Each row binds a downstream model to a SET of policy labels it can +// serve. The middleware picks the first candidate whose labels are a +// superset of the active label set from the classifier, so admins +// order candidates smallest → largest. +// +// Schema mirrors core/config.RouterCandidate: +// { model: string, labels: []string } +// +// Labels are picked from the parent form's router.policies (a multi- +// select rather than a free-text input) so a typo in one place doesn't +// silently disable a candidate. Labels typed manually are still kept +// — useful when admins paste a config before defining the policies. + +export default function RouterCandidatesEditor({ value, onChange }) { + const items = Array.isArray(value) ? value : [] + const knownLabels = usePolicyLabels() + const knownLabelSet = useMemo(() => new Set(knownLabels), [knownLabels]) + + const update = (index, mut) => { + const next = items.map((it, i) => (i === index ? mut({ ...it }) : it)) + onChange(next) + } + const remove = (index) => onChange(items.filter((_, i) => i !== index)) + const move = (index, dir) => { + const j = index + dir + if (j < 0 || j >= items.length) return + const next = items.slice() + ;[next[index], next[j]] = [next[j], next[index]] + onChange(next) + } + const add = () => onChange([...items, { model: '', labels: [] }]) + + return ( +
+ {items.length === 0 && ( +
+ No candidates yet. Add at least one — order from smallest model to largest. + The middleware picks the FIRST candidate whose labels superset the active set. +
+ )} + + {items.map((row, i) => ( + update(i, mut)} + onRemove={() => remove(i)} + onMove={(dir) => move(i, dir)} + /> + ))} + + +
+ ) +} + +function CandidateRow({ index, total, row, knownLabels, knownLabelSet, onChange, onRemove, onMove }) { + const labels = Array.isArray(row?.labels) ? row.labels : [] + const toggleLabel = (label) => onChange((r) => ({ + ...r, + labels: labels.includes(label) ? labels.filter(l => l !== label) : [...labels, label], + })) + + // Row-local labels not in the parent policy list are still surfaced + // (with a warning chip) so a stale row doesn't silently lose its + // labels while the policy list is being edited. + const unknownOnRow = labels.filter(l => !knownLabelSet.has(l)) + const visible = [...knownLabels, ...unknownOnRow] + + return ( +
+
+ #{index + 1} + + + + {index === 0 ? 'tried first' : index === total - 1 ? 'tried last (fallback-class)' : ''} + +
+ +
+ onChange((r) => ({ ...r, model: v }))} + placeholder="downstream model..." + /> + +
+ +
+
+ {visible.length === 0 + ? 'No policies defined yet — add policies above before assigning labels.' + : 'Labels this model can serve. The middleware requires the candidate to cover every label the classifier activates.'} +
+
+ {visible.map((label) => { + const on = labels.includes(label) + const known = knownLabelSet.has(label) + return ( + + ) + })} +
+
+
+ ) +} + +// usePolicyLabels reads router.policies from the surrounding form state +// and returns the list of declared labels. Falls back to [] when no +// FormContext is present (e.g. preview render). +function usePolicyLabels() { + const ctx = useFormContext() + const policies = ctx?.formData?.['router.policies'] + if (!Array.isArray(policies)) return [] + return policies.map(p => p?.label).filter(Boolean) +} diff --git a/core/http/react-ui/src/components/RouterPoliciesEditor.jsx b/core/http/react-ui/src/components/RouterPoliciesEditor.jsx new file mode 100644 index 000000000000..b323bc288737 --- /dev/null +++ b/core/http/react-ui/src/components/RouterPoliciesEditor.jsx @@ -0,0 +1,109 @@ +import { useMemo } from 'react' + +// RouterPoliciesEditor renders the label vocabulary the score +// classifier ranks for each request. The shape mirrors +// core/config.RouterPolicy: +// +// { label: string, description: string } +// +// The description ends up verbatim in the routing system prompt fed +// to the classifier model. Short, action-oriented sentences ("writing +// or debugging code", "small talk") consistently produce cleaner +// label distributions on Arch-Router-style scorers than longer +// taxonomies — keep them tight. + +export default function RouterPoliciesEditor({ value, onChange }) { + const items = Array.isArray(value) ? value : [] + + const duplicateLabels = useMemo(() => { + const seen = new Set() + const dup = new Set() + for (const it of items) { + const label = it?.label + if (!label) continue + if (seen.has(label)) dup.add(label) + else seen.add(label) + } + return dup + }, [items]) + + const update = (index, mut) => { + const next = items.map((it, i) => (i === index ? mut({ ...it }) : it)) + onChange(next) + } + const remove = (index) => onChange(items.filter((_, i) => i !== index)) + const add = () => onChange([...items, { label: '', description: '' }]) + + return ( +
+ {items.length === 0 && ( +
+ No policies defined. Add at least one — the classifier needs a label vocabulary to rank over, + and candidates reference these labels. +
+ )} + + {items.map((row, i) => ( + update(i, mut)} + onRemove={() => remove(i)} + /> + ))} + + +
+ ) +} + +function PolicyRow({ row, duplicate, onChange, onRemove }) { + return ( +
+ onChange((r) => ({ ...r, label: e.target.value }))} + style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem' }} + title={duplicate ? 'Duplicate label — candidates won\'t be able to distinguish them' : ''} + /> + onChange((r) => ({ ...r, description: e.target.value }))} + style={{ fontSize: '0.8125rem' }} + /> + +
+ ) +} + diff --git a/core/http/react-ui/src/components/Sidebar.jsx b/core/http/react-ui/src/components/Sidebar.jsx index 9956fb7c5b7f..148a33bfb603 100644 --- a/core/http/react-ui/src/components/Sidebar.jsx +++ b/core/http/react-ui/src/components/Sidebar.jsx @@ -69,8 +69,9 @@ const sections = [ id: 'system', titleKey: 'sections.system', items: [ - { path: '/app/usage', icon: 'fas fa-chart-bar', labelKey: 'items.usage', authOnly: true }, + { path: '/app/usage', icon: 'fas fa-chart-bar', labelKey: 'items.usage' }, { path: '/app/users', icon: 'fas fa-users', labelKey: 'items.users', adminOnly: true, authOnly: true }, + { path: '/app/middleware', icon: 'fas fa-shield-halved', labelKey: 'items.middleware', adminOnly: true }, { path: '/app/backends', icon: 'fas fa-server', labelKey: 'items.backends', adminOnly: true }, { path: '/app/traces', icon: 'fas fa-chart-line', labelKey: 'items.traces', adminOnly: true }, { path: '/app/nodes', icon: 'fas fa-network-wired', labelKey: 'items.nodes', adminOnly: true, feature: 'distributed' }, diff --git a/core/http/react-ui/src/components/StructuredCodeEditor.jsx b/core/http/react-ui/src/components/StructuredCodeEditor.jsx new file mode 100644 index 000000000000..496d0cb1b8e7 --- /dev/null +++ b/core/http/react-ui/src/components/StructuredCodeEditor.jsx @@ -0,0 +1,80 @@ +import { useEffect, useState } from 'react' +import YAML from 'yaml' +import CodeEditor from './CodeEditor' + +// StructuredCodeEditor is the wrapper that lets a `code-editor` +// field hold a structured value (object / array) rather than a raw +// string. Two reasons we need this: +// +// 1. CodeMirror's EditorState.create({ doc }) requires a string — +// pass an array and it crashes inside CM's Text class with +// "(intermediate value).split is not a function". +// 2. The model-editor save path uses unflattenConfig + YAML.stringify +// which needs the structured value to round-trip cleanly into +// YAML (otherwise a YAML-string-of-YAML appears in the file). +// +// The component keeps two pieces of state in sync: +// - `text`: the YAML representation shown to the user. The user +// edits this; we don't reformat while they type. +// - upstream `value`: the parsed structured value held by the +// editor form. We try to parse `text` on every edit; if the +// parse succeeds we publish the new structure, otherwise the +// structured value lags until the YAML is syntactically valid +// again (the linter shows the error inline). +export default function StructuredCodeEditor({ value, onChange, minHeight }) { + // Lazy-init: stringify the initial structured value once. Subsequent + // re-renders driven by our own onChange keep `text` authoritative — + // we only re-sync from `value` when it changes due to an external + // edit (template selection, YAML-tab save). + const [text, setText] = useState(() => structuredToYAML(value)) + const [lastExternal, setLastExternal] = useState(value) + + useEffect(() => { + // Detect external changes (a different `value` reference that + // didn't come from our own parse). reference-equality is enough + // because onChange always publishes the parsed object, never the + // text. + if (value !== lastExternal) { + const next = structuredToYAML(value) + setText(next) + setLastExternal(value) + } + }, [value, lastExternal]) + + const handleTextChange = (nextText) => { + setText(nextText) + // Empty buffer publishes empty array — the most common "I want to + // start fresh" case and keeps a YAML-valid round-trip. + if (!nextText.trim()) { + onChange([]) + setLastExternal([]) + return + } + try { + const parsed = YAML.parse(nextText) + onChange(parsed) + setLastExternal(parsed) + } catch { + // Hold the structured value steady while YAML is being typed + // and is temporarily invalid. The CodeMirror YAML linter + // surfaces the syntax error inline. + } + } + + return +} + +// structuredToYAML renders the form-state value as the YAML text the +// editor shows. Strings pass through untouched (so a legacy template +// that supplied a pre-formatted YAML string still renders cleanly). +// null/undefined renders as empty so the editor starts blank rather +// than showing the literal "null\n". +export function structuredToYAML(value) { + if (value === null || value === undefined) return '' + if (typeof value === 'string') return value + try { + return YAML.stringify(value) + } catch { + return '' + } +} diff --git a/core/http/react-ui/src/contexts/FormContext.jsx b/core/http/react-ui/src/contexts/FormContext.jsx new file mode 100644 index 000000000000..f29402e34764 --- /dev/null +++ b/core/http/react-ui/src/contexts/FormContext.jsx @@ -0,0 +1,26 @@ +import { createContext, useContext, useMemo } from 'react' + +// FormContext exposes the surrounding form's read-only state to deep +// field editors that need to inspect sibling fields. Used by the +// router-candidates editor to read router.policies so candidate +// labels can be picked from the declared policy vocabulary rather +// than typed by hand. +// +// Only the read shape is exposed (formData); mutations still go +// through the parent's onChange so the editor remains the single +// source of truth. +const FormContext = createContext(null) + +export function FormContextProvider({ formData, children }) { + // Memo the wrapper so consumers don't re-render on every keystroke + // when formData itself is referentially stable. ModelEditor's + // setValues replaces the object on each edit, so this still + // propagates updates — it just avoids spurious churn when an + // ancestor re-renders without changing values. + const value = useMemo(() => ({ formData }), [formData]) + return {children} +} + +export function useFormContext() { + return useContext(FormContext) +} diff --git a/core/http/react-ui/src/pages/Middleware.jsx b/core/http/react-ui/src/pages/Middleware.jsx new file mode 100644 index 000000000000..7ea5dd4a548b --- /dev/null +++ b/core/http/react-ui/src/pages/Middleware.jsx @@ -0,0 +1,957 @@ +import { useState, useEffect, useCallback, useRef } from 'react' +import { useOutletContext, Link, useNavigate } from 'react-router-dom' +import { apiUrl } from '../utils/basePath' +import { settingsApi } from '../utils/api' +import LoadingSpinner from '../components/LoadingSpinner' + +// Middleware admin page. Three tabs: +// - Filtering: PII pattern catalogue + per-model resolved state + +// pattern-action editor (PUT /api/pii/patterns/:id, transient). +// - Routing: placeholder until subsystem 2 lands. Renders the note +// from /api/router/status so admins see "not yet implemented" rather +// than an empty page. +// - Events: recent PIIEvent rows from /api/pii/events. The page +// intentionally NEVER displays the redacted content (the redactor +// never stores it); only pattern_id, byte_offset, length, and an +// 8-char sha256 prefix admins can use to dedupe recurring leaks. +// +// Wiring is admin-only: RequireAdmin in router.jsx already redirects +// non-admin viewers; in single-user no-auth mode the local user has +// admin role so the page works without --auth. + +const TABS = [ + { id: 'filtering', label: 'Filtering', icon: 'fa-shield-halved' }, + { id: 'routing', label: 'Routing', icon: 'fa-route' }, + { id: 'proxy', label: 'MITM Proxy', icon: 'fa-shield' }, + { id: 'events', label: 'Events', icon: 'fa-list-ul' }, +] + +const ACTIONS = ['mask', 'block', 'route_local'] + +function actionBadge(action) { + const colors = { + mask: 'var(--color-primary)', + block: 'var(--color-error)', + route_local: 'var(--color-warning)', + } + return ( + + {action} + + ) +} + +function enabledBadge(enabled) { + return ( + + {enabled ? 'on' : 'off'} + + ) +} + +export default function Middleware() { + const { addToast } = useOutletContext() + const [status, setStatus] = useState(null) + const [events, setEvents] = useState([]) + const [decisions, setDecisions] = useState([]) + const [loading, setLoading] = useState(true) + const [activeTab, setActiveTab] = useState('filtering') + const [pendingPattern, setPendingPattern] = useState(null) // id while a PUT is in flight + + // silent=true on background polls: skips the loading spinner and + // suppresses toast spam if the server is briefly unreachable. + const fetchAll = useCallback(async (silent = false) => { + if (!silent) setLoading(true) + try { + const [statusRes, eventsRes, decisionsRes] = await Promise.all([ + fetch(apiUrl('/api/middleware/status')), + fetch(apiUrl('/api/pii/events?limit=100')), + fetch(apiUrl('/api/router/decisions?limit=100')), + ]) + if (!statusRes.ok) throw new Error(`status: HTTP ${statusRes.status}`) + const statusData = await statusRes.json() + setStatus(statusData) + if (eventsRes.ok) { + const data = await eventsRes.json() + setEvents(data.events || []) + } + if (decisionsRes.ok) { + const data = await decisionsRes.json() + setDecisions(data.decisions || []) + } + } catch (err) { + if (!silent) addToast(`Failed to load middleware status: ${err.message}`, 'error') + } finally { + if (!silent) setLoading(false) + } + }, [addToast]) + + useEffect(() => { fetchAll() }, [fetchAll]) + + // Auto-refresh every 5s so admins watching the Events / Routing tabs + // see new rows without manual refresh. Matches the Traces page cadence. + // ProxyTab guards against clobbering mid-typed config via its own + // `dirty` check, so the poll is safe while the form is in use. + const refreshRef = useRef(null) + useEffect(() => { + refreshRef.current = setInterval(() => fetchAll(true), 5000) + return () => clearInterval(refreshRef.current) + }, [fetchAll]) + + const mutatePattern = async (patternID, body, successMsg) => { + setPendingPattern(patternID) + try { + const res = await fetch(apiUrl(`/api/pii/patterns/${encodeURIComponent(patternID)}`), { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + }) + if (!res.ok) { + const data = await res.json().catch(() => ({})) + throw new Error(data.error || `HTTP ${res.status}`) + } + addToast(successMsg, 'success') + await fetchAll() + } catch (err) { + addToast(`Failed to update pattern: ${err.message}`, 'error') + } finally { + setPendingPattern(null) + } + } + + const setPatternAction = (patternID, action) => + mutatePattern(patternID, { action }, `Pattern ${patternID}: action ${action} (transient — click "Save to disk" to persist)`) + + const setPatternDisabled = (patternID, disabled) => + mutatePattern(patternID, { disabled }, `Pattern ${patternID}: ${disabled ? 'disabled' : 'enabled'} (transient — click "Save to disk" to persist)`) + + const [persisting, setPersisting] = useState(false) + const persistPatterns = async () => { + setPersisting(true) + try { + const res = await fetch(apiUrl('/api/pii/patterns/persist'), { method: 'POST' }) + if (!res.ok) { + const data = await res.json().catch(() => ({})) + throw new Error(data.error || `HTTP ${res.status}`) + } + const data = await res.json().catch(() => ({})) + addToast(`Saved ${data.override_count ?? 0} pattern override(s) to runtime_settings.json`, 'success') + } catch (err) { + addToast(`Failed to persist: ${err.message}`, 'error') + } finally { + setPersisting(false) + } + } + + return ( +
+
+

Middleware

+

+ Inspect and configure routing-module middleware: PII filtering and intelligent routing. +

+
+ + {/* Tab bar */} +
+ {TABS.map(tab => ( + + ))} +
+ +
+ + {loading && !status ? ( +
+ +
+ ) : activeTab === 'filtering' ? ( + + ) : activeTab === 'routing' ? ( + + ) : activeTab === 'proxy' ? ( + + ) : ( + + )} +
+ ) +} + +function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPersist, persisting }) { + if (!status?.pii) return null + const pii = status.pii + + if (!pii.enabled_globally) { + return ( +
+
+

PII filtering disabled

+

+ The PII filter is disabled by {pii.reason || '--disable-pii'}. + Restart without that flag to enable it. +

+
+ ) + } + + return ( + <> + {/* Default rule banner */} +
+
+ +
+
Default policy
+
+ PII redaction is per-model and OFF by default. Backends matching {(pii.default_enabled_for_backends || []).join(', ')} default to ON (cloud passthroughs). Override per model with pii: {'{'} enabled: true {'}'} in the model YAML. +
+
+
+
+ + {/* Patterns table */} +
+
+ Active patterns +
+ + Toggle / action edits are transient — click Save to disk to persist. + + +
+
+
+ + + + + + + + + + + + {pii.patterns.map(p => { + const enabled = !p.disabled + const muted = p.disabled + return ( + + + + + + + + )})} + +
EnabledPatternDescriptionActionChange
+ onSetDisabled(p.id, !e.target.checked)} + style={{ cursor: 'pointer' }} + aria-label={`Enable ${p.id} pattern`} + /> + {p.id}{p.description}{actionBadge(p.action)} +
+ {ACTIONS.map(a => ( + + ))} +
+
+
+
+ + {/* Per-model resolved state */} +
+
+ Per-model state + + Edit the model YAML to change these. + +
+
+ + + + + + + + + + + + + {(pii.models || []).map(m => ( + + + + + + + + + ))} + {(!pii.models || pii.models.length === 0) && ( + + + + )} + +
ModelBackendPIISourcePattern overridesEdit
{m.name}{m.backend || '—'}{enabledBadge(m.enabled)} + {m.explicit ? 'YAML' : (m.default_for_backend ? 'backend default' : 'default off')} + + {m.overrides && Object.keys(m.overrides).length > 0 + ? Object.entries(m.overrides).map(([k, v]) => `${k}=${v}`).join(', ') + : } + + + Edit + +
+ No models loaded. +
+
+
+ + ) +} + +function RoutingTab({ status, decisions }) { + const navigate = useNavigate() + const router = status?.router || { configured: false } + + if (!router.configured || !router.models || router.models.length === 0) { + return ( +
+
+

No routers configured

+

+ {router.note || 'Add a `router:` block to a model YAML to enable intelligent routing. The classifier picks one of the listed candidates per request and the standard model-resolution path runs against the chosen target.'} +

+ +
+ ) + } + + return ( + <> + {/* Configured router models */} +
+
+ Active routers +
+ + Edit the router model YAML to change candidates or rules. + + +
+
+
+ + + + + + + + + + + + {router.models.map(m => ( + + + + + + + + ))} + +
ModelClassifierCandidatesEmbedding cacheFallback
{m.name}{m.classifier} + {(m.candidates || []).map((c, i) => ( +
+ {(c.labels || []).join(', ') || '—'} + + {c.model} +
+ ))} +
+ + + {m.fallback || '—'} +
+
+
+ + {/* Recent decisions */} +
+
+ Recent decisions + + Newest first, capped at 100. + +
+ {(!decisions || decisions.length === 0) ? ( +
+ No routing decisions yet. Send a request to a router model to populate this log. +
+ ) : ( +
+ + + + + + + + + + + + + {decisions.map(d => ( + + + + + + + + + ))} + +
TimeRouterLabelServedLatencyCorrelation
{d.created_at}{d.router_model}{d.label}{d.served_model}{d.latency_ms}ms + {d.correlation_id || '—'} +
+
+ )} +
+ + ) +} + +function ProxyTab({ status, addToast, onChanged }) { + const navigate = useNavigate() + const mitm = status?.mitm + const serverListen = mitm?.configured_addr || '' + + const [listen, setListen] = useState(serverListen) + const [saving, setSaving] = useState(false) + + const dirty = listen !== serverListen + + // Refresh local state from the server only when the user has no + // pending edits to clobber. + useEffect(() => { + if (dirty) return + setListen(serverListen) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [serverListen]) + + const save = async () => { + setSaving(true) + try { + const body = await settingsApi.save({ mitm_listen: listen }) + if (body && body.success === false) { + throw new Error(body.error || 'unknown error') + } + addToast('MITM proxy settings updated', 'success') + onChanged?.() + } catch (err) { + addToast(`Failed to save: ${err.message}`, 'error') + } finally { + setSaving(false) + } + } + + if (!mitm) { + return ( +
+
+

MITM proxy status unavailable

+

The status endpoint did not return a mitm section.

+
+ ) + } + + const conflicts = mitm.host_conflicts || {} + const owners = mitm.host_owners || {} + const conflictHosts = Object.keys(conflicts) + const ownerEntries = Object.entries(owners) + const mitmModels = mitm.models || [] + + return ( +
+ {conflictHosts.length > 0 && ( +
+
+ + MITM listener disabled — duplicate host claims +
+

+ Each MITM intercept host must be owned by exactly one model config. Resolve by editing the conflicting model YAMLs. +

+
    + {conflictHosts.map(h => ( +
  • + {h} + {' claimed by: '} + {(conflicts[h] || []).map(name => ( + + {name} + + ))} +
  • + ))} +
+
+ )} + +
+
+

State

+ {enabledBadge(mitm.running)} + {mitm.running && ( + + listening on {mitm.listen_addr} + + )} +
+

+ The MITM proxy terminates TLS for allowlisted hosts so PII redaction + can run on traffic from clients that authenticate via OAuth / + subscription (Claude Code, Codex CLI). Non-allowlisted hosts get a + plain CONNECT tunnel — no inspection, no CA-trust required. +

+ {ownerEntries.length > 0 ? ( +
+
Hosts claimed by model configs (PII settings flow from the owning config):
+
    + {ownerEntries.map(([host, name]) => ( +
  • + {host} → {name} +
  • + ))} +
+
+ ) : ( +
+ No model config declares an MITM intercept host. Without one, every CONNECT tunnels through unmodified. Create one from the Add Model page using the MITM Intercept template. +
+ )} + {mitm.ca_available ? ( + + Download CA cert + + ) : ( + + CA not generated yet — start the listener to generate it. + + )} +
+ +
+
+

MITM Models

+ +
+ {mitmModels.length === 0 ? ( +
+ No model config declares mitm.hosts. Use the Add MITM model button above — the template defaults to api.anthropic.com with PII filtering on. +
+ ) : ( + + + + + + + + + + + {mitmModels.map(m => ( + + + + + + + ))} + +
ModelHostsPIIEdit
{m.name} + {(m.hosts || []).join(', ')} + {enabledBadge(m.pii_enabled)} + + Edit + +
+ )} +
+ +
+

Configuration

+ + + +
+ Intercept hosts are declared per-model in the model YAML's + {' '}mitm.hosts:{' '} + block. Each host is owned by exactly one model config; PII filtering and + pattern overrides flow from the owning config when the host is intercepted. +
+ +
+ + {dirty && ( + + )} +
+
+ +
+

Client setup

+
    +
  1. Download the CA cert (button above).
  2. +
  3. Trust it on the client. For Node-based CLIs (Claude Code, Codex): export NODE_EXTRA_CA_CERTS=$(pwd)/localai-mitm-ca.crt
  4. +
  5. Point the client at the proxy: export HTTPS_PROXY=http://<host>:<port> (yes, http:// — clients speak plain HTTP to the proxy, which then terminates TLS for allowlisted hosts on the inner connection).
  6. +
+
+
+ ) +} + +const EVENT_KINDS = [ + { id: '', label: 'All' }, + { id: 'pii', label: 'PII' }, + { id: 'proxy_connect', label: 'Proxy connect' }, + { id: 'proxy_traffic', label: 'Proxy traffic' }, + { id: 'admission', label: 'Admission' }, +] + +function eventKind(e) { + return e.kind || 'pii' +} + +function eventSubject(e) { + switch (eventKind(e)) { + case 'proxy_connect': + case 'proxy_traffic': + case 'admission': + return e.host || '—' + default: + return e.pattern_id || '—' + } +} + +function eventDetails(e) { + switch (eventKind(e)) { + case 'proxy_connect': + return e.intercepted ? 'intercepted (TLS terminated)' : 'tunneled (passthrough)' + case 'proxy_traffic': { + const status = e.status_code ? `HTTP ${e.status_code}` : 'no upstream' + const sent = formatBytes(e.bytes_sent) + const recv = formatBytes(e.bytes_received) + const dur = e.duration_ms != null ? `${e.duration_ms}ms` : '' + return `${status} · ↑${sent} ↓${recv} · ${dur}` + } + case 'admission': { + const retry = e.duration_ms != null ? `retry-after ${Math.round(e.duration_ms / 1000)}s` : '' + return `HTTP 503 rejected · ${retry}` + } + default: { + const len = e.length != null ? `len ${e.length}` : '' + const hash = e.hash_prefix ? `hash ${e.hash_prefix}` : '' + return [len, hash].filter(Boolean).join(' · ') || '—' + } + } +} + +function formatBytes(n) { + if (!n) return '0B' + if (n < 1024) return `${n}B` + if (n < 1024 * 1024) return `${(n / 1024).toFixed(1)}KB` + return `${(n / (1024 * 1024)).toFixed(1)}MB` +} + +function kindBadge(kind) { + const colors = { + pii: 'var(--color-warning)', + proxy_connect: 'var(--color-primary)', + proxy_traffic: 'var(--color-text-muted)', + admission: 'var(--color-error)', + } + return ( + + {kind.replace(/_/g, ' ')} + + ) +} + +function EventsTab({ events }) { + const [kindFilter, setKindFilter] = useState('') + const filtered = kindFilter ? events.filter(e => eventKind(e) === kindFilter) : events + + return ( +
+
+
+ Recent events + + shared by PII filter and MITM proxy · newest first · capped at 100 + +
+
+ {EVENT_KINDS.map(k => ( + + ))} +
+
+ {filtered.length === 0 ? ( +
+
+

No events

+

+ Events appear here when the PII filter matches a pattern, when the MITM proxy decides whether + to intercept a hostname, or when an intercepted request finishes. Request bodies are never + stored — use the API and backend traces for that. +

+
+ ) : ( +
+ + + + + + + + + + + + + {filtered.map(e => ( + + + + + + + + + ))} + +
TimeKindSubjectDetailsActionCorrelation
+ {e.created_at} + {kindBadge(eventKind(e))} + {eventSubject(e)} + + {eventDetails(e)} + {e.action ? actionBadge(e.action) : '—'} + {e.correlation_id || '—'} +
+
+ )} +
+ ) +} + +// RouterCacheCell renders the L2 embedding-cache state for one router +// model. Shows nothing for routers without an embedding_cache: block; +// for configured caches, shows hit/miss/near-miss counters plus a +// similarity histogram with a marker at the configured threshold so +// admins can tell at a glance whether the threshold is well-placed. +function RouterCacheCell({ cache }) { + if (!cache) { + return + } + const stats = cache.stats || {} + const hits = stats.hits || 0 + const misses = stats.misses || 0 + const nearMisses = stats.near_misses || 0 + const lowConf = stats.low_confidence || 0 + const totalLookups = hits + misses + nearMisses + const hitRate = totalLookups > 0 ? Math.round((hits / totalLookups) * 100) : null + const errors = (stats.embedder_errors || 0) + (stats.store_errors || 0) + const buckets = stats.similarity_buckets || [] + const bucketMax = buckets.length ? Math.max(...buckets, 1) : 1 + const threshold = cache.similarity_threshold || 0.80 + const thresholdBucket = Math.max(0, Math.min(9, Math.floor(threshold * 10))) + return ( +
+
{cache.embedding_model}
+
+ {totalLookups === 0 ? ( + no traffic yet + ) : ( + <> + = 50 ? 'var(--color-success, #2da44e)' : 'var(--color-text-muted)' }}> + {hitRate}% hit + + · {hits}h/{nearMisses}n/{misses}m + {lowConf > 0 && · {lowConf} skipped} + {errors > 0 && · {errors} err} + + )} +
+ {buckets.length === 10 && buckets.some(v => v > 0) && ( +
+ {buckets.map((count, i) => { + const h = bucketMax > 0 ? Math.max(2, Math.round((count / bucketMax) * 18)) : 2 + const inHitZone = i >= thresholdBucket + return ( +
+ ) + })} +
+ sim ≥ {threshold} +
+
+ )} +
+ ) +} diff --git a/core/http/react-ui/src/pages/ModelEditor.jsx b/core/http/react-ui/src/pages/ModelEditor.jsx index 40446b2bc18f..9cb032f1b38c 100644 --- a/core/http/react-ui/src/pages/ModelEditor.jsx +++ b/core/http/react-ui/src/pages/ModelEditor.jsx @@ -9,6 +9,7 @@ import LoadingSpinner from '../components/LoadingSpinner' import CodeEditor from '../components/CodeEditor' import FieldBrowser from '../components/FieldBrowser' import ConfigFieldRenderer from '../components/ConfigFieldRenderer' +import { FormContextProvider } from '../contexts/FormContext' import TemplateSelector from '../components/TemplateSelector' import MODEL_TEMPLATES from '../utils/modelTemplates' @@ -386,6 +387,7 @@ export default function ModelEditor() { if (metaError) return

Failed to load config metadata: {metaError}

return ( +
{/* Header */}
)}
+ ) } diff --git a/core/http/react-ui/src/pages/Usage.jsx b/core/http/react-ui/src/pages/Usage.jsx index 9d5b51f695ee..0b0d954cf72a 100644 --- a/core/http/react-ui/src/pages/Usage.jsx +++ b/core/http/react-ui/src/pages/Usage.jsx @@ -629,7 +629,7 @@ function ModelDistChart({ rows }) { export default function Usage() { const { addToast } = useOutletContext() - const { isAdmin, authEnabled } = useAuth() + const { isAdmin, authEnabled, loading: authLoading } = useAuth() const { t } = useTranslation('admin') const [period, setPeriod] = useState('month') const [loading, setLoading] = useState(true) @@ -644,8 +644,13 @@ export default function Usage() { const fetchUsage = useCallback(async () => { setLoading(true) try { - const usagePromise = fetch(apiUrl(`/api/auth/usage?period=${period}`)) - const quotaPromise = fetch(apiUrl('/api/auth/quota')) + // /api/usage works in no-auth single-user mode (returns the synthetic + // local user's usage). /api/auth/usage is the legacy auth-required + // path; we keep using it when auth is on so /api/auth/quota and + // friends remain consistent. + const userUsageURL = authEnabled ? '/api/auth/usage' : '/api/usage' + const usagePromise = fetch(apiUrl(`${userUsageURL}?period=${period}`)) + const quotaPromise = authEnabled ? fetch(apiUrl('/api/auth/quota')) : Promise.resolve(null) const [res, quotaRes] = await Promise.all([usagePromise, quotaPromise]) @@ -654,13 +659,18 @@ export default function Usage() { setUsage(data.usage || []) setTotals(data.totals || {}) - if (quotaRes.ok) { + if (quotaRes && quotaRes.ok) { const quotaData = await quotaRes.json() setQuotas(quotaData.quotas || []) } if (isAdmin) { - const adminRes = await fetch(apiUrl(`/api/auth/admin/usage?period=${period}`)) + // /api/usage/all serves the cluster-wide view in both modes. + // The synthetic local user has Role: admin, so single-user mode + // gets the admin-style cross-user table (which collapses to one + // row, but keeps the UI shape consistent). + const adminURL = authEnabled ? '/api/auth/admin/usage' : '/api/usage/all' + const adminRes = await fetch(apiUrl(`${adminURL}?period=${period}`)) if (adminRes.ok) { const adminData = await adminRes.json() setAdminUsage(adminData.usage || []) @@ -672,24 +682,12 @@ export default function Usage() { } finally { setLoading(false) } - }, [period, isAdmin, addToast]) + }, [period, isAdmin, authEnabled, addToast]) useEffect(() => { - if (authEnabled) fetchUsage() - else setLoading(false) - }, [fetchUsage, authEnabled]) - - if (!authEnabled) { - return ( -
-
-
-

Usage tracking unavailable

-

Authentication must be enabled to track API usage.

-
-
- ) - } + if (authLoading) return + fetchUsage() + }, [fetchUsage, authLoading]) const modelRows = aggregateByModel(isAdmin ? adminUsage : usage) const userRows = isAdmin ? aggregateByUser(adminUsage) : [] diff --git a/core/http/react-ui/src/router.jsx b/core/http/react-ui/src/router.jsx index 2e07fea5f35e..ae662a8be3c4 100644 --- a/core/http/react-ui/src/router.jsx +++ b/core/http/react-ui/src/router.jsx @@ -42,9 +42,11 @@ import NodeBackendLogs from './pages/NodeBackendLogs' import NotFound from './pages/NotFound' import Usage from './pages/Usage' import Users from './pages/Users' +import Middleware from './pages/Middleware' import Account from './pages/Account' import RequireAdmin from './components/RequireAdmin' import RequireAuth from './components/RequireAuth' +import RequireAuthEnabled from './components/RequireAuthEnabled' import RequireFeature from './components/RequireFeature' function BrowseRedirect() { @@ -84,7 +86,8 @@ const appChildren = [ { path: 'voice/:model', element: }, { path: 'usage', element: }, { path: 'account', element: }, - { path: 'users', element: }, + { path: 'users', element: }, + { path: 'middleware', element: }, { path: 'manage', element: }, { path: 'backends', element: }, { path: 'settings', element: }, diff --git a/core/http/react-ui/src/utils/modelTemplates.js b/core/http/react-ui/src/utils/modelTemplates.js index 576d66c5c6c4..3d4b6205e42f 100644 --- a/core/http/react-ui/src/utils/modelTemplates.js +++ b/core/http/react-ui/src/utils/modelTemplates.js @@ -74,6 +74,86 @@ const MODEL_TEMPLATES = [ 'embeddings': true, }, }, + { + id: 'proxy-openai', + label: 'OpenAI Proxy', + icon: 'fa-cloud', + description: 'Forward chat completions to OpenAI or any OpenAI-compatible provider; PII redaction runs in flight', + // known_usecases is pre-seeded with chat so the proxy model + // surfaces in places that filter by capability — model pickers + // for chat, router fallback dropdowns, etc. Backends without an + // explicit usecase list are filtered out of those selectors. + fields: { + 'name': '', + 'backend': 'proxy-openai', + 'known_usecases': ['chat'], + 'proxy.upstream_url': 'https://api.openai.com/v1/chat/completions', + 'proxy.api_key_env': 'OPENAI_API_KEY', + 'proxy.upstream_model': '', + 'proxy.request_timeout_seconds': 120, + 'pii.enabled': true, + }, + }, + { + id: 'proxy-anthropic', + label: 'Anthropic Proxy', + icon: 'fa-cloud', + description: 'Forward Messages API requests to Anthropic; PII redaction runs in flight', + fields: { + 'name': '', + 'backend': 'proxy-anthropic', + 'known_usecases': ['chat'], + 'proxy.upstream_url': 'https://api.anthropic.com/v1/messages', + 'proxy.api_key_env': 'ANTHROPIC_API_KEY', + 'proxy.upstream_model': '', + 'proxy.request_timeout_seconds': 300, + 'pii.enabled': true, + }, + }, + { + id: 'router', + label: 'Routing Model', + icon: 'fa-route', + description: 'Score-classifier router with three example policies and two candidates. Fill in the classifier_model (Arch-Router-1.5B recommended), the per-candidate downstream models, and the fallback. The L2 embedding cache is opt-in via the Routing section.', + fields: { + 'name': 'smart-router', + 'router.classifier': 'score', + 'router.classifier_model': '', + 'router.fallback': '', + 'router.activation_threshold': 0.40, + 'router.policies': [ + { label: 'code-generation', description: 'writing, debugging, reading, or explaining code in any programming language' }, + { label: 'casual-chat', description: 'small talk, greetings, jokes, or general conversation with no specific task' }, + { label: 'math-reasoning', description: 'arithmetic, equations, percentage calculations, or step-by-step word problems' }, + ], + 'router.candidates': [ + { model: '', labels: ['casual-chat'] }, + { model: '', labels: ['code-generation', 'casual-chat', 'math-reasoning'] }, + ], + }, + }, + { + id: 'mitm', + label: 'MITM Intercept', + icon: 'fa-shield-halved', + description: 'Bind a hostname to this config for the cloudproxy MITM listener. PII filtering and pattern overrides flow from this config when the host is intercepted.', + // The mitm- name prefix is a convention, not a contract — the + // dispatcher looks up by host, not name. Prefixing keeps the + // config out of the way of callable model names so a chat client + // accidentally requesting "anthropic" doesn't hit a backendless + // intercept config. + // + // pii.patterns is pre-seeded with an empty list so the override + // editor is visible by default — admins typically want to tighten + // a couple of pattern actions when intercepting a cloud provider. + // An empty list serializes out and the redactor ignores it. + fields: { + 'name': 'mitm-anthropic', + 'mitm.hosts': ['api.anthropic.com'], + 'pii.enabled': true, + 'pii.patterns': [], + }, + }, ] export default MODEL_TEMPLATES diff --git a/core/http/routes/anthropic.go b/core/http/routes/anthropic.go index 68b3079bd359..e770791babc4 100644 --- a/core/http/routes/anthropic.go +++ b/core/http/routes/anthropic.go @@ -13,6 +13,8 @@ import ( mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/piiadapter" "github.com/mudler/xlog" ) @@ -32,14 +34,34 @@ func RegisterAnthropicRoutes(app *echo.Echo, application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, + application.PIIRedactor(), + application.PIIEvents(), ) messagesMiddleware := []echo.MiddlewareFunc{ - middleware.UsageMiddleware(application.AuthDB()), + middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()), middleware.TraceMiddleware(application), re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.AnthropicRequest) }), setAnthropicRequestContext(application.ApplicationConfig()), + // RouteModel runs after the request is parsed but before the + // PII filter — see the OpenAI route for why this order matters + // (per-model PII configs apply to the routed target). + middleware.RouteModel( + application.ModelConfigLoader(), + application.ApplicationConfig(), + application.RouterDecisions(), + application.FallbackUser(), + middleware.AnthropicProbe, + middleware.ClassifierDeps{ + Scorer: application.ScorerFactory(), + Embedder: application.EmbedderFactory(), + VectorStore: application.VectorStoreFactory(), + Registry: application.RouterClassifierRegistry(), + }, + ), + middleware.AdmissionControl(application.AdmissionLimiter(), application.PIIEvents()), + pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.Anthropic(), application.FallbackUser()), } // Main Anthropic endpoint diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 5c341b90c8be..4f61c852594a 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -216,6 +216,11 @@ func RegisterLocalAIRoutes(router *echo.Echo, router.GET("/api/p2p", localai.ShowP2PNodes(appConfig), adminMiddleware) router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig), adminMiddleware) + // Score (logprob over candidate continuations) — admin-only smoke-test + // surface for the gRPC Score primitive. Production consumers should + // use application.ScorerFactory() directly rather than HTTP. + router.POST("/api/score", localai.ScoreEndpoint(cl, ml, appConfig), adminMiddleware) + router.GET("/version", func(c echo.Context) error { return c.JSON(200, struct { Version string `json:"version"` diff --git a/core/http/routes/middleware.go b/core/http/routes/middleware.go new file mode 100644 index 000000000000..4bec48d5cc3d --- /dev/null +++ b/core/http/routes/middleware.go @@ -0,0 +1,328 @@ +package routes + +import ( + "context" + "net/http" + "strconv" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/router" +) + +// RegisterMiddlewareRoutes wires the routing-module admin surface that +// powers the /app/middleware React page. Two endpoints: +// +// - GET /api/middleware/status — single round-trip aggregator. Lists +// PII patterns with current actions, each model's resolved +// enabled/override state, recent event count, and a router status +// stub (until subsystem 2 lands). +// - GET /api/router/status — placeholder that the page renders for +// the Routing tab. Returns { configured: false, models: [] } today; +// subsystem 2 fills it in. +// +// Both are admin-only when auth is on. In single-user (no-auth) mode +// the synthetic local user has Role: admin so the page works without +// extra config — same gating shape as the existing /api/usage/all. +func RegisterMiddlewareRoutes(e *echo.Echo, app *application.Application) { + e.GET("/api/middleware/status", func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + + piiSection := buildPIIStatus(app) + routerSection := buildRouterStatus(app) + mitmSection := buildMITMStatus(app) + admissionSection := buildAdmissionStatus(app) + + return c.JSON(http.StatusOK, map[string]any{ + "pii": piiSection, + "router": routerSection, + "mitm": mitmSection, + "admission": admissionSection, + }) + }) + + e.GET("/api/router/status", func(c echo.Context) error { + // Read-only — admins want to see classifier configurations + // without authenticating, same as /api/pii/patterns. + return c.JSON(http.StatusOK, buildRouterStatus(app)) + }) + + e.GET("/api/middleware/proxy-ca.crt", func(c echo.Context) error { + // The CA cert is the public half — safe to expose without + // auth so clients can curl it during initial setup. The + // private key never leaves disk and is mode 0600. Returning + // 404 (rather than 500) when MITM is disabled keeps the + // endpoint a clean "is this feature available?" probe. + ca := app.MITMCA() + if ca == nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": "mitm proxy is not enabled (set --mitm-listen to start it)", + }) + } + c.Response().Header().Set("Content-Type", "application/x-pem-file") + c.Response().Header().Set("Content-Disposition", `attachment; filename="localai-mitm-ca.crt"`) + return c.Blob(http.StatusOK, "application/x-pem-file", ca.PublicCertPEM()) + }) + + e.GET("/api/router/decisions", func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + // Decision logs may include user ids — admin-only when auth is + // on; the synthetic local user has admin so single-user mode + // works. + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + + store := app.RouterDecisions() + if store == nil { + return c.JSON(http.StatusOK, map[string]any{"decisions": []any{}}) + } + + limit := 100 + if v := c.QueryParam("limit"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + limit = n + } + } + decisions, err := store.List(c.Request().Context(), router.DecisionListQuery{ + CorrelationID: c.QueryParam("correlation_id"), + UserID: c.QueryParam("user_id"), + RouterModel: c.QueryParam("router_model"), + Limit: limit, + }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to list decisions"}) + } + return c.JSON(http.StatusOK, map[string]any{"decisions": decisions}) + }) + + // GET /api/router/cache/stats — embedding-cache counters per + // router model. Read-only; same auth gating as /api/router/status + // (any authenticated user can see configuration). Omitted entries + // indicate "embedding cache not enabled for this router". + e.GET("/api/router/cache/stats", func(c echo.Context) error { + reg := app.RouterClassifierRegistry() + stats := map[string]router.EmbeddingCacheStats{} + if reg != nil { + stats = reg.EmbeddingCacheStatsByRouter() + } + return c.JSON(http.StatusOK, map[string]any{"caches": stats}) + }) +} + +// buildRouterStatus inventories every model that declares a Router +// block and reports their classifiers + candidate tables. Reads from +// the same loader the RouteModel middleware uses so the admin page +// agrees with what's actually live in the request path. +func buildRouterStatus(app *application.Application) map[string]any { + models := []map[string]any{} + hasAny := false + cacheStats := map[string]router.EmbeddingCacheStats{} + if reg := app.RouterClassifierRegistry(); reg != nil { + cacheStats = reg.EmbeddingCacheStatsByRouter() + } + for _, cfg := range app.ModelConfigLoader().GetAllModelsConfigs() { + if !cfg.HasRouter() { + continue + } + hasAny = true + candidates := make([]map[string]any, 0, len(cfg.Router.Candidates)) + for _, ca := range cfg.Router.Candidates { + candidates = append(candidates, map[string]any{ + "model": ca.Model, + "labels": ca.Labels, + }) + } + policies := make([]map[string]any, 0, len(cfg.Router.Policies)) + for _, p := range cfg.Router.Policies { + policies = append(policies, map[string]any{ + "label": p.Label, + "description": p.Description, + }) + } + classifier := cfg.Router.Classifier + if classifier == "" { + classifier = router.ClassifierScore + } + entry := map[string]any{ + "name": cfg.Name, + "classifier": classifier, + "policies": policies, + "candidates": candidates, + "fallback": cfg.Router.Fallback, + } + if ec := cfg.Router.EmbeddingCache; ec != nil { + cacheEntry := map[string]any{ + "embedding_model": ec.EmbeddingModel, + "similarity_threshold": ec.SimilarityThreshold, + "confidence_threshold": ec.ConfidenceThreshold, + "store_name": ec.StoreName, + } + if s, ok := cacheStats[cfg.Name]; ok { + cacheEntry["stats"] = s + } + entry["embedding_cache"] = cacheEntry + } + models = append(models, entry) + } + + recentCount := 0 + if store := app.RouterDecisions(); store != nil { + if n, err := store.Count(context.Background()); err == nil { + recentCount = n + } + } + + out := map[string]any{ + "configured": hasAny, + "models": models, + "recent_decision_count": recentCount, + "available_classifiers": []string{router.ClassifierScore}, + } + if !hasAny { + out["note"] = "No router models configured. Add a `router:` block to a model YAML to enable intelligent routing." + } + return out +} + +func buildMITMStatus(app *application.Application) map[string]any { + srv := app.MITMServer() + ca := app.MITMCA() + cfg := app.ApplicationConfig() + + // MITM-bound model configs — anything with an mitm: block, even + // if hosts is empty. Surfaces a "fresh from template" config the + // admin started but hasn't yet attached a host to. + mitmModels := []map[string]any{} + for _, mc := range app.ModelConfigLoader().GetModelConfigsByFilter(func(_ string, c *config.ModelConfig) bool { + return len(c.MITM.Hosts) > 0 + }) { + mitmModels = append(mitmModels, map[string]any{ + "name": mc.Name, + "hosts": mc.MITM.Hosts, + "pii_enabled": mc.PIIIsEnabled(), + "backend": mc.Backend, + }) + } + + out := map[string]any{ + "running": srv != nil, + "listen_addr": "", + "configured_addr": cfg.MITMListen, + "host_owners": app.MITMHostOwners(), + "host_conflicts": app.MITMHostConflicts(), + "models": mitmModels, + "ca_available": ca != nil, + "ca_cert_url": "", + } + if conflicts := app.MITMHostConflicts(); len(conflicts) > 0 { + out["error"] = "MITM listener disabled: duplicate host claims across model configs (see host_conflicts). Resolve by editing the conflicting model YAMLs so each host appears in at most one mitm.hosts list." + } + if srv != nil { + out["listen_addr"] = srv.Addr() + } + if ca != nil { + out["ca_cert_url"] = "/api/middleware/proxy-ca.crt" + } + return out +} + +// buildAdmissionStatus reports each model's MaxConcurrent ceiling +// and current in-flight count. Models with no limit set are +// omitted — the dashboard view is "what's gated", not "every +// model in the loader". +func buildAdmissionStatus(app *application.Application) map[string]any { + limiter := app.AdmissionLimiter() + models := []map[string]any{} + if limiter == nil { + return map[string]any{"models": models} + } + for _, cfg := range app.ModelConfigLoader().GetAllModelsConfigs() { + if cfg.Limits.MaxConcurrent <= 0 { + continue + } + models = append(models, map[string]any{ + "name": cfg.Name, + "max_concurrent": cfg.Limits.MaxConcurrent, + "retry_after_seconds": cfg.Limits.RetryAfterSeconds, + "in_flight": limiter.InFlight(cfg.Name), + }) + } + return map[string]any{"models": models} +} + +// buildPIIStatus builds the pii section of /api/middleware/status. It +// reads the live redactor, walks every model config, and reports the +// resolved enabled state plus any per-pattern overrides — that's what +// the admin page renders side-by-side so the operator can see at a +// glance which models are protected. +// +// Returns a sentinel "disabled" payload when the redactor is nil +// (--disable-pii), letting the page show "filter switched off" rather +// than a confusing empty state. +func buildPIIStatus(app *application.Application) map[string]any { + redactor := app.PIIRedactor() + if redactor == nil { + return map[string]any{ + "enabled_globally": false, + "reason": "--disable-pii", + "patterns": []any{}, + "models": []any{}, + } + } + + patterns := redactor.Patterns() + patternList := make([]map[string]any, 0, len(patterns)) + for _, p := range patterns { + patternList = append(patternList, map[string]any{ + "id": p.ID, + "description": p.Description, + "action": string(p.Action), + "disabled": p.Disabled, + "max_match_length": p.MaxMatchLength, + }) + } + + models := []map[string]any{} + for _, cfg := range app.ModelConfigLoader().GetAllModelsConfigs() { + entry := map[string]any{ + "name": cfg.Name, + "backend": cfg.Backend, + "enabled": cfg.PIIIsEnabled(), + "overrides": cfg.PIIPatternOverrides(), + } + // explicit-set tells the UI whether the resolved state came + // from the YAML or the backend-prefix default. Helps admins + // understand "why is this on?" without reading source. + entry["explicit"] = cfg.PII.Enabled != nil + entry["default_for_backend"] = strings.HasPrefix(cfg.Backend, "proxy-") + models = append(models, entry) + } + + recentCount := 0 + if app.PIIEvents() != nil { + if n, err := app.PIIEvents().Count(context.Background()); err == nil { + recentCount = n + } + } + + return map[string]any{ + "enabled_globally": true, + "default_enabled_for_backends": []string{"proxy-*"}, + "patterns": patternList, + "models": models, + "recent_event_count": recentCount, + } +} diff --git a/core/http/routes/ollama.go b/core/http/routes/ollama.go index aba0d8e976b2..f02db76368d8 100644 --- a/core/http/routes/ollama.go +++ b/core/http/routes/ollama.go @@ -17,7 +17,7 @@ func RegisterOllamaRoutes(app *echo.Echo, application *application.Application) { traceMiddleware := middleware.TraceMiddleware(application) - usageMiddleware := middleware.UsageMiddleware(application.AuthDB()) + usageMiddleware := middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()) // Chat endpoint: POST /api/chat chatHandler := ollama.ChatEndpoint( diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index bd7793ae9111..3b8853c750a6 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -9,6 +9,8 @@ import ( "github.com/mudler/LocalAI/core/http/endpoints/openai" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/piiadapter" ) func RegisterOpenAIRoutes(app *echo.Echo, @@ -16,7 +18,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, application *application.Application) { // openAI compatible API endpoint traceMiddleware := middleware.TraceMiddleware(application) - usageMiddleware := middleware.UsageMiddleware(application.AuthDB()) + usageMiddleware := middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()) // realtime // TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions @@ -32,7 +34,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, } // chat - chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant()) + chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant(), application.PIIRedactor(), application.PIIEvents()) chatMiddleware := []echo.MiddlewareFunc{ usageMiddleware, traceMiddleware, @@ -46,6 +48,36 @@ func RegisterOpenAIRoutes(app *echo.Echo, return next(c) } }, + // RouteModel runs AFTER the schema-specific request parser so + // the classifier sees a populated *schema.OpenAIRequest. When + // the resolved model has a Router config, the middleware + // rewrites input.Model to the chosen candidate, swaps + // MODEL_CONFIG, and stamps RequestedModel/ServedModel for the + // usage log. Models without a Router pass through. + middleware.RouteModel( + application.ModelConfigLoader(), + application.ApplicationConfig(), + application.RouterDecisions(), + application.FallbackUser(), + middleware.OpenAIProbe, + middleware.ClassifierDeps{ + Scorer: application.ScorerFactory(), + Embedder: application.EmbedderFactory(), + VectorStore: application.VectorStoreFactory(), + Registry: application.RouterClassifierRegistry(), + }, + ), + // Admission control runs after RouteModel so the SERVED + // model's limits apply — a router fanout that lands on a + // saturated downstream gets rejected even when the requested + // router-model has slack. + middleware.AdmissionControl(application.AdmissionLimiter(), application.PIIEvents()), + // PII redaction runs INNERMOST, after RouteModel has resolved + // the actual served model. This is what makes per-model PII + // configs honour the routed target (e.g., a router fans out to + // claude-strict; that model's pii block applies, not the + // router model's). + pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAI(), application.FallbackUser()), } app.POST("/v1/chat/completions", chatHandler, chatMiddleware...) app.POST("/chat/completions", chatHandler, chatMiddleware...) @@ -71,7 +103,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, app.POST("/edits", editHandler, editMiddleware...) // completion - completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()) + completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), application.PIIRedactor(), application.PIIEvents()) completionMiddleware := []echo.MiddlewareFunc{ usageMiddleware, traceMiddleware, diff --git a/core/http/routes/openresponses.go b/core/http/routes/openresponses.go index 951e34910c7e..a5932fb0f94b 100644 --- a/core/http/routes/openresponses.go +++ b/core/http/routes/openresponses.go @@ -34,7 +34,7 @@ func RegisterOpenResponsesRoutes(app *echo.Echo, // Intercept requests where the model name matches an agent — route directly // to the agent pool without going through the model config resolution pipeline. localai.AgentResponsesInterceptor(application), - middleware.UsageMiddleware(application.AuthDB()), + middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()), middleware.TraceMiddleware(application), re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenResponsesRequest) }), @@ -49,8 +49,8 @@ func RegisterOpenResponsesRoutes(app *echo.Echo, // WebSocket mode for Responses API wsHandler := openresponses.WebSocketEndpoint(application) - app.GET("/v1/responses", wsHandler, middleware.UsageMiddleware(application.AuthDB()), middleware.TraceMiddleware(application)) - app.GET("/responses", wsHandler, middleware.UsageMiddleware(application.AuthDB()), middleware.TraceMiddleware(application)) + app.GET("/v1/responses", wsHandler, middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()), middleware.TraceMiddleware(application)) + app.GET("/responses", wsHandler, middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser()), middleware.TraceMiddleware(application)) // GET /responses/:id - Retrieve a response (for polling background requests) getResponseHandler := openresponses.GetResponseEndpoint() diff --git a/core/http/routes/pii.go b/core/http/routes/pii.go new file mode 100644 index 000000000000..8c488a85f631 --- /dev/null +++ b/core/http/routes/pii.go @@ -0,0 +1,242 @@ +package routes + +import ( + "net/http" + "strconv" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/pii" +) + +// RegisterPIIRoutes wires the read-only routing-PII endpoints. They +// surface (a) the active pattern set so admins can verify what is +// being filtered, (b) the recent PIIEvent log so they can audit what +// has been redacted, and (c) a dry-run "test" endpoint so an admin +// can paste candidate text and see what the redactor would do without +// sending a real request. +// +// The redactor itself runs from the chat middleware in routes/openai.go; +// these endpoints are observation- and configuration-side only. +func RegisterPIIRoutes(e *echo.Echo, app *application.Application) { + if app.PIIRedactor() == nil { + stub := func(c echo.Context) error { + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "error": "PII filter is disabled (--disable-pii)", + }) + } + e.GET("/api/pii/patterns", stub) + e.GET("/api/pii/events", stub) + e.POST("/api/pii/test", stub) + e.POST("/api/pii/patterns/persist", stub) + return + } + + // GetPIIPatternsEndpoint godoc + // @Summary List the active PII patterns + // @Description Returns the configured pattern set with their actions. Available without auth. + // @Tags pii + // @Produce json + // @Success 200 {object} map[string]interface{} + // @Router /api/pii/patterns [get] + e.GET("/api/pii/patterns", func(c echo.Context) error { + patterns := app.PIIRedactor().Patterns() + out := make([]map[string]any, 0, len(patterns)) + for _, p := range patterns { + out = append(out, map[string]any{ + "id": p.ID, + "description": p.Description, + "action": string(p.Action), + "disabled": p.Disabled, + "max_match_length": p.MaxMatchLength, + }) + } + return c.JSON(http.StatusOK, map[string]any{"patterns": out}) + }) + + // GetPIIEventsEndpoint godoc + // @Summary List recent middleware events + // @Description The event log is shared between the PII filter and the MITM proxy: PII redactions, proxy_connect (intercept decisions), and proxy_traffic (per-request byte counts) all flow through the same store. Filter by kind to narrow the view. Admin-only when auth is on; available to the local user in single-user mode. + // @Tags pii + // @Produce json + // @Param correlation_id query string false "Correlation ID join key" + // @Param user_id query string false "User id" + // @Param pattern_id query string false "Pattern id (e.g. email, ssn)" + // @Param kind query string false "Event kind: pii | proxy_connect | proxy_traffic" + // @Param limit query int false "Max events" default(100) + // @Success 200 {object} map[string]interface{} + // @Router /api/pii/events [get] + e.GET("/api/pii/events", func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + // Admin-only when auth is enabled. Local user has Role: admin. + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + + limit := 100 + if v := c.QueryParam("limit"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + limit = n + } + } + events, err := app.PIIEvents().List(c.Request().Context(), pii.ListQuery{ + CorrelationID: c.QueryParam("correlation_id"), + UserID: c.QueryParam("user_id"), + PatternID: c.QueryParam("pattern_id"), + Kind: pii.EventKind(c.QueryParam("kind")), + Limit: limit, + }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to list events"}) + } + return c.JSON(http.StatusOK, map[string]any{"events": events}) + }) + + // PostPIITestEndpoint godoc + // @Summary Dry-run the PII redactor against text + // @Description Useful for admins tuning patterns. Returns the redacted text, matched spans, and whether the input would have been blocked. + // @Tags pii + // @Accept json + // @Produce json + // @Param body body map[string]string true "JSON {\"text\":\"...\"}" + // @Success 200 {object} map[string]interface{} + // @Router /api/pii/test [post] + e.POST("/api/pii/test", func(c echo.Context) error { + var body struct { + Text string `json:"text"` + } + if err := c.Bind(&body); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON"}) + } + res := app.PIIRedactor().Redact(body.Text) + return c.JSON(http.StatusOK, map[string]any{ + "redacted": res.Redacted, + "spans": res.Spans, + "blocked": res.Blocked, + "local_only": res.LocalOnly, + }) + }) + + // PutPIIPatternActionEndpoint godoc + // @Summary Change a pattern's action in-process + // @Description Mutates the named pattern's action (mask|block|route_local). Transient — restored to YAML defaults on restart. Admin-only. + // @Tags pii + // @Accept json + // @Produce json + // @Param id path string true "Pattern id" + // @Param body body map[string]string true "JSON {\"action\":\"mask|block|route_local\"}" + // @Success 200 {object} map[string]interface{} + // @Router /api/pii/patterns/{id} [put] + e.PUT("/api/pii/patterns/:id", func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + + id := c.Param("id") + if id == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "pattern id is required"}) + } + // Either field is optional. The body must set at least one; + // otherwise the call is a no-op and the client probably means + // to PUT something. + var body struct { + Action *string `json:"action,omitempty"` + Disabled *bool `json:"disabled,omitempty"` + } + if err := c.Bind(&body); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON"}) + } + if body.Action == nil && body.Disabled == nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "must specify action and/or disabled"}) + } + if body.Action != nil { + if err := app.PIIRedactor().SetAction(id, pii.Action(*body.Action)); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) + } + } + if body.Disabled != nil { + if err := app.PIIRedactor().SetDisabled(id, *body.Disabled); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) + } + } + return c.JSON(http.StatusOK, map[string]any{ + "id": id, + "action": body.Action, + "disabled": body.Disabled, + "persisted": false, + }) + }) + + // PostPIIPatternsPersistEndpoint godoc + // @Summary Persist current pattern overrides to disk + // @Description Snapshots the live redactor's per-pattern (action, disabled) state into runtime_settings.json so the next process start re-applies it. Admin-only. Pairs with PUT /api/pii/patterns/:id which only mutates in-process. + // @Tags pii + // @Produce json + // @Success 200 {object} map[string]interface{} + // @Router /api/pii/patterns/persist [post] + e.POST("/api/pii/patterns/persist", func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + + appCfg := app.ApplicationConfig() + existing, err := appCfg.ReadPersistedSettings() + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "read settings: " + err.Error()}) + } + // Only persist patterns whose live state differs from the YAML + // default — that way an operator can compare runtime_settings.json + // at a glance and see only the deltas they applied. + defaults, dErr := pii.LoadConfig(appCfg.PIIConfigPath) + if dErr != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "reload defaults: " + dErr.Error()}) + } + defaultByID := make(map[string]pii.Pattern, len(defaults)) + for _, d := range defaults { + defaultByID[d.ID] = d + } + overrides := map[string]config.PIIPatternRuntimeOverride{} + for _, p := range app.PIIRedactor().Patterns() { + d, ok := defaultByID[p.ID] + ov := config.PIIPatternRuntimeOverride{} + changed := false + if !ok || p.Action != d.Action { + action := string(p.Action) + ov.Action = &action + changed = true + } + if !ok || p.Disabled != d.Disabled { + disabled := p.Disabled + ov.Disabled = &disabled + changed = true + } + if changed { + overrides[p.ID] = ov + } + } + existing.PIIPatternOverrides = &overrides + if err := appCfg.WritePersistedSettings(existing); err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "write settings: " + err.Error()}) + } + // Mirror onto the live ApplicationConfig so a subsequent reload + // without a process restart sees the same map. + appCfg.PIIPatternOverrides = overrides + return c.JSON(http.StatusOK, map[string]any{ + "persisted": true, + "override_count": len(overrides), + }) + }) +} diff --git a/core/http/routes/usage.go b/core/http/routes/usage.go new file mode 100644 index 000000000000..4565d3ee9bf6 --- /dev/null +++ b/core/http/routes/usage.go @@ -0,0 +1,157 @@ +package routes + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/billing" +) + +// RegisterUsageRoutes wires the routing-module billing endpoints. These +// are the auth-agnostic siblings of /api/auth/usage and +// /api/auth/admin/usage — they go through application.StatsRecorder() +// so that a no-auth single-user box also gets a working dashboard +// (the existing /api/auth/usage hardcodes a 401 when no user is on the +// context). +// +// Permission model: +// - GET /api/usage → current user's own usage; falls back to +// the synthetic "local" user when auth is off. +// - GET /api/usage/all → cluster-wide; requires admin when auth +// is on. In no-auth mode the local user is the only principal and +// is treated as admin (the LocalUser is constructed with Role: +// admin), so this endpoint returns the same data as /api/usage. +// +// Both endpoints accept ?period={day|week|month|all} (default month) +// and ?user_id=… on the admin path. +func RegisterUsageRoutes(e *echo.Echo, app *application.Application) { + rec := app.StatsRecorder() + if rec == nil { + // Stats explicitly disabled (--disable-stats). Register stub + // handlers that return 503 with a clear reason rather than + // 404; clients (UI, MCP tools) can distinguish "not enabled + // here" from "endpoint missing entirely". + stub := func(c echo.Context) error { + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "error": "usage tracking is disabled (--disable-stats)", + }) + } + e.GET("/api/usage", stub) + e.GET("/api/usage/all", stub) + return + } + + // GetUsageEndpoint godoc + // @Summary Get usage and token totals for the current user + // @Description Returns time-bucketed token usage for the authenticated user. In single-user no-auth mode, returns usage for the synthetic "local" user. Pass ?period={day|week|month|all}. + // @Tags usage + // @Produce json + // @Param period query string false "Time window: day, week, month, all" default(month) + // @Success 200 {object} map[string]interface{} + // @Router /api/usage [get] + e.GET("/api/usage", func(c echo.Context) error { + user := resolveUsageUser(c, app) + if user == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{ + "error": "not authenticated", + }) + } + + period := c.QueryParam("period") + if period == "" { + period = "month" + } + + buckets, err := rec.Aggregate(c.Request().Context(), billing.AggregateQuery{ + UserID: user.ID, + Period: period, + }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to get usage", + }) + } + return c.JSON(http.StatusOK, usageResponse(buckets, user)) + }) + + // GetAllUsageEndpoint godoc + // @Summary Get cluster-wide usage (admin) + // @Description Returns aggregate usage across all users. Requires admin role when auth is enabled. In single-user no-auth mode, returns the same data as /api/usage (the local user is the only principal). + // @Tags usage + // @Produce json + // @Param period query string false "Time window: day, week, month, all" default(month) + // @Param user_id query string false "Filter to a specific user" + // @Success 200 {object} map[string]interface{} + // @Failure 403 {object} map[string]string + // @Router /api/usage/all [get] + e.GET("/api/usage/all", func(c echo.Context) error { + user := resolveUsageUser(c, app) + if user == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{ + "error": "not authenticated", + }) + } + // Admin gate. The synthetic local user is built with Role: admin + // in single-user mode, so this passes naturally when auth is off. + if user.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{ + "error": "admin access required", + }) + } + + period := c.QueryParam("period") + if period == "" { + period = "month" + } + filterUser := c.QueryParam("user_id") + + buckets, err := rec.Aggregate(c.Request().Context(), billing.AggregateQuery{ + UserID: filterUser, // empty = all users + Period: period, + }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to get usage", + }) + } + return c.JSON(http.StatusOK, usageResponse(buckets, user)) + }) +} + +// resolveUsageUser returns the authenticated user when present, +// otherwise the synthetic local user when auth is off. Centralizes the +// "if not auth, fall back to local" pattern that both routes need. +func resolveUsageUser(c echo.Context, app *application.Application) *auth.User { + if u := auth.GetUser(c); u != nil { + return u + } + return app.FallbackUser() +} + +// usageResponse builds the JSON shape the UI consumes. The "viewer" +// field surfaces who the data belongs to so a single-user dashboard +// can show "local" without inventing its own labels. +func usageResponse(buckets []auth.UsageBucket, viewer *auth.User) map[string]any { + totals := auth.UsageTotals{} + for _, b := range buckets { + totals.PromptTokens += b.PromptTokens + totals.CompletionTokens += b.CompletionTokens + totals.TotalTokens += b.TotalTokens + totals.RequestCount += b.RequestCount + } + resp := map[string]any{ + "usage": buckets, + "totals": totals, + } + if viewer != nil { + resp["viewer"] = map[string]string{ + "id": viewer.ID, + "name": viewer.Name, + "role": viewer.Role, + "provider": viewer.Provider, + } + } + return resp +} diff --git a/core/http/routes/usage_test.go b/core/http/routes/usage_test.go new file mode 100644 index 000000000000..74187878100b --- /dev/null +++ b/core/http/routes/usage_test.go @@ -0,0 +1,135 @@ +package routes_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/billing" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// fakeRecorderBackend lets us assert what the handler asked for without +// pulling in a real GORM/SQLite. The aggregate query is captured so +// the test can verify (a) it ran with the right user/period and (b) +// the JSON shape of the response matches the UI's expectations. +type fakeRecorderBackend struct { + lastQuery billing.AggregateQuery + buckets []auth.UsageBucket +} + +func (f *fakeRecorderBackend) Record(_ context.Context, _ *auth.UsageRecord) error { return nil } +func (f *fakeRecorderBackend) Aggregate(_ context.Context, q billing.AggregateQuery) ([]auth.UsageBucket, error) { + f.lastQuery = q + return f.buckets, nil +} +func (f *fakeRecorderBackend) Close() error { return nil } + +// usageHandler reproduces the /api/usage handler logic from +// routes/usage.go without going through application.Application, which +// drags in galleryop, model loaders, etc. Keeping this tight test +// surface lets the no-auth path (the user-visible feature here) be +// covered without the auth build tag. +func usageHandler(rec *billing.Recorder, fallback *auth.User) echo.HandlerFunc { + return func(c echo.Context) error { + user := auth.GetUser(c) + if user == nil { + user = fallback + } + if user == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + period := c.QueryParam("period") + if period == "" { + period = "month" + } + buckets, err := rec.Aggregate(c.Request().Context(), billing.AggregateQuery{ + UserID: user.ID, + Period: period, + }) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "agg failed"}) + } + return c.JSON(http.StatusOK, map[string]any{ + "usage": buckets, + "viewer": map[string]string{ + "id": user.ID, + "name": user.Name, + "role": user.Role, + }, + }) + } +} + +var _ = Describe("Usage endpoint", func() { + It("resolves the local user in no-auth mode", func() { + fb := &fakeRecorderBackend{ + buckets: []auth.UsageBucket{ + {Bucket: "2026-05-05", Model: "qwen-7b", PromptTokens: 100, TotalTokens: 150, RequestCount: 3}, + }, + } + rec := billing.NewRecorder(fb) + fallback := &auth.User{ID: "local-uuid", Name: "local", Role: auth.RoleAdmin} + + e := echo.New() + e.GET("/api/usage", usageHandler(rec, fallback)) + + // No Authorization header: simulates --auth=off. The handler must + // fall through to the fallback user instead of 401-ing. + req := httptest.NewRequest(http.MethodGet, "/api/usage?period=week", nil) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK), "status: got %d, body: %s", w.Code, w.Body.String()) + Expect(fb.lastQuery.UserID).To(Equal("local-uuid")) + Expect(fb.lastQuery.Period).To(Equal("week")) + + var resp struct { + Usage []struct { + Model string `json:"model"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` + } `json:"usage"` + Viewer struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"viewer"` + } + Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp.Usage).To(HaveLen(1)) + Expect(resp.Usage[0].Model).To(Equal("qwen-7b")) + Expect(resp.Viewer.ID).To(Equal("local-uuid")) + Expect(resp.Viewer.Name).To(Equal("local")) + }) + + It("returns 401 when there is no user and no fallback", func() { + rec := billing.NewRecorder(&fakeRecorderBackend{}) + e := echo.New() + e.GET("/api/usage", usageHandler(rec, nil)) + + req := httptest.NewRequest(http.MethodGet, "/api/usage", nil) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("defaults to month period when none is supplied", func() { + fb := &fakeRecorderBackend{} + rec := billing.NewRecorder(fb) + fallback := &auth.User{ID: "u", Name: "u", Role: auth.RoleAdmin} + e := echo.New() + e.GET("/api/usage", usageHandler(rec, fallback)) + + req := httptest.NewRequest(http.MethodGet, "/api/usage", nil) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(fb.lastQuery.Period).To(Equal("month")) + }) +}) diff --git a/core/services/cloudproxy/mitm/ca.go b/core/services/cloudproxy/mitm/ca.go new file mode 100644 index 000000000000..1dea43566a82 --- /dev/null +++ b/core/services/cloudproxy/mitm/ca.go @@ -0,0 +1,177 @@ +// Package mitm implements a TLS man-in-the-middle proxy that +// applies per-request PII redaction to allowlisted LLM API hosts +// while tunnelling everything else byte-for-byte. +package mitm + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "sync" + "time" +) + +type CA struct { + cert *x509.Certificate + key *ecdsa.PrivateKey + publicPEM []byte + + mu sync.Mutex + leaves map[string]*leafEntry +} + +// LoadOrCreateCA loads the CA from dir if both files exist, or +// generates a new ECDSA-P256 CA and persists it. The key file is +// mode 0600. +func LoadOrCreateCA(dir string) (*CA, error) { + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("mitm: create ca dir %q: %w", dir, err) + } + + certPath := filepath.Join(dir, "ca.crt") + keyPath := filepath.Join(dir, "ca.key") + + certPEM, err1 := os.ReadFile(certPath) + keyPEM, err2 := os.ReadFile(keyPath) + if err1 == nil && err2 == nil { + ca, err := parseCA(certPEM, keyPEM) + if err == nil { + return ca, nil + } + // Fall through and regenerate. We don't auto-delete the + // existing files — the operator might have hand-edited + // them. Surface the parse error instead. + return nil, fmt.Errorf("mitm: parse existing CA at %s: %w (delete to regenerate)", dir, err) + } + + ca, certPEMOut, keyPEMOut, err := generateCA() + if err != nil { + return nil, err + } + if err := os.WriteFile(certPath, certPEMOut, 0o644); err != nil { + return nil, fmt.Errorf("mitm: write ca cert %q: %w", certPath, err) + } + if err := os.WriteFile(keyPath, keyPEMOut, 0o600); err != nil { + return nil, fmt.Errorf("mitm: write ca key %q: %w", keyPath, err) + } + return ca, nil +} + +func generateCA() (*CA, []byte, []byte, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, nil, fmt.Errorf("mitm: generate ca key: %w", err) + } + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, nil, fmt.Errorf("mitm: serial: %w", err) + } + + now := time.Now().UTC() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: "LocalAI MITM Proxy CA", + Organization: []string{"LocalAI"}, + }, + NotBefore: now.Add(-1 * time.Hour), + NotAfter: now.Add(10 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + return nil, nil, nil, fmt.Errorf("mitm: create ca cert: %w", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, nil, nil, fmt.Errorf("mitm: re-parse ca cert: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, nil, nil, fmt.Errorf("mitm: marshal ca key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return &CA{ + cert: cert, + key: key, + publicPEM: certPEM, + leaves: make(map[string]*leafEntry), + }, certPEM, keyPEM, nil +} + +// NewInMemoryCA mints an ephemeral CA for tests. +func NewInMemoryCA() (*CA, error) { + ca, _, _, err := generateCA() + return ca, err +} + +func parseCA(certPEM, keyPEM []byte) (*CA, error) { + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil || certBlock.Type != "CERTIFICATE" { + return nil, fmt.Errorf("mitm: ca cert PEM block missing or wrong type") + } + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("mitm: parse ca cert: %w", err) + } + if !cert.IsCA { + return nil, fmt.Errorf("mitm: stored cert at is not a CA") + } + + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, fmt.Errorf("mitm: ca key PEM block missing") + } + var key *ecdsa.PrivateKey + switch keyBlock.Type { + case "EC PRIVATE KEY": + k, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("mitm: parse ec ca key: %w", err) + } + key = k + case "PRIVATE KEY": + k, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("mitm: parse pkcs8 ca key: %w", err) + } + ecKey, ok := k.(*ecdsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("mitm: pkcs8 key is not ECDSA") + } + key = ecKey + default: + return nil, fmt.Errorf("mitm: unsupported ca key PEM type %q", keyBlock.Type) + } + + return &CA{ + cert: cert, + key: key, + publicPEM: certPEM, + leaves: make(map[string]*leafEntry), + }, nil +} + +// PublicCertPEM returns a copy of the PEM-encoded CA certificate. +func (c *CA) PublicCertPEM() []byte { + out := make([]byte, len(c.publicPEM)) + copy(out, c.publicPEM) + return out +} + +func (c *CA) Cert() *x509.Certificate { return c.cert } diff --git a/core/services/cloudproxy/mitm/ca_test.go b/core/services/cloudproxy/mitm/ca_test.go new file mode 100644 index 000000000000..308361919343 --- /dev/null +++ b/core/services/cloudproxy/mitm/ca_test.go @@ -0,0 +1,79 @@ +package mitm + +import ( + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LoadOrCreateCA", func() { + It("generates and persists", func() { + dir := GinkgoT().TempDir() + + ca1, err := LoadOrCreateCA(dir) + Expect(err).NotTo(HaveOccurred(), "first call") + Expect(ca1.cert).NotTo(BeNil()) + Expect(ca1.cert.IsCA).To(BeTrue(), "generated cert is not a CA") + // Files must be on disk after first call. + for _, name := range []string{"ca.crt", "ca.key"} { + path := filepath.Join(dir, name) + info, err := os.Stat(path) + Expect(err).NotTo(HaveOccurred(), "expected %s to exist", path) + mode := info.Mode().Perm() + if name == "ca.key" { + Expect(mode).To(Equal(os.FileMode(0o600))) + } + } + + // Second load must round-trip the same cert (same serial number + // proves we read from disk rather than regenerating). + ca2, err := LoadOrCreateCA(dir) + Expect(err).NotTo(HaveOccurred(), "second call") + Expect(ca1.cert.SerialNumber.Cmp(ca2.cert.SerialNumber)).To(Equal(0), "second load regenerated instead of reading from disk") + }) + + It("rejects non-CA stored cert", func() { + dir := GinkgoT().TempDir() + // Write a non-CA leaf cert into the slot reserved for the CA. + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + leaf, err := ca.IssueLeaf("example.com") + Expect(err).NotTo(HaveOccurred()) + leafPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leaf.Certificate[0]}) + Expect(os.WriteFile(filepath.Join(dir, "ca.crt"), leafPEM, 0o644)).To(Succeed()) + // Pair with a key file so LoadOrCreateCA proceeds to parse. + Expect(os.WriteFile(filepath.Join(dir, "ca.key"), []byte("garbage"), 0o600)).To(Succeed()) + _, err = LoadOrCreateCA(dir) + Expect(err).To(HaveOccurred(), "expected error for non-CA cert in CA slot") + Expect(strings.Contains(err.Error(), "delete to regenerate")).To(BeTrue(), "error should mention regenerate path") + }) +}) + +var _ = Describe("PublicCertPEM", func() { + It("is a valid certificate", func() { + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + pemBytes := ca.PublicCertPEM() + block, _ := pem.Decode(pemBytes) + Expect(block).NotTo(BeNil()) + Expect(block.Type).To(Equal("CERTIFICATE")) + cert, err := x509.ParseCertificate(block.Bytes) + Expect(err).NotTo(HaveOccurred()) + Expect(cert.IsCA).To(BeTrue(), "decoded cert is not a CA") + }) + + It("returns a copy", func() { + // Mutating the returned slice must not poison subsequent calls. + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + first := ca.PublicCertPEM() + first[0] = 0x00 // corrupt + second := ca.PublicCertPEM() + Expect(second[0]).NotTo(Equal(byte(0x00)), "PublicCertPEM aliased its cache; mutation leaked") + }) +}) diff --git a/core/services/cloudproxy/mitm/handler.go b/core/services/cloudproxy/mitm/handler.go new file mode 100644 index 000000000000..8d73fe73568b --- /dev/null +++ b/core/services/cloudproxy/mitm/handler.go @@ -0,0 +1,442 @@ +package mitm + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" + "time" + + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/cloudproxy/ssewire" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/piiadapter" + "github.com/mudler/xlog" + "golang.org/x/net/http2" +) + +// PIIHandlerOptions configures NewPIIHandler. +type PIIHandlerOptions struct { + // Redactor is the regex PII redactor. nil disables redaction. + Redactor *pii.Redactor + + // EventStore receives PIIEvent rows. nil discards events. + EventStore pii.EventStore + + // UpstreamTLS overrides the tls.Config used when dialing the + // real upstream. Defaults to a system-trust HTTPS client. + UpstreamTLS *tls.Config + + // CorrelationIDHeader names the request header carrying a + // caller-supplied correlation ID. Defaults to "X-Correlation-ID". + CorrelationIDHeader string + + // DialHost optionally remaps the host used for the outbound + // upstream URL. Identity by default; tests inject a httptest + // listener address. + DialHost func(host string) string + + // HostsWithPIIDisabled lists destination hosts whose request + // bodies should NOT run through the redactor. TLS termination, + // upstream forwarding, and audit events still happen — only the + // regex pass is bypassed. Useful for telemetry/probe endpoints + // whose bodies aren't PII-shaped. + HostsWithPIIDisabled []string +} + +func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler { + tlsCfg := opts.UpstreamTLS + if tlsCfg == nil { + tlsCfg = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} + } else if len(tlsCfg.NextProtos) == 0 { + tlsCfg.NextProtos = []string{"h2", "http/1.1"} + } + transport := &http.Transport{ + TLSClientConfig: tlsCfg, + ForceAttemptHTTP2: true, + } + if err := http2.ConfigureTransport(transport); err != nil { + xlog.Debug("mitm: http2.ConfigureTransport failed", "error", err) + } + + corrHeader := opts.CorrelationIDHeader + if corrHeader == "" { + corrHeader = "X-Correlation-ID" + } + + dialHost := opts.DialHost + if dialHost == nil { + dialHost = func(h string) string { return h } + } + + patternAction := map[string]pii.Action{} + if opts.Redactor != nil { + for _, p := range opts.Redactor.Patterns() { + patternAction[p.ID] = p.Action + } + } + + piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled)) + for _, h := range opts.HostsWithPIIDisabled { + piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true + } + + d := &piiDispatcher{ + client: &http.Client{Transport: transport}, + redactor: opts.Redactor, + store: opts.EventStore, + patternAction: patternAction, + corrHeader: corrHeader, + dialHost: dialHost, + piiDisabled: piiDisabled, + } + return d.serve +} + +type piiDispatcher struct { + client *http.Client + redactor *pii.Redactor + store pii.EventStore + patternAction map[string]pii.Action + corrHeader string + dialHost func(host string) string + piiDisabled map[string]bool + eventSeq atomic.Uint64 +} + +func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) { + start := time.Now() + cw := &countingResponseWriter{ResponseWriter: w} + w = cw + + var ( + correlationID string + bytesSent int64 + ) + defer func() { + d.recordTrafficEvent(host, correlationID, bytesSent, cw.bytes, cw.status, start) + }() + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "mitm: read body: "+err.Error(), http.StatusBadGateway) + return + } + _ = r.Body.Close() + + correlationID = r.Header.Get(d.corrHeader) + if correlationID == "" { + correlationID = r.Header.Get("x-request-id") + } + + shape := classifyRequestShape(host, r.URL.Path) + if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] { + redacted, blocked, err := d.redactRequest(body, shape, correlationID) + switch { + case err != nil: + xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err) + case blocked: + writePIIBlocked(w, correlationID) + return + default: + body = redacted + } + } + + upstreamURL := "https://" + d.dialHost(host) + r.URL.RequestURI() + upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, bytes.NewReader(body)) + if err != nil { + http.Error(w, "mitm: build upstream request: "+err.Error(), http.StatusBadGateway) + return + } + upstreamReq.Header = cloneHopByHopFiltered(r.Header) + upstreamReq.ContentLength = int64(len(body)) + upstreamReq.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + bytesSent = int64(len(body)) + + resp, err := d.client.Do(upstreamReq) + if err != nil { + http.Error(w, "mitm: upstream: "+err.Error(), http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for k, vs := range resp.Header { + if isHopByHop(k) || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Content-Length") { + continue + } + for _, v := range vs { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + + contentType := resp.Header.Get("Content-Type") + if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) { + d.streamWithPII(w, resp.Body, shape, correlationID) + return + } + + if isSSE(contentType) { + flusher, _ := w.(http.Flusher) + buf := make([]byte, 32*1024) + for { + n, rErr := resp.Body.Read(buf) + if n > 0 { + if _, wErr := w.Write(buf[:n]); wErr != nil { + return + } + if flusher != nil { + flusher.Flush() + } + } + if rErr != nil { + return + } + } + } + + _, _ = io.Copy(w, resp.Body) +} + +type requestShape int + +const ( + shapeUnknown requestShape = iota + shapeOpenAIChat + shapeAnthropicMessages +) + +func classifyRequestShape(host, path string) requestShape { + host = strings.ToLower(host) + switch { + case host == "api.openai.com" && strings.HasSuffix(path, "/v1/chat/completions"): + return shapeOpenAIChat + case host == "api.anthropic.com" && strings.HasSuffix(path, "/v1/messages"): + return shapeAnthropicMessages + } + return shapeUnknown +} + +func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) { + var parsed any + var adapter pii.Adapter + switch shape { + case shapeOpenAIChat: + req := &schema.OpenAIRequest{} + if err := json.Unmarshal(body, req); err != nil { + return nil, false, fmt.Errorf("parse openai: %w", err) + } + parsed = req + adapter = piiadapter.OpenAI() + case shapeAnthropicMessages: + req := &schema.AnthropicRequest{} + if err := json.Unmarshal(body, req); err != nil { + return nil, false, fmt.Errorf("parse anthropic: %w", err) + } + parsed = req + adapter = piiadapter.Anthropic() + default: + return body, false, nil + } + + texts := adapter.Scan(parsed) + if len(texts) == 0 { + return body, false, nil + } + + updates := make([]pii.ScannedText, 0, len(texts)) + blocked := false + for _, st := range texts { + if st.Text == "" { + continue + } + res := d.redactor.RedactWithOverrides(st.Text, nil) + if len(res.Spans) == 0 { + continue + } + d.recordEvents(res.Spans, correlationID) + if res.Blocked { + blocked = true + } + updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted}) + } + + if len(updates) > 0 { + adapter.Apply(parsed, updates) + } + + out, err := json.Marshal(parsed) + if err != nil { + return nil, false, fmt.Errorf("re-marshal: %w", err) + } + return out, blocked, nil +} + +func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) { + if d.store == nil { + return + } + for _, span := range spans { + ev := pii.PIIEvent{ + ID: fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)), + Kind: pii.KindPII, + CorrelationID: correlationID, + Direction: pii.DirectionIn, + PatternID: span.Pattern, + ByteOffset: span.Start, + Length: span.End - span.Start, + HashPrefix: span.HashPrefix, + Action: d.patternAction[span.Pattern], + CreatedAt: time.Now(), + } + if err := d.store.Record(context.Background(), ev); err != nil { + xlog.Debug("mitm: failed to record pii event", "error", err, "pattern", span.Pattern) + } + } +} + +func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) { + flusher, _ := w.(http.Flusher) + filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "") + + provider := ssewire.OpenAI + if shape == shapeAnthropicMessages { + provider = ssewire.Anthropic + } + + emit := func(s string) { + _, _ = w.Write([]byte(s)) + if flusher != nil { + flusher.Flush() + } + } + + scanner := ssewire.NewScanner(src) + for scanner.Scan() { + ev := scanner.Event() + if ssewire.IsTerminalMarker(ev.DataLine, provider) { + if residual := filter.Drain(); residual != "" { + emit(ssewire.SynthResidualEvent(provider, residual)) + } + emit(ev.Raw) + continue + } + out := ev.Raw + if ev.DataLine != "" { + rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter) + if drop { + continue + } + if rewritten != ev.DataLine { + out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1) + } + } + emit(out) + } + if residual := filter.Drain(); residual != "" { + emit(ssewire.SynthResidualEvent(provider, residual)) + } +} + +func writePIIBlocked(w http.ResponseWriter, correlationID string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + resp := map[string]any{ + "error": map[string]string{ + "message": "request blocked by LocalAI MITM proxy (sensitive data detected)", + "type": "pii_blocked", + }, + "correlation_id": correlationID, + } + _ = json.NewEncoder(w).Encode(resp) +} + +func isSSE(contentType string) bool { + return strings.HasPrefix(strings.TrimSpace(contentType), "text/event-stream") +} + +// hopByHopHeaders are not forwarded by the proxy (RFC 7230 §6.1). +var hopByHopHeaders = map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailers": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, +} + +func isHopByHop(name string) bool { + _, ok := hopByHopHeaders[http.CanonicalHeaderKey(name)] + return ok +} + +// countingResponseWriter wraps an http.ResponseWriter to track the +// total bytes written downstream and the status code. It implements +// http.Flusher because the SSE paths flush per event; without that +// the assertion `w.(http.Flusher)` would silently degrade to no-op. +type countingResponseWriter struct { + http.ResponseWriter + bytes int64 + status int +} + +func (w *countingResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + n, err := w.ResponseWriter.Write(p) + w.bytes += int64(n) + return n, err +} + +func (w *countingResponseWriter) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *countingResponseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (d *piiDispatcher) recordTrafficEvent(host, correlationID string, sent, received int64, status int, start time.Time) { + if d.store == nil { + return + } + ev := pii.PIIEvent{ + ID: fmt.Sprintf("proxy_traffic_%s_%d", correlationID, d.eventSeq.Add(1)), + Kind: pii.KindProxyTraffic, + CorrelationID: correlationID, + Host: host, + BytesSent: sent, + BytesReceived: received, + StatusCode: status, + DurationMS: time.Since(start).Milliseconds(), + CreatedAt: time.Now(), + } + if err := d.store.Record(context.Background(), ev); err != nil { + xlog.Debug("mitm: failed to record proxy_traffic event", "error", err, "host", host) + } +} + +func cloneHopByHopFiltered(in http.Header) http.Header { + out := make(http.Header, len(in)) + for k, vs := range in { + if isHopByHop(k) { + continue + } + copied := make([]string, len(vs)) + copy(copied, vs) + out[k] = copied + } + return out +} diff --git a/core/services/cloudproxy/mitm/handler_test.go b/core/services/cloudproxy/mitm/handler_test.go new file mode 100644 index 000000000000..b6177b0e9fe0 --- /dev/null +++ b/core/services/cloudproxy/mitm/handler_test.go @@ -0,0 +1,329 @@ +package mitm + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + + "github.com/mudler/LocalAI/core/services/routing/pii" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// startPIITestRig is the same shape as startMITMTestRig but plugs +// in the production PII handler instead of the passthrough fixture. +// The "host" the client thinks it's reaching is forced to +// api.anthropic.com so the request shape classifier matches. +func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, func()) { + // Upstream fake — plays the role of api.anthropic.com. + ts := httptest.NewTLSServer(upstream) + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + upstreamURL, _ := url.Parse(ts.URL) + + // Compiled patterns required for the redactor to actually fire + // (DefaultPatterns alone returns Pattern structs without regex). + patterns, err := pii.Compile(pii.DefaultPatterns()) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + redactor := pii.NewRedactor(patterns) + store := &fakeStore{} + + ca, err := NewInMemoryCA() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + + // DialHost remaps the upstream dial target to the httptest + // fake while leaving the classifier-facing host + // ("api.anthropic.com") untouched. ServerName=example.com is + // what httptest.NewTLSServer issues its cert for. + upstreamHost := upstreamURL.Host + prodHandler := NewPIIHandler(PIIHandlerOptions{ + Redactor: redactor, + EventStore: store, + UpstreamTLS: &tls.Config{ + RootCAs: upstreamCertPool, + ServerName: "example.com", + }, + DialHost: func(_ string) string { return upstreamHost }, + }) + + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + InterceptHosts: []string{"api.anthropic.com"}, + Handler: prodHandler, + EventStore: store, + }) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + ExpectWithOffset(1, srv.Start()).To(Succeed()) + + clientPool := x509.NewCertPool() + clientPool.AddCert(ca.Cert()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{RootCAs: clientPool}, + }, + } + + cleanup := func() { + srv.Stop() + ts.Close() + } + // We point requests at api.anthropic.com so classifyRequestShape + // matches; the wrappedHandler retargets to the upstream fake. + return client, "https://api.anthropic.com", store, cleanup +} + +type fakeStore struct{ events []pii.PIIEvent } + +func (s *fakeStore) Record(_ context.Context, ev pii.PIIEvent) error { + s.events = append(s.events, ev) + return nil +} + +func (s *fakeStore) List(_ context.Context, _ pii.ListQuery) ([]pii.PIIEvent, error) { + return s.events, nil +} + +func (s *fakeStore) Count(_ context.Context) (int, error) { return len(s.events), nil } +func (s *fakeStore) Close() error { return nil } + +func (s *fakeStore) recorded() int { return len(s.events) } + +var _ = Describe("PIIHandler", func() { + It("redacts request email", func() { + var receivedBody []byte + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, `{"id":"msg_x","content":[{"type":"text","text":"ok"}]}`) + }) + + client, base, store, cleanup := startPIITestRig(upstream) + defer cleanup() + + body := `{"model":"claude-3-5-sonnet","max_tokens":100,"messages":[{"role":"user","content":"my email is alice@example.com please reply"}]}` + resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body)) + Expect(err).NotTo(HaveOccurred(), "client.Post") + defer func() { _ = resp.Body.Close() }() + Expect(resp.StatusCode).To(Equal(200)) + + Expect(string(receivedBody)).NotTo(ContainSubstring("alice@example.com"), "upstream received unredacted body") + Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:email]"), "upstream did not see redaction marker") + Expect(store.recorded()).NotTo(BeZero(), "no PIIEvent recorded for the email match") + }) + + It("blocks api key in request", func() { + upstreamCalled := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalled = true + w.WriteHeader(200) + }) + + client, base, _, cleanup := startPIITestRig(upstream) + defer cleanup() + + body := `{"model":"claude-3-5-sonnet","max_tokens":100,"messages":[{"role":"user","content":"my key is sk-abcdefghijklmnopqrstuvwxyz1234"}]}` + resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body)) + Expect(err).NotTo(HaveOccurred(), "client.Post") + defer func() { _ = resp.Body.Close() }() + Expect(resp.StatusCode).To(Equal(400), "api_key_prefix has Block default") + Expect(upstreamCalled).To(BeFalse(), "upstream was called despite block — proxy should short-circuit") + body2, _ := io.ReadAll(resp.Body) + Expect(string(body2)).To(ContainSubstring("pii_blocked")) + }) + + It("streaming redaction", func() { + // Anthropic-shape SSE; "alice@" + "example.com" splits the + // email across chunks so the StreamFilter has to buffer. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + flusher := w.(http.Flusher) + chunks := []string{ + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"contact me at alice@"}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"example.com any time"}}`, + `{"type":"message_stop"}`, + } + for _, c := range chunks { + _, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", "content_block_delta", c) + flusher.Flush() + } + }) + + client, base, _, cleanup := startPIITestRig(upstream) + defer cleanup() + + body := `{"model":"claude-3-5-sonnet","max_tokens":100,"stream":true,"messages":[{"role":"user","content":"hi"}]}` + resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body)) + Expect(err).NotTo(HaveOccurred(), "Post") + defer func() { _ = resp.Body.Close() }() + out, _ := io.ReadAll(resp.Body) + outStr := string(out) + Expect(outStr).NotTo(ContainSubstring("alice@example.com"), "email leaked through MITM stream") + Expect(outStr).To(ContainSubstring("[REDACTED:email]"), "redaction marker missing from MITM stream") + }) + + It("non-chat path passes through", func() { + // A path the classifier doesn't recognise (e.g. an OAuth + // callback) must forward the body verbatim, no PII parsing. + var receivedBody []byte + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, `{"ok":true}`) + }) + + client, base, _, cleanup := startPIITestRig(upstream) + defer cleanup() + + body := `{"email":"alice@example.com"}` + resp, err := client.Post(base+"/oauth/callback", "application/json", strings.NewReader(body)) + Expect(err).NotTo(HaveOccurred()) + defer func() { _ = resp.Body.Close() }() + Expect(string(receivedBody)).To(Equal(body), "body forwarded with mutation") + }) +}) + +var _ = Describe("redactRequest", func() { + It("handles anthropic shape", func() { + patterns, _ := pii.Compile(pii.DefaultPatterns()) + r := pii.NewRedactor(patterns) + body := []byte(`{"model":"claude","max_tokens":10,"messages":[{"role":"user","content":"reach me at bob@example.org"}]}`) + + d := &piiDispatcher{redactor: r, patternAction: map[string]pii.Action{}} + out, blocked, err := d.redactRequest(body, shapeAnthropicMessages, "corr-1") + Expect(err).NotTo(HaveOccurred()) + Expect(blocked).To(BeFalse(), "email is mask, not block — blocked should be false") + var parsed map[string]any + Expect(json.Unmarshal(out, &parsed)).To(Succeed()) + msgs := parsed["messages"].([]any) + first := msgs[0].(map[string]any) + content, _ := first["content"].(string) + Expect(content).NotTo(ContainSubstring("bob@example.org"), "redaction did not run") + }) +}) + +var _ = Describe("Proxy events", func() { + It("emits connect and traffic events", func() { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, `{"id":"msg_x","content":[{"type":"text","text":"ok"}]}`) + }) + + client, base, store, cleanup := startPIITestRig(upstream) + defer cleanup() + + body := `{"model":"claude-3-5-sonnet","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}` + resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body)) + Expect(err).NotTo(HaveOccurred(), "client.Post") + defer func() { _ = resp.Body.Close() }() + _, _ = io.Copy(io.Discard, resp.Body) + + var connect, traffic *pii.PIIEvent + for i := range store.events { + ev := &store.events[i] + switch ev.ResolvedKind() { + case pii.KindProxyConnect: + connect = ev + case pii.KindProxyTraffic: + traffic = ev + } + } + + Expect(connect).NotTo(BeNil(), "no proxy_connect event recorded") + Expect(connect.Host).To(Equal("api.anthropic.com")) + Expect(connect.Intercepted).NotTo(BeNil()) + Expect(*connect.Intercepted).To(BeTrue(), "connect.Intercepted should be true for an allowlisted host") + + Expect(traffic).NotTo(BeNil(), "no proxy_traffic event recorded") + Expect(traffic.Host).To(Equal("api.anthropic.com")) + Expect(traffic.BytesSent).To(BeNumerically(">", 0)) + Expect(traffic.BytesReceived).To(BeNumerically(">", 0)) + Expect(traffic.StatusCode).To(Equal(200)) + }) + + It("tunneled host emits connect event only", func() { + // A non-allowlisted CONNECT must record a proxy_connect with + // Intercepted=false and NOT a proxy_traffic event (tunneled + // bytes never reach the dispatcher). + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, "passthrough") + }) + ts := httptest.NewTLSServer(upstream) + defer ts.Close() + upstreamURL, _ := url.Parse(ts.URL) + upstreamHost, _, _ := net.SplitHostPort(upstreamURL.Host) + + ca, _ := NewInMemoryCA() + store := &fakeStore{} + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + InterceptHosts: []string{"some-other-host"}, + Handler: func(w http.ResponseWriter, r *http.Request, h string) {}, + EventStore: store, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(srv.Start()).To(Succeed()) + defer srv.Stop() + + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: upstreamCertPool, + ServerName: upstreamHost, + }, + }, + } + resp, err := client.Get(ts.URL) + Expect(err).NotTo(HaveOccurred(), "Get through tunnel") + _ = resp.Body.Close() + + var connect *pii.PIIEvent + for i := range store.events { + ev := &store.events[i] + Expect(ev.ResolvedKind()).NotTo(Equal(pii.KindProxyTraffic), "unexpected proxy_traffic event for tunneled host: %+v", ev) + if ev.ResolvedKind() == pii.KindProxyConnect { + connect = ev + } + } + Expect(connect).NotTo(BeNil(), "no proxy_connect event recorded for tunneled host") + Expect(connect.Intercepted).NotTo(BeNil()) + Expect(*connect.Intercepted).To(BeFalse(), "connect.Intercepted should be false (tunneled)") + Expect(connect.Host).NotTo(BeEmpty()) + }) +}) + +var _ = Describe("classifyRequestShape", func() { + cases := []struct { + host string + path string + want requestShape + }{ + {"api.anthropic.com", "/v1/messages", shapeAnthropicMessages}, + {"api.openai.com", "/v1/chat/completions", shapeOpenAIChat}, + {"api.anthropic.com", "/v1/oauth/token", shapeUnknown}, + {"api.openai.com", "/v1/embeddings", shapeUnknown}, + {"example.com", "/v1/messages", shapeUnknown}, + } + for _, c := range cases { + It(fmt.Sprintf("classifies (%q, %q)", c.host, c.path), func() { + Expect(classifyRequestShape(c.host, c.path)).To(Equal(c.want)) + }) + } +}) diff --git a/core/services/cloudproxy/mitm/http2_test.go b/core/services/cloudproxy/mitm/http2_test.go new file mode 100644 index 000000000000..8eb70ff9443b --- /dev/null +++ b/core/services/cloudproxy/mitm/http2_test.go @@ -0,0 +1,165 @@ +package mitm + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/net/http2" +) + +// h2InterceptRig is the test fixture for HTTP/2 paths. Two things +// differ from the H1.1 rig: +// - The client http.Transport has http2.ConfigureTransport called +// so it negotiates h2 with our proxy. +// - The upstream httptest server is started via StartTLS *and* +// manually configured for h2 (httptest does this by default in +// modern Go but we make it explicit for clarity). +func h2InterceptRig(interceptHost string, upstream http.Handler) (*http.Client, string, func()) { + ts := httptest.NewUnstartedServer(upstream) + ts.EnableHTTP2 = true + ts.StartTLS() + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + upstreamURL, _ := url.Parse(ts.URL) + + ca, err := NewInMemoryCA() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + InterceptHosts: []string{interceptHost}, + Handler: passthroughHandler(upstreamCertPool, upstreamURL.Host), + }) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + ExpectWithOffset(1, srv.Start()).To(Succeed()) + + // Client with HTTP/2 explicitly enabled (modern net/http does + // this by default, but configuring the Transport directly makes + // the test independent of stdlib defaults). + clientPool := x509.NewCertPool() + clientPool.AddCert(ca.Cert()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: clientPool, + NextProtos: []string{"h2", "http/1.1"}, + }, + ForceAttemptHTTP2: true, + } + ExpectWithOffset(1, http2.ConfigureTransport(transport)).To(Succeed(), "client h2 configure") + client := &http.Client{Transport: transport} + + cleanup := func() { + srv.Stop() + ts.Close() + } + return client, "https://" + interceptHost, cleanup +} + +var _ = Describe("Proxy HTTP/2", func() { + It("negotiates HTTP/2", func() { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The upstream side: when serving over h2, r.ProtoMajor == 2. + w.Header().Set("X-Upstream-Proto", r.Proto) + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, `{"ok":true}`) + }) + + client, base, cleanup := h2InterceptRig("api.test.local", upstream) + defer cleanup() + + resp, err := client.Get(base + "/v1/test") + Expect(err).NotTo(HaveOccurred(), "Get") + defer func() { _ = resp.Body.Close() }() + + // The proxy ↔ client leg: client sees h2 because we ALPN- + // negotiated it. resp.Proto is the protocol the client used. + Expect(resp.Proto).To(Equal("HTTP/2.0"), "proxy did not serve h2") + body, _ := io.ReadAll(resp.Body) + Expect(string(body)).To(ContainSubstring(`"ok":true`)) + }) + + It("streams over HTTP/2", func() { + // h2 streaming: the proxy must flush each frame promptly. The + // upstream sends 3 SSE-style chunks; we read them back through + // a streaming decoder so a buffering bug would surface. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + flusher := w.(http.Flusher) + for _, msg := range []string{"first", "second", "third"} { + _, _ = fmt.Fprintf(w, "data: %s\n\n", msg) + flusher.Flush() + } + }) + + client, base, cleanup := h2InterceptRig("api.test.local", upstream) + defer cleanup() + + resp, err := client.Get(base + "/stream") + Expect(err).NotTo(HaveOccurred(), "Get") + defer func() { _ = resp.Body.Close() }() + + Expect(resp.Proto).To(Equal("HTTP/2.0"), "expected h2 for streaming response") + body, _ := io.ReadAll(resp.Body) + for _, msg := range []string{"first", "second", "third"} { + Expect(strings.Contains(string(body), "data: "+msg)).To(BeTrue(), "missing %q in h2 streamed body: %s", msg, body) + } + }) + + It("falls back to HTTP/1.1", func() { + // Force the client to negotiate h1.1 only, by overriding ALPN. + // Verifies the fallback path still works for legacy clients. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, `{"ok":true}`) + }) + + ts := httptest.NewTLSServer(upstream) + defer ts.Close() + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + upstreamURL, _ := url.Parse(ts.URL) + + ca, _ := NewInMemoryCA() + srv, _ := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + InterceptHosts: []string{"api.test.local"}, + Handler: passthroughHandler(upstreamCertPool, upstreamURL.Host), + }) + Expect(srv.Start()).To(Succeed()) + defer srv.Stop() + + clientPool := x509.NewCertPool() + clientPool.AddCert(ca.Cert()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + // ALPN intentionally restricted to http/1.1 to force the + // fallback path. Most clients will negotiate h2, but the + // proxy must keep h1 working for the rare case. + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: clientPool, + NextProtos: []string{"http/1.1"}, + }, + }, + } + resp, err := client.Get("https://api.test.local/v1/test") + Expect(err).NotTo(HaveOccurred(), "Get") + defer func() { _ = resp.Body.Close() }() + Expect(resp.Proto).To(Equal("HTTP/1.1")) + body, _ := io.ReadAll(resp.Body) + Expect(string(body)).To(ContainSubstring(`"ok":true`)) + }) +}) diff --git a/core/services/cloudproxy/mitm/leaf.go b/core/services/cloudproxy/mitm/leaf.go new file mode 100644 index 000000000000..4e542d6baea6 --- /dev/null +++ b/core/services/cloudproxy/mitm/leaf.go @@ -0,0 +1,102 @@ +package mitm + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "strings" + "time" +) + +type leafEntry struct { + cert *tls.Certificate + expiresAt time.Time +} + +const ( + leafLifetime = 30 * 24 * time.Hour + minBeforeReissue = 24 * time.Hour +) + +// IssueLeaf returns a TLS certificate for host, signed by this CA. +// Cached per host, re-minted when the cached cert is within +// minBeforeReissue of expiry. +func (c *CA) IssueLeaf(host string) (*tls.Certificate, error) { + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + host = strings.ToLower(host) + + now := time.Now() + + c.mu.Lock() + if entry, ok := c.leaves[host]; ok { + if entry.expiresAt.After(now.Add(minBeforeReissue)) { + c.mu.Unlock() + return entry.cert, nil + } + delete(c.leaves, host) + } + c.mu.Unlock() + + // Mint outside the lock so a slow ECDSA key-gen doesn't block + // concurrent lookups for already-cached hosts. + leaf, err := c.mintLeaf(host) + if err != nil { + return nil, err + } + + c.mu.Lock() + c.leaves[host] = &leafEntry{ + cert: leaf, + expiresAt: now.Add(leafLifetime), + } + c.mu.Unlock() + return leaf, nil +} + +func (c *CA) mintLeaf(host string) (*tls.Certificate, error) { + leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("mitm: leaf key for %q: %w", host, err) + } + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("mitm: leaf serial: %w", err) + } + + now := time.Now().UTC() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: host}, + NotBefore: now.Add(-1 * time.Hour), + NotAfter: now.Add(leafLifetime), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + } + if ip := net.ParseIP(host); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{host} + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, c.cert, &leafKey.PublicKey, c.key) + if err != nil { + return nil, fmt.Errorf("mitm: sign leaf for %q: %w", host, err) + } + + return &tls.Certificate{ + Certificate: [][]byte{der, c.cert.Raw}, + PrivateKey: leafKey, + }, nil +} diff --git a/core/services/cloudproxy/mitm/leaf_test.go b/core/services/cloudproxy/mitm/leaf_test.go new file mode 100644 index 000000000000..3a1bd9b05702 --- /dev/null +++ b/core/services/cloudproxy/mitm/leaf_test.go @@ -0,0 +1,103 @@ +package mitm + +import ( + "crypto/tls" + "crypto/x509" + "net" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("IssueLeaf", func() { + It("chains to CA", func() { + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + leaf, err := ca.IssueLeaf("api.anthropic.com") + Expect(err).NotTo(HaveOccurred()) + Expect(len(leaf.Certificate)).To(BeNumerically(">=", 1), "leaf has no DER") + parsed, err := x509.ParseCertificate(leaf.Certificate[0]) + Expect(err).NotTo(HaveOccurred()) + // Verify it's actually signed by the CA we generated. + pool := x509.NewCertPool() + pool.AddCert(ca.Cert()) + _, err = parsed.Verify(x509.VerifyOptions{ + Roots: pool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSName: "api.anthropic.com", + }) + Expect(err).NotTo(HaveOccurred(), "verify chain") + }) + + It("populates DNS and IP SANs correctly", func() { + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + + // Hostname → DNSNames + leafDNS, err := ca.IssueLeaf("example.com") + Expect(err).NotTo(HaveOccurred()) + parsedDNS, _ := x509.ParseCertificate(leafDNS.Certificate[0]) + Expect(parsedDNS.DNSNames).NotTo(BeEmpty()) + Expect(parsedDNS.DNSNames[0]).To(Equal("example.com")) + Expect(parsedDNS.IPAddresses).To(BeEmpty(), "hostname leaf should have no IP SAN") + + // IP → IPAddresses + leafIP, err := ca.IssueLeaf("127.0.0.1") + Expect(err).NotTo(HaveOccurred()) + parsedIP, _ := x509.ParseCertificate(leafIP.Certificate[0]) + Expect(parsedIP.IPAddresses).NotTo(BeEmpty()) + Expect(parsedIP.IPAddresses[0].Equal(net.ParseIP("127.0.0.1"))).To(BeTrue()) + Expect(parsedIP.DNSNames).To(BeEmpty(), "IP leaf should have no DNS SAN") + }) + + It("caches by host", func() { + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + a, _ := ca.IssueLeaf("api.example.com") + b, _ := ca.IssueLeaf("api.example.com") + Expect(a).To(BeIdenticalTo(b), "expected cached leaf to be returned, got distinct certs") + c, _ := ca.IssueLeaf("API.Example.com") // case-insensitive + Expect(a).To(BeIdenticalTo(c), "expected case-insensitive cache hit") + d, _ := ca.IssueLeaf("api.example.com:443") // host:port stripped + Expect(a).To(BeIdenticalTo(d), "expected port-stripped cache hit") + }) + + It("handshake accepted by client", func() { + // End-to-end check: a TLS server using the leaf, with a client + // trusting the CA, completes a handshake. This is the property + // every other flow in this package depends on. + ca, err := NewInMemoryCA() + Expect(err).NotTo(HaveOccurred()) + leaf, err := ca.IssueLeaf("localhost") + Expect(err).NotTo(HaveOccurred()) + + pool := x509.NewCertPool() + pool.AddCert(ca.Cert()) + + listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{*leaf}, + }) + Expect(err).NotTo(HaveOccurred()) + defer func() { _ = listener.Close() }() + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer func() { _ = conn.Close() }() + _, _ = conn.Write([]byte("ok")) + }() + + conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{ + RootCAs: pool, + ServerName: "localhost", + }) + Expect(err).NotTo(HaveOccurred(), "client TLS dial") + defer func() { _ = conn.Close() }() + buf := make([]byte, 2) + _, err = conn.Read(buf) + Expect(err).NotTo(HaveOccurred(), "read") + Expect(string(buf)).To(Equal("ok")) + }) +}) diff --git a/core/services/cloudproxy/mitm/mitm_suite_test.go b/core/services/cloudproxy/mitm/mitm_suite_test.go new file mode 100644 index 000000000000..aeb019112852 --- /dev/null +++ b/core/services/cloudproxy/mitm/mitm_suite_test.go @@ -0,0 +1,13 @@ +package mitm + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMitm(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "mitm test suite") +} diff --git a/core/services/cloudproxy/mitm/proxy.go b/core/services/cloudproxy/mitm/proxy.go new file mode 100644 index 000000000000..79f49aa648f2 --- /dev/null +++ b/core/services/cloudproxy/mitm/proxy.go @@ -0,0 +1,306 @@ +package mitm + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/xlog" + "golang.org/x/net/http2" +) + +// Server is an HTTPS forward proxy that MITMs traffic for hosts +// in its intercept allowlist; non-allowlisted hosts get a plain +// TCP CONNECT tunnel. +type Server struct { + addr string + ca *CA + interceptHosts map[string]bool + handler InterceptHandler + connectTimeout time.Duration + dialTimeout time.Duration + upstreamTLS *tls.Config + events pii.EventStore + eventSeq atomic.Uint64 + + listener net.Listener + srv *http.Server + + wg sync.WaitGroup + stopOnce sync.Once + stopped chan struct{} +} + +// InterceptHandler runs after the proxy terminates TLS for an +// allowlisted host. It is responsible for forwarding the upstream +// response bytes back to w. +type InterceptHandler func(w http.ResponseWriter, r *http.Request, upstreamHost string) + +type Config struct { + Addr string + CA *CA + InterceptHosts []string + Handler InterceptHandler + // EventStore optionally receives a proxy_connect event for every + // CONNECT, recording the destination host and whether the proxy + // intercepted or tunneled it. nil disables connect-event recording. + EventStore pii.EventStore +} + +func NewServer(cfg Config) (*Server, error) { + if cfg.CA == nil { + return nil, errors.New("mitm: NewServer: CA is required") + } + if cfg.Handler == nil { + return nil, errors.New("mitm: NewServer: Handler is required") + } + hosts := make(map[string]bool, len(cfg.InterceptHosts)) + for _, h := range cfg.InterceptHosts { + hosts[strings.ToLower(strings.TrimSpace(h))] = true + } + return &Server{ + addr: cfg.Addr, + ca: cfg.CA, + interceptHosts: hosts, + handler: cfg.Handler, + connectTimeout: 30 * time.Second, + dialTimeout: 15 * time.Second, + upstreamTLS: &tls.Config{NextProtos: []string{"http/1.1"}}, + events: cfg.EventStore, + stopped: make(chan struct{}), + }, nil +} + +func (s *Server) Start() error { + ln, err := net.Listen("tcp", s.addr) + if err != nil { + return fmt.Errorf("mitm: listen %q: %w", s.addr, err) + } + s.listener = ln + s.srv = &http.Server{ + Handler: http.HandlerFunc(s.handle), + ReadHeaderTimeout: 30 * time.Second, + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + err := s.srv.Serve(ln) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + xlog.Error("mitm: serve error", "error", err) + } + }() + xlog.Info("mitm: listening", "addr", ln.Addr().String(), "intercept_hosts", len(s.interceptHosts)) + return nil +} + +// Addr returns the bound listener address. Useful when Start was +// called with ":0" — the kernel picks a port and tests need to +// discover which. +func (s *Server) Addr() string { + if s.listener == nil { + return s.addr + } + return s.listener.Addr().String() +} + +// Stop is idempotent. +func (s *Server) Stop() { + s.stopOnce.Do(func() { + close(s.stopped) + if s.srv != nil { + _ = s.srv.Close() + } + s.wg.Wait() + }) +} + +func (s *Server) handle(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "this proxy only supports HTTPS via CONNECT", http.StatusMethodNotAllowed) + return + } + + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + host = strings.ToLower(host) + + intercept := s.shouldIntercept(host) + s.recordConnectEvent(host, intercept) + if !intercept { + s.handleTunnel(w, r) + return + } + s.handleIntercept(w, r, host) +} + +// recordConnectEvent writes a proxy_connect audit row. Best-effort — +// store errors are logged at debug only so a failing recorder cannot +// break a CONNECT. +func (s *Server) recordConnectEvent(host string, intercepted bool) { + if s.events == nil { + return + } + flag := intercepted + ev := pii.PIIEvent{ + ID: fmt.Sprintf("proxy_connect_%d", s.eventSeq.Add(1)), + Kind: pii.KindProxyConnect, + Host: host, + Intercepted: &flag, + CreatedAt: time.Now(), + } + if err := s.events.Record(context.Background(), ev); err != nil { + xlog.Debug("mitm: failed to record proxy_connect event", "error", err, "host", host) + } +} + +// shouldIntercept reports whether host is in the allowlist. An +// empty allowlist tunnels everything. +func (s *Server) shouldIntercept(host string) bool { + if len(s.interceptHosts) == 0 { + return false + } + return s.interceptHosts[host] +} + +func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) { + upstream, err := net.DialTimeout("tcp", normalizeHostPort(r.Host), s.dialTimeout) + if err != nil { + http.Error(w, "mitm: tunnel dial: "+err.Error(), http.StatusBadGateway) + return + } + defer func() { _ = upstream.Close() }() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError) + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, "mitm: hijack failed: "+err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = clientConn.Close() }() + + if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil { + return + } + + pipe(clientConn, upstream) +} + +func pipe(a, b net.Conn) { + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(a, b) + _ = a.SetDeadline(time.Now()) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(b, a) + _ = b.SetDeadline(time.Now()) + done <- struct{}{} + }() + <-done +} + +func (s *Server) handleIntercept(w http.ResponseWriter, r *http.Request, host string) { + leaf, err := s.ca.IssueLeaf(host) + if err != nil { + http.Error(w, "mitm: leaf issuance failed: "+err.Error(), http.StatusInternalServerError) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError) + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, "mitm: hijack failed: "+err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = clientConn.Close() }() + + if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil { + return + } + + tlsConn := tls.Server(clientConn, &tls.Config{ + Certificates: []tls.Certificate{*leaf}, + NextProtos: []string{"h2", "http/1.1"}, + }) + defer func() { _ = tlsConn.Close() }() + + // Deadline applies to the handshake only; cleared before the + // request loop so long-running streams don't get cut off. Fail + // closed if SetDeadline errors — better than handshaking without + // a deadline. + if err := tlsConn.SetDeadline(time.Now().Add(s.connectTimeout)); err != nil { + xlog.Debug("mitm: TLS handshake set-deadline failed", "host", host, "error", err) + return + } + if err := tlsConn.Handshake(); err != nil { + xlog.Debug("mitm: TLS handshake failed", "host", host, "error", err) + return + } + _ = tlsConn.SetDeadline(time.Time{}) + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + req.URL.Scheme = "https" + if req.URL.Host == "" { + req.URL.Host = req.Host + } + s.handler(rw, req, host) + }) + + switch tlsConn.ConnectionState().NegotiatedProtocol { + case "h2": + h2srv := &http2.Server{} + h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{ + Handler: handler, + Context: r.Context(), + }) + default: + s.serveHTTP1(tlsConn, handler, host) + } +} + +func (s *Server) serveHTTP1(tlsConn *tls.Conn, handler http.Handler, host string) { + br := bufio.NewReader(tlsConn) + for { + req, err := http.ReadRequest(br) + if err != nil { + if !errors.Is(err, io.EOF) { + xlog.Debug("mitm: read request", "host", host, "error", err) + } + return + } + rw := newConnResponseWriter(tlsConn, req) + handler.ServeHTTP(rw, req) + rw.finish() + if req.Close || rw.closeAfter { + return + } + } +} + +func normalizeHostPort(host string) string { + if _, _, err := net.SplitHostPort(host); err == nil { + return host + } + return host + ":443" +} diff --git a/core/services/cloudproxy/mitm/proxy_test.go b/core/services/cloudproxy/mitm/proxy_test.go new file mode 100644 index 000000000000..7f4cb9fcf379 --- /dev/null +++ b/core/services/cloudproxy/mitm/proxy_test.go @@ -0,0 +1,278 @@ +package mitm + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// passthroughHandler is the test fixture: forward the parsed +// request to the upstream and stream the response back. Mirrors +// what a production handler would do without any PII rewriting, +// so the proxy core's CONNECT/TLS/req-loop semantics are testable +// in isolation from the redaction logic. +func passthroughHandler(upstreamRoots *x509.CertPool, upstreamAddr string) InterceptHandler { + return func(w http.ResponseWriter, r *http.Request, host string) { + // Build the upstream URL — host is what the client thought + // it was talking to (api.anthropic.com); upstreamAddr is + // where the test fake actually lives. We use upstreamAddr + // directly because the test fake's cert is self-signed + // against an arbitrary CA we control. + u := *r.URL + u.Scheme = "https" + u.Host = upstreamAddr + + body := r.Body + req, err := http.NewRequest(r.Method, u.String(), body) + if err != nil { + http.Error(w, "bad request: "+err.Error(), http.StatusBadRequest) + return + } + req.Header = r.Header.Clone() + req.Header.Set("Host", host) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: upstreamRoots, + // httptest.NewTLSServer issues a cert for + // example.com / *.example.com regardless of the + // listener's actual hostname. Trust that name + // rather than the SNI the client used — + // production code would set ServerName=host. + ServerName: "example.com", + }, + }, + Timeout: 10 * time.Second, + } + resp, err := client.Do(req) + if err != nil { + http.Error(w, "upstream: "+err.Error(), http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for k, vs := range resp.Header { + for _, v := range vs { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) + } +} + +// startMITMTestRig spins up: +// - A fake "upstream" HTTPS server with a self-signed cert +// - A MITM proxy that intercepts the upstream's hostname +// +// Returns a client http.Client whose Transport points at the proxy +// and trusts the MITM CA, plus the upstream URL the client should +// use. Callers tear down with the returned cleanup. +func startMITMTestRig(interceptHost string, upstream http.Handler) (*http.Client, string, func()) { + // Upstream: real TLS server with its own cert. Trust this + // from the proxy's outbound side only. + ts := httptest.NewTLSServer(upstream) + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + upstreamURL, _ := url.Parse(ts.URL) + + ca, err := NewInMemoryCA() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + InterceptHosts: []string{interceptHost}, + Handler: passthroughHandler(upstreamCertPool, upstreamURL.Host), + }) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + ExpectWithOffset(1, srv.Start()).To(Succeed()) + + // Client side: trust the MITM CA so the proxied TLS handshake + // succeeds. Configure HTTPS_PROXY to the proxy listener. + clientPool := x509.NewCertPool() + clientPool.AddCert(ca.Cert()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{RootCAs: clientPool}, + }, + Timeout: 10 * time.Second, + } + + cleanup := func() { + srv.Stop() + ts.Close() + } + return client, "https://" + interceptHost, cleanup +} + +var _ = Describe("Proxy", func() { + It("intercepts allowlisted host", func() { + captured := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = true + // Upstream receives whatever Host header the proxy + // forwarded — in production this would be the real + // hostname; in this test it's the upstream's listener. + // We just verify *some* request landed at the upstream. + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprint(w, `{"ok":true}`) + }) + + client, baseURL, cleanup := startMITMTestRig("api.test.local", upstream) + defer cleanup() + + resp, err := client.Get(baseURL + "/v1/test") + Expect(err).NotTo(HaveOccurred(), "client.Get") + defer func() { _ = resp.Body.Close() }() + + Expect(resp.StatusCode).To(Equal(200)) + body, _ := io.ReadAll(resp.Body) + Expect(string(body)).To(ContainSubstring(`"ok":true`)) + Expect(captured).To(BeTrue(), "upstream handler was never called — proxy did not forward") + }) + + It("tunnels non-allowlisted host", func() { + // Set up a "different" upstream we don't put in the allowlist. + // The proxy should tunnel CONNECTs to it without TLS termination, + // so we need to dial through the proxy and verify the upstream + // sees the raw TLS — the MITM CA isn't used. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, `passthrough`) + }) + ts := httptest.NewTLSServer(upstream) + defer ts.Close() + upstreamURL, _ := url.Parse(ts.URL) + upstreamHost, upstreamPort, _ := net.SplitHostPort(upstreamURL.Host) + + ca, _ := NewInMemoryCA() + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + // Allowlist only "api.test.local" — upstream's host is NOT + // on it, so CONNECT to it must tunnel. + InterceptHosts: []string{"api.test.local"}, + Handler: func(w http.ResponseWriter, r *http.Request, h string) { http.Error(w, "should not be called", 500) }, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(srv.Start()).To(Succeed()) + defer srv.Stop() + + // Client trusts the upstream's actual cert (NOT the MITM CA), + // so a successful TLS handshake proves the proxy did not MITM. + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: upstreamCertPool, + ServerName: upstreamHost, + }, + }, + Timeout: 10 * time.Second, + } + _ = upstreamPort + + resp, err := client.Get(ts.URL) + Expect(err).NotTo(HaveOccurred(), "Get through tunnel") + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(resp.Body) + Expect(string(body)).To(Equal("passthrough")) + }) + + It("rejects non-CONNECT requests", func() { + ca, _ := NewInMemoryCA() + srv, err := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + Handler: func(w http.ResponseWriter, r *http.Request, h string) {}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(srv.Start()).To(Succeed()) + defer srv.Stop() + + resp, err := http.Get("http://" + srv.Addr() + "/") + Expect(err).NotTo(HaveOccurred(), "GET") + defer func() { _ = resp.Body.Close() }() + Expect(resp.StatusCode).To(Equal(http.StatusMethodNotAllowed)) + }) + + It("streams responses", func() { + // SSE-style upstream: send three text chunks with explicit + // flushes so the proxy's Flusher path is exercised. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + flusher := w.(http.Flusher) + for _, msg := range []string{"a", "b", "c"} { + _, _ = fmt.Fprintf(w, "data: %s\n\n", msg) + flusher.Flush() + } + }) + client, baseURL, cleanup := startMITMTestRig("api.test.local", upstream) + defer cleanup() + + resp, err := client.Get(baseURL + "/stream") + Expect(err).NotTo(HaveOccurred(), "Get") + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(resp.Body) + for _, msg := range []string{"a", "b", "c"} { + Expect(string(body)).To(ContainSubstring("data: " + msg)) + } + }) + + It("with no allowlist tunnels everything", func() { + // Empty InterceptHosts means the proxy is in observability- + // only mode: every CONNECT tunnels. Verifies the default- + // fail-safe behaviour mentioned in shouldIntercept. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, "tunneled") + }) + ts := httptest.NewTLSServer(upstream) + defer ts.Close() + upstreamURL, _ := url.Parse(ts.URL) + upstreamHost, _, _ := net.SplitHostPort(upstreamURL.Host) + + ca, _ := NewInMemoryCA() + srv, _ := NewServer(Config{ + Addr: "127.0.0.1:0", + CA: ca, + Handler: func(w http.ResponseWriter, r *http.Request, h string) { Fail("intercept handler called with empty allowlist") }, + // InterceptHosts intentionally empty. + }) + Expect(srv.Start()).To(Succeed()) + defer srv.Stop() + + upstreamCertPool := x509.NewCertPool() + upstreamCertPool.AddCert(ts.Certificate()) + proxyURL, _ := url.Parse("http://" + srv.Addr()) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: upstreamCertPool, + ServerName: upstreamHost, + }, + }, + } + resp, err := client.Get(ts.URL) + Expect(err).NotTo(HaveOccurred(), "Get") + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(resp.Body) + Expect(string(body)).To(Equal("tunneled")) + }) +}) diff --git a/core/services/cloudproxy/mitm/response.go b/core/services/cloudproxy/mitm/response.go new file mode 100644 index 000000000000..067877bfe511 --- /dev/null +++ b/core/services/cloudproxy/mitm/response.go @@ -0,0 +1,105 @@ +package mitm + +import ( + "bufio" + "crypto/tls" + "fmt" + "net/http" + "strconv" + "strings" +) + +// connResponseWriter is a minimal HTTP/1.1 http.ResponseWriter +// that writes directly to a hijacked TLS connection. +type connResponseWriter struct { + conn *tls.Conn + bw *bufio.Writer + req *http.Request + + header http.Header + wroteHeader bool + chunked bool + contentLength int64 + written int64 + closeAfter bool +} + +func newConnResponseWriter(conn *tls.Conn, req *http.Request) *connResponseWriter { + return &connResponseWriter{ + conn: conn, + bw: bufio.NewWriter(conn), + req: req, + header: make(http.Header), + contentLength: -1, + } +} + +func (w *connResponseWriter) Header() http.Header { return w.header } + +func (w *connResponseWriter) WriteHeader(status int) { + if w.wroteHeader { + return + } + w.wroteHeader = true + + if cl := w.header.Get("Content-Length"); cl != "" { + if n, err := strconv.ParseInt(cl, 10, 64); err == nil { + w.contentLength = n + } + } + if w.contentLength < 0 { + w.chunked = true + w.header.Set("Transfer-Encoding", "chunked") + w.header.Del("Content-Length") + } + + // "Connection: close" is case-insensitive per RFC 9110 §7.6.1; some + // upstreams send "Close" or "CLOSE". Use EqualFold so any casing + // triggers the post-response disconnect. + for _, v := range w.header.Values("Connection") { + if strings.EqualFold(v, "close") { + w.closeAfter = true + } + } + + _, _ = fmt.Fprintf(w.bw, "HTTP/1.1 %d %s\r\n", status, http.StatusText(status)) + _ = w.header.Write(w.bw) + _, _ = w.bw.WriteString("\r\n") +} + +func (w *connResponseWriter) Write(p []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + if w.chunked { + if _, err := fmt.Fprintf(w.bw, "%x\r\n", len(p)); err != nil { + return 0, err + } + n, err := w.bw.Write(p) + if err != nil { + return n, err + } + if _, err := w.bw.WriteString("\r\n"); err != nil { + return n, err + } + w.written += int64(n) + return n, nil + } + n, err := w.bw.Write(p) + w.written += int64(n) + return n, err +} + +func (w *connResponseWriter) Flush() { + _ = w.bw.Flush() +} + +func (w *connResponseWriter) finish() { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + if w.chunked { + _, _ = w.bw.WriteString("0\r\n\r\n") + } + _ = w.bw.Flush() +} diff --git a/core/services/cloudproxy/mitm/restart_test.go b/core/services/cloudproxy/mitm/restart_test.go new file mode 100644 index 000000000000..fd39a5ff86af --- /dev/null +++ b/core/services/cloudproxy/mitm/restart_test.go @@ -0,0 +1,98 @@ +package mitm + +import ( + "fmt" + "net" + "net/http" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// noopHandler is the simplest InterceptHandler that satisfies NewServer. +// We only exercise Start/Stop lifecycle here — no requests go through. +func noopHandler(_ http.ResponseWriter, _ *http.Request, _ string) {} + +func newTestServer(addr string, hosts []string) *Server { + ca, err := NewInMemoryCA() + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "NewInMemoryCA") + srv, err := NewServer(Config{ + Addr: addr, + CA: ca, + InterceptHosts: hosts, + Handler: noopHandler, + }) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "NewServer") + return srv +} + +// Server_StopIdempotent: calling Stop twice (and Stop without +// Start) must not panic or deadlock. The application's RestartMITM +// path is sensitive to this — it always calls Stop before Start, even +// when the server is already stopped. +var _ = Describe("Server", func() { + It("Stop is idempotent", func() { + srv := newTestServer("127.0.0.1:0", nil) + srv.Stop() // never started + srv.Stop() // double-stop after never-started + + srv2 := newTestServer("127.0.0.1:0", nil) + Expect(srv2.Start()).To(Succeed()) + srv2.Stop() + srv2.Stop() // second Stop after Start+Stop + }) + + // Server_RestartCycle: two sequential Server lifecycles on the + // same address must rebind cleanly, the new listener must accept + // connections, and the new allowlist must take effect — the shape + // RestartMITM relies on. + It("restart cycle rebinds and swaps allowlist", func() { + // First, find a free port we can rebind to. + probe, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred(), "probe listen") + addr := probe.Addr().String() + _ = probe.Close() + + srv1 := newTestServer(addr, []string{"first.example.com"}) + if err := srv1.Start(); err != nil { + // Port could have been recycled between probe close and Start. + // Skip rather than flake — the production path uses dynamic + // addrs anyway. + Skip(fmt.Sprintf("could not bind probed addr: %v", err)) + } + Expect(strings.HasPrefix(srv1.Addr(), "127.0.0.1:")).To(BeTrue(), "Addr() = %q, want 127.0.0.1:* prefix", srv1.Addr()) + srv1.Stop() + + // Now bring up a second server on the same addr with a different + // allowlist — mirrors the RestartMITM-with-edited-hosts path. + srv2 := newTestServer(addr, []string{"second.example.com"}) + if err := srv2.Start(); err != nil { + // SO_REUSEADDR is not set; brief TIME_WAIT collisions are + // possible on slow CI runners. Retry once on a fresh port so + // the test still exercises the "different hosts" property. + srv2 = newTestServer("127.0.0.1:0", []string{"second.example.com"}) + Expect(srv2.Start()).To(Succeed(), "second Start (fresh port fallback)") + } + defer srv2.Stop() + + // Smoke: the new listener accepts a TCP connection. + conn, err := net.Dial("tcp", srv2.Addr()) + Expect(err).NotTo(HaveOccurred(), "dial restarted listener") + _ = conn.Close() + + // Allowlist swap took effect: the new server intercepts + // "second.example.com" (and not the old "first.example.com"). + Expect(srv2.shouldIntercept("second.example.com")).To(BeTrue(), "second server did not pick up the new InterceptHosts") + Expect(srv2.shouldIntercept("first.example.com")).To(BeFalse(), "second server still has the first server's allowlist") + }) + + // Server_AddrBeforeStart: Addr() pre-Start returns the configured + // address rather than panicking on a nil listener. The admin status + // endpoint reads it under MITMServer() — when an admin queries between + // configuration and Start, the response should still render cleanly. + It("Addr before start returns configured address", func() { + srv := newTestServer(":12345", nil) + Expect(srv.Addr()).To(Equal(":12345")) + }) +}) diff --git a/core/services/cloudproxy/proxy.go b/core/services/cloudproxy/proxy.go new file mode 100644 index 000000000000..106fea996d72 --- /dev/null +++ b/core/services/cloudproxy/proxy.go @@ -0,0 +1,232 @@ +// Package cloudproxy forwards LocalAI requests to external +// provider APIs (Backend = "proxy-*") wire-format-faithfully — it +// does not translate request shapes between providers. +package cloudproxy + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/cloudproxy/ssewire" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/xlog" +) + +// Backend names recognised by the proxy. Any backend starting with +// "proxy-" goes through Forward; these two are the wire formats the +// auth + SSE rewriter knows. Unknown proxy-* backends fall back to +// OpenAI auth. +const ( + BackendProxyOpenAI = "proxy-openai" + BackendProxyAnthropic = "proxy-anthropic" +) + +// transport is overridable in tests; production uses http.DefaultTransport. +var transport http.RoundTripper = http.DefaultTransport + +// SetTransport swaps the HTTP transport used by every Forward call. +// Test-only. +func SetTransport(rt http.RoundTripper) func() { + prev := transport + transport = rt + return func() { transport = prev } +} + +// providerName classifies a backend string into one of the +// supported upstreams. Unknown "proxy-*" names fall back to +// openai-shaped auth since most third-party providers mirror it. +func providerName(backend string) string { + switch backend { + case BackendProxyAnthropic: + return "anthropic" + default: + return "openai" + } +} + +func buildHTTPRequest(ctx context.Context, cfg *config.ModelConfig, body []byte) (*http.Request, error) { + if cfg.Proxy.UpstreamURL == "" { + return nil, fmt.Errorf("cloudproxy: proxy.upstream_url is empty for model %q", cfg.Name) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.Proxy.UpstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + + apiKey := "" + if cfg.Proxy.APIKeyEnv != "" { + apiKey = os.Getenv(cfg.Proxy.APIKeyEnv) + } + + switch providerName(cfg.Backend) { + case "anthropic": + if apiKey != "" { + req.Header.Set("x-api-key", apiKey) + } + req.Header.Set("anthropic-version", "2023-06-01") + default: + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + } + return req, nil +} + +func httpClient(cfg *config.ModelConfig) *http.Client { + c := &http.Client{Transport: transport} + if cfg.Proxy.RequestTimeoutSeconds > 0 { + c.Timeout = time.Duration(cfg.Proxy.RequestTimeoutSeconds) * time.Second + } + return c +} + +func rewriteModel(body []byte, upstreamModel string) ([]byte, error) { + if upstreamModel == "" { + return body, nil + } + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return nil, fmt.Errorf("cloudproxy: parse request body: %w", err) + } + m["model"] = upstreamModel + return json.Marshal(m) +} + +func streaming(body []byte) bool { + var probe struct { + Stream bool `json:"stream"` + } + if err := json.Unmarshal(body, &probe); err != nil { + return false + } + return probe.Stream +} + +// Forward proxies a chat-style request to the upstream and writes +// the response to the client, applying filter to per-token text +// extracted from streaming SSE responses. +func Forward(c echo.Context, cfg *config.ModelConfig, body []byte, filter *pii.StreamFilter) error { + body, err := rewriteModel(body, cfg.Proxy.UpstreamModel) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + req, err := buildHTTPRequest(c.Request().Context(), cfg, body) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + isStream := streaming(body) + xlog.Debug("cloudproxy: forwarding", + "model", cfg.Name, + "backend", cfg.Backend, + "upstream", cfg.Proxy.UpstreamURL, + "stream", isStream, + ) + + resp, err := httpClient(cfg).Do(req) + if err != nil { + return echo.NewHTTPError(http.StatusBadGateway, "cloudproxy: upstream request failed: "+err.Error()) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + return passthroughError(c, resp) + } + + if isStream { + return forwardStream(c, resp, providerName(cfg.Backend), filter) + } + return forwardBuffered(c, resp) +} + +func passthroughError(c echo.Context, resp *http.Response) error { + const maxErrBody = 1 << 20 + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody)) + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Response().Header().Set("Content-Type", ct) + } + c.Response().WriteHeader(resp.StatusCode) + _, _ = c.Response().Writer.Write(body) + return nil +} + +func forwardBuffered(c echo.Context, resp *http.Response) error { + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Response().Header().Set("Content-Type", ct) + } + c.Response().WriteHeader(resp.StatusCode) + _, err := io.Copy(c.Response().Writer, resp.Body) + return err +} + +func forwardStream(c echo.Context, resp *http.Response, provider string, filter *pii.StreamFilter) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + emit := func(line string) error { + _, err := fmt.Fprint(c.Response().Writer, line) + if err != nil { + return err + } + c.Response().Flush() + return nil + } + + flushResidual := func() { + if filter == nil { + return + } + residual := filter.Drain() + if residual == "" { + return + } + if line := ssewire.SynthResidualEvent(ssewire.Provider(provider), residual); line != "" { + _ = emit(line) + } + } + + prov := ssewire.Provider(provider) + scanner := ssewire.NewScanner(resp.Body) + for scanner.Scan() { + ev := scanner.Event() + if ssewire.IsTerminalMarker(ev.DataLine, prov) { + flushResidual() + _ = emit(ev.Raw) + continue + } + out := ev.Raw + if filter != nil && ev.DataLine != "" { + rewritten, drop := ssewire.RewritePayload(ev.DataLine, prov, filter) + if drop { + continue + } + if rewritten != ev.DataLine { + // strings.Replace with n=1 touches only the data + // line, preserving any "event:"/"id:" preamble. + out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1) + } + } + if err := emit(out); err != nil { + return nil + } + } + if err := scanner.Err(); err != nil && err != io.EOF { + xlog.Debug("cloudproxy: stream read error", "error", err) + } + flushResidual() + return nil +} diff --git a/core/services/cloudproxy/proxy_suite_test.go b/core/services/cloudproxy/proxy_suite_test.go new file mode 100644 index 000000000000..a30c14ec505f --- /dev/null +++ b/core/services/cloudproxy/proxy_suite_test.go @@ -0,0 +1,13 @@ +package cloudproxy + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestCloudproxy(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "cloudproxy test suite") +} diff --git a/core/services/cloudproxy/proxy_test.go b/core/services/cloudproxy/proxy_test.go new file mode 100644 index 000000000000..994f58ff31f1 --- /dev/null +++ b/core/services/cloudproxy/proxy_test.go @@ -0,0 +1,284 @@ +package cloudproxy + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/routing/pii" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// fakeUpstream returns an httptest.Server whose handler captures the +// inbound request and replies with the given status/body. It is the +// shared fixture for proxy integration tests. +func fakeUpstream(status int, body string, captured **http.Request, capturedBody *[]byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if captured != nil { + cloned := r.Clone(context.Background()) + *captured = cloned + } + if capturedBody != nil { + b, _ := io.ReadAll(r.Body) + *capturedBody = b + } + // Mirror the upstream's content-type so the proxy passes it + // through faithfully. Tests for streaming explicitly set + // text/event-stream; tests for buffered set application/json. + ct := r.Header.Get("X-Test-Response-Content-Type") + if ct == "" { + ct = "application/json" + } + w.Header().Set("Content-Type", ct) + w.WriteHeader(status) + _, _ = io.WriteString(w, body) + })) +} + +func newEchoCtx(method, path string, body string) (echo.Context, *httptest.ResponseRecorder) { + e := echo.New() + req := httptest.NewRequest(method, path, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + return c, rec +} + +var _ = Describe("Forward", func() { + It("buffered passthrough", func() { + var capturedBody []byte + var capturedReq *http.Request + upstream := fakeUpstream(200, `{"id":"abc","choices":[{"message":{"content":"hi"}}]}`, &capturedReq, &capturedBody) + defer upstream.Close() + + cfg := &config.ModelConfig{ + Backend: "proxy-openai", + Proxy: config.ProxyConfig{ + UpstreamURL: upstream.URL, + UpstreamModel: "gpt-4o-mini", + }, + } + cfg.Name = "alias" + + c, rec := newEchoCtx("POST", "/v1/chat/completions", `{"model":"alias","stream":false,"messages":[{"role":"user","content":"hi"}]}`) + + body := `{"model":"alias","stream":false,"messages":[{"role":"user","content":"hi"}]}` + Expect(Forward(c, cfg, []byte(body), nil)).To(Succeed()) + + Expect(rec.Code).To(Equal(200)) + Expect(rec.Body.String()).To(ContainSubstring(`"content":"hi"`)) + // Model rewrite must have replaced the alias. + var sent map[string]any + Expect(json.Unmarshal(capturedBody, &sent)).To(Succeed()) + Expect(sent["model"]).To(Equal("gpt-4o-mini")) + }) + + It("openai auth header", func() { + var captured *http.Request + upstream := fakeUpstream(200, `{}`, &captured, nil) + defer upstream.Close() + + GinkgoT().Setenv("PROXY_TEST_KEY", "sk-secret-xyz") + cfg := &config.ModelConfig{ + Backend: "proxy-openai", + Proxy: config.ProxyConfig{ + UpstreamURL: upstream.URL, + APIKeyEnv: "PROXY_TEST_KEY", + }, + } + c, _ := newEchoCtx("POST", "/v1/chat/completions", `{}`) + Expect(Forward(c, cfg, []byte(`{"model":"x"}`), nil)).To(Succeed()) + Expect(captured.Header.Get("Authorization")).To(Equal("Bearer sk-secret-xyz")) + Expect(captured.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key leaked on openai backend") + }) + + It("anthropic auth header", func() { + var captured *http.Request + upstream := fakeUpstream(200, `{}`, &captured, nil) + defer upstream.Close() + + GinkgoT().Setenv("PROXY_TEST_KEY", "ant-secret") + cfg := &config.ModelConfig{ + Backend: "proxy-anthropic", + Proxy: config.ProxyConfig{ + UpstreamURL: upstream.URL, + APIKeyEnv: "PROXY_TEST_KEY", + }, + } + c, _ := newEchoCtx("POST", "/v1/messages", `{}`) + Expect(Forward(c, cfg, []byte(`{"model":"x"}`), nil)).To(Succeed()) + Expect(captured.Header.Get("x-api-key")).To(Equal("ant-secret")) + Expect(captured.Header.Get("anthropic-version")).NotTo(BeEmpty(), "anthropic-version header missing") + Expect(captured.Header.Get("Authorization")).To(BeEmpty(), "Authorization leaked on anthropic backend") + }) + + It("upstream error passthrough", func() { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + _, _ = io.WriteString(w, `{"error":{"type":"rate_limit"}}`) + })) + defer upstream.Close() + + cfg := &config.ModelConfig{ + Backend: "proxy-openai", + Proxy: config.ProxyConfig{UpstreamURL: upstream.URL}, + } + c, rec := newEchoCtx("POST", "/v1/chat/completions", `{}`) + Expect(Forward(c, cfg, []byte(`{"model":"x"}`), nil)).To(Succeed()) + Expect(rec.Code).To(Equal(429)) + Expect(rec.Body.String()).To(ContainSubstring("rate_limit")) + }) + + It("openai stream passthrough", func() { + stream := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"hello "}}]}`, + ``, + `data: {"choices":[{"delta":{"content":"world"}}]}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + _, _ = io.WriteString(w, stream) + })) + defer upstream.Close() + + cfg := &config.ModelConfig{ + Backend: "proxy-openai", + Proxy: config.ProxyConfig{UpstreamURL: upstream.URL}, + } + c, rec := newEchoCtx("POST", "/v1/chat/completions", `{"stream":true}`) + Expect(Forward(c, cfg, []byte(`{"model":"x","stream":true}`), nil)).To(Succeed()) + out := rec.Body.String() + Expect(out).To(ContainSubstring(`"content":"hello "`)) + Expect(out).To(ContainSubstring(`"content":"world"`)) + Expect(out).To(ContainSubstring("[DONE]")) + }) + + It("openai stream PII rewrite", func() { + // The redactor masks email addresses by default (see DefaultPatterns). + // Splitting "alice@example.com" across two chunks proves the + // streaming filter buffers correctly when the PII spans events. + stream := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"contact me at alice@"}}]}`, + ``, + `data: {"choices":[{"delta":{"content":"example.com please"}}]}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + _, _ = io.WriteString(w, stream) + })) + defer upstream.Close() + + patterns, err := pii.Compile(pii.DefaultPatterns()) + Expect(err).NotTo(HaveOccurred()) + r := pii.NewRedactor(patterns) + store := pii.NewMemoryEventStore(16) + filter := pii.NewStreamFilter(r, nil, store, "corr-1", "user-1") + + cfg := &config.ModelConfig{ + Backend: "proxy-openai", + Proxy: config.ProxyConfig{UpstreamURL: upstream.URL}, + } + c, rec := newEchoCtx("POST", "/v1/chat/completions", `{"stream":true}`) + Expect(Forward(c, cfg, []byte(`{"model":"x","stream":true}`), filter)).To(Succeed()) + out := rec.Body.String() + Expect(out).NotTo(ContainSubstring("alice@example.com"), "email leaked through stream") + Expect(out).To(ContainSubstring("[REDACTED:email]")) + Expect(out).To(ContainSubstring("[DONE]")) + }) + + It("anthropic stream PII rewrite", func() { + stream := strings.Join([]string{ + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"reach me at bob@"}}`, + ``, + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"example.org any time"}}`, + ``, + }, "\n") + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(200) + _, _ = io.WriteString(w, stream) + })) + defer upstream.Close() + + patterns, err := pii.Compile(pii.DefaultPatterns()) + Expect(err).NotTo(HaveOccurred()) + r := pii.NewRedactor(patterns) + store := pii.NewMemoryEventStore(16) + filter := pii.NewStreamFilter(r, nil, store, "corr-1", "user-1") + + cfg := &config.ModelConfig{ + Backend: "proxy-anthropic", + Proxy: config.ProxyConfig{UpstreamURL: upstream.URL}, + } + c, rec := newEchoCtx("POST", "/v1/messages", `{"stream":true}`) + Expect(Forward(c, cfg, []byte(`{"model":"x","stream":true}`), filter)).To(Succeed()) + out := rec.Body.String() + Expect(out).NotTo(ContainSubstring("bob@example.org"), "email leaked through anthropic stream") + Expect(out).To(ContainSubstring("[REDACTED:email]")) + // Event-name preamble must survive the rewrite — Anthropic SDKs + // route on it. + Expect(out).To(ContainSubstring("event: content_block_delta")) + }) +}) + +var _ = Describe("rewriteModel", func() { + It("is a no-op when upstream model is empty", func() { + body := []byte(`{"model":"x","stream":false}`) + out, err := rewriteModel(body, "") + Expect(err).NotTo(HaveOccurred()) + Expect(string(out)).To(Equal(string(body))) + }) + + It("replaces the model", func() { + body := []byte(`{"model":"alias","stream":false}`) + out, err := rewriteModel(body, "real-model-id") + Expect(err).NotTo(HaveOccurred()) + var m map[string]any + Expect(json.Unmarshal(out, &m)).To(Succeed()) + Expect(m["model"]).To(Equal("real-model-id")) + }) +}) + +var _ = Describe("providerName", func() { + cases := map[string]string{ + "proxy-openai": "openai", + "proxy-anthropic": "anthropic", + "proxy-grok": "openai", // unknowns fall back to openai + "": "openai", + } + for backend, want := range cases { + It("maps "+backend, func() { + Expect(providerName(backend)).To(Equal(want)) + }) + } +}) + +var _ = Describe("streaming", func() { + It("detects stream=true", func() { + Expect(streaming([]byte(`{"stream":true}`))).To(BeTrue()) + }) + It("detects stream=false", func() { + Expect(streaming([]byte(`{"stream":false}`))).To(BeFalse()) + }) + It("returns false when stream key absent", func() { + Expect(streaming([]byte(`{}`))).To(BeFalse()) + }) +}) diff --git a/core/services/cloudproxy/ssewire/ssewire.go b/core/services/cloudproxy/ssewire/ssewire.go new file mode 100644 index 000000000000..ed3cb862ba01 --- /dev/null +++ b/core/services/cloudproxy/ssewire/ssewire.go @@ -0,0 +1,218 @@ +// Package ssewire holds the SSE-format helpers shared between +// the request-shape cloud proxy (core/services/cloudproxy) and the +// TLS-terminating MITM proxy (core/services/cloudproxy/mitm). Both +// run a pii.StreamFilter over per-token text extracted from +// provider-specific JSON chunks; this package owns the JSON shapes +// so a future provider addition is one edit, not two. +package ssewire + +import ( + "bufio" + "encoding/json" + "io" + "strings" + + "github.com/mudler/LocalAI/core/services/routing/pii" +) + +// Provider is the upstream wire format an SSE stream conforms to. +type Provider string + +const ( + OpenAI Provider = "openai" + Anthropic Provider = "anthropic" +) + +// Event is one SSE event with its exact wire bytes preserved in +// Raw (so unmodified events round-trip byte-for-byte) and the +// extracted JSON payload from the data: line in DataLine. +type Event struct { + Raw string + DataLine string +} + +// Scanner reads SSE events one at a time from an upstream body. +type Scanner struct { + r *bufio.Reader + ev Event + err error +} + +func NewScanner(r io.Reader) *Scanner { + return &Scanner{r: bufio.NewReaderSize(r, 64*1024)} +} + +func (s *Scanner) Scan() bool { + var raw strings.Builder + var dataLine string + for { + line, err := s.r.ReadString('\n') + if line != "" { + raw.WriteString(line) + trimmed := strings.TrimRight(line, "\r\n") + if trimmed == "" { + if raw.Len() == len(line) { + raw.Reset() + continue + } + s.ev = Event{Raw: raw.String(), DataLine: dataLine} + return true + } + if strings.HasPrefix(trimmed, "data:") && dataLine == "" { + payload := strings.TrimPrefix(trimmed, "data:") + payload = strings.TrimPrefix(payload, " ") + dataLine = payload + } + } + if err != nil { + s.err = err + if raw.Len() > 0 { + s.ev = Event{Raw: raw.String(), DataLine: dataLine} + return true + } + return false + } + } +} + +func (s *Scanner) Event() Event { return s.ev } +func (s *Scanner) Err() error { return s.err } + +// IsTerminalMarker reports whether the data line is the per-provider +// end-of-stream sentinel. The streaming PII filter must drain its +// residue before the caller forwards a terminal marker — clients +// stop reading after it. +func IsTerminalMarker(dataLine string, provider Provider) bool { + if dataLine == "" { + return false + } + if strings.TrimSpace(dataLine) == "[DONE]" { + return true + } + if provider == Anthropic { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal([]byte(dataLine), &probe); err == nil { + return probe.Type == "message_stop" + } + } + return false +} + +// RewritePayload runs the data line's content-bearing field through +// the streaming filter. drop=true tells the caller to suppress the +// SSE event entirely (the filter buffered the whole token while +// disambiguating a pattern boundary). +func RewritePayload(dataLine string, provider Provider, filter *pii.StreamFilter) (rewritten string, drop bool) { + if strings.TrimSpace(dataLine) == "[DONE]" { + return dataLine, false + } + switch provider { + case Anthropic: + return rewriteAnthropic(dataLine, filter) + default: + return rewriteOpenAI(dataLine, filter) + } +} + +func rewriteOpenAI(dataLine string, filter *pii.StreamFilter) (string, bool) { + var m map[string]any + if err := json.Unmarshal([]byte(dataLine), &m); err != nil { + return dataLine, false + } + choices, ok := m["choices"].([]any) + if !ok || len(choices) == 0 { + return dataLine, false + } + first, ok := choices[0].(map[string]any) + if !ok { + return dataLine, false + } + delta, ok := first["delta"].(map[string]any) + if !ok { + return dataLine, false + } + content, ok := delta["content"].(string) + if !ok || content == "" { + return dataLine, false + } + rewritten := filter.Push(content) + if rewritten == "" { + return "", true + } + if rewritten == content { + return dataLine, false + } + delta["content"] = rewritten + out, err := json.Marshal(m) + if err != nil { + return dataLine, false + } + return string(out), false +} + +func rewriteAnthropic(dataLine string, filter *pii.StreamFilter) (string, bool) { + var m map[string]any + if err := json.Unmarshal([]byte(dataLine), &m); err != nil { + return dataLine, false + } + if t, _ := m["type"].(string); t != "content_block_delta" { + return dataLine, false + } + delta, ok := m["delta"].(map[string]any) + if !ok { + return dataLine, false + } + if dt, _ := delta["type"].(string); dt != "text_delta" { + return dataLine, false + } + text, ok := delta["text"].(string) + if !ok || text == "" { + return dataLine, false + } + rewritten := filter.Push(text) + if rewritten == "" { + return "", true + } + if rewritten == text { + return dataLine, false + } + delta["text"] = rewritten + out, err := json.Marshal(m) + if err != nil { + return dataLine, false + } + return string(out), false +} + +// SynthResidualEvent builds a provider-shaped SSE event carrying +// the streaming filter's drained tail so the response body remains +// a valid event stream after the proxy splices in held-back text. +func SynthResidualEvent(provider Provider, text string) string { + switch provider { + case Anthropic: + payload := map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]string{"type": "text_delta", "text": text}, + } + b, err := json.Marshal(payload) + if err != nil { + return "" + } + return "event: content_block_delta\ndata: " + string(b) + "\n\n" + default: + payload := map[string]any{ + "object": "chat.completion.chunk", + "choices": []map[string]any{ + {"index": 0, "delta": map[string]string{"content": text}}, + }, + } + b, err := json.Marshal(payload) + if err != nil { + return "" + } + return "data: " + string(b) + "\n\n" + } +} diff --git a/core/services/cloudproxy/ssewire/ssewire_suite_test.go b/core/services/cloudproxy/ssewire/ssewire_suite_test.go new file mode 100644 index 000000000000..6925017f0171 --- /dev/null +++ b/core/services/cloudproxy/ssewire/ssewire_suite_test.go @@ -0,0 +1,13 @@ +package ssewire + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestSsewire(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "ssewire test suite") +} diff --git a/core/services/cloudproxy/ssewire/ssewire_test.go b/core/services/cloudproxy/ssewire/ssewire_test.go new file mode 100644 index 000000000000..2750367fda48 --- /dev/null +++ b/core/services/cloudproxy/ssewire/ssewire_test.go @@ -0,0 +1,114 @@ +package ssewire + +import ( + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// Scanner contract: returns one Event per double-newline-terminated +// SSE block, preserving the raw bytes (so unmodified events round-trip +// exactly) and extracting the first data: payload as DataLine. + +var _ = Describe("Scanner", func() { + It("scans a basic event", func() { + in := "event: foo\ndata: hello\n\n" + s := NewScanner(strings.NewReader(in)) + Expect(s.Scan()).To(BeTrue(), "Scan returned false on a well-formed event; err=%v", s.Err()) + ev := s.Event() + Expect(ev.Raw).To(Equal(in)) + Expect(ev.DataLine).To(Equal("hello")) + Expect(s.Scan()).To(BeFalse(), "Scan should return false after the only event") + }) + + It("handles CRLF", func() { + // Some upstreams emit CRLF instead of LF. The scanner trims + // trailing \r off the data line so DataLine carries the same + // bytes whichever line ending the producer chose. + in := "event: foo\r\ndata: hello\r\n\r\n" + s := NewScanner(strings.NewReader(in)) + Expect(s.Scan()).To(BeTrue(), "Scan returned false on CRLF event; err=%v", s.Err()) + Expect(s.Event().DataLine).To(Equal("hello")) + }) + + It("scans multiple events", func() { + in := "data: one\n\ndata: two\n\ndata: three\n\n" + s := NewScanner(strings.NewReader(in)) + got := []string{} + for s.Scan() { + got = append(got, s.Event().DataLine) + } + Expect(got).To(Equal([]string{"one", "two", "three"})) + }) + + It("handles empty data payload", func() { + // "data:" with no payload is valid SSE — DataLine should be empty + // and Scan should still surface the event so callers can decide. + in := "data:\n\n" + s := NewScanner(strings.NewReader(in)) + Expect(s.Scan()).To(BeTrue(), "Scan returned false on empty data payload; err=%v", s.Err()) + Expect(s.Event().DataLine).To(Equal("")) + }) + + It("skips leading blank lines", func() { + // A producer that prints a blank "keep-alive" before the first + // real event must not produce a phantom event. + in := "\n\n\ndata: real\n\n" + s := NewScanner(strings.NewReader(in)) + Expect(s.Scan()).To(BeTrue(), "Scan returned false; err=%v", s.Err()) + Expect(s.Event().DataLine).To(Equal("real")) + }) + + It("handles mid-event EOF", func() { + // EOF mid-event still surfaces the partial event with whatever + // data was extracted — the StreamFilter+caller decides how to + // handle a truncated upstream rather than silently dropping it. + in := "data: half" + s := NewScanner(strings.NewReader(in)) + Expect(s.Scan()).To(BeTrue(), "Scan returned false on partial event") + ev := s.Event() + Expect(ev.DataLine).To(Equal("half")) + Expect(s.Scan()).To(BeFalse(), "Scan should not surface a second event after EOF") + }) +}) + +var _ = Describe("IsTerminalMarker", func() { + cases := []struct { + name string + dataLine string + provider Provider + want bool + }{ + {"openai DONE", "[DONE]", OpenAI, true}, + {"openai DONE with whitespace", " [DONE] ", OpenAI, true}, + {"anthropic DONE also recognised", "[DONE]", Anthropic, true}, + {"anthropic message_stop", `{"type":"message_stop"}`, Anthropic, true}, + {"anthropic content_block_delta is not terminal", `{"type":"content_block_delta"}`, Anthropic, false}, + {"openai chat.completion.chunk is not terminal", `{"object":"chat.completion.chunk"}`, OpenAI, false}, + {"openai message_stop is not terminal (wrong provider)", `{"type":"message_stop"}`, OpenAI, false}, + {"empty data", "", OpenAI, false}, + {"non-json garbage", "garbage", Anthropic, false}, + } + for _, c := range cases { + It(c.name, func() { + Expect(IsTerminalMarker(c.dataLine, c.provider)).To(Equal(c.want)) + }) + } +}) + +var _ = Describe("SynthResidualEvent", func() { + It("anthropic", func() { + got := SynthResidualEvent(Anthropic, "tail") + Expect(strings.HasPrefix(got, "event: content_block_delta\ndata:")).To(BeTrue(), "Anthropic residual event missing event/data lines: %q", got) + Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "Anthropic residual event missing trailing blank line: %q", got) + Expect(got).To(ContainSubstring(`"text":"tail"`)) + }) + + It("openai", func() { + got := SynthResidualEvent(OpenAI, "tail") + Expect(strings.HasPrefix(got, "data: ")).To(BeTrue(), "OpenAI residual event missing data: prefix: %q", got) + Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "OpenAI residual event missing trailing blank line: %q", got) + Expect(got).To(ContainSubstring(`"content":"tail"`)) + }) +}) diff --git a/core/services/monitoring/metrics.go b/core/services/monitoring/metrics.go index fa5663210546..f9644698e6df 100644 --- a/core/services/monitoring/metrics.go +++ b/core/services/monitoring/metrics.go @@ -4,6 +4,7 @@ import ( "context" "github.com/mudler/xlog" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/metric" @@ -12,6 +13,7 @@ import ( type LocalAIMetricsService struct { Meter metric.Meter + Provider *metricApi.MeterProvider ApiTimeMetric metric.Float64Histogram } @@ -31,6 +33,13 @@ func NewLocalAIMetricsService() (*LocalAIMetricsService, error) { return nil, err } provider := metricApi.NewMeterProvider(metricApi.WithReader(exporter)) + // Share the provider with the OTel global so packages outside this + // service (e.g., core/services/routing/billing) see the same Prom + // exporter when they call otel.Meter(...). Without this, the billing + // counters would route to the no-op global provider and never reach + // /metrics — which is exactly the silent-billing-loss class of bug + // the routing module is designed to surface. + otel.SetMeterProvider(provider) meter := provider.Meter("github.com/mudler/LocalAI") apiTimeMetric, err := meter.Float64Histogram("api_call", metric.WithDescription("api calls")) @@ -40,6 +49,7 @@ func NewLocalAIMetricsService() (*LocalAIMetricsService, error) { return &LocalAIMetricsService{ Meter: meter, + Provider: provider, ApiTimeMetric: apiTimeMetric, }, nil } diff --git a/core/services/nodes/health_mock_test.go b/core/services/nodes/health_mock_test.go index 45f37b5633fc..b53798c8c239 100644 --- a/core/services/nodes/health_mock_test.go +++ b/core/services/nodes/health_mock_test.go @@ -265,6 +265,12 @@ func (c *fakeBackendClient) StopQuantization(_ context.Context, _ *pb.Quantizati func (c *fakeBackendClient) Free(_ context.Context) error { return nil } +func (c *fakeBackendClient) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) { + return nil, nil +} +func (c *fakeBackendClient) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) { + return nil, nil +} // --- fakeBackendClientFactory --- diff --git a/core/services/nodes/inflight_test.go b/core/services/nodes/inflight_test.go index edb04b6f81a7..53dde692f975 100644 --- a/core/services/nodes/inflight_test.go +++ b/core/services/nodes/inflight_test.go @@ -220,6 +220,14 @@ func (f *fakeGRPCBackend) Free(_ context.Context) error { return nil } +func (f *fakeGRPCBackend) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) { + return nil, nil +} + +func (f *fakeGRPCBackend) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) { + return nil, nil +} + // --- Tests --- var _ = Describe("InFlightTrackingClient", func() { diff --git a/core/services/routing/admission/admission.go b/core/services/routing/admission/admission.go new file mode 100644 index 000000000000..16824818178d --- /dev/null +++ b/core/services/routing/admission/admission.go @@ -0,0 +1,105 @@ +// Package admission is routing-module subsystem 5: per-model +// concurrency control + audit. The middleware acquires a slot +// before the handler runs; on full, the request gets 503 with +// Retry-After so clients back off rather than pile on. The audit +// row goes into the shared event store alongside PII and proxy +// rows so admins see a single timeline of routing pressure. +// +// Concurrency model: one buffered channel per model name (kept in +// a sync.Map). Acquire is a non-blocking send; full = reject. No +// queueing in the MVP — adding queue depth + timeout is a small +// follow-up if/when telemetry shows admins want it. +package admission + +import ( + "sync" + "time" +) + +// Limiter holds the per-model semaphores. Safe for concurrent use. +// +// Each model's slot count is fixed at first Acquire — a config +// edit that reduces MaxConcurrent only takes effect on the NEXT +// process start (or after the limiter is rebuilt). The alternative +// (dynamic resize on every call) would require swapping the channel +// out from under in-flight Acquires; the simplicity tradeoff favors +// "restart to apply" since admins editing limits do so rarely. +type Limiter struct { + mu sync.Mutex + slots map[string]chan struct{} +} + +// New returns an empty Limiter. +func New() *Limiter { + return &Limiter{slots: make(map[string]chan struct{})} +} + +// Acquire takes a slot for the named model. maxConcurrent <= 0 +// means unlimited — Acquire returns immediately with a no-op +// release. When all slots are busy, returns ok=false. Callers +// MUST call the returned release when done (typically via defer); +// missing a release leaks one slot for the lifetime of the +// process. +func (l *Limiter) Acquire(modelName string, maxConcurrent int) (release func(), ok bool) { + if maxConcurrent <= 0 { + return func() {}, true + } + ch := l.slot(modelName, maxConcurrent) + select { + case ch <- struct{}{}: + return func() { <-ch }, true + default: + return nil, false + } +} + +// InFlight reports the number of currently-held slots for the +// named model. Used by the admin status surface — read-only and +// approximate (ch length is racy with concurrent Acquire/release +// but that's fine for a dashboard). +func (l *Limiter) InFlight(modelName string) int { + l.mu.Lock() + ch, ok := l.slots[modelName] + l.mu.Unlock() + if !ok { + return 0 + } + return len(ch) +} + +// Capacity reports the limiter's configured slot count for the +// named model, or 0 if the model has never had Acquire called +// against it. Same dashboard-only purpose as InFlight. +func (l *Limiter) Capacity(modelName string) int { + l.mu.Lock() + ch, ok := l.slots[modelName] + l.mu.Unlock() + if !ok { + return 0 + } + return cap(ch) +} + +// slot returns the per-model channel, creating it on first use. +func (l *Limiter) slot(modelName string, capacity int) chan struct{} { + l.mu.Lock() + defer l.mu.Unlock() + if ch, ok := l.slots[modelName]; ok { + return ch + } + ch := make(chan struct{}, capacity) + l.slots[modelName] = ch + return ch +} + +// RetryAfter returns the Retry-After header value for a rejected +// request. The Limiter doesn't track rolling latency — this is a +// pure config-driven hint, identity-mapped to the LimitsConfig +// field with a 1s fallback. Centralised here so the middleware +// doesn't reimplement the default rule. +func RetryAfter(configured int) time.Duration { + if configured > 0 { + return time.Duration(configured) * time.Second + } + return time.Second +} diff --git a/core/services/routing/admission/admission_suite_test.go b/core/services/routing/admission/admission_suite_test.go new file mode 100644 index 000000000000..163c84db386d --- /dev/null +++ b/core/services/routing/admission/admission_suite_test.go @@ -0,0 +1,13 @@ +package admission + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAdmission(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Admission test suite") +} diff --git a/core/services/routing/admission/admission_test.go b/core/services/routing/admission/admission_test.go new file mode 100644 index 000000000000..b97e18969eec --- /dev/null +++ b/core/services/routing/admission/admission_test.go @@ -0,0 +1,103 @@ +package admission + +import ( + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Limiter", func() { + It("returns immediate no-op when unlimited", func() { + l := New() + for i := 0; i < 5; i++ { + release, ok := l.Acquire("anything", 0) + Expect(ok).To(BeTrue(), "max=0 should never reject") + release() + } + Expect(l.InFlight("anything")).To(Equal(0)) + Expect(l.Capacity("anything")).To(Equal(0)) + }) + + It("rejects when full", func() { + // Two concurrent requests at MaxConcurrent=1: second is + // rejected and the limiter reports the in-flight count. + l := New() + r1, ok := l.Acquire("m", 1) + Expect(ok).To(BeTrue(), "first Acquire should succeed") + defer r1() + + _, ok = l.Acquire("m", 1) + Expect(ok).To(BeFalse(), "second Acquire should reject — slot is held") + Expect(l.InFlight("m")).To(Equal(1)) + Expect(l.Capacity("m")).To(Equal(1)) + }) + + It("allows the next Acquire after Release", func() { + l := New() + r1, _ := l.Acquire("m", 1) + r1() + _, ok := l.Acquire("m", 1) + Expect(ok).To(BeTrue(), "Acquire after release should succeed") + }) + + It("isolates slots per-model", func() { + // Slots are per-model; saturating one does not affect another. + l := New() + r1, _ := l.Acquire("m1", 1) + defer r1() + _, ok := l.Acquire("m2", 1) + Expect(ok).To(BeTrue(), "m2 should have its own slot") + }) + + It("honours the cap under concurrent Acquires", func() { + // Hammer Acquire from multiple goroutines; the count of + // successful acquires must not exceed the cap. + l := New() + const cap = 4 + const goroutines = 50 + var wg sync.WaitGroup + successes := make(chan func(), goroutines) + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if release, ok := l.Acquire("m", cap); ok { + successes <- release + } + }() + } + wg.Wait() + close(successes) + + count := 0 + for r := range successes { + count++ + r() + } + Expect(count).To(Equal(cap)) + }) + + // First-Acquire fixes the channel capacity. A subsequent Acquire + // at a different maxConcurrent does NOT resize — admins editing + // limits expect a process restart. Pin that behaviour so the + // surprise isn't accidentally introduced. + It("fixes the cap at first Acquire", func() { + l := New() + r1, _ := l.Acquire("m", 2) + defer r1() + // Try to acquire with cap=10 — should still be bounded by 2. + r2, _ := l.Acquire("m", 10) + defer r2() + _, ok := l.Acquire("m", 10) + Expect(ok).To(BeFalse(), "third Acquire should reject — initial cap of 2 still applies") + }) +}) + +var _ = Describe("RetryAfter", func() { + It("defaults to one second", func() { + Expect(RetryAfter(0)).To(Equal(time.Second)) + Expect(RetryAfter(5)).To(Equal(5 * time.Second)) + }) +}) diff --git a/core/services/routing/billing/backend.go b/core/services/routing/billing/backend.go new file mode 100644 index 000000000000..69119878eef5 --- /dev/null +++ b/core/services/routing/billing/backend.go @@ -0,0 +1,52 @@ +// Package billing provides the StatsBackend abstraction that decouples +// per-request token tracking from the auth database. This lets a +// single-user no-auth deployment still see usage and costs, which the +// pre-routing-module middleware did not allow. +package billing + +import ( + "context" + + "github.com/mudler/LocalAI/core/http/auth" +) + +// StatsBackend is the persistence target for usage records. Three +// implementations exist: +// +// - GORM (auth-DB-backed) — used when --auth is on; records share the +// auth database and existing aggregation queries continue to work. +// - Memory (ring buffer) — used when --auth is off and no other DB is +// configured. Records are lost on restart by design; the same +// process can still answer aggregation queries for live dashboards. +// - Disabled — explicit no-op when --disable-stats is set, useful in +// ephemeral CI runs. +// +// All implementations are safe for concurrent use. Record() must not +// block the caller for more than the time it takes to enqueue — durable +// flushing happens on a background goroutine inside the implementation. +type StatsBackend interface { + // Record enqueues a single usage record. The record is asynchronously + // persisted; callers should not assume durability on return. The ctx + // is currently unused but reserved for future cancellation. + Record(ctx context.Context, r *auth.UsageRecord) error + + // Aggregate returns time-bucketed totals for the dashboard. The + // AggregateQuery's UserID is required; pass the empty string only + // from admin-scoped paths. Implementations that do not support + // aggregation (e.g., ring buffer in saturation) may return an empty + // result with no error. + Aggregate(ctx context.Context, q AggregateQuery) ([]auth.UsageBucket, error) + + // Close releases resources (flushes pending records, stops + // goroutines). Safe to call multiple times. + Close() error +} + +// AggregateQuery describes a usage aggregation request. Period is one of +// "day", "week", "month", "all" (matching the existing auth.UsageRecord +// vocabulary). UserID empty means cluster-wide; callers must enforce +// admin permission before passing the empty string. +type AggregateQuery struct { + UserID string + Period string +} diff --git a/core/services/routing/billing/billing_suite_test.go b/core/services/routing/billing/billing_suite_test.go new file mode 100644 index 000000000000..0b43673eb2a5 --- /dev/null +++ b/core/services/routing/billing/billing_suite_test.go @@ -0,0 +1,13 @@ +package billing + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestBilling(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "billing test suite") +} diff --git a/core/services/routing/billing/disabled.go b/core/services/routing/billing/disabled.go new file mode 100644 index 000000000000..c6d35df6c735 --- /dev/null +++ b/core/services/routing/billing/disabled.go @@ -0,0 +1,20 @@ +package billing + +import ( + "context" + + "github.com/mudler/LocalAI/core/http/auth" +) + +// disabledBackend drops every record. Used when --disable-stats is set, +// e.g., for ephemeral CI runs where token tracking is just noise. +type disabledBackend struct{} + +// NewDisabledBackend returns a no-op StatsBackend. +func NewDisabledBackend() StatsBackend { return disabledBackend{} } + +func (disabledBackend) Record(_ context.Context, _ *auth.UsageRecord) error { return nil } +func (disabledBackend) Aggregate(_ context.Context, _ AggregateQuery) ([]auth.UsageBucket, error) { + return nil, nil +} +func (disabledBackend) Close() error { return nil } diff --git a/core/services/routing/billing/gorm.go b/core/services/routing/billing/gorm.go new file mode 100644 index 000000000000..d304af10a445 --- /dev/null +++ b/core/services/routing/billing/gorm.go @@ -0,0 +1,111 @@ +package billing + +import ( + "context" + "sync" + "time" + + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +// gormBackend writes UsageRecord rows to a GORM-backed database (the +// existing auth DB when --auth is enabled). It batches inserts every +// flushInterval to amortize round-trips; pre-routing-module middleware +// did the same with a private batcher — we keep the same cadence. +type gormBackend struct { + db *gorm.DB + flushInterval time.Duration + maxPending int + + mu sync.Mutex + pending []*auth.UsageRecord + + stopCh chan struct{} + doneCh chan struct{} +} + +// NewGormBackend constructs a StatsBackend that persists records to db. +// The returned backend launches a background flush goroutine; call +// Close() to stop it. flushInterval ≤ 0 picks the prior 5s default; +// maxPending ≤ 0 picks 5000. +func NewGormBackend(db *gorm.DB, flushInterval time.Duration, maxPending int) StatsBackend { + if flushInterval <= 0 { + flushInterval = 5 * time.Second + } + if maxPending <= 0 { + maxPending = 5000 + } + b := &gormBackend{ + db: db, + flushInterval: flushInterval, + maxPending: maxPending, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go b.run() + return b +} + +func (b *gormBackend) Record(_ context.Context, r *auth.UsageRecord) error { + b.mu.Lock() + b.pending = append(b.pending, r) + b.mu.Unlock() + return nil +} + +func (b *gormBackend) Aggregate(_ context.Context, q AggregateQuery) ([]auth.UsageBucket, error) { + if q.UserID == "" { + return auth.GetAllUsage(b.db, q.Period, "") + } + return auth.GetUserUsage(b.db, q.UserID, q.Period) +} + +func (b *gormBackend) Close() error { + select { + case <-b.stopCh: + // already stopped + default: + close(b.stopCh) + } + <-b.doneCh + return nil +} + +func (b *gormBackend) run() { + defer close(b.doneCh) + ticker := time.NewTicker(b.flushInterval) + defer ticker.Stop() + for { + select { + case <-b.stopCh: + b.flush() + return + case <-ticker.C: + b.flush() + } + } +} + +func (b *gormBackend) flush() { + b.mu.Lock() + batch := b.pending + b.pending = nil + b.mu.Unlock() + + if len(batch) == 0 { + return + } + + if err := b.db.Create(&batch).Error; err != nil { + xlog.Error("failed to flush usage batch", "count", len(batch), "error", err) + // Re-queue with a cap to avoid unbounded growth on persistent DB + // failure (matches the prior behavior in core/http/middleware/usage.go). + b.mu.Lock() + if len(b.pending) < b.maxPending { + b.pending = append(batch, b.pending...) + } + b.mu.Unlock() + } +} diff --git a/core/services/routing/billing/inmem.go b/core/services/routing/billing/inmem.go new file mode 100644 index 000000000000..3be341c7b01a --- /dev/null +++ b/core/services/routing/billing/inmem.go @@ -0,0 +1,157 @@ +package billing + +import ( + "context" + "sync" + "time" + + "github.com/mudler/LocalAI/core/http/auth" +) + +// memoryBackend keeps the most recent N records in a ring buffer. It is +// the no-auth, no-DB fallback: a single user running LocalAI on a +// laptop still gets live aggregation against this buffer until the +// process exits. Records are not durable. +// +// Aggregation is computed by linear scan — fine because the ring is +// bounded (default 50_000 records) and aggregation is rare (UI dashboard +// poll, MCP tool calls). If the working set grows beyond what scan can +// service in <100ms, the operator should enable auth+DB. +type memoryBackend struct { + mu sync.RWMutex + ring []*auth.UsageRecord + cap int + cursor int // next write position + full bool +} + +// NewMemoryBackend returns a StatsBackend backed by an in-process ring +// buffer. capacity ≤ 0 uses 50_000. +func NewMemoryBackend(capacity int) StatsBackend { + if capacity <= 0 { + capacity = 50_000 + } + return &memoryBackend{ + ring: make([]*auth.UsageRecord, capacity), + cap: capacity, + } +} + +func (b *memoryBackend) Record(_ context.Context, r *auth.UsageRecord) error { + b.mu.Lock() + defer b.mu.Unlock() + b.ring[b.cursor] = r + b.cursor++ + if b.cursor == b.cap { + b.cursor = 0 + b.full = true + } + return nil +} + +func (b *memoryBackend) Aggregate(_ context.Context, q AggregateQuery) ([]auth.UsageBucket, error) { + since := periodStart(q.Period) + bucketWidth := bucketWidthFor(q.Period) + dateFmt := bucketFormatFor(q.Period) + + type aggKey struct { + bucket string + model string + userID string + userName string + } + agg := make(map[aggKey]*auth.UsageBucket) + + b.mu.RLock() + defer b.mu.RUnlock() + + scan := func(r *auth.UsageRecord) { + if r == nil { + return + } + if !since.IsZero() && r.CreatedAt.Before(since) { + return + } + if q.UserID != "" && r.UserID != q.UserID { + return + } + bucketTime := r.CreatedAt.Truncate(bucketWidth) + key := aggKey{ + bucket: bucketTime.Format(dateFmt), + model: r.Model, + userID: r.UserID, + userName: r.UserName, + } + entry, ok := agg[key] + if !ok { + entry = &auth.UsageBucket{ + Bucket: key.bucket, + Model: key.model, + UserID: key.userID, + UserName: key.userName, + } + agg[key] = entry + } + entry.PromptTokens += r.PromptTokens + entry.CompletionTokens += r.CompletionTokens + entry.TotalTokens += r.TotalTokens + entry.RequestCount++ + } + + if b.full { + for _, r := range b.ring { + scan(r) + } + } else { + for i := 0; i < b.cursor; i++ { + scan(b.ring[i]) + } + } + + out := make([]auth.UsageBucket, 0, len(agg)) + for _, v := range agg { + out = append(out, *v) + } + return out, nil +} + +func (b *memoryBackend) Close() error { return nil } + +// periodStart returns the lower bound of the time window for the +// given period. Mirrors auth.periodToWindow but without GORM +// dialector concerns. +func periodStart(period string) time.Time { + now := time.Now() + switch period { + case "day": + return now.Add(-24 * time.Hour) + case "week": + return now.Add(-7 * 24 * time.Hour) + case "all": + return time.Time{} + default: // "month" + return now.Add(-30 * 24 * time.Hour) + } +} + +func bucketWidthFor(period string) time.Duration { + switch period { + case "day": + return time.Hour + case "all": + return 30 * 24 * time.Hour + default: // week, month + return 24 * time.Hour + } +} + +func bucketFormatFor(period string) string { + switch period { + case "day": + return "2006-01-02 15:00" + case "all": + return "2006-01" + default: + return "2006-01-02" + } +} diff --git a/core/services/routing/billing/inmem_test.go b/core/services/routing/billing/inmem_test.go new file mode 100644 index 000000000000..c630249a9caa --- /dev/null +++ b/core/services/routing/billing/inmem_test.go @@ -0,0 +1,140 @@ +package billing + +import ( + "context" + "time" + + "github.com/mudler/LocalAI/core/http/auth" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("MemoryBackend", func() { + It("records and aggregates", func() { + ctx := context.Background() + b := NewMemoryBackend(0) + defer func() { _ = b.Close() }() + + now := time.Now() + for i := 0; i < 5; i++ { + err := b.Record(ctx, &auth.UsageRecord{ + UserID: "u-1", + UserName: "alice", + Model: "qwen-7b", + Endpoint: "/v1/chat/completions", + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + CreatedAt: now, + }) + Expect(err).NotTo(HaveOccurred(), "record") + } + for i := 0; i < 3; i++ { + err := b.Record(ctx, &auth.UsageRecord{ + UserID: "u-2", + UserName: "bob", + Model: "qwen-7b", + Endpoint: "/v1/chat/completions", + PromptTokens: 7, + CompletionTokens: 13, + TotalTokens: 20, + CreatedAt: now, + }) + Expect(err).NotTo(HaveOccurred(), "record") + } + + buckets, err := b.Aggregate(ctx, AggregateQuery{UserID: "u-1", Period: "month"}) + Expect(err).NotTo(HaveOccurred(), "aggregate") + var promptTotal, reqTotal int64 + for _, bk := range buckets { + Expect(bk.UserID).To(Equal("u-1"), "expected only u-1 buckets") + promptTotal += bk.PromptTokens + reqTotal += bk.RequestCount + } + Expect(promptTotal).To(Equal(int64(50))) + Expect(reqTotal).To(Equal(int64(5))) + + all, err := b.Aggregate(ctx, AggregateQuery{Period: "month"}) + Expect(err).NotTo(HaveOccurred(), "aggregate all") + var allPrompt, allReqs int64 + for _, bk := range all { + allPrompt += bk.PromptTokens + allReqs += bk.RequestCount + } + Expect(allPrompt).To(Equal(int64(50 + 21))) + Expect(allReqs).To(Equal(int64(8))) + }) + + It("filters by period", func() { + ctx := context.Background() + b := NewMemoryBackend(0) + defer func() { _ = b.Close() }() + + old := time.Now().Add(-48 * time.Hour) + recent := time.Now() + + err := b.Record(ctx, &auth.UsageRecord{ + UserID: "u", UserName: "u", Model: "m", + PromptTokens: 100, TotalTokens: 100, CreatedAt: old, + }) + Expect(err).NotTo(HaveOccurred()) + err = b.Record(ctx, &auth.UsageRecord{ + UserID: "u", UserName: "u", Model: "m", + PromptTokens: 50, TotalTokens: 50, CreatedAt: recent, + }) + Expect(err).NotTo(HaveOccurred()) + + dayBuckets, err := b.Aggregate(ctx, AggregateQuery{UserID: "u", Period: "day"}) + Expect(err).NotTo(HaveOccurred()) + var dayTotal int64 + for _, bk := range dayBuckets { + dayTotal += bk.PromptTokens + } + Expect(dayTotal).To(Equal(int64(50)), "day window should only include the recent record") + + monthBuckets, err := b.Aggregate(ctx, AggregateQuery{UserID: "u", Period: "month"}) + Expect(err).NotTo(HaveOccurred()) + var monthTotal int64 + for _, bk := range monthBuckets { + monthTotal += bk.PromptTokens + } + Expect(monthTotal).To(Equal(int64(150)), "month window should include both records") + }) + + It("ring wraps", func() { + ctx := context.Background() + b := NewMemoryBackend(4) // tiny ring so we can observe wrap + + for i := 0; i < 10; i++ { + err := b.Record(ctx, &auth.UsageRecord{ + UserID: "u", + UserName: "u", + Model: "m", + PromptTokens: 1, + TotalTokens: 1, + CreatedAt: time.Now(), + }) + Expect(err).NotTo(HaveOccurred()) + } + + buckets, err := b.Aggregate(ctx, AggregateQuery{UserID: "u", Period: "month"}) + Expect(err).NotTo(HaveOccurred()) + var total int64 + for _, bk := range buckets { + total += bk.PromptTokens + } + Expect(total).To(Equal(int64(4)), "ring should keep last 4 records") + }) +}) + +var _ = Describe("DisabledBackend", func() { + It("is a no-op", func() { + ctx := context.Background() + b := NewDisabledBackend() + Expect(b.Record(ctx, &auth.UsageRecord{UserID: "u"})).To(Succeed(), "disabled record should not error") + out, err := b.Aggregate(ctx, AggregateQuery{Period: "month"}) + Expect(err).NotTo(HaveOccurred(), "disabled aggregate should not error") + Expect(out).To(BeNil(), "disabled aggregate should return nil") + }) +}) diff --git a/core/services/routing/billing/local_user.go b/core/services/routing/billing/local_user.go new file mode 100644 index 000000000000..c3ae1fa485ae --- /dev/null +++ b/core/services/routing/billing/local_user.go @@ -0,0 +1,84 @@ +package billing + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "os" + "path/filepath" + "sync" + + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/xlog" +) + +// LocalUserName is the fixed display name used for the synthetic +// no-auth user. Surfaces it in the dashboard so single-user installs +// have a recognizable label rather than an opaque UUID. +const LocalUserName = "local" + +// localUserIDFile is the basename, inside DataPath, where we persist +// the synthetic user's UUID so it stays stable across restarts. +const localUserIDFile = ".local_user_id" + +var ( + localOnce sync.Once + localUser *auth.User +) + +// LocalUser returns a process-singleton "local" user used by +// UsageMiddleware when --auth is off. The user's ID is persisted to +// dataPath so usage history aggregates correctly across restarts; if +// dataPath is empty, a fresh random UUID is generated for this process +// only and aggregation drops on restart (in-memory mode). +// +// Concurrency note: the singleton uses sync.Once, so calling LocalUser +// from any goroutine is safe; the first call may briefly hit disk. +func LocalUser(dataPath string) *auth.User { + localOnce.Do(func() { + id := loadOrGenerateLocalUserID(dataPath) + localUser = &auth.User{ + ID: id, + Name: LocalUserName, + Email: "", + Provider: auth.ProviderLocal, + Role: "admin", // single-user box: the only user has full access + Status: "active", + } + }) + return localUser +} + +func loadOrGenerateLocalUserID(dataPath string) string { + if dataPath != "" { + path := filepath.Join(dataPath, localUserIDFile) + if b, err := os.ReadFile(path); err == nil { + id := string(b) + if len(id) > 0 { + return id + } + } else if !errors.Is(err, os.ErrNotExist) { + xlog.Warn("failed to read local user id file; generating fresh", "path", path, "error", err) + } + id := newUUID() + // 0600: only the LocalAI process owner should read this. The file + // is just a stable identifier, not a credential, but we keep it + // tight by default. + if err := os.WriteFile(path, []byte(id), 0o600); err != nil { + xlog.Warn("failed to persist local user id; will regenerate next start", "path", path, "error", err) + } + return id + } + return newUUID() +} + +func newUUID() string { + var b [16]byte + _, _ = rand.Read(b[:]) + // Set version 4 + RFC 4122 variant bits so this round-trips through + // any UUID parser the rest of the codebase might use. + b[6] = (b[6] & 0x0f) | 0x40 + b[8] = (b[8] & 0x3f) | 0x80 + hexb := hex.EncodeToString(b[:]) + return hexb[0:8] + "-" + hexb[8:12] + "-" + hexb[12:16] + "-" + hexb[16:20] + "-" + hexb[20:32] +} diff --git a/core/services/routing/billing/local_user_test.go b/core/services/routing/billing/local_user_test.go new file mode 100644 index 000000000000..6d80b8607a10 --- /dev/null +++ b/core/services/routing/billing/local_user_test.go @@ -0,0 +1,70 @@ +package billing + +import ( + "os" + "path/filepath" + "sync" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LocalUser", func() { + It("persists ID", func() { + // Reset the package-singleton sentinel so this test gets a fresh + // LocalUser call. Without this, other tests racing through LocalUser + // would freeze the value before we set DataPath. + resetLocalUserForTesting() + + dir := GinkgoT().TempDir() + u1 := LocalUser(dir) + Expect(u1).NotTo(BeNil(), "LocalUser returned nil") + Expect(u1.ID).NotTo(BeEmpty(), "LocalUser must have a non-empty ID") + Expect(u1.Name).To(Equal(LocalUserName)) + + // File written? + idPath := filepath.Join(dir, localUserIDFile) + got, err := os.ReadFile(idPath) + Expect(err).NotTo(HaveOccurred(), "expected %s to exist", idPath) + Expect(string(got)).To(Equal(u1.ID)) + + // Singleton: subsequent calls return the same pointer. + u2 := LocalUser(dir) + Expect(u2).To(BeIdenticalTo(u1), "LocalUser returned a different instance on second call") + }) + + It("is stable across processes", func() { + resetLocalUserForTesting() + dir := GinkgoT().TempDir() + + first := LocalUser(dir).ID + + // Simulate process restart by clearing the singleton; the disk file + // must let us recover the same UUID. + resetLocalUserForTesting() + + second := LocalUser(dir).ID + Expect(first).To(Equal(second), "local user id not stable across restart") + }) + + It("works with no data path", func() { + resetLocalUserForTesting() + u := LocalUser("") + Expect(u).NotTo(BeNil()) + Expect(u.ID).NotTo(BeEmpty(), "LocalUser with empty data path must still produce a usable user") + }) +}) + +// resetLocalUserForTesting clears the package singleton so a test can +// rebind LocalUser to a fresh state. Tests must serialize on a mutex +// because Go tests within a package run concurrently within the same +// goroutine pool — LocalUser's sync.Once is a global, and these tests +// deliberately reach past it. +var testResetMu sync.Mutex + +func resetLocalUserForTesting() { + testResetMu.Lock() + defer testResetMu.Unlock() + localOnce = sync.Once{} + localUser = nil +} diff --git a/core/services/routing/billing/prom.go b/core/services/routing/billing/prom.go new file mode 100644 index 000000000000..352f699edf38 --- /dev/null +++ b/core/services/routing/billing/prom.go @@ -0,0 +1,215 @@ +package billing + +import ( + "context" + "sync" + + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/contract" + "github.com/mudler/xlog" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// Recorder is the single increment site for billing data. It writes +// the same record to (a) the StatsBackend (durable / queryable) and +// (b) Prometheus counters (live ops). Splitting these would invite +// drift; this type guarantees both fire in lockstep from one call. +// +// The plan calls out a DB-vs-Prom drift assertion. With a single +// increment site, drift can only come from StatsBackend.Record returning +// without persisting (e.g., the DB flusher dropping batches under load +// — see gormBackend.flush). We log+invariant-fail in that path; a future +// drift goroutine compares Prom to a SUM(total_tokens) checkpoint as +// extra defense in depth. +type Recorder struct { + backend StatsBackend + + tokensCounter metric.Int64Counter + costCounter metric.Float64Counter + requestsCount metric.Int64Counter +} + +var ( + metricsOnce sync.Once + sharedTokensCounter metric.Int64Counter + sharedCostCounter metric.Float64Counter + sharedRequestsCount metric.Int64Counter + sharedUnrecordedCounter metric.Int64Counter + + // configuredMeter is the meter handed in by the caller (typically + // monitoring.LocalAIMetricsService). Setting it before initMetrics + // runs makes sure billing's counters land on the same Prom-backed + // MeterProvider that exports /metrics. Without this we relied on + // otel.SetMeterProvider race ordering, which silently dropped + // counters when initMetrics ran first. + configuredMeterMu sync.Mutex + configuredMeter metric.Meter +) + +// SetMeter wires the meter from monitoring.LocalAIMetricsService (or any +// caller-controlled MeterProvider) before any Recorder is constructed. +// Call from application startup — initMetrics uses this meter rather than +// the OTel global the moment it's set. +func SetMeter(m metric.Meter) { + configuredMeterMu.Lock() + defer configuredMeterMu.Unlock() + configuredMeter = m +} + +func resolveMeter() metric.Meter { + configuredMeterMu.Lock() + m := configuredMeter + configuredMeterMu.Unlock() + if m != nil { + return m + } + return otel.Meter("github.com/mudler/LocalAI/core/services/routing/billing") +} + +func initMetrics() { + metricsOnce.Do(func() { + meter := resolveMeter() + var err error + sharedTokensCounter, err = meter.Int64Counter( + "localai_tokens_total", + metric.WithDescription("Cumulative tokens accounted, labeled by user, served_model, kind"), + ) + if err != nil { + xlog.Error("billing: failed to create tokens counter", "error", err) + } + sharedCostCounter, err = meter.Float64Counter( + "localai_cost_usd_total", + metric.WithDescription("Cumulative USD cost accounted, labeled by user, served_model"), + ) + if err != nil { + xlog.Error("billing: failed to create cost counter", "error", err) + } + sharedRequestsCount, err = meter.Int64Counter( + "localai_billed_requests_total", + metric.WithDescription("Cumulative billed requests, labeled by user, served_model, endpoint"), + ) + if err != nil { + xlog.Error("billing: failed to create requests counter", "error", err) + } + sharedUnrecordedCounter, err = meter.Int64Counter( + "localai_usage_unrecorded_total", + metric.WithDescription("Requests that completed but produced no UsageRecord, labeled by endpoint and reason. A non-zero rate signals a billing gap (handler didn't stamp, body lacked usage, no user resolvable)."), + ) + if err != nil { + xlog.Error("billing: failed to create unrecorded counter", "error", err) + } + }) +} + +// CountUnrecorded ticks the localai_usage_unrecorded_total counter so that +// silent billing misses are observable. UsageMiddleware calls this whenever +// a request completes without producing a UsageRecord. Reasons should be +// short, stable strings ("no_handler_stamp", "no_user", "parse_failed", …) +// — never user-supplied content. +func CountUnrecorded(ctx context.Context, endpoint, reason string) { + initMetrics() + if sharedUnrecordedCounter == nil { + return + } + sharedUnrecordedCounter.Add(ctx, 1, + metric.WithAttributes( + attribute.String("endpoint", endpoint), + attribute.String("reason", reason), + )) +} + +// NewRecorder returns a Recorder that fans out to the given StatsBackend +// and to Prometheus. The Prom counters are package-singletons so that +// multiple Recorders (e.g., reusing the same metrics across rebuilds) +// don't double-register identical metric names. +func NewRecorder(backend StatsBackend) *Recorder { + initMetrics() + return &Recorder{ + backend: backend, + tokensCounter: sharedTokensCounter, + costCounter: sharedCostCounter, + requestsCount: sharedRequestsCount, + } +} + +// Record asserts billing invariants, persists the record, and emits the +// matching Prom counters. r must not be mutated by the caller after +// this call; the backend takes ownership. +func (rec *Recorder) Record(ctx context.Context, r *auth.UsageRecord) error { + rec.assertInvariants(r) + + if err := rec.backend.Record(ctx, r); err != nil { + return err + } + + if rec.tokensCounter != nil { + userAttr := attribute.String("user", r.UserID) + modelAttr := attribute.String("served_model", servedModelOf(r)) + rec.tokensCounter.Add(ctx, r.PromptTokens, + metric.WithAttributes(userAttr, modelAttr, attribute.String("kind", "prompt"))) + rec.tokensCounter.Add(ctx, r.CompletionTokens, + metric.WithAttributes(userAttr, modelAttr, attribute.String("kind", "completion"))) + } + if rec.costCounter != nil && r.PricingVersionID != "" { + rec.costCounter.Add(ctx, r.CostUSD, + metric.WithAttributes( + attribute.String("user", r.UserID), + attribute.String("served_model", servedModelOf(r)), + )) + } + if rec.requestsCount != nil { + rec.requestsCount.Add(ctx, 1, + metric.WithAttributes( + attribute.String("user", r.UserID), + attribute.String("served_model", servedModelOf(r)), + attribute.String("endpoint", r.Endpoint), + )) + } + return nil +} + +// Aggregate is a convenience pass-through. +func (rec *Recorder) Aggregate(ctx context.Context, q AggregateQuery) ([]auth.UsageBucket, error) { + return rec.backend.Aggregate(ctx, q) +} + +// Close flushes the underlying backend. +func (rec *Recorder) Close() error { return rec.backend.Close() } + +func (rec *Recorder) assertInvariants(r *auth.UsageRecord) { + contract.Invariant( + "billing.user_id_present", + r.UserID != "", + "endpoint", r.Endpoint, "model", r.Model, + ) + // PII can only shrink the prompt; a post-filter count above pre-filter + // would mean the filter expanded text, which is impossible by design. + // Both are zero on legacy paths that don't populate the new fields, + // so the assertion only fires when one side is set. + if r.PreFilterPromptTokens > 0 || r.PostFilterPromptTokens > 0 { + contract.Invariant( + "billing.prefilter_ge_postfilter", + r.PreFilterPromptTokens >= r.PostFilterPromptTokens, + "pre", r.PreFilterPromptTokens, "post", r.PostFilterPromptTokens, + "user", r.UserID, "model", r.Model, + ) + } + // CostUSD without a pricing version is a data-integrity bug: we'd + // be unable to retroactively recompute or audit the rate used. + if r.CostUSD != 0 { + contract.Invariant( + "billing.cost_requires_pricing_version", + r.PricingVersionID != "", + "cost", r.CostUSD, "model", r.Model, + ) + } +} + +func servedModelOf(r *auth.UsageRecord) string { + if r.ServedModel != "" { + return r.ServedModel + } + return r.Model +} diff --git a/core/services/routing/billing/recorder_test.go b/core/services/routing/billing/recorder_test.go new file mode 100644 index 000000000000..f75e62f265ea --- /dev/null +++ b/core/services/routing/billing/recorder_test.go @@ -0,0 +1,82 @@ +package billing + +import ( + "context" + "sync" + + "github.com/mudler/LocalAI/core/http/auth" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// fakeBackend is a minimal StatsBackend that records what it received +// without actually writing anywhere. Lets the Recorder be tested in +// isolation from GORM/SQLite/in-memory specifics. +type fakeBackend struct { + mu sync.Mutex + records []*auth.UsageRecord +} + +func (f *fakeBackend) Record(_ context.Context, r *auth.UsageRecord) error { + f.mu.Lock() + defer f.mu.Unlock() + f.records = append(f.records, r) + return nil +} +func (f *fakeBackend) Aggregate(_ context.Context, _ AggregateQuery) ([]auth.UsageBucket, error) { + return nil, nil +} +func (f *fakeBackend) Close() error { return nil } + +var _ = Describe("Recorder", func() { + It("forwards to backend", func() { + fb := &fakeBackend{} + rec := NewRecorder(fb) + + r := &auth.UsageRecord{ + UserID: "u-1", + UserName: "alice", + Model: "qwen-7b", + Endpoint: "/v1/chat/completions", + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + } + Expect(rec.Record(context.Background(), r)).To(Succeed(), "recorder.Record") + + fb.mu.Lock() + defer fb.mu.Unlock() + Expect(fb.records).To(HaveLen(1)) + Expect(fb.records[0]).To(BeIdenticalTo(r), "recorder must pass the record through without copying") + }) + + // RecorderInvariantsPassWhenZero ensures legacy paths that don't + // populate the routing-extension fields still record successfully — + // the invariants only fire when a partial routing fact is set. + It("invariants pass when zero", func() { + rec := NewRecorder(&fakeBackend{}) + err := rec.Record(context.Background(), &auth.UsageRecord{ + UserID: "u-1", Model: "qwen-7b", Endpoint: "/v1/chat/completions", + }) + Expect(err).NotTo(HaveOccurred(), "zero routing fields must record cleanly") + }) + + // RecorderInvariantsDetectShrinkViolation: setting both pre/post + // prompt tokens with post > pre (impossible — PII can only shrink the + // prompt) should trigger the contract assertion. In a non-strict build + // the call still succeeds (logs + counter) but a routing_strict build + // would panic. We assert the call returns nil here; the strict-build + // behavior is covered by an integration test that compiles with the + // tag. + It("invariants detect shrink violation", func() { + rec := NewRecorder(&fakeBackend{}) + err := rec.Record(context.Background(), &auth.UsageRecord{ + UserID: "u-1", + Model: "qwen-7b", + PreFilterPromptTokens: 5, + PostFilterPromptTokens: 10, // post > pre is impossible by design + }) + Expect(err).NotTo(HaveOccurred(), "non-strict build must not error on invariant violation") + }) +}) diff --git a/core/services/routing/contract/contract.go b/core/services/routing/contract/contract.go new file mode 100644 index 000000000000..3e2d7605a457 --- /dev/null +++ b/core/services/routing/contract/contract.go @@ -0,0 +1,55 @@ +// Package contract provides runtime invariant assertions for the routing +// module. Each Invariant call logs at error level via xlog, increments a +// Prometheus counter, and (under build tag routing_strict) panics so test +// runs surface violations as test failures. +// +// The routing subsystems (billing, router, pii, proxy, admission) all +// publish invariants through this single package so that observability — +// dashboards, alerts, post-mortem analysis — joins on a single counter +// name regardless of which subsystem fired. +package contract + +import ( + "context" + + "github.com/mudler/xlog" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +var violationCounter metric.Int64Counter + +func init() { + meter := otel.Meter("github.com/mudler/LocalAI/core/services/routing") + c, err := meter.Int64Counter( + "localai_invariant_violation_total", + metric.WithDescription("Routing-module runtime invariant violations, labeled by name"), + ) + if err != nil { + // OTel API never returns an error in practice for a simple counter; + // log and fall back to a nil counter (Add becomes a no-op). + xlog.Error("failed to create invariant violation counter", "error", err) + return + } + violationCounter = c +} + +// Invariant asserts that cond is true. If false, it logs the violation +// and increments localai_invariant_violation_total{name=name}. Use +// fields for structured context (e.g., "model", "qwen-7b", "user", uid). +// +// In a build with -tags=routing_strict, a violation panics — meant for +// test suites and nightly E2E runs to surface drift. Production builds +// degrade silently into a metric so a single bad request does not crash +// the server. +func Invariant(name string, cond bool, fields ...any) { + if cond { + return + } + xlog.Error("routing invariant violated", append([]any{"name", name}, fields...)...) + if violationCounter != nil { + violationCounter.Add(context.Background(), 1, metric.WithAttributes(attribute.String("name", name))) + } + panicIfStrict(name, fields...) +} diff --git a/core/services/routing/contract/strict_off.go b/core/services/routing/contract/strict_off.go new file mode 100644 index 000000000000..f0f70829f9ab --- /dev/null +++ b/core/services/routing/contract/strict_off.go @@ -0,0 +1,5 @@ +//go:build !routing_strict + +package contract + +func panicIfStrict(name string, fields ...any) {} diff --git a/core/services/routing/contract/strict_on.go b/core/services/routing/contract/strict_on.go new file mode 100644 index 000000000000..7ea6e96e35dc --- /dev/null +++ b/core/services/routing/contract/strict_on.go @@ -0,0 +1,9 @@ +//go:build routing_strict + +package contract + +import "fmt" + +func panicIfStrict(name string, fields ...any) { + panic(fmt.Sprintf("routing invariant violated under -tags=routing_strict: %s %v", name, fields)) +} diff --git a/core/services/routing/pii/config.go b/core/services/routing/pii/config.go new file mode 100644 index 000000000000..64f7096750d2 --- /dev/null +++ b/core/services/routing/pii/config.go @@ -0,0 +1,71 @@ +package pii + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +// FileConfig is the on-disk schema for pii.yaml. Each Pattern entry +// overrides the matching default by ID; missing fields fall back to +// the default. Unknown IDs are rejected at load time so an admin who +// fat-fingers a pattern name gets a clear error rather than a silent +// no-op. +type FileConfig struct { + Patterns []FilePattern `yaml:"patterns"` +} + +type FilePattern struct { + ID string `yaml:"id"` + Action Action `yaml:"action"` +} + +// LoadConfig reads pii.yaml from path and merges it on top of +// DefaultPatterns(). path == "" returns the defaults compiled and +// ready. The returned slice is already Compile()'d, so callers can +// pass it straight to NewRedactor. +func LoadConfig(path string) ([]Pattern, error) { + defaults := DefaultPatterns() + if path == "" { + return Compile(defaults) + } + + raw, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("pii: read config %q: %w", path, err) + } + var cfg FileConfig + if err := yaml.Unmarshal(raw, &cfg); err != nil { + return nil, fmt.Errorf("pii: parse config %q: %w", path, err) + } + + overrides := make(map[string]Action, len(cfg.Patterns)) + known := make(map[string]bool, len(defaults)) + for _, d := range defaults { + known[d.ID] = true + } + for _, p := range cfg.Patterns { + if !known[p.ID] { + return nil, fmt.Errorf("pii: unknown pattern id %q in %q", p.ID, path) + } + if p.Action == "" { + continue + } + switch p.Action { + case ActionMask, ActionBlock, ActionRouteLocal: + overrides[p.ID] = p.Action + default: + return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID) + } + } + + merged := make([]Pattern, len(defaults)) + for i, d := range defaults { + if a, ok := overrides[d.ID]; ok { + d.Action = a + } + merged[i] = d + } + return Compile(merged) +} diff --git a/core/services/routing/pii/config_test.go b/core/services/routing/pii/config_test.go new file mode 100644 index 000000000000..650b804f01a7 --- /dev/null +++ b/core/services/routing/pii/config_test.go @@ -0,0 +1,56 @@ +package pii + +import ( + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LoadConfig", func() { + It("returns defaults when no path given", func() { + patterns, err := LoadConfig("") + Expect(err).NotTo(HaveOccurred()) + Expect(patterns).To(HaveLen(len(DefaultPatterns()))) + }) + + It("overrides action", func() { + dir := GinkgoT().TempDir() + path := filepath.Join(dir, "pii.yaml") + body := []byte(`patterns: + - id: email + action: block + - id: ssn + action: route_local +`) + Expect(os.WriteFile(path, body, 0o600)).To(Succeed()) + patterns, err := LoadConfig(path) + Expect(err).NotTo(HaveOccurred()) + + got := map[string]Action{} + for _, p := range patterns { + got[p.ID] = p.Action + } + Expect(got["email"]).To(Equal(ActionBlock)) + Expect(got["ssn"]).To(Equal(ActionRouteLocal)) + // Unmentioned patterns keep their default action. + Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost") + }) + + It("rejects unknown id", func() { + dir := GinkgoT().TempDir() + path := filepath.Join(dir, "pii.yaml") + Expect(os.WriteFile(path, []byte("patterns:\n - id: nonsense\n action: mask\n"), 0o600)).To(Succeed()) + _, err := LoadConfig(path) + Expect(err).To(HaveOccurred(), "expected error on unknown pattern id") + }) + + It("rejects invalid action", func() { + dir := GinkgoT().TempDir() + path := filepath.Join(dir, "pii.yaml") + Expect(os.WriteFile(path, []byte("patterns:\n - id: email\n action: lolwhat\n"), 0o600)).To(Succeed()) + _, err := LoadConfig(path) + Expect(err).To(HaveOccurred(), "expected error on invalid action") + }) +}) diff --git a/core/services/routing/pii/middleware.go b/core/services/routing/pii/middleware.go new file mode 100644 index 000000000000..0994c32ba927 --- /dev/null +++ b/core/services/routing/pii/middleware.go @@ -0,0 +1,260 @@ +package pii + +import ( + "context" + "crypto/rand" + "encoding/hex" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/contract" + "github.com/mudler/xlog" +) + +// Echo context keys this middleware reads from / writes to. The string +// values must match the constants in core/http/middleware/context_keys.go; +// kept in sync by hand because echoing constants across packages would +// drag the http/middleware package into pii's import graph and create +// a cycle (http/middleware will import this one). +const ( + ctxKeyCorrelationID = "routing.correlation_id" + ctxKeyPIIEventID = "routing.pii_event_id" + ctxKeyLocalOnly = "routing.local_only" + // Must match the constants in core/http/middleware/request.go. + // Echoing them across packages would create an import cycle + // (http/middleware imports this package). Drift is caught by + // integration tests against the chat route. + ctxKeyParsedRequest = "LOCALAI_REQUEST" + ctxKeyModelConfig = "MODEL_CONFIG" +) + +// ModelPIIConfig is the duck-typed view this middleware needs of the +// per-model PII configuration carried on the echo context. *config.ModelConfig +// satisfies it via PIIIsEnabled / PIIPatternOverrides; the indirection +// keeps the pii package from importing core/config. +// +// Consumers of the override map: the action returned from PIIPatternOverrides +// is the raw YAML string (e.g. "block"). Validation against the canonical +// ActionMask/Block/RouteLocal constants happens here, so a typo in a model +// YAML logs and is ignored rather than panicking. +type ModelPIIConfig interface { + PIIIsEnabled() bool + PIIPatternOverrides() map[string]string +} + +// ScannedText is one piece of user text from the request. Index is +// opaque to the middleware — the Adapter implementation uses it to +// put the redacted version back in the right place. +type ScannedText struct { + Index int + Text string +} + +// Adapter pulls scannable text out of a parsed request and writes +// redacted text back. Provided as a per-API-shape function rather +// than an interface on the request type so the schema package does +// not have to depend on pii. Each route registration passes the +// adapter that knows its request format. +// +// The middleware calls Scan once per request and Apply once with +// every span the redactor returned. updates are guaranteed to share +// indices the adapter previously returned from Scan; the adapter +// must not assume input order matches scan order. +type Adapter struct { + Scan func(parsed any) []ScannedText + Apply func(parsed any, updates []ScannedText) +} + +// RequestMiddleware applies the regex PII tier to incoming chat +// requests. If the parsed request is not a MessageScanner (e.g., +// non-chat endpoints registered against the same group later), the +// middleware passes through. +// +// - On match with action=block: the request is rejected with 400 and +// a PIIEvent is recorded. The matched value is never echoed back +// to the client. +// - On match with action=mask: the redacted text replaces the +// original on the parsed request. PIIEvents are recorded. +// - On match with action=route_local: the original text is left +// intact, but the echo context is annotated so the (future) router +// middleware refuses cloud-proxy candidates. +// +// recorder is the Recorder on which to record events; nil disables +// recording (the redaction still happens). fallbackUser supplies the +// no-auth identity. The middleware writes ctxKeyPIIEventID on the echo +// context so the usage middleware can later cross-reference the event +// with the UsageRecord. +func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if redactor == nil || len(redactor.Patterns()) == 0 || adapter.Scan == nil { + return next(c) + } + + // Per-model gating: redaction is opt-in per model. If the + // resolved config disables PII for this model (the default + // for non-proxy backends), pass through immediately. We do + // this before parsing the request so a disabled model + // doesn't pay the regex scan cost. + if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok { + if !cfg.PIIIsEnabled() { + return next(c) + } + } else { + // No ModelPIIConfig on context → fail-closed: skip + // redaction. This protects routes that wire the + // middleware before SetModelAndConfig runs (or non-chat + // routes that don't carry a model). The middleware was + // previously fail-open, applying the global redactor + // unconditionally; the new contract is per-model + // opt-in, and a missing model is treated as disabled. + return next(c) + } + + parsed := c.Get(ctxKeyParsedRequest) + if parsed == nil { + return next(c) + } + + user := auth.GetUser(c) + if user == nil { + user = fallbackUser + } + userID := "" + if user != nil { + userID = user.ID + } + correlationID, _ := c.Get(ctxKeyCorrelationID).(string) + + // Resolve per-model action overrides once per request. The + // raw map is YAML strings; convert to the typed Action set + // and silently drop unknown values rather than failing the + // request — model YAML typos shouldn't take chat down. + var overrides map[string]Action + if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok { + if raw := cfg.PIIPatternOverrides(); len(raw) > 0 { + overrides = make(map[string]Action, len(raw)) + for id, action := range raw { + switch Action(action) { + case ActionMask, ActionBlock, ActionRouteLocal: + overrides[id] = Action(action) + default: + xlog.Warn("pii: ignoring unknown action in per-model override", + "pattern", id, "action", action) + } + } + } + } + + texts := adapter.Scan(parsed) + updates := make([]ScannedText, 0, len(texts)) + var blocked bool + var localOnly bool + var firstEventID string + + for _, st := range texts { + if st.Text == "" { + continue + } + res := redactor.RedactWithOverrides(st.Text, overrides) + if len(res.Spans) == 0 { + continue + } + + // Persist one event per span so admins can see exactly + // which patterns fired in which positions. The action + // recorded is the resolved one (after override), so the + // events log reflects what actually happened to the + // request, not the global default. + for _, span := range res.Spans { + action := actionForSpan(redactor.Patterns(), span.Pattern, overrides) + ev := PIIEvent{ + ID: newEventID(), + CorrelationID: correlationID, + UserID: userID, + Direction: DirectionIn, + PatternID: span.Pattern, + ByteOffset: span.Start, + Length: span.End - span.Start, + HashPrefix: span.HashPrefix, + Action: action, + CreatedAt: time.Now().UTC(), + } + if firstEventID == "" { + firstEventID = ev.ID + } + if store != nil { + if err := store.Record(context.Background(), ev); err != nil { + xlog.Error("pii: failed to record event", "error", err, "pattern", span.Pattern) + } + } + // Contract: every span must produce an event. + contract.Invariant( + "pii.event_per_span", + span.Pattern != "" && ev.PatternID != "", + "correlation", correlationID, "pattern", span.Pattern, + ) + } + + if res.Blocked { + blocked = true + } + if res.LocalOnly { + localOnly = true + } + updates = append(updates, ScannedText{Index: st.Index, Text: res.Redacted}) + } + + if blocked { + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": map[string]string{ + "message": "request blocked by content policy (sensitive data detected)", + "type": "pii_blocked", + }, + "correlation_id": correlationID, + "pii_event_id": firstEventID, + }) + } + + if len(updates) > 0 && adapter.Apply != nil { + adapter.Apply(parsed, updates) + } + if firstEventID != "" { + c.Set(ctxKeyPIIEventID, firstEventID) + } + if localOnly { + c.Set(ctxKeyLocalOnly, true) + } + + return next(c) + } + } +} + +func actionForPattern(patterns []Pattern, id string) Action { + for _, p := range patterns { + if p.ID == id { + return p.Action + } + } + return ActionMask +} + +// actionForSpan returns the resolved action for a span, preferring a +// per-request override over the pattern's stored action. Used so the +// PIIEvent log reflects the action that actually fired (e.g., a model +// upgraded email from mask to block — the event row says "block"). +func actionForSpan(patterns []Pattern, id string, overrides map[string]Action) Action { + if action, ok := overrides[id]; ok { + return action + } + return actionForPattern(patterns, id) +} + +func newEventID() string { + var b [12]byte + _, _ = rand.Read(b[:]) + return "pii_" + hex.EncodeToString(b[:]) +} diff --git a/core/services/routing/pii/middleware_test.go b/core/services/routing/pii/middleware_test.go new file mode 100644 index 000000000000..d3bbbb2e7219 --- /dev/null +++ b/core/services/routing/pii/middleware_test.go @@ -0,0 +1,309 @@ +package pii + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// fakeRequest is the simplest possible parsed-request shape: a list of +// strings that the adapter scans and writes back. Lets us drive the +// middleware without dragging the real schema package in. +type fakeRequest struct { + Messages []string +} + +func fakeAdapter() Adapter { + return Adapter{ + Scan: func(parsed any) []ScannedText { + r, ok := parsed.(*fakeRequest) + if !ok { + return nil + } + out := make([]ScannedText, len(r.Messages)) + for i, m := range r.Messages { + out[i] = ScannedText{Index: i, Text: m} + } + return out + }, + Apply: func(parsed any, updates []ScannedText) { + r, ok := parsed.(*fakeRequest) + if !ok { + return + } + for _, u := range updates { + r.Messages[u.Index] = u.Text + } + }, + } +} + +func setRequestOnContext(req *fakeRequest) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set(ctxKeyParsedRequest, req) + return next(c) + } + } +} + +// fakeModelPIIConfig satisfies the duck-typed ModelPIIConfig interface +// the middleware expects on the echo context. The real implementation +// lives on *config.ModelConfig; using a fake here keeps these tests +// out of the core/config import graph. +type fakeModelPIIConfig struct { + enabled bool + overrides map[string]string +} + +func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled } +func (f fakeModelPIIConfig) PIIPatternOverrides() map[string]string { return f.overrides } + +// withModelConfig wires a ModelPIIConfig onto the context so the +// middleware's per-model gate doesn't fail-closed during tests. Pass +// enabled=true for the default test path; explicit-false tests should +// use the gating spec further down instead. +func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set(ctxKeyModelConfig, cfg) + return next(c) + } + } +} + +func newTestRedactor(ids ...string) *Redactor { + patterns, err := Compile(pick(DefaultPatterns(), ids)) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile") + return NewRedactor(patterns) +} + +var _ = Describe("RequestMiddleware", func() { + It("masks email", func() { + red := newTestRedactor("email") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + user := &auth.User{ID: "user-1", Name: "alice"} + + body := &fakeRequest{Messages: []string{"contact me at alice@example.com"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + e.POST("/chat", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw, func(next echo.HandlerFunc) echo.HandlerFunc { + // Inject the user as if upstream auth ran. + return func(c echo.Context) error { + c.Set("auth_user", user) + return next(c) + } + }) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String()) + Expect(body.Messages[0]).NotTo(ContainSubstring("alice@example.com"), "request body should be redacted in place") + Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:email]")) + + events, err := store.List(context.Background(), ListQuery{Limit: 100}) + Expect(err).NotTo(HaveOccurred(), "list events") + Expect(events).To(HaveLen(1)) + Expect(events[0].PatternID).To(Equal("email")) + Expect(events[0].Direction).To(Equal(DirectionIn)) + }) + + It("blocks api key", func() { + red := newTestRedactor("api_key_prefix") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"my key is sk-abcdefghijklmnopqrstuvwxyz0123456789"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + handlerCalled := false + e.POST("/chat", func(c echo.Context) error { + handlerCalled = true + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 on block; body=%s", w.Body.String()) + Expect(handlerCalled).To(BeFalse(), "handler must not run when request is blocked") + // Ensure the matched value never appears in the response body. + Expect(w.Body.String()).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "blocked response leaks the matched value") + + var resp map[string]any + Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed()) + errBlock, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errBlock["type"]).To(Equal("pii_blocked")) + }) + + It("route_local sets context flag", func() { + patterns, _ := Compile([]Pattern{{ + ID: "email", Description: "Email", Action: ActionRouteLocal, MaxMatchLength: 254, + }}) + red := NewRedactor(patterns) + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"hi at alice@example.com"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + var observedLocalOnly bool + e.POST("/chat", func(c echo.Context) error { + v, _ := c.Get(ctxKeyLocalOnly).(bool) + observedLocalOnly = v + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(observedLocalOnly).To(BeTrue(), "ctxKeyLocalOnly should be true on route_local match") + // route_local does NOT mutate the body — the model still sees the email. + Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "route_local should leave text intact") + }) + + It("no match passes through", func() { + red := newTestRedactor() + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"perfectly innocent text"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + e.POST("/chat", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(body.Messages[0]).To(Equal("perfectly innocent text"), "body should be untouched") + events, _ := store.List(context.Background(), ListQuery{Limit: 100}) + Expect(events).To(BeEmpty(), "expected 0 events on no-match input") + }) + + It("skips when model config disabled", func() { + // Per-model gating is the new contract: a model with PIIIsEnabled + // returning false must bypass redaction entirely, even if the + // global redactor has matching patterns. + red := newTestRedactor("email") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"contact alice@example.com"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + e.POST("/chat", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: false}), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "disabled model must not redact") + events, _ := store.List(context.Background(), ListQuery{Limit: 100}) + Expect(events).To(BeEmpty(), "disabled model must produce no events") + }) + + It("fails closed without model config", func() { + // Routes that wire the middleware before SetModelAndConfig, or + // non-chat routes lacking a model, hit this path. The contract + // is fail-closed: pass through without redaction so a missing + // model can't accidentally leak through global defaults. + red := newTestRedactor("email") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"contact alice@example.com"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + // Note: no withModelConfig in the chain. + e.POST("/chat", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "missing ModelPIIConfig should fail-closed (no redaction)") + }) + + It("applies per-model override", func() { + // email defaults to mask. A per-model override upgrades it to + // block. The middleware short-circuits with 400, the request + // body is never touched, and the events log records action=block. + red := newTestRedactor("email") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + + body := &fakeRequest{Messages: []string{"contact alice@example.com"}} + mw := RequestMiddleware(red, store, fakeAdapter(), nil) + + e := echo.New() + handlerCalled := false + e.POST("/chat", func(c echo.Context) error { + handlerCalled = true + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), + withModelConfig(fakeModelPIIConfig{ + enabled: true, + overrides: map[string]string{"email": "block"}, + }), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 from override-block; body=%s", w.Body.String()) + Expect(handlerCalled).To(BeFalse(), "handler must not run when override blocks") + events, _ := store.List(context.Background(), ListQuery{Limit: 100}) + Expect(events).To(HaveLen(1)) + Expect(events[0].Action).To(Equal(ActionBlock), "event must record the resolved (override) action") + }) + + It("nil redactor is passthrough", func() { + body := &fakeRequest{Messages: []string{"alice@example.com"}} + mw := RequestMiddleware(nil, nil, fakeAdapter(), nil) + + e := echo.New() + e.POST("/chat", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]string{"ok": "yes"}) + }, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw) + + req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`)) + w := httptest.NewRecorder() + e.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(body.Messages[0]).To(Equal("alice@example.com"), "nil redactor must be a no-op") + }) +}) diff --git a/core/services/routing/pii/ner.go b/core/services/routing/pii/ner.go new file mode 100644 index 000000000000..57d25cded5eb --- /dev/null +++ b/core/services/routing/pii/ner.go @@ -0,0 +1,97 @@ +package pii + +import ( + "context" + "fmt" +) + +// NERDetector is the contract the redactor's encoder/NER tier expects. +// One detector wraps one loaded token-classification model. The +// implementation (e.g. via the transformers gRPC backend) is wired in +// from core/application; this package stays free of core/backend +// imports so the redactor remains unit-testable with a stub detector. +// +// Implementations must honour ctx cancellation — NER round-trips can +// take tens of milliseconds and a client-aborted request should not +// stall the redactor. +type NERDetector interface { + Detect(ctx context.Context, text string) ([]NEREntity, error) +} + +// NEREntity is one detected span. Start/End are byte offsets into the +// text passed to Detect — half-open, addressing text[Start:End]. The +// Group is the entity label (e.g. "PER", "LOC", "EMAIL"); the exact +// vocabulary depends on the model. The redactor's action map keys off +// Group, so admins configure per-label behaviour. +type NEREntity struct { + Group string + Start int + End int + Score float32 +} + +// NERConfig configures the encoder tier for one redactor invocation. +// Per-request so the same Redactor instance can serve multiple models +// (each with its own NER preferences) without per-model redactor +// instances. +type NERConfig struct { + // Detector is the loaded model. nil disables the NER tier — the + // redactor falls back to the regex-only path with no allocation + // cost. + Detector NERDetector + + // MinScore is the confidence floor; entities below this are dropped + // before being merged into the hit list. 0 keeps every result the + // detector returns. + MinScore float32 + + // EntityActions maps entity_group → Action. Unknown groups (groups + // the detector returns but the admin didn't configure) use + // DefaultAction. Empty map + DefaultAction empty = NER detections + // recorded as audit rows but no redaction applied. + EntityActions map[string]Action + + // DefaultAction is applied when a detected entity_group has no + // explicit override. Empty (zero value) means "drop unmatched + // entities silently" — useful when the model returns a broad + // taxonomy but the admin only cares about a subset. + DefaultAction Action +} + +// ResolveAction returns the action configured for a detected entity +// group, falling back to DefaultAction. Returns ("", false) when the +// entity should be ignored entirely (no override + no default). +func (c NERConfig) ResolveAction(group string) (Action, bool) { + if a, ok := c.EntityActions[group]; ok { + return a, true + } + if c.DefaultAction != "" { + return c.DefaultAction, true + } + return "", false +} + +// nerPatternID returns the synthetic pattern ID that audit rows carry +// for NER hits. Prefixing with "ner:" keeps these distinguishable from +// regex pattern IDs in the events tab and in filter queries; admins +// can switch off a single entity type with the same Disabled-pattern +// machinery used for regex. +func nerPatternID(group string) string { + return "ner:" + group +} + +// errNERDetector is a NERDetector that always returns the wrapped +// error. Exported via NewErrNERDetector so the application wiring can +// surface "model not loaded" without taking on a fmt-only dependency +// just to format the error. +type errNERDetector struct{ err error } + +func (e errNERDetector) Detect(context.Context, string) ([]NEREntity, error) { + return nil, e.err +} + +// NewErrNERDetector returns a detector whose Detect always fails with +// the supplied error. Used by the application-level adapter when the +// configured NER model can't be resolved — the redactor surfaces a +// clear runtime error rather than silently skipping the NER tier. +func NewErrNERDetector(msg string) NERDetector { return errNERDetector{err: fmt.Errorf("%s", msg)} } diff --git a/core/services/routing/pii/ner_test.go b/core/services/routing/pii/ner_test.go new file mode 100644 index 000000000000..b4d822234a6c --- /dev/null +++ b/core/services/routing/pii/ner_test.go @@ -0,0 +1,174 @@ +package pii + +import ( + "context" + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// stubNERDetector returns a fixed slice of entities and tracks call +// count so tests can assert the detector isn't called when text is +// empty / no patterns / detector disabled. +type stubNERDetector struct { + entities []NEREntity + err error + calls int +} + +func (s *stubNERDetector) Detect(_ context.Context, _ string) ([]NEREntity, error) { + s.calls++ + return s.entities, s.err +} + +var _ = Describe("RedactWithNER", func() { + It("nil detector is regex-only", func() { + // When the NER tier is disabled (Detector == nil) the redactor + // must behave exactly like the existing regex-only path — no + // detector call, same Result shape, no error. + r := NewRedactor([]Pattern{pickEmail()}) + res, err := r.RedactWithNER(context.Background(), "ping me at alice@example.com", nil, NERConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still run when Detector is nil") + }) + + It("applies entity actions", func() { + det := &stubNERDetector{entities: []NEREntity{ + {Group: "PER", Start: 6, End: 11, Score: 0.95}, // "Alice" in "Hi I'm Alice today" + }} + r := NewRedactor(nil) + res, err := r.RedactWithNER(context.Background(), "Hi I'm Alice today", nil, NERConfig{ + Detector: det, + EntityActions: map[string]Action{"PER": ActionMask}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(det.calls).To(Equal(1)) + Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]")) + Expect(res.Spans).To(HaveLen(1)) + Expect(res.Spans[0].Pattern).To(Equal("ner:PER")) + }) + + It("filters below MinScore", func() { + det := &stubNERDetector{entities: []NEREntity{ + {Group: "PER", Start: 0, End: 5, Score: 0.20}, + }} + r := NewRedactor(nil) + res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{ + Detector: det, + MinScore: 0.50, + EntityActions: map[string]Action{"PER": ActionMask}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Redacted).To(Equal("Alice"), "low-confidence entity should be dropped") + }) + + It("default action applies to unconfigured groups", func() { + det := &stubNERDetector{entities: []NEREntity{ + {Group: "ORG", Start: 7, End: 11, Score: 0.9}, // "Acme" in "Hello, Acme!" + }} + r := NewRedactor(nil) + res, err := r.RedactWithNER(context.Background(), "Hello, Acme!", nil, NERConfig{ + Detector: det, + DefaultAction: ActionMask, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:ORG]"), "DefaultAction should apply to ORG") + }) + + It("drops unconfigured groups with no default", func() { + // EntityActions has no entry for ORG and DefaultAction is empty — + // the detected entity must be ignored entirely (no audit row, no + // redaction). + det := &stubNERDetector{entities: []NEREntity{ + {Group: "ORG", Start: 0, End: 4, Score: 0.9}, + }} + r := NewRedactor(nil) + res, err := r.RedactWithNER(context.Background(), "Acme", nil, NERConfig{ + Detector: det, + EntityActions: map[string]Action{"PER": ActionMask}, // ORG is unconfigured + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Redacted).To(Equal("Acme")) + Expect(res.Spans).To(BeEmpty()) + }) + + It("overlapping hits keep stronger action", func() { + // Regex marks 0..10 as mask; NER marks 5..15 as block. After + // merge, the union 0..15 keeps the strongest action (block). + pat := Pattern{ID: "test", Action: ActionMask, regex: rangeRegex(0, 10)} + r := NewRedactor([]Pattern{pat}) + det := &stubNERDetector{entities: []NEREntity{ + {Group: "PER", Start: 5, End: 15, Score: 0.9}, + }} + text := "0123456789ABCDEF" + res, err := r.RedactWithNER(context.Background(), text, nil, NERConfig{ + Detector: det, + EntityActions: map[string]Action{"PER": ActionBlock}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Blocked).To(BeTrue(), "overlapping mask+block should set Blocked=true") + }) + + It("detector error returns regex result and error", func() { + // Fail-open: when the NER detector errors, the redactor still + // returns regex-tier hits so an offline NER backend doesn't strip + // the cheap protection. Caller can read the error and decide + // whether to surface it. + det := &stubNERDetector{err: errors.New("backend offline")} + r := NewRedactor([]Pattern{pickEmail()}) + res, err := r.RedactWithNER(context.Background(), "ping alice@example.com", nil, NERConfig{ + Detector: det, + DefaultAction: ActionMask, + }) + Expect(err).To(HaveOccurred(), "expected detector error to surface") + Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still apply on NER failure") + }) + + It("out-of-bounds offsets are skipped", func() { + // A misconfigured / buggy backend could return offsets past the + // end of text. The redactor must not panic on slice OOB. + det := &stubNERDetector{entities: []NEREntity{ + {Group: "PER", Start: 0, End: 999, Score: 0.9}, + {Group: "PER", Start: -1, End: 3, Score: 0.9}, + {Group: "PER", Start: 5, End: 5, Score: 0.9}, // zero-length + }} + r := NewRedactor(nil) + res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{ + Detector: det, + DefaultAction: ActionMask, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Redacted).To(Equal("Alice")) + Expect(res.Spans).To(BeEmpty()) + }) +}) + +// --- test helpers --- + +// rangeMatcher is a deterministic regexpMatcher stub: it claims one +// fixed range regardless of input. Lets the overlap-merge test +// produce a known regex/NER intersection without depending on a real +// compiled regex. +type rangeMatcher struct{ start, end int } + +func (m rangeMatcher) FindAllStringIndex(_ string, _ int) [][]int { + return [][]int{{m.start, m.end}} +} + +func rangeRegex(start, end int) regexpMatcher { return rangeMatcher{start: start, end: end} } + +// pickEmail returns the compiled "email" pattern from DefaultPatterns +// — the NER tests use it as the regex tier's contribution. +func pickEmail() Pattern { + for _, p := range DefaultPatterns() { + if p.ID == "email" { + compiled, err := Compile([]Pattern{p}) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile") + return compiled[0] + } + } + Fail("email pattern missing from DefaultPatterns") + return Pattern{} +} + diff --git a/core/services/routing/pii/patterns.go b/core/services/routing/pii/patterns.go new file mode 100644 index 000000000000..1e1ef50a14f7 --- /dev/null +++ b/core/services/routing/pii/patterns.go @@ -0,0 +1,188 @@ +package pii + +import ( + "fmt" + "regexp" + "strings" +) + +// regexpMatcher is a thin wrapper so tests can swap in a deterministic +// matcher without touching the regexp package. Real usage uses +// regexpMatcherFromPattern; tests can construct fakes. +type regexpMatcher interface { + FindAllStringIndex(s string, n int) [][]int +} + +type goRegexp struct{ r *regexp.Regexp } + +func (g goRegexp) FindAllStringIndex(s string, n int) [][]int { + return g.r.FindAllStringIndex(s, n) +} + +// DefaultPatterns returns the built-in regex set. Each entry includes +// a conservative MaxMatchLength so the streaming filter can size its +// tail buffer without re-parsing the regex at runtime. +// +// Caveats by design: +// - The phone pattern matches international and US formats but does +// not validate area codes. False positives on numbers that look +// phone-like (e.g., timestamps in some formats) are accepted in +// return for reliable coverage. +// - The credit card pattern requires the Luhn check (verifyLuhn) to +// reduce false positives — random 16-digit strings won't match. +// - The API-key pattern targets common provider prefixes (sk-, pk-, +// xoxb-, ghp_, github_pat_) rather than guessing entropy. Adding +// new providers should append a new Pattern, not extend an +// existing alternation, so the admin UI can show one row per +// provider with its own toggle. +func DefaultPatterns() []Pattern { + return []Pattern{ + { + ID: "email", + Description: "Email address", + Action: ActionMask, + MaxMatchLength: 254, // RFC 5321 max + }, + { + ID: "phone", + Description: "Phone number (international or US format)", + Action: ActionMask, + MaxMatchLength: 24, + }, + { + ID: "ssn", + Description: "US Social Security Number (NNN-NN-NNNN)", + Action: ActionMask, + MaxMatchLength: 11, + }, + { + ID: "credit_card", + Description: "Credit card number (Luhn-verified)", + Action: ActionMask, + MaxMatchLength: 19, + }, + { + ID: "ipv4", + Description: "IPv4 address", + Action: ActionMask, + MaxMatchLength: 15, + }, + { + ID: "api_key_prefix", + Description: "Common API key prefixes (sk-, pk-, xoxb-, ghp_, github_pat_)", + Action: ActionBlock, // tighter default — leaked credentials are higher harm + MaxMatchLength: 200, + }, + } +} + +// patternRegexps maps Pattern.ID to its compiled regex. Kept separate +// from the Pattern struct so DefaultPatterns can be data-only and +// tests can swap matchers via Compile(). +var patternRegexps = map[string]*regexp.Regexp{ + // Pragmatic email — does not implement RFC 5322 in full (no one + // sane does in a regex). Catches the common shape; the encoder + // NER tier (future) catches edge cases. + "email": regexp.MustCompile(`(?i)[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}`), + // US: (123) 456-7890, 123-456-7890, 123.456.7890, 1234567890. + // International: +-- with separators. + "phone": regexp.MustCompile(`(?:\+?\d{1,3}[\s\-.]?)?(?:\(\d{3}\)|\d{3})[\s\-.]?\d{3}[\s\-.]?\d{4}`), + "ssn": regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`), + // 13-19 digit Luhn-eligible runs. The verifier in match() rejects + // non-Luhn matches. + "credit_card": regexp.MustCompile(`\b(?:\d[ \-]?){13,19}\b`), + "ipv4": regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`), + // Common provider prefixes; each alternative is a separate + // well-known marker rather than a permissive entropy match. + "api_key_prefix": regexp.MustCompile(`(?:sk-[A-Za-z0-9]{20,}|pk-[A-Za-z0-9]{20,}|xoxb-[A-Za-z0-9\-]{20,}|ghp_[A-Za-z0-9]{20,}|github_pat_[A-Za-z0-9_]{20,})`), +} + +// Compile attaches matchers to each pattern. Patterns whose ID is not +// in patternRegexps are returned as a typed error so an admin who +// adds a custom pattern via config gets a clear "no regex registered" +// message instead of silent skip. +func Compile(patterns []Pattern) ([]Pattern, error) { + out := make([]Pattern, len(patterns)) + for i, p := range patterns { + r, ok := patternRegexps[p.ID] + if !ok { + return nil, fmt.Errorf("pii: no regex registered for pattern id %q", p.ID) + } + p.regex = goRegexp{r: r} + out[i] = p + } + return out, nil +} + +// VerifyMatch applies pattern-specific post-checks (e.g. Luhn for +// credit_card). Returns the original match or "" to discard it. +func VerifyMatch(patternID, candidate string) string { + switch patternID { + case "credit_card": + digits := stripNonDigits(candidate) + if len(digits) < 13 || len(digits) > 19 { + return "" + } + if !verifyLuhn(digits) { + return "" + } + case "ipv4": + // Each octet must be 0..255. The regex allows 0..999 since + // regex isn't great at numeric ranges; we tighten here. + for oct := range strings.SplitSeq(candidate, ".") { + n := 0 + for _, c := range oct { + if c < '0' || c > '9' { + return "" + } + n = n*10 + int(c-'0') + } + if n > 255 { + return "" + } + } + } + return candidate +} + +func stripNonDigits(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, c := range s { + if c >= '0' && c <= '9' { + b.WriteRune(c) + } + } + return b.String() +} + +// verifyLuhn implements the Luhn checksum used by credit-card numbers. +// Returns true iff the digits pass. +func verifyLuhn(digits string) bool { + sum := 0 + double := false + for i := len(digits) - 1; i >= 0; i-- { + d := int(digits[i] - '0') + if double { + d *= 2 + if d > 9 { + d -= 9 + } + } + sum += d + double = !double + } + return sum%10 == 0 +} + +// MaxPatternLength returns the longest MaxMatchLength across the input +// patterns. Used by the streaming filter to size its tail buffer. +func MaxPatternLength(patterns []Pattern) int { + max := 0 + for _, p := range patterns { + if p.MaxMatchLength > max { + max = p.MaxMatchLength + } + } + return max +} diff --git a/core/services/routing/pii/pii_suite_test.go b/core/services/routing/pii/pii_suite_test.go new file mode 100644 index 000000000000..634b66df4928 --- /dev/null +++ b/core/services/routing/pii/pii_suite_test.go @@ -0,0 +1,13 @@ +package pii + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestPii(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "pii test suite") +} diff --git a/core/services/routing/pii/redactor.go b/core/services/routing/pii/redactor.go new file mode 100644 index 000000000000..b70192cacd2f --- /dev/null +++ b/core/services/routing/pii/redactor.go @@ -0,0 +1,342 @@ +package pii + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "slices" + "sort" + "strings" + "sync" +) + +// rawHit is one detection — regex-side or NER-side — before +// overlap-merging. Lifted to file scope so the regex and NER +// collectors can both produce them and feed the same merge/emit step. +type rawHit struct { + patternID string + action Action + start int + end int +} + +// Redactor scans text against a configured pattern set and applies the +// per-pattern action. The pattern set itself is mutable at runtime via +// SetAction (the /api/pii/patterns/:id admin endpoint mutates it +// in-place); reads are guarded by a mutex so concurrent requests stay +// race-free. +type Redactor struct { + mu sync.RWMutex + patterns []Pattern + maxLen int +} + +// NewRedactor constructs a redactor from a list of compiled patterns +// (use Compile() to compile config-loaded patterns first). nil +// patterns is valid and produces a no-op redactor — convenient for the +// "PII disabled" deployment. +func NewRedactor(patterns []Pattern) *Redactor { + return &Redactor{ + patterns: patterns, + maxLen: MaxPatternLength(patterns), + } +} + +// MaxPatternLength is exposed so the streaming wrapper can size its +// tail buffer to match. +func (r *Redactor) MaxPatternLength() int { return r.maxLen } + +// Patterns returns a copy of the configured pattern set so callers can +// iterate without holding the redactor lock. The compiled regexes are +// shared — they are immutable once built. +func (r *Redactor) Patterns() []Pattern { + r.mu.RLock() + defer r.mu.RUnlock() + return slices.Clone(r.patterns) +} + +// SetAction overrides the action for a single pattern. Used by the +// /api/pii/patterns/:id admin endpoint and the set_pii_pattern_action +// MCP tool — transient until process restart unless persisted via +// --pii-config. +// +// Publishes a new slice so concurrent Redact callers iterating an +// older snapshot don't race on the per-element Action string (Go +// strings are not atomic two-word values). +func (r *Redactor) SetAction(id string, action Action) error { + if action != ActionMask && action != ActionBlock && action != ActionRouteLocal { + return fmt.Errorf("unknown action %q (must be mask, block, or route_local)", action) + } + r.mu.Lock() + defer r.mu.Unlock() + for i := range r.patterns { + if r.patterns[i].ID == id { + next := slices.Clone(r.patterns) + next[i].Action = action + r.patterns = next + return nil + } + } + return fmt.Errorf("unknown pattern id %q", id) +} + +// SetDisabled toggles a pattern's enabled state in the live redactor. +// Same COW publish as SetAction. +func (r *Redactor) SetDisabled(id string, disabled bool) error { + r.mu.Lock() + defer r.mu.Unlock() + for i := range r.patterns { + if r.patterns[i].ID == id { + next := slices.Clone(r.patterns) + next[i].Disabled = disabled + r.patterns = next + return nil + } + } + return fmt.Errorf("unknown pattern id %q", id) +} + +// Redact is a thin wrapper for callers that don't need per-request +// action overrides. It applies each pattern's compiled-in default +// action. +func (r *Redactor) Redact(text string) Result { + return r.RedactWithOverrides(text, nil) +} + +// RedactWithOverrides scans text and returns the result. The override +// map is keyed by pattern id; when present, the value replaces the +// pattern's compiled-in action for this call only — the redactor's +// stored action is unchanged. Pattern ids missing from the map use +// their stored action. +// +// For every match it records a Span (with HashPrefix, never the value) +// and applies the resolved Action: +// - block: sets Result.Blocked, leaves text intact (caller decides +// whether to surface the redacted form). +// - mask: replaces the span with maskFor(pattern.ID). +// - route_local: sets Result.LocalOnly, leaves text intact. +// +// Spans are returned in the original input's coordinate system so the +// PIIEvent record can be written without re-running the scan. +func (r *Redactor) RedactWithOverrides(text string, overrides map[string]Action) Result { + return r.redact(context.Background(), text, overrides, NERConfig{}) +} + +// RedactWithNER is the encoder-tier variant: runs both the regex tier +// (with per-pattern overrides) and the NER tier, merges hits, and +// emits one redacted output. A nil NERConfig.Detector skips the NER +// pass — callers can hand the same path the same NERConfig{} whether +// or not the model has NER configured. +// +// Errors from the NER detector are returned alongside a best-effort +// regex-only Result so the caller can decide whether to fail open +// (return the regex Result, log the error) or fail closed (refuse the +// request). The regex tier never errors. +func (r *Redactor) RedactWithNER(ctx context.Context, text string, overrides map[string]Action, nerCfg NERConfig) (Result, error) { + if nerCfg.Detector == nil { + return r.redact(ctx, text, overrides, nerCfg), nil + } + hits, err := r.collectRegexHits(text, overrides) + if err != nil { + return Result{Redacted: text}, err + } + nerHits, nerErr := collectNERHits(ctx, text, nerCfg) + if nerErr != nil { + // Return the regex-only result so a NER-backend outage doesn't + // strip the cheap protection. Caller decides fail-open vs + // fail-closed via the returned error. + return mergeAndEmit(text, hits), nerErr + } + return mergeAndEmit(text, append(hits, nerHits...)), nil +} + +// redact is the internal regex-only entry point. RedactWithOverrides +// is the public wrapper; RedactWithNER routes through here only when +// the NER detector is nil (so the call site doesn't need a separate +// "regex-only" code path). +func (r *Redactor) redact(_ context.Context, text string, overrides map[string]Action, _ NERConfig) Result { + hits, _ := r.collectRegexHits(text, overrides) + return mergeAndEmit(text, hits) +} + +// collectRegexHits walks the configured pattern set against text and +// returns each verified match as a rawHit. The redactor lock is held +// only long enough to snapshot the pattern slice — regex evaluation +// runs lock-free against the snapshot, so SetAction/SetDisabled don't +// stall a long-running Redact. +func (r *Redactor) collectRegexHits(text string, overrides map[string]Action) ([]rawHit, error) { + r.mu.RLock() + patterns := r.patterns + r.mu.RUnlock() + + if len(patterns) == 0 || text == "" { + return nil, nil + } + var hits []rawHit + for _, p := range patterns { + if p.regex == nil { + // Pattern declared but Compile() not called. Skip rather + // than panic; the caller already saw an error from Compile. + continue + } + if p.Disabled { + continue + } + action := p.Action + if override, ok := overrides[p.ID]; ok { + action = override + } + idxs := p.regex.FindAllStringIndex(text, -1) + for _, idx := range idxs { + candidate := text[idx[0]:idx[1]] + if VerifyMatch(p.ID, candidate) == "" { + continue + } + hits = append(hits, rawHit{ + patternID: p.ID, + action: action, + start: idx[0], + end: idx[1], + }) + } + } + return hits, nil +} + +// collectNERHits invokes the configured NERDetector and converts each +// returned entity into a rawHit using the NERConfig's action map. +// Entities below MinScore or with no resolved action are dropped — the +// detector doesn't know which entity groups the admin cares about, so +// the redactor filters here. +func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit, error) { + if cfg.Detector == nil || text == "" { + return nil, nil + } + entities, err := cfg.Detector.Detect(ctx, text) + if err != nil { + return nil, err + } + var hits []rawHit + for _, e := range entities { + if e.Score < cfg.MinScore { + continue + } + action, ok := cfg.ResolveAction(e.Group) + if !ok { + continue + } + if e.Start < 0 || e.End <= e.Start || e.End > len(text) { + // Defensive: the backend should return byte offsets into + // the original text, but a misconfigured model could + // produce garbage. Skip rather than panic on slice OOB. + continue + } + hits = append(hits, rawHit{ + patternID: nerPatternID(e.Group), + action: action, + start: e.Start, + end: e.End, + }) + } + return hits, nil +} + +// mergeAndEmit handles the overlap-merge + masked-output step that +// regex-only and combined regex+NER redactions both perform. Sorts by +// start (stable on equal starts by descending action strength), drops +// overlapping hits in favour of the stronger action, and walks the +// text once to emit replacement spans. +func mergeAndEmit(text string, hits []rawHit) Result { + if len(hits) == 0 { + return Result{Redacted: text} + } + // Sort and deduplicate overlapping hits — when two patterns claim + // the same span (e.g., a credit-card-shaped value also scans as + // digits, or NER tags a span the regex also caught), keep the one + // with the strongest action. Order: block > route_local > mask. + sort.Slice(hits, func(i, j int) bool { + if hits[i].start != hits[j].start { + return hits[i].start < hits[j].start + } + return actionRank(hits[i].action) > actionRank(hits[j].action) + }) + merged := hits[:0] + for _, h := range hits { + if len(merged) > 0 { + last := &merged[len(merged)-1] + if h.start < last.end { + if actionRank(h.action) > actionRank(last.action) { + last.action = h.action + last.patternID = h.patternID + } + if h.end > last.end { + last.end = h.end + } + continue + } + } + merged = append(merged, h) + } + + res := Result{} + var out strings.Builder + out.Grow(len(text)) + cursor := 0 + for _, h := range merged { + matched := text[h.start:h.end] + span := Span{ + Start: h.start, + End: h.end, + Pattern: h.patternID, + HashPrefix: hashPrefix(matched), + } + res.Spans = append(res.Spans, span) + + out.WriteString(text[cursor:h.start]) + switch h.action { + case ActionBlock: + res.Blocked = true + out.WriteString(matched) + case ActionRouteLocal: + res.LocalOnly = true + out.WriteString(matched) + default: + out.WriteString(maskFor(h.patternID)) + } + cursor = h.end + } + out.WriteString(text[cursor:]) + res.Redacted = out.String() + return res +} + +// maskFor returns the placeholder that replaces a matched span. The +// shape "[REDACTED:]" is intentionally stable — it surfaces the +// pattern id back to the model, which is sometimes useful (e.g., the +// model can say "I see you redacted an email"). Admins who want a +// less informative replacement can build one in front of this. +func maskFor(patternID string) string { + return "[REDACTED:" + patternID + "]" +} + +// hashPrefix returns the first 8 chars of sha256(value). Two calls +// with the same input produce the same prefix so an admin auditing +// the PIIEvent log can spot a recurring leak ("the same SSN appears +// 200 times this hour") without ever recovering the value. +func hashPrefix(value string) string { + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:])[:8] +} + +func actionRank(a Action) int { + switch a { + case ActionBlock: + return 3 + case ActionRouteLocal: + return 2 + case ActionMask: + return 1 + } + return 0 +} diff --git a/core/services/routing/pii/redactor_race_test.go b/core/services/routing/pii/redactor_race_test.go new file mode 100644 index 000000000000..f926ea64dea0 --- /dev/null +++ b/core/services/routing/pii/redactor_race_test.go @@ -0,0 +1,66 @@ +package pii + +import ( + "sync" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// Redactor_SetActionConcurrentRedact pins the SetAction copy-on- +// write contract: concurrent SetAction must not race with readers +// iterating an older patterns snapshot. Run with -race to surface the +// regression that motivated the COW (in-place mutation of the +// per-element Action string is not atomic). +var _ = Describe("Redactor", func() { + It("SetAction concurrent with Redact", func() { + patterns, err := Compile(DefaultPatterns()) + Expect(err).NotTo(HaveOccurred(), "compile") + r := NewRedactor(patterns) + + const writers = 4 + const readers = 8 + const iter = 100 + + var wg sync.WaitGroup + stop := make(chan struct{}) + + for w := 0; w < writers; w++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iter; i++ { + select { + case <-stop: + return + default: + } + action := ActionMask + if i%2 == 0 { + action = ActionBlock + } + _ = r.SetAction("email", action) + } + }() + } + + for rd := 0; rd < readers; rd++ { + wg.Add(1) + go func() { + defer wg.Done() + text := "contact alice@example.com please" + for i := 0; i < iter*2; i++ { + select { + case <-stop: + return + default: + } + _ = r.Redact(text) + } + }() + } + + wg.Wait() + close(stop) + }) +}) diff --git a/core/services/routing/pii/redactor_test.go b/core/services/routing/pii/redactor_test.go new file mode 100644 index 000000000000..a084e4d542f5 --- /dev/null +++ b/core/services/routing/pii/redactor_test.go @@ -0,0 +1,184 @@ +package pii + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func mustCompile(ids ...string) []Pattern { + all := DefaultPatterns() + if len(ids) == 0 { + out, err := Compile(all) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile") + return out + } + pickP := pick(all, ids) + out, err := Compile(pickP) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile") + return out +} + +func pick(all []Pattern, ids []string) []Pattern { + keep := map[string]bool{} + for _, id := range ids { + keep[id] = true + } + var out []Pattern + for _, p := range all { + if keep[p.ID] { + out = append(out, p) + } + } + return out +} + +var _ = Describe("Redactor", func() { + It("masks email", func() { + r := NewRedactor(mustCompile("email")) + res := r.Redact("Contact me at alice@example.com any time.") + Expect(res.Blocked).To(BeFalse(), "email is mask-action by default, should not block") + Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]")) + Expect(res.Redacted).NotTo(ContainSubstring("alice@example.com")) + Expect(res.Spans).To(HaveLen(1)) + Expect(res.Spans[0].HashPrefix).NotTo(BeEmpty(), "hash prefix must be set so audits can dedupe leaks") + }) + + It("masks SSN", func() { + r := NewRedactor(mustCompile("ssn")) + res := r.Redact("call me about SSN 123-45-6789 please") + Expect(res.Redacted).To(ContainSubstring("[REDACTED:ssn]")) + }) + + It("uses Luhn for credit card", func() { + r := NewRedactor(mustCompile("credit_card")) + + // 4111 1111 1111 1111 — canonical Luhn-valid Visa test number. + good := r.Redact("card: 4111 1111 1111 1111") + Expect(good.Spans).To(HaveLen(1)) + Expect(good.Redacted).To(ContainSubstring("[REDACTED:credit_card]")) + + // 4111 1111 1111 1112 — same shape, fails Luhn. Must NOT match. + bad := r.Redact("card: 4111 1111 1111 1112") + Expect(bad.Spans).To(BeEmpty(), "Luhn-invalid 16-digit run must not be redacted") + Expect(bad.Redacted).To(ContainSubstring("1112"), "Luhn-invalid input should pass through untouched") + }) + + It("validates IPv4 octets", func() { + r := NewRedactor(mustCompile("ipv4")) + + good := r.Redact("server at 192.168.1.10 is up") + Expect(good.Spans).To(HaveLen(1)) + + // 999.999.999.999 — regex matches but octet > 255 must reject. + bad := r.Redact("not an ip: 999.999.999.999") + Expect(bad.Spans).To(BeEmpty(), "ipv4 with octet>255 must not match") + }) + + It("api_key defaults to block", func() { + r := NewRedactor(mustCompile("api_key_prefix")) + res := r.Redact("here's a token sk-abcdefghijklmnopqrstuvwxyz0123456789 to use") + Expect(res.Blocked).To(BeTrue(), "api_key default action is block; Result.Blocked must be true") + // The redacted output keeps the matched value when blocking — the + // caller is expected to refuse the request, not to forward a partial. + Expect(res.Redacted).To(ContainSubstring("sk-abcdefghijklmn"), "blocked actions leave the matched span intact for caller inspection") + }) + + It("preserves non-matching text", func() { + r := NewRedactor(mustCompile()) // all default patterns + in := "no PII here at all, just words and numbers like 42 and 1.5" + res := r.Redact(in) + Expect(res.Redacted).To(Equal(in), "non-PII input should pass through unchanged") + Expect(res.Spans).To(BeEmpty()) + }) + + It("handles empty input", func() { + r := NewRedactor(mustCompile()) + res := r.Redact("") + Expect(res.Redacted).To(BeEmpty()) + Expect(res.Blocked).To(BeFalse()) + Expect(res.LocalOnly).To(BeFalse()) + Expect(res.Spans).To(BeEmpty()) + }) + + It("nil patterns is a no-op", func() { + // Disabled-PII deployment: pii.NewRedactor(nil) is a no-op. + r := NewRedactor(nil) + res := r.Redact("alice@example.com sent it") + Expect(res.Redacted).To(Equal("alice@example.com sent it")) + }) + + It("hash prefix is stable", func() { + r := NewRedactor(mustCompile("email")) + a := r.Redact("a@b.com") + b := r.Redact("hi a@b.com again") + Expect(a.Spans).To(HaveLen(1)) + Expect(b.Spans).To(HaveLen(1)) + Expect(a.Spans[0].HashPrefix).To(Equal(b.Spans[0].HashPrefix), "same matched value must produce same hash prefix") + }) +}) + +var _ = Describe("Compile", func() { + It("rejects unknown pattern id", func() { + _, err := Compile([]Pattern{{ID: "nonexistent", Action: ActionMask}}) + Expect(err).To(HaveOccurred(), "Compile must error on unknown pattern id") + }) +}) + +var _ = Describe("MaxPatternLength", func() { + It("returns the longest pattern's max length", func() { + patterns := mustCompile("email", "ssn") + got := MaxPatternLength(patterns) + // email is the longer of the two (254). The streaming filter + // will use this to size its tail buffer. + Expect(got).To(Equal(254)) + }) +}) + +var _ = Describe("RedactWithOverrides", func() { + It("upgrades action", func() { + // email is mask by default; the per-model override turns it into a + // hard block for one request without mutating the redactor. + r := NewRedactor(mustCompile("email")) + res := r.RedactWithOverrides("contact alice@example.com", + map[string]Action{"email": ActionBlock}) + Expect(res.Blocked).To(BeTrue(), "override should have set Blocked") + // Block leaves the value intact (the caller short-circuits the + // request) — the redactor never echoes the matched text. + Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "block leaves text intact for the caller to discard") + // Stored action is unchanged so a subsequent default Redact still + // masks rather than blocks. + res2 := r.Redact("contact alice@example.com") + Expect(res2.Blocked).To(BeFalse(), "override must not mutate stored action") + }) + + It("ignores unknown IDs", func() { + // An override for a pattern this redactor doesn't know about is a + // no-op rather than an error — per-model configs may reference + // patterns from a wider catalogue than the active redactor holds. + r := NewRedactor(mustCompile("email")) + res := r.RedactWithOverrides("contact alice@example.com", + map[string]Action{"ssn": ActionBlock}) + Expect(res.Blocked).To(BeFalse(), "ssn override against email-only redactor must be no-op") + }) +}) + +var _ = Describe("SetAction", func() { + It("swaps in place", func() { + r := NewRedactor(mustCompile("email")) + Expect(r.SetAction("email", ActionRouteLocal)).To(Succeed()) + res := r.Redact("contact alice@example.com") + Expect(res.LocalOnly).To(BeTrue(), "expected LocalOnly after SetAction(route_local)") + Expect(res.Blocked).To(BeFalse(), "SetAction(route_local) should not block") + }) + + It("rejects unknown id", func() { + r := NewRedactor(mustCompile("email")) + Expect(r.SetAction("nonexistent", ActionMask)).NotTo(Succeed(), "expected error for unknown pattern id") + }) + + It("rejects unknown action", func() { + r := NewRedactor(mustCompile("email")) + Expect(r.SetAction("email", Action("frobnicate"))).NotTo(Succeed(), "expected error for unknown action") + }) +}) + diff --git a/core/services/routing/pii/store.go b/core/services/routing/pii/store.go new file mode 100644 index 000000000000..2b3a5df6540b --- /dev/null +++ b/core/services/routing/pii/store.go @@ -0,0 +1,130 @@ +package pii + +import ( + "context" + "sync" +) + +// EventStore persists PIIEvent records. Mirrors the StatsBackend +// abstraction in the billing package: in-process by default so a +// no-auth box still gets an event log; a future GORM-backed impl +// (when --auth is on) will reuse the auth DB. +type EventStore interface { + Record(ctx context.Context, e PIIEvent) error + List(ctx context.Context, q ListQuery) ([]PIIEvent, error) + // Count returns the number of events currently stored. Used by + // /api/middleware/status to surface a "recent_event_count" without + // pulling the whole list (the dashboard polls this on a refresh). + Count(ctx context.Context) (int, error) + Close() error +} + +// ListQuery filters the event log. CorrelationID, UserID, PatternID, +// Kind each scope the search; empty values match anything. Limit ≤ 0 +// returns up to a default cap. +type ListQuery struct { + CorrelationID string + UserID string + PatternID string + Kind EventKind + Limit int +} + +// NewMemoryEventStore returns an in-memory ring-buffer event store. +// capacity ≤ 0 picks 10_000. +// +// Why a ring: PII events are noisy; a chatty deployment can produce +// thousands per minute. A bounded buffer keeps memory predictable, +// and the GORM impl (when added) handles long-term retention. +func NewMemoryEventStore(capacity int) EventStore { + if capacity <= 0 { + capacity = 10_000 + } + return &memoryEventStore{ + ring: make([]PIIEvent, capacity), + cap: capacity, + } +} + +type memoryEventStore struct { + mu sync.RWMutex + ring []PIIEvent + cap int + cursor int + full bool +} + +func (s *memoryEventStore) Record(_ context.Context, e PIIEvent) error { + s.mu.Lock() + defer s.mu.Unlock() + s.ring[s.cursor] = e + s.cursor++ + if s.cursor == s.cap { + s.cursor = 0 + s.full = true + } + return nil +} + +func (s *memoryEventStore) List(_ context.Context, q ListQuery) ([]PIIEvent, error) { + limit := q.Limit + if limit <= 0 { + limit = 1000 + } + s.mu.RLock() + defer s.mu.RUnlock() + + out := make([]PIIEvent, 0, limit) + scan := func(e PIIEvent) bool { + if e.ID == "" { + return false // empty slot + } + if q.CorrelationID != "" && e.CorrelationID != q.CorrelationID { + return false + } + if q.UserID != "" && e.UserID != q.UserID { + return false + } + if q.PatternID != "" && e.PatternID != q.PatternID { + return false + } + if q.Kind != "" && e.ResolvedKind() != q.Kind { + return false + } + out = append(out, e) + return len(out) >= limit + } + + // Walk newest-first: cursor-1 down to 0, then cap-1 down to cursor + // when the ring has wrapped. + if s.full { + for i := s.cursor - 1; i >= 0; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + for i := s.cap - 1; i >= s.cursor; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + } else { + for i := s.cursor - 1; i >= 0; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + } + return out, nil +} + +func (s *memoryEventStore) Count(_ context.Context) (int, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if s.full { + return s.cap, nil + } + return s.cursor, nil +} + +func (s *memoryEventStore) Close() error { return nil } diff --git a/core/services/routing/pii/stream.go b/core/services/routing/pii/stream.go new file mode 100644 index 000000000000..93a5cd261f75 --- /dev/null +++ b/core/services/routing/pii/stream.go @@ -0,0 +1,197 @@ +package pii + +import ( + "context" + "crypto/rand" + "encoding/hex" + "strings" + "time" + "unicode/utf8" +) + +// StreamFilter applies the regex PII tier to a streaming response, +// chunk by chunk, with a buffered-emit invariant: for any active +// pattern with bounded max-length L, the filter never emits the +// trailing L-1 characters of the cumulative input until either +// +// (a) more text arrives that disambiguates the boundary, or +// (b) the stream closes (Drain). +// +// That keeps the redactor honest across chunk splits — an email +// arriving as "alice@" + "example.com" still masks the same way as +// "alice@example.com" arriving in one piece. +// +// Action handling in stream mode differs from the request-side +// middleware. Earlier chunks of the response are already on the wire +// by the time later chunks are scanned, so a "block" can't actually +// reject the request. We remap block → mask for redaction purposes +// while still recording PIIEvent rows with action="block" so audits +// surface the original intent ("the model would have leaked X here, +// suppressed in flight"). route_local on the output side is a no-op +// (the dispatch decision was already made on the request side). +// +// StreamFilter is NOT safe for concurrent use across goroutines; one +// instance per response stream. +type StreamFilter struct { + redactor *Redactor + maskOverrides map[string]Action // block → mask map used for redaction + auditActions map[string]Action // original action per pattern, used for events + store EventStore + correlationID string + userID string + holdLen int + buffer strings.Builder + emittedBytes int +} + +// NewStreamFilter constructs a per-response filter. modelOverrides is +// the per-model action override map (same shape the request-side +// middleware uses); it can be nil when the model only accepts global +// defaults. +// +// store may be nil — events are then computed but not persisted, which +// is what the chat handler does when --disable-stats is set. +func NewStreamFilter(redactor *Redactor, modelOverrides map[string]Action, store EventStore, correlationID, userID string) *StreamFilter { + if redactor == nil { + return &StreamFilter{} + } + + patterns := redactor.Patterns() + + // auditActions: the action we *would* have applied if this match + // occurred on the request side. Honours the per-model override. + auditActions := make(map[string]Action, len(patterns)) + for _, p := range patterns { + auditActions[p.ID] = p.Action + } + for id, action := range modelOverrides { + auditActions[id] = action + } + + // maskOverrides: the action we actually apply to the stream. Same + // as auditActions, but with every block remapped to mask. + maskOverrides := make(map[string]Action, len(auditActions)) + for id, action := range auditActions { + if action == ActionBlock { + maskOverrides[id] = ActionMask + } else { + maskOverrides[id] = action + } + } + + return &StreamFilter{ + redactor: redactor, + maskOverrides: maskOverrides, + auditActions: auditActions, + store: store, + correlationID: correlationID, + userID: userID, + holdLen: redactor.MaxPatternLength() - 1, + } +} + +// Push appends new text to the filter's buffer and returns the prefix +// safe to emit downstream — the cumulative input minus a tail of +// holdLen characters that might still be the start of a longer match. +// Returned text has masks already applied. +// +// Returns an empty string when not enough text has arrived to clear +// the hold window. +func (sf *StreamFilter) Push(text string) string { + if sf.redactor == nil || sf.holdLen <= 0 { + return text + } + sf.buffer.WriteString(text) + bufStr := sf.buffer.String() + n := len(bufStr) + + if n <= sf.holdLen { + return "" + } + + emitBoundary := n - sf.holdLen + + // Scan the entire buffer. A match whose start is before the + // boundary but whose end runs past it crosses the window — pull + // the boundary back to match.start so the pattern stays whole in + // the buffer for the next Push to scan again. + full := sf.redactor.RedactWithOverrides(bufStr, sf.maskOverrides) + for _, span := range full.Spans { + if span.Start < emitBoundary && span.End > emitBoundary { + emitBoundary = span.Start + } + } + + // holdLen is byte-sized but a chunk boundary may land mid-codepoint. + // Snap back to the nearest rune start so neither the emitted prefix + // nor the retained tail contains a split codepoint — otherwise the + // next regex scan over an invalid-UTF-8 prefix could mis-match. + for emitBoundary > 0 && emitBoundary < n && !utf8.RuneStart(bufStr[emitBoundary]) { + emitBoundary-- + } + + if emitBoundary <= 0 { + return "" + } + + emitted := sf.applyAndEmit(bufStr[:emitBoundary]) + sf.buffer.Reset() + sf.buffer.WriteString(bufStr[emitBoundary:]) + return emitted +} + +// Drain emits whatever's left in the buffer with all matches applied. +// Call exactly once when the stream closes — repeat calls return the +// empty string. +func (sf *StreamFilter) Drain() string { + if sf.redactor == nil { + return sf.buffer.String() + } + bufStr := sf.buffer.String() + if bufStr == "" { + return "" + } + emitted := sf.applyAndEmit(bufStr) + sf.buffer.Reset() + return emitted +} + +// applyAndEmit runs the redactor over a committed-for-emit fragment, +// substitutes mask/block placeholders inline, and records one +// PIIEvent per matched span (with the audit action, not the masked +// one). ByteOffset is referenced to the cumulative emitted output so +// admins can correlate event positions against the streamed body. +func (sf *StreamFilter) applyAndEmit(fragment string) string { + res := sf.redactor.RedactWithOverrides(fragment, sf.maskOverrides) + output := res.Redacted + + if len(res.Spans) > 0 { + now := time.Now().UTC() + for _, span := range res.Spans { + ev := PIIEvent{ + ID: newStreamEventID(), + CorrelationID: sf.correlationID, + UserID: sf.userID, + Direction: DirectionOut, + PatternID: span.Pattern, + ByteOffset: sf.emittedBytes + span.Start, + Length: span.End - span.Start, + HashPrefix: span.HashPrefix, + Action: sf.auditActions[span.Pattern], + CreatedAt: now, + } + if sf.store != nil { + _ = sf.store.Record(context.Background(), ev) + } + } + } + + sf.emittedBytes += len(fragment) + return output +} + +func newStreamEventID() string { + var b [12]byte + _, _ = rand.Read(b[:]) + return "pii_" + hex.EncodeToString(b[:]) +} diff --git a/core/services/routing/pii/stream_test.go b/core/services/routing/pii/stream_test.go new file mode 100644 index 000000000000..037020609d85 --- /dev/null +++ b/core/services/routing/pii/stream_test.go @@ -0,0 +1,184 @@ +package pii + +import ( + "context" + "fmt" + "math/rand" + "strings" + "unicode/utf8" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func newStreamRedactor(ids ...string) *Redactor { + all := DefaultPatterns() + chosen := all + if len(ids) > 0 { + chosen = pick(all, ids) + } + patterns, err := Compile(chosen) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile") + return NewRedactor(patterns) +} + +var _ = Describe("StreamFilter", func() { + It("masks across chunks", func() { + // The most important streaming test: an email split arbitrarily + // across chunk boundaries must mask exactly the same way as one + // arriving in a single Push. + red := newStreamRedactor("email") + sf := NewStreamFilter(red, nil, nil, "", "") + + // "alice@example.com" (17 bytes) split between '@' and 'e'. + out := "" + out += sf.Push("hi alice@") + out += sf.Push("example.com! end") + out += sf.Drain() + + Expect(out).NotTo(ContainSubstring("alice@example.com"), "stream leaked email across chunk boundary") + Expect(out).To(ContainSubstring("[REDACTED:email]")) + }) + + It("block becomes mask", func() { + // api_key_prefix is block by default. In stream mode the earlier + // chunks are already on the wire so block is impossible — the + // filter remaps to mask while still recording action="block" so + // the audit log keeps the original intent. + red := newStreamRedactor("api_key_prefix") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + sf := NewStreamFilter(red, nil, store, "corr-1", "user-1") + + out := sf.Push("here is your token: sk-abcdefghijklmnopqrstuvwxyz0123456789 done") + out += sf.Drain() + + Expect(out).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "block-in-stream must mask, leaked the value") + Expect(out).To(ContainSubstring("[REDACTED:api_key_prefix]")) + + events, _ := store.List(context.Background(), ListQuery{Limit: 10}) + Expect(events).To(HaveLen(1)) + Expect(events[0].Action).To(Equal(ActionBlock), "audit must record original block action") + Expect(events[0].Direction).To(Equal(DirectionOut), "stream events must be DirectionOut") + }) + + It("no match passthrough", func() { + red := newStreamRedactor("email") + sf := NewStreamFilter(red, nil, nil, "", "") + out := sf.Push("perfectly clean text that should") + sf.Push(" pass through unchanged.") + sf.Drain() + Expect(out).To(Equal("perfectly clean text that should pass through unchanged.")) + }) + + It("nil redactor passthrough", func() { + // --disable-pii path: NewStreamFilter(nil, ...) returns a filter + // that just forwards Push input verbatim. + sf := NewStreamFilter(nil, nil, nil, "", "") + out := sf.Push("any old text including alice@example.com") + sf.Drain() + Expect(out).To(Equal("any old text including alice@example.com")) + }) + + It("per-model overrides", func() { + // email defaults to mask; per-model override upgrades to block. + // In stream mode the override still maps to mask placeholder, but + // the audit event records action="block". + red := newStreamRedactor("email") + store := NewMemoryEventStore(0) + defer func() { _ = store.Close() }() + sf := NewStreamFilter(red, map[string]Action{"email": ActionBlock}, store, "corr-2", "user-2") + + out := sf.Push("contact alice@example.com please") + sf.Drain() + Expect(out).NotTo(ContainSubstring("alice@example.com"), "override block-in-stream must mask") + events, _ := store.List(context.Background(), ListQuery{Limit: 10}) + Expect(events).To(HaveLen(1)) + Expect(events[0].Action).To(Equal(ActionBlock)) + }) + + // StreamFilter_BufferedEmitInvariant feeds the redactor a corpus + // one rune at a time, randomly chunked, and asserts: + // + // 1. Across all (input, splitting) pairs, the cumulative emitted + // output never contains any of the secret values that were + // embedded in the input. + // 2. The output, fully drained, equals what Redact would have + // produced on the unsplit input. + // + // This is the load-bearing property of streaming PII: regardless of + // where chunks split, the emitted bytes cannot contain a value that a + // single-shot redactor would have masked. + It("buffered emit invariant", func() { + corpus := []struct { + text string + secrets []string + }{ + {"contact alice@example.com or bob@example.org", []string{"alice@example.com", "bob@example.org"}}, + {"my SSN is 123-45-6789 and his is 987-65-4321", []string{"123-45-6789", "987-65-4321"}}, + {"sk-abcdefghijklmnopqrstuvwxyz0123456789 leaked", []string{"sk-abcdefghijklmnopqrstuvwxyz0123456789"}}, + {"repeats: alice@example.com / alice@example.com / alice@example.com", []string{"alice@example.com"}}, + // Multibyte UTF-8 corpora pin the rune-boundary snap in + // StreamFilter.Push: holdLen is byte-sized, so a chunk boundary + // may land mid-codepoint. Without the snap, the retained tail + // has a partial codepoint and the next regex scan can mis-align. + // Each entry mixes ASCII secrets with surrounding multibyte text + // so a byte-aligned cut would land inside a CJK or accented + // character on at least some splits. + {"こんにちは alice@example.com さようなら", []string{"alice@example.com"}}, + {"クレジットカード: 4111-1111-1111-1111 終わり", []string{"4111-1111-1111-1111"}}, + {"naïve résumé: alice@example.com, façade", []string{"alice@example.com"}}, + } + + red := newStreamRedactor() // all default patterns + rng := rand.New(rand.NewSource(1)) // seeded for reproducibility + + for _, tc := range corpus { + for trial := 0; trial < 10; trial++ { + sf := NewStreamFilter(red, nil, nil, "", "") + var out strings.Builder + for i := 0; i < utf8.RuneCountInString(tc.text); { + // Random chunk size 1-8 runes, never crossing the end. + chunk := 1 + rng.Intn(8) + if i+chunk > utf8.RuneCountInString(tc.text) { + chunk = utf8.RuneCountInString(tc.text) - i + } + out.WriteString(sf.Push(stringSlice(tc.text, i, i+chunk))) + i += chunk + } + out.WriteString(sf.Drain()) + result := out.String() + + // Property 1: no secret value appears anywhere in the + // output. + for _, secret := range tc.secrets { + Expect(result).NotTo(ContainSubstring(secret), + fmt.Sprintf("trial %d: secret %q leaked through streaming\n input: %q\n output: %q", trial, secret, tc.text, result)) + } + + // Property 2: the streamed output equals what a single-shot + // Redact would have produced on the same input. (Block + // patterns get masked in stream mode, so we compare against + // a remapped redaction.) + expected := singleShotMaskAll(red, tc.text) + Expect(result).To(Equal(expected), + fmt.Sprintf("trial %d: stream != single-shot\n input: %q", trial, tc.text)) + } + } + }) +}) + +// singleShotMaskAll runs the redactor in one pass with all blocks +// remapped to mask — the same view the StreamFilter produces. +func singleShotMaskAll(red *Redactor, text string) string { + patterns := red.Patterns() + overrides := make(map[string]Action, len(patterns)) + for _, p := range patterns { + if p.Action == ActionBlock { + overrides[p.ID] = ActionMask + } + } + res := red.RedactWithOverrides(text, overrides) + return res.Redacted +} + +func stringSlice(s string, fromRune, toRune int) string { + runes := []rune(s) + return string(runes[fromRune:toRune]) +} diff --git a/core/services/routing/pii/types.go b/core/services/routing/pii/types.go new file mode 100644 index 000000000000..afdcc7ad44be --- /dev/null +++ b/core/services/routing/pii/types.go @@ -0,0 +1,170 @@ +// Package pii implements the routing-module PII / sensitive-data filter. +// +// Two tiers are planned (per the routing plan): +// +// 1. Regex tier: cheap, deterministic patterns (email, phone, SSN, credit +// card with Luhn, IPs, API-key prefixes). Always on by default. +// 2. Encoder NER tier: a HF token-classification model exposed via a new +// gRPC TokenClassify RPC. Out of scope for this slice — added later. +// +// This file ships tier 1 only. The Pipeline interface is shaped so tier 2 +// drops in without changing call sites. +// +// Configuration model: each pattern has an Action (block | mask | +// route_local). Actions are evaluated in this order: +// - block: short-circuits the request with an error (the middleware +// returns 400 to the client). +// - mask: replaces the matched span with ReplacementFor(pattern). +// - route_local: leaves the text alone but sets a context flag the +// router (subsystem 2) treats as "this request must stay on a local +// model" — never crosses the boundary to a cloud proxy backend. +package pii + +import "time" + +// Action describes what to do when a pattern matches. +type Action string + +const ( + // ActionMask replaces the matched span with a placeholder. The + // default. Lets the request proceed to the backend with the + // sensitive token removed. + ActionMask Action = "mask" + + // ActionBlock rejects the entire request. The middleware returns + // 400 with an error referencing the matched pattern_id (but never + // the matched value). + ActionBlock Action = "block" + + // ActionRouteLocal leaves the text intact but flags the request so + // the content router will refuse to dispatch it to a cloud proxy + // backend. Useful when a deployment trusts local models with + // sensitive data but not external providers. + ActionRouteLocal Action = "route_local" +) + +// Direction tags whether a PIIEvent fired on input (request body before +// dispatch) or output (response stream after generation). Stored in the +// PIIEvent record so admins can see which direction PII appeared in. +type Direction string + +const ( + DirectionIn Direction = "in" + DirectionOut Direction = "out" +) + +// Span is a half-open byte range [Start, End) within a scanned string. +// Pattern is the rule that matched. Text never holds the matched value +// itself — call sites that need the value (for masking) do their own +// substring slicing; call sites that need to log it strip it via +// HashPrefix. +type Span struct { + Start int + End int + Pattern string // matches Pattern.ID + HashPrefix string // first 8 chars of sha256(matched value); audit-safe +} + +// Result is what Redact returns. Redacted is the input string after +// all configured masks were applied. Spans are the original positions +// of every match (in the original input — not the redacted output — +// so admins can see where things were). +// +// Blocked is true iff at least one matched pattern had Action=block; +// the call site must enforce this by returning a 400 / refusing to +// dispatch. +// +// LocalOnly is true iff at least one matched pattern had +// Action=route_local. The router middleware reads this and constrains +// candidate selection. +type Result struct { + Redacted string + Spans []Span + Blocked bool + LocalOnly bool +} + +// Pattern is one configurable rule. Description is shown in the admin +// UI alongside the pattern; the regex itself stays an implementation +// detail (a leak-prone admin showing an SSN regex with a sample value +// in the field is a risk we deliberately design around). +type Pattern struct { + ID string + Description string + Action Action + // Disabled skips the pattern entirely when true — useful for + // admins who want to keep a regex around (visible in the UI) but + // turn it off without removing the YAML entry. Default-false so + // every existing pattern stays active without touching its config. + Disabled bool + // MaxMatchLength is the longest possible match in characters. The + // streaming filter (subsystem 3, follow-up commit) uses this to + // size its tail buffer. For regex patterns we compute it at + // compile time from the pattern's structure when possible, or set + // a conservative upper bound otherwise. + MaxMatchLength int + + // internal — populated by Compile(). + regex regexpMatcher +} + +// EventKind classifies a stored audit event. The store is shared by the +// PII filter (its original use), the MITM proxy (connect decisions and +// per-request traffic counters), and — when subsystem 2 lands — the +// content router. Filtering by Kind keeps unrelated event types out of +// each other's UI tabs without splitting storage. +// +// An empty Kind is treated as KindPII so rows written before this field +// existed still classify correctly. +type EventKind string + +const ( + KindPII EventKind = "pii" + KindProxyConnect EventKind = "proxy_connect" + KindProxyTraffic EventKind = "proxy_traffic" + // KindAdmission rows are written by the admission middleware + // (routing subsystem 5) when a request is rejected because a + // model's MaxConcurrent ceiling is full. The Host field carries + // the model name (overloading the existing column rather than + // adding a new one — admins read it as "the thing that was + // busy"); StatusCode is 503. + KindAdmission EventKind = "admission" +) + +// PIIEvent is the persisted record. The Hash field is the first 8 chars +// of sha256(matched value) — enough to deduplicate "is this the same +// thing as last time" without ever storing the value itself. +// +// Proxy-event fields (Host, Intercepted, Bytes*, StatusCode, DurationMS) +// are only set when Kind is KindProxyConnect or KindProxyTraffic. They +// hold connection-level metadata for audit and basic diagnostics — never +// request bodies. Use the API/backend traces to inspect contents. +type PIIEvent struct { + ID string `json:"id"` + Kind EventKind `json:"kind,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + UserID string `json:"user_id,omitempty"` + Direction Direction `json:"direction,omitempty"` + PatternID string `json:"pattern_id,omitempty"` + ByteOffset int `json:"byte_offset,omitempty"` + Length int `json:"length,omitempty"` + HashPrefix string `json:"hash_prefix,omitempty"` + Action Action `json:"action,omitempty"` + CreatedAt time.Time `json:"created_at"` + + Host string `json:"host,omitempty"` + Intercepted *bool `json:"intercepted,omitempty"` + BytesSent int64 `json:"bytes_sent,omitempty"` + BytesReceived int64 `json:"bytes_received,omitempty"` + StatusCode int `json:"status_code,omitempty"` + DurationMS int64 `json:"duration_ms,omitempty"` +} + +// ResolvedKind returns the event's Kind, treating an empty value as +// KindPII for rows written before Kind existed. +func (e PIIEvent) ResolvedKind() EventKind { + if e.Kind == "" { + return KindPII + } + return e.Kind +} diff --git a/core/services/routing/piiadapter/anthropic.go b/core/services/routing/piiadapter/anthropic.go new file mode 100644 index 000000000000..e059e6bc1881 --- /dev/null +++ b/core/services/routing/piiadapter/anthropic.go @@ -0,0 +1,81 @@ +package piiadapter + +import ( + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" +) + +// Anthropic returns a pii.Adapter for *schema.AnthropicRequest. The +// scan walks every message's text content (string-form or text blocks +// inside the structured `[]any` content), and the apply writes redacted +// text back in place. +// +// The shape mirrors OpenAI() — Anthropic's multimodal blocks +// (`{"type":"image","source":{...}}`, `{"type":"tool_use", ...}`) are +// left untouched; text-block scanning covers the chat-completion path. +// +// System prompts in the Anthropic API live on the request's top-level +// System field, not in Messages — they're skipped here for now (chat +// messages are the high-traffic surface). System-prompt scanning is a +// follow-up if a deployment proves it needs it. +func Anthropic() pii.Adapter { + return pii.Adapter{ + Scan: func(parsed any) []pii.ScannedText { + req, ok := parsed.(*schema.AnthropicRequest) + if !ok || req == nil { + return nil + } + var out []pii.ScannedText + for i := range req.Messages { + msg := &req.Messages[i] + switch ct := msg.Content.(type) { + case string: + if ct != "" { + out = append(out, pii.ScannedText{ + Index: encodeIdx(i, -1), + Text: ct, + }) + } + case []any: + for j, block := range ct { + if blockMap, ok := block.(map[string]any); ok { + if blockMap["type"] == "text" { + if text, ok := blockMap["text"].(string); ok && text != "" { + out = append(out, pii.ScannedText{ + Index: encodeIdx(i, j), + Text: text, + }) + } + } + } + } + } + } + return out + }, + Apply: func(parsed any, updates []pii.ScannedText) { + req, ok := parsed.(*schema.AnthropicRequest) + if !ok || req == nil { + return + } + for _, u := range updates { + msgIdx, blockIdx := decodeIdx(u.Index) + if msgIdx < 0 || msgIdx >= len(req.Messages) { + continue + } + msg := &req.Messages[msgIdx] + if blockIdx < 0 { + msg.Content = u.Text + continue + } + blocks, ok := msg.Content.([]any) + if !ok || blockIdx >= len(blocks) { + continue + } + if blockMap, ok := blocks[blockIdx].(map[string]any); ok { + blockMap["text"] = u.Text + } + } + }, + } +} diff --git a/core/services/routing/piiadapter/anthropic_test.go b/core/services/routing/piiadapter/anthropic_test.go new file mode 100644 index 000000000000..1ec72d4ee56a --- /dev/null +++ b/core/services/routing/piiadapter/anthropic_test.go @@ -0,0 +1,69 @@ +package piiadapter + +import ( + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Anthropic adapter", func() { + It("scans string content", func() { + req := &schema.AnthropicRequest{ + Messages: []schema.AnthropicMessage{ + {Role: "user", Content: "hi alice@example.com"}, + }, + } + got := Anthropic().Scan(req) + Expect(got).To(HaveLen(1)) + Expect(got[0].Text).To(Equal("hi alice@example.com")) + }) + + It("scans text blocks", func() { + // AnthropicMessage.Content is `any`. After JSON decode of a real + // request it is []any of map[string]any blocks, exactly mirroring + // OpenAI's content-block shape — image blocks must be skipped, text + // blocks must be scanned. + req := &schema.AnthropicRequest{ + Messages: []schema.AnthropicMessage{ + {Role: "user", Content: []any{ + map[string]any{"type": "text", "text": "first text"}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": "..."}}, + map[string]any{"type": "text", "text": "second text"}, + }}, + }, + } + got := Anthropic().Scan(req) + Expect(got).To(HaveLen(2)) + Expect(got[0].Text).To(Equal("first text")) + Expect(got[1].Text).To(Equal("second text")) + }) + + It("Apply mutates string content", func() { + req := &schema.AnthropicRequest{ + Messages: []schema.AnthropicMessage{ + {Role: "user", Content: "original"}, + }, + } + adapter := Anthropic() + got := adapter.Scan(req) + adapter.Apply(req, []pii.ScannedText{{Index: got[0].Index, Text: "redacted"}}) + Expect(req.Messages[0].Content).To(Equal("redacted")) + }) + + It("Apply mutates text block content", func() { + req := &schema.AnthropicRequest{ + Messages: []schema.AnthropicMessage{ + {Role: "user", Content: []any{ + map[string]any{"type": "text", "text": "original"}, + }}, + }, + } + adapter := Anthropic() + got := adapter.Scan(req) + adapter.Apply(req, []pii.ScannedText{{Index: got[0].Index, Text: "redacted"}}) + blocks := req.Messages[0].Content.([]any) + block := blocks[0].(map[string]any) + Expect(block["text"]).To(Equal("redacted")) + }) +}) diff --git a/core/services/routing/piiadapter/openai.go b/core/services/routing/piiadapter/openai.go new file mode 100644 index 000000000000..a79b4dd7e506 --- /dev/null +++ b/core/services/routing/piiadapter/openai.go @@ -0,0 +1,112 @@ +// Package piiadapter holds the per-API-shape adapters that translate +// between the routing/pii middleware and concrete request types from +// core/schema. Lives outside core/services/routing/pii so the schema +// package never imports pii (and pii never imports schema), keeping +// the dependency direction clean. +package piiadapter + +import ( + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/pii" +) + +// OpenAI returns a pii.Adapter for *schema.OpenAIRequest. It scans +// every chat message's text content (string-form or text blocks of +// the structured `[]any` content), and writes redacted text back. +// +// Multimodal content (image_url, audio_url, video_url) is left alone +// — PII in image bytes is the encoder NER tier's problem, not the +// regex tier's. We do walk text fields embedded inside content +// blocks because those are the most common shape Claude Code and +// similar clients produce. +// +// System / developer / tool messages are scanned as well: an API key +// pasted into a system prompt is just as leak-prone as one in a user +// message. +func OpenAI() pii.Adapter { + return pii.Adapter{ + Scan: func(parsed any) []pii.ScannedText { + req, ok := parsed.(*schema.OpenAIRequest) + if !ok || req == nil { + return nil + } + var out []pii.ScannedText + for i := range req.Messages { + msg := &req.Messages[i] + switch ct := msg.Content.(type) { + case string: + if ct != "" { + // Index encodes (message index, -1) to mean + // "the whole Content string". Negative + // inner indices are a valid sentinel because + // real array indices are ≥ 0. + out = append(out, pii.ScannedText{ + Index: encodeIdx(i, -1), + Text: ct, + }) + } + case []any: + for j, block := range ct { + if blockMap, ok := block.(map[string]any); ok { + if blockMap["type"] == "text" { + if text, ok := blockMap["text"].(string); ok && text != "" { + out = append(out, pii.ScannedText{ + Index: encodeIdx(i, j), + Text: text, + }) + } + } + } + } + } + } + return out + }, + Apply: func(parsed any, updates []pii.ScannedText) { + req, ok := parsed.(*schema.OpenAIRequest) + if !ok || req == nil { + return + } + for _, u := range updates { + msgIdx, blockIdx := decodeIdx(u.Index) + if msgIdx < 0 || msgIdx >= len(req.Messages) { + continue + } + msg := &req.Messages[msgIdx] + if blockIdx < 0 { + // Whole-string content. + msg.Content = u.Text + continue + } + blocks, ok := msg.Content.([]any) + if !ok || blockIdx >= len(blocks) { + continue + } + if blockMap, ok := blocks[blockIdx].(map[string]any); ok { + blockMap["text"] = u.Text + } + } + }, + } +} + +// encodeIdx packs (msg, block) into one int. block=-1 means +// "the whole Content string"; bit 24 is the sentinel flag and +// bits 0..23 hold the block index, leaving the rest for msg. +const idxWholeStringFlag = 1 << 24 +const idxBlockMask = (1 << 24) - 1 + +func encodeIdx(msg, block int) int { + if block < 0 { + return (msg << 25) | idxWholeStringFlag + } + return (msg << 25) | (block & idxBlockMask) +} + +func decodeIdx(packed int) (msg, block int) { + msg = packed >> 25 + if packed&idxWholeStringFlag != 0 { + return msg, -1 + } + return msg, packed & idxBlockMask +} diff --git a/core/services/routing/piiadapter/openai_test.go b/core/services/routing/piiadapter/openai_test.go new file mode 100644 index 000000000000..f35ac959fd86 --- /dev/null +++ b/core/services/routing/piiadapter/openai_test.go @@ -0,0 +1,93 @@ +package piiadapter + +import ( + "github.com/mudler/LocalAI/core/schema" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("OpenAI adapter", func() { + It("scans string content", func() { + req := &schema.OpenAIRequest{ + Messages: []schema.Message{ + {Role: "user", Content: "hello alice@example.com"}, + }, + } + adapter := OpenAI() + got := adapter.Scan(req) + Expect(got).To(HaveLen(1)) + Expect(got[0].Text).To(Equal("hello alice@example.com")) + }) + + It("scans content blocks", func() { + req := &schema.OpenAIRequest{ + Messages: []schema.Message{ + {Role: "user", Content: []any{ + map[string]any{"type": "text", "text": "block one"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,xyz"}}, + map[string]any{"type": "text", "text": "block two"}, + }}, + }, + } + adapter := OpenAI() + got := adapter.Scan(req) + Expect(got).To(HaveLen(2)) + Expect(got[0].Text).To(Equal("block one")) + Expect(got[1].Text).To(Equal("block two")) + }) + + It("Apply mutates string content", func() { + req := &schema.OpenAIRequest{ + Messages: []schema.Message{ + {Role: "user", Content: "original"}, + {Role: "user", Content: "second"}, + }, + } + adapter := OpenAI() + scans := adapter.Scan(req) + updates := scans + updates[0].Text = "REDACTED-0" + updates[1].Text = "REDACTED-1" + adapter.Apply(req, updates) + + Expect(req.Messages[0].Content.(string)).To(Equal("REDACTED-0")) + Expect(req.Messages[1].Content.(string)).To(Equal("REDACTED-1")) + }) + + It("Apply mutates content block selectively", func() { + req := &schema.OpenAIRequest{ + Messages: []schema.Message{ + {Role: "user", Content: []any{ + map[string]any{"type": "text", "text": "before"}, + map[string]any{"type": "text", "text": "untouched"}, + }}, + }, + } + adapter := OpenAI() + scans := adapter.Scan(req) + Expect(scans).To(HaveLen(2)) + + // Redact only the first block. + updates := []struct{ idx int }{{0}} + scans[updates[0].idx].Text = "AFTER" + adapter.Apply(req, scans[:1]) + + blocks := req.Messages[0].Content.([]any) + Expect(blocks[0].(map[string]any)["text"]).To(Equal("AFTER")) + Expect(blocks[1].(map[string]any)["text"]).To(Equal("untouched")) + }) +}) + +var _ = Describe("encodeIdx/decodeIdx", func() { + It("round-trips message and block indices", func() { + cases := []struct{ msg, block int }{ + {0, 0}, {0, 5}, {3, 0}, {3, 12}, {7, -1}, {0, -1}, + } + for _, c := range cases { + got := encodeIdx(c.msg, c.block) + m, b := decodeIdx(got) + Expect(m).To(Equal(c.msg), "round-trip msg for (%d,%d)", c.msg, c.block) + Expect(b).To(Equal(c.block), "round-trip block for (%d,%d)", c.msg, c.block) + } + }) +}) diff --git a/core/services/routing/piiadapter/piiadapter_suite_test.go b/core/services/routing/piiadapter/piiadapter_suite_test.go new file mode 100644 index 000000000000..9d313498787e --- /dev/null +++ b/core/services/routing/piiadapter/piiadapter_suite_test.go @@ -0,0 +1,13 @@ +package piiadapter + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestPiiAdapter(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "PII Adapter test suite") +} diff --git a/core/services/routing/router/decisions.go b/core/services/routing/router/decisions.go new file mode 100644 index 000000000000..aed706a851b2 --- /dev/null +++ b/core/services/routing/router/decisions.go @@ -0,0 +1,136 @@ +package router + +import ( + "context" + "sync" + "time" +) + +// Decision row written to the in-memory store. Mirrors the PIIEvent +// shape so the admin page can render the two side-by-side. Note: +// Prompt is NEVER stored — admins audit by Hash if they need to +// dedupe recurring routing patterns. +type DecisionRecord struct { + ID string `json:"id"` + CorrelationID string `json:"correlation_id"` + UserID string `json:"user_id"` + RouterModel string `json:"router_model"` // The smart-router model name the client asked for. + RequestedModel string `json:"requested_model"`// Same as RouterModel for now; reserved for chained routers. + ServedModel string `json:"served_model"` // The candidate the classifier picked. + Classifier string `json:"classifier"` // Classifier.Name(), e.g. "score". + Label string `json:"label"` + Score float64 `json:"score"` + LatencyMs int64 `json:"latency_ms"` + Cached bool `json:"cached"` // True when the decision came from the L2 embedding cache. + CacheSimilarity float64 `json:"cache_similarity,omitempty"` // Cosine similarity of the cache hit, 0 when not cached. + CreatedAt time.Time `json:"created_at"` +} + +// DecisionStore persists routing decisions for the admin page and +// future drift checks. In-process by default so a no-auth box still +// gets a decision log; a future GORM impl can reuse the auth DB. +type DecisionStore interface { + Record(ctx context.Context, r DecisionRecord) error + List(ctx context.Context, q DecisionListQuery) ([]DecisionRecord, error) + Count(ctx context.Context) (int, error) + Close() error +} + +// DecisionListQuery filters the decision log. Empty fields match all. +// Limit ≤ 0 picks a default cap. +type DecisionListQuery struct { + CorrelationID string + UserID string + RouterModel string + Limit int +} + +// NewMemoryDecisionStore returns a ring-buffer DecisionStore. capacity +// ≤ 0 picks 5_000 — same order of magnitude as PIIEvents but smaller +// because routing decisions correlate one-to-one with usage records; +// the existing UsageRecord log carries the bulk. +func NewMemoryDecisionStore(capacity int) DecisionStore { + if capacity <= 0 { + capacity = 5_000 + } + return &memoryDecisionStore{ + ring: make([]DecisionRecord, capacity), + cap: capacity, + } +} + +type memoryDecisionStore struct { + mu sync.RWMutex + ring []DecisionRecord + cap int + cursor int + full bool +} + +func (s *memoryDecisionStore) Record(_ context.Context, r DecisionRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + s.ring[s.cursor] = r + s.cursor++ + if s.cursor == s.cap { + s.cursor = 0 + s.full = true + } + return nil +} + +func (s *memoryDecisionStore) List(_ context.Context, q DecisionListQuery) ([]DecisionRecord, error) { + limit := q.Limit + if limit <= 0 { + limit = 1000 + } + s.mu.RLock() + defer s.mu.RUnlock() + out := make([]DecisionRecord, 0, limit) + scan := func(r DecisionRecord) bool { + if r.ID == "" { + return false + } + if q.CorrelationID != "" && r.CorrelationID != q.CorrelationID { + return false + } + if q.UserID != "" && r.UserID != q.UserID { + return false + } + if q.RouterModel != "" && r.RouterModel != q.RouterModel { + return false + } + out = append(out, r) + return len(out) >= limit + } + if s.full { + for i := s.cursor - 1; i >= 0; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + for i := s.cap - 1; i >= s.cursor; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + } else { + for i := s.cursor - 1; i >= 0; i-- { + if scan(s.ring[i]) { + return out, nil + } + } + } + return out, nil +} + +func (s *memoryDecisionStore) Count(_ context.Context) (int, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if s.full { + return s.cap, nil + } + return s.cursor, nil +} + +func (s *memoryDecisionStore) Close() error { return nil } diff --git a/core/services/routing/router/embedding_cache.go b/core/services/routing/router/embedding_cache.go new file mode 100644 index 000000000000..c057f551b015 --- /dev/null +++ b/core/services/routing/router/embedding_cache.go @@ -0,0 +1,244 @@ +package router + +import ( + "context" + "encoding/json" + "sync/atomic" + "time" + + "github.com/mudler/xlog" +) + +// Embedder produces a fixed-dimension vector from a prompt. The router +// uses it to look up semantically-similar past decisions in the L2 +// cache. Implementations honour ctx cancellation so a slow embedder +// doesn't pin the request goroutine. +type Embedder interface { + Embed(ctx context.Context, text string) ([]float32, error) +} + +// VectorStore is the KNN backend the embedding cache writes to. Search +// returns the top-1 match (similarity in [-1, 1] cosine space) and the +// serialized Decision payload, or ok=false on a clean miss. Insert +// stores a new (vector, decision) pair; size and eviction are the +// store implementation's concern. +type VectorStore interface { + Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error) + Insert(ctx context.Context, vec []float32, payload []byte) error +} + +// EmbeddingCacheStats reports per-classifier cache hit/miss/error +// counts. Surfaced through /api/router/cache/stats and the Routing tab +// so admins can see whether the cache is paying off. +// +// Hits + NearMisses + Misses equals the total number of Search calls +// that succeeded (no embedder/store error). NearMisses are kept +// separate from Misses because their similarity is observable — +// lowering similarity_threshold turns near-misses into hits without +// growing the cache, so the ratio tells admins how much room is left +// in the current threshold. +type EmbeddingCacheStats struct { + Hits uint64 `json:"hits"` + Misses uint64 `json:"misses"` // empty store or no similar key + NearMisses uint64 `json:"near_misses"` // store returned a key but below similarity_threshold + LowConfidence uint64 `json:"low_confidence"` // decisions we deliberately did not cache + EmbedderErrors uint64 `json:"embedder_errors"` + StoreErrors uint64 `json:"store_errors"` + + // SimilarityBuckets is a 10-bin histogram of the cosine + // similarities the store reported for any successful Search (hits + // and near-misses combined). Index i covers similarity [i/10, + // (i+1)/10). Counts are non-decreasing across the classifier's + // lifetime; reset via process restart. + SimilarityBuckets [10]uint64 `json:"similarity_buckets"` +} + +// EmbeddingCacheClassifier wraps an inner Classifier with an +// embedding-similarity cache. On Classify it first embeds the probe, +// asks the vector store for the nearest past decision, and returns +// it if similarity passes the configured threshold. Misses fall +// through to the inner classifier, and high-confidence outcomes are +// inserted into the store for future hits. +// +// Failure modes — embedder error, store error — degrade to the inner +// classifier's result. Routing never fails because of cache plumbing. +type EmbeddingCacheClassifier struct { + inner Classifier + embedder Embedder + store VectorStore + similarityThreshold float64 + confidenceThreshold float64 + + hits atomic.Uint64 + misses atomic.Uint64 + nearMisses atomic.Uint64 + lowConfidence atomic.Uint64 + embedderErrors atomic.Uint64 + storeErrors atomic.Uint64 + simBuckets [10]atomic.Uint64 +} + +// Default thresholds. Re-tune per (embedding model, corpus) — the +// admin histogram on the Routing tab shows where the cosine +// distribution actually sits. +const ( + defaultEmbeddingSimilarity = 0.80 + defaultEmbeddingConfidence = 0.60 +) + +// NewEmbeddingCacheClassifier wraps inner with an embedding-similarity +// cache. Panics on misconfiguration (nil inner / embedder / store) — +// same fail-fast posture as the score classifier. +// +// Zero threshold picks the package default (defaultEmbeddingSimilarity +// / defaultEmbeddingConfidence). +func NewEmbeddingCacheClassifier(inner Classifier, embedder Embedder, store VectorStore, similarityThreshold, confidenceThreshold float64) *EmbeddingCacheClassifier { + if inner == nil { + panic("router/embedding_cache: inner classifier is required") + } + if embedder == nil { + panic("router/embedding_cache: embedder is required") + } + if store == nil { + panic("router/embedding_cache: vector store is required") + } + if similarityThreshold <= 0 { + similarityThreshold = defaultEmbeddingSimilarity + } + if confidenceThreshold <= 0 { + confidenceThreshold = defaultEmbeddingConfidence + } + return &EmbeddingCacheClassifier{ + inner: inner, + embedder: embedder, + store: store, + similarityThreshold: similarityThreshold, + confidenceThreshold: confidenceThreshold, + } +} + +// Name is the inner classifier's name — the decision-log "classifier" +// field should reflect *what* made the decision, not the caching +// transport. Cache hits set Decision.Cached separately so admins can +// still distinguish a cached lookup from a fresh run. +func (c *EmbeddingCacheClassifier) Name() string { + return c.inner.Name() +} + +// Stats returns a snapshot of the cache counters. +func (c *EmbeddingCacheClassifier) Stats() EmbeddingCacheStats { + s := EmbeddingCacheStats{ + Hits: c.hits.Load(), + Misses: c.misses.Load(), + NearMisses: c.nearMisses.Load(), + LowConfidence: c.lowConfidence.Load(), + EmbedderErrors: c.embedderErrors.Load(), + StoreErrors: c.storeErrors.Load(), + } + for i := range c.simBuckets { + s.SimilarityBuckets[i] = c.simBuckets[i].Load() + } + return s +} + +func (c *EmbeddingCacheClassifier) Classify(ctx context.Context, p Probe) (Decision, error) { + start := time.Now() + + vec, err := c.embedder.Embed(ctx, p.Prompt) + if err != nil { + c.embedderErrors.Add(1) + xlog.Warn("router: embedding cache embed failed", "error", err) + // Embedder failure — fall through to the inner classifier so + // routing still happens. The miss is not a hard error. + return c.inner.Classify(ctx, p) + } + + sim, payload, hit, err := c.store.Search(ctx, vec) + if err != nil { + c.storeErrors.Add(1) + xlog.Warn("router: embedding cache store.Search failed", "error", err, "vec_dim", len(vec)) + return c.inner.Classify(ctx, p) + } + if hit { + // Bin the similarity once, regardless of threshold outcome. + // Admins read this back to see where the cosine distribution + // sits relative to the configured similarity_threshold. + c.recordSimilarity(sim) + if sim >= c.similarityThreshold { + if cached, ok := decodeCachedDecision(payload); ok { + c.hits.Add(1) + cached.Cached = true + cached.CacheSimilarity = sim + cached.Latency = time.Since(start) + return cached, nil + } + // Payload corrupt — treat as miss and overwrite on the next + // confident decision. + c.misses.Add(1) + } else { + c.nearMisses.Add(1) + } + } else { + c.misses.Add(1) + } + decision, err := c.inner.Classify(ctx, p) + if err != nil { + return decision, err + } + + // Don't poison the cache with uncertain decisions. The score + // classifier's softmax can put the top label as low as 1/N in + // pathological cases; only store outcomes where the model is + // clearly committed. + if decision.Score < c.confidenceThreshold { + c.lowConfidence.Add(1) + return decision, nil + } + + payload, encodeErr := encodeCachedDecision(decision) + if encodeErr != nil { + // Encoding can't realistically fail for the Decision type but + // guard so a future field doesn't break routing silently. + return decision, nil + } + if insertErr := c.store.Insert(ctx, vec, payload); insertErr != nil { + c.storeErrors.Add(1) + xlog.Warn("router: embedding cache store.Insert failed", "error", insertErr, "vec_dim", len(vec)) + // Insert failure is non-fatal — the decision is still good + // for this request, only the future-hit benefit is lost. + } + return decision, nil +} + +// recordSimilarity increments the histogram bucket covering the given +// cosine similarity. The store occasionally returns sim slightly above +// 1.0 due to floating-point error on exact matches; we clamp to the +// top bin to keep the histogram bounded. +func (c *EmbeddingCacheClassifier) recordSimilarity(sim float64) { + bucket := max(0, min(9, int(sim*10))) + c.simBuckets[bucket].Add(1) +} + +// cachedDecision is the on-disk shape stored in the vector backend. +// Kept separate from Decision so transient fields (Latency, Cached, +// CacheSimilarity) don't get serialized — they're per-call, not +// per-prompt. +type cachedDecision struct { + Labels []string `json:"labels"` + Score float64 `json:"score"` +} + +func encodeCachedDecision(d Decision) ([]byte, error) { + return json.Marshal(cachedDecision{Labels: append([]string(nil), d.Labels...), Score: d.Score}) +} + +func decodeCachedDecision(b []byte) (Decision, bool) { + var cd cachedDecision + if err := json.Unmarshal(b, &cd); err != nil { + return Decision{}, false + } + if len(cd.Labels) == 0 { + return Decision{}, false + } + return Decision{Labels: cd.Labels, Score: cd.Score}, true +} diff --git a/core/services/routing/router/embedding_cache_test.go b/core/services/routing/router/embedding_cache_test.go new file mode 100644 index 000000000000..726614d0e966 --- /dev/null +++ b/core/services/routing/router/embedding_cache_test.go @@ -0,0 +1,311 @@ +package router_test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/mudler/LocalAI/core/services/routing/router" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// fakeEmbedder returns a vector keyed by a lookup table; this lets the +// test exercise hit/miss control without depending on a real model. +type fakeEmbedder struct { + mu sync.Mutex + table map[string][]float32 + failOnce bool +} + +func (e *fakeEmbedder) Embed(ctx context.Context, text string) ([]float32, error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.failOnce { + e.failOnce = false + return nil, errors.New("embedder offline") + } + v, ok := e.table[text] + if !ok { + return nil, errors.New("no embedding for: " + text) + } + return v, nil +} + +// memVectorStore is an in-memory KNN store with exact-vector hits, used +// to exercise the cache layer without a real local-store backend. +// Similarity is 1.0 for an exact match (after vector quantisation), 0.5 +// for "close" (configured via the second-arg suffix), 0.0 otherwise. +type memVectorStore struct { + mu sync.Mutex + entries []memEntry + failOps int // remaining Search calls to fail before returning miss +} + +type memEntry struct { + vec []float32 + payload []byte +} + +func (s *memVectorStore) Search(_ context.Context, vec []float32) (float64, []byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.failOps > 0 { + s.failOps-- + return 0, nil, false, errors.New("store offline") + } + for _, e := range s.entries { + if vecEqual(e.vec, vec) { + return 1.0, e.payload, true, nil + } + } + // "close" hit if the leading element matches but the rest doesn't — + // lets a test simulate sim=0.8 without floating-point fragility. + for _, e := range s.entries { + if len(vec) > 0 && len(e.vec) > 0 && vec[0] == e.vec[0] { + return 0.80, e.payload, true, nil + } + } + return 0, nil, false, nil +} + +func (s *memVectorStore) Insert(_ context.Context, vec []float32, payload []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = append(s.entries, memEntry{vec: append([]float32(nil), vec...), payload: append([]byte(nil), payload...)}) + return nil +} + +func vecEqual(a, b []float32) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// stubInner is a Classifier that records call count and returns a +// pre-programmed Decision. +type stubInner struct { + name string + decision router.Decision + err error + calls int +} + +func (s *stubInner) Classify(_ context.Context, _ router.Probe) (router.Decision, error) { + s.calls++ + if s.err != nil { + return router.Decision{}, s.err + } + return s.decision, nil +} + +func (s *stubInner) Name() string { return s.name } + +var _ = Describe("EmbeddingCache", func() { + ctx := context.Background() + + Context("miss then hit on exact prompt", func() { + It("populates the cache and serves the second call", func() { + embedder := &fakeEmbedder{table: map[string][]float32{ + "how do I exit vim": {1, 2, 3}, + }} + store := &memVectorStore{} + inner := &stubInner{ + name: "score", + decision: router.Decision{Labels: []string{"code-generation"}, Score: 0.9}, + } + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + // First call → miss, inner runs, decision stored. + d, err := cache.Classify(ctx, router.Probe{Prompt: "how do I exit vim"}) + Expect(err).NotTo(HaveOccurred(), "first classify") + Expect(d.Cached).To(BeFalse(), "first call should be a miss") + Expect(inner.calls).To(Equal(1)) + + // Second call with the same prompt → hit, inner NOT called again. + d, err = cache.Classify(ctx, router.Probe{Prompt: "how do I exit vim"}) + Expect(err).NotTo(HaveOccurred(), "second classify") + Expect(d.Cached).To(BeTrue(), "second call should be a cache hit") + Expect(d.CacheSimilarity).To(Equal(1.0)) + Expect(inner.calls).To(Equal(1), "inner ran on a hit") + Expect(d.Labels).To(Equal([]string{"code-generation"})) + + stats := cache.Stats() + Expect(stats.Hits).To(Equal(uint64(1))) + Expect(stats.Misses).To(Equal(uint64(1))) + // Second call had sim=1.0 (exact match), so the top bucket + // should have one count. + Expect(stats.SimilarityBuckets[9]).To(Equal(uint64(1)), "SimilarityBuckets[9] should be 1 (sim=1.0 hit)") + }) + }) + + Context("similarity below threshold", func() { + It("counts as a near-miss", func() { + // Two distinct prompts that produce vectors sharing only the + // first element — memVectorStore reports similarity 0.80, below + // the 0.92 threshold. + embedder := &fakeEmbedder{table: map[string][]float32{ + "first prompt": {1, 1, 1}, + "second prompt": {1, 9, 9}, + }} + store := &memVectorStore{} + inner := &stubInner{ + name: "score", + decision: router.Decision{Labels: []string{"math-reasoning"}, Score: 0.95}, + } + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + _, _ = cache.Classify(ctx, router.Probe{Prompt: "first prompt"}) + d, err := cache.Classify(ctx, router.Probe{Prompt: "second prompt"}) + Expect(err).NotTo(HaveOccurred(), "classify") + Expect(d.Cached).To(BeFalse(), "0.80 sim below 0.92 threshold should not hit") + Expect(inner.calls).To(Equal(2), "inner should have run twice") + stats := cache.Stats() + Expect(stats.NearMisses).To(Equal(uint64(1)), "NearMisses (sim=0.80 below 0.92 threshold)") + // Second call hit at sim=0.80 → bucket [0.8, 0.9) = index 8. + // First call missed cleanly (empty store) → no bucket. + Expect(stats.SimilarityBuckets[8]).To(Equal(uint64(1)), "SimilarityBuckets[8] (sim=0.80 near-miss)") + }) + }) + + Context("low confidence decisions", func() { + It("are not cached", func() { + embedder := &fakeEmbedder{table: map[string][]float32{ + "ambiguous": {7, 7, 7}, + }} + store := &memVectorStore{} + // Score 0.4 < confidenceThreshold 0.6 → don't cache. + inner := &stubInner{ + name: "score", + decision: router.Decision{Labels: []string{"casual-chat"}, Score: 0.4}, + } + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + _, _ = cache.Classify(ctx, router.Probe{Prompt: "ambiguous"}) + _, _ = cache.Classify(ctx, router.Probe{Prompt: "ambiguous"}) + + Expect(inner.calls).To(Equal(2), "second call should also miss") + stats := cache.Stats() + Expect(stats.LowConfidence).To(Equal(uint64(2))) + Expect(stats.Hits).To(Equal(uint64(0))) + }) + }) + + Context("embedder error", func() { + It("degrades to inner classifier", func() { + embedder := &fakeEmbedder{ + table: map[string][]float32{"p": {1}}, + failOnce: true, + } + store := &memVectorStore{} + inner := &stubInner{ + name: "score", + decision: router.Decision{Labels: []string{"x"}, Score: 0.99}, + } + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + d, err := cache.Classify(ctx, router.Probe{Prompt: "p"}) + Expect(err).NotTo(HaveOccurred(), "classify") + Expect(d.Cached).To(BeFalse(), "embedder error should not produce a cache hit") + Expect(inner.calls).To(Equal(1), "inner should have run once via fallthrough") + stats := cache.Stats() + Expect(stats.EmbedderErrors).To(Equal(uint64(1))) + }) + }) + + Context("store error", func() { + It("degrades to inner classifier", func() { + embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}} + store := &memVectorStore{failOps: 1} + inner := &stubInner{ + name: "score", + decision: router.Decision{Labels: []string{"x"}, Score: 0.99}, + } + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + _, _ = cache.Classify(ctx, router.Probe{Prompt: "p"}) + stats := cache.Stats() + Expect(stats.StoreErrors).To(Equal(uint64(1))) + }) + }) + + Context("Name", func() { + It("returns inner classifier name", func() { + embedder := &fakeEmbedder{} + store := &memVectorStore{} + inner := &stubInner{name: "score"} + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0, 0) + Expect(cache.Name()).To(Equal("score")) + }) + }) + + Context("inner error", func() { + It("propagates", func() { + embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}} + store := &memVectorStore{} + inner := &stubInner{name: "score", err: errors.New("classifier blew up")} + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + _, err := cache.Classify(ctx, router.Probe{Prompt: "p"}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("classifier blew up")) + }) + }) + + Context("default thresholds", func() { + It("apply for zero values", func() { + embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}} + store := &memVectorStore{} + inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"y"}, Score: 0.7}} + // thresholds=0 → defaults (0.92 / 0.60). 0.7 > 0.60 so should + // cache, and a re-call hits at sim=1.0 > 0.92. + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0, 0) + _, _ = cache.Classify(ctx, router.Probe{Prompt: "p"}) + d, _ := cache.Classify(ctx, router.Probe{Prompt: "p"}) + Expect(d.Cached).To(BeTrue(), "expected hit with default thresholds") + }) + }) + + Context("corrupt payload", func() { + It("is treated as miss", func() { + embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}} + store := &memVectorStore{} + // Pre-poison the store with garbage that decodes to an empty + // label slice — Search will hit but the payload decoder must + // reject it, falling through to the inner classifier. + garbage, _ := json.Marshal(map[string]any{"labels": []string{}, "score": 1.0}) + _ = store.Insert(ctx, []float32{1}, garbage) + inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"ok"}, Score: 0.8}} + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.5, 0.5) + d, err := cache.Classify(ctx, router.Probe{Prompt: "p"}) + Expect(err).NotTo(HaveOccurred(), "classify") + Expect(d.Cached).To(BeFalse(), "corrupt payload should not surface as a hit") + Expect(inner.calls).To(Equal(1), "inner should have run via fallthrough") + }) + }) +}) + +var _ = Describe("EmbeddingCache latency", func() { + It("is populated on hits", func() { + embedder := &fakeEmbedder{table: map[string][]float32{"p": {1}}} + store := &memVectorStore{} + inner := &stubInner{name: "score", decision: router.Decision{Labels: []string{"x"}, Score: 0.9, Latency: time.Millisecond}} + cache := router.NewEmbeddingCacheClassifier(inner, embedder, store, 0.92, 0.6) + + _, _ = cache.Classify(context.Background(), router.Probe{Prompt: "p"}) + d, _ := cache.Classify(context.Background(), router.Probe{Prompt: "p"}) + Expect(d.Cached).To(BeTrue(), "expected hit") + // On a hit, Latency reflects the cache-lookup time, NOT the original + // classifier latency stored in the payload. + Expect(d.Latency).To(BeNumerically("<", time.Second), "Latency unreasonably high for an in-memory hit") + }) +}) diff --git a/core/services/routing/router/registry.go b/core/services/routing/router/registry.go new file mode 100644 index 000000000000..3eac5a103a04 --- /dev/null +++ b/core/services/routing/router/registry.go @@ -0,0 +1,76 @@ +package router + +import ( + "sync" +) + +// Registry is the process-wide store of built classifiers, keyed by +// router-model name. The middleware uses it to avoid rebuilding the +// score classifier on every request, and the admin status endpoint +// reads from it to surface per-classifier cache stats. +// +// Each entry carries the fingerprint of the RouterConfig it was built +// from. A Get() with a stale fingerprint reports a miss so the +// middleware rebuilds — matches the previous local-sync.Map behaviour +// that keyed on fingerprint alone. +type Registry struct { + entries sync.Map // name → *registryEntry +} + +type registryEntry struct { + fingerprint uint64 + classifier Classifier +} + +func NewRegistry() *Registry { return &Registry{} } + +// Get returns the cached classifier for the named router model iff the +// stored fingerprint matches. A miss (no entry, or stale fingerprint) +// returns false; the caller is expected to rebuild and Put the result. +func (r *Registry) Get(name string, fingerprint uint64) (Classifier, bool) { + if r == nil { + return nil, false + } + v, ok := r.entries.Load(name) + if !ok { + return nil, false + } + e := v.(*registryEntry) + if e.fingerprint != fingerprint { + return nil, false + } + return e.classifier, true +} + +// Put stores a built classifier under (name, fingerprint), replacing +// any prior entry. The middleware calls this after a Get miss. +func (r *Registry) Put(name string, fingerprint uint64, c Classifier) { + if r == nil { + return + } + r.entries.Store(name, ®istryEntry{fingerprint: fingerprint, classifier: c}) +} + +// EmbeddingCacheStatsByRouter returns a snapshot of every embedding +// cache currently in the registry, keyed by router-model name. Plain +// classifiers without the L2 cache wrapper are skipped — callers +// distinguish "cache disabled" from "cache enabled with zero hits" by +// the presence of the map key. +func (r *Registry) EmbeddingCacheStatsByRouter() map[string]EmbeddingCacheStats { + if r == nil { + return nil + } + out := map[string]EmbeddingCacheStats{} + r.entries.Range(func(k, v any) bool { + name, _ := k.(string) + e, _ := v.(*registryEntry) + if e == nil { + return true + } + if ec, ok := e.classifier.(*EmbeddingCacheClassifier); ok { + out[name] = ec.Stats() + } + return true + }) + return out +} diff --git a/core/services/routing/router/router_suite_test.go b/core/services/routing/router/router_suite_test.go new file mode 100644 index 000000000000..a51a2ee5049d --- /dev/null +++ b/core/services/routing/router/router_suite_test.go @@ -0,0 +1,13 @@ +package router_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestRouter(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "router test suite") +} diff --git a/core/services/routing/router/score.go b/core/services/routing/router/score.go new file mode 100644 index 000000000000..711dc3ce2958 --- /dev/null +++ b/core/services/routing/router/score.go @@ -0,0 +1,304 @@ +package router + +import ( + "context" + "fmt" + "math" + "sort" + "strings" + "sync" + "time" +) + +// CandidateScore mirrors core/backend.CandidateScore at the router- +// boundary. Lives here so the router can depend on this abstraction +// without importing core/backend. +type CandidateScore struct { + LogProb float64 + LengthNormalizedLogProb float64 + NumTokens int +} + +// Scorer evaluates a model's joint log-probability of each candidate +// continuation given a shared prompt. The classifier consumes this for +// multi-label policy selection (read off the distribution rather than +// asking the model to emit a single argmax token). +type Scorer interface { + Score(ctx context.Context, prompt string, candidates []string) ([]CandidateScore, error) +} + +// ScorePolicy mirrors config.RouterPolicy at the classifier boundary — +// a label string plus its natural-language description for the +// routing system prompt. +type ScorePolicy struct { + Label string + Description string +} + +// defaultActivationThreshold is the softmax-probability floor a policy +// must clear to be considered "active." Picked low enough that two +// reasonably-confident labels (each ~0.4) both activate, high enough +// that a flat distribution doesn't activate everything. +const defaultActivationThreshold = 0.15 + +// ScoreClassifier scores every policy label as a continuation of the +// routing prompt, converts log-probabilities into a softmax +// distribution, and returns the set of labels whose probability +// passes the activation threshold. +// +// This is the off-the-shelf-Arch-Router approach extended for multi- +// label. The classifier model is trained to emit a single policy +// label, but its output distribution still spreads probability mass +// across competing labels when more than one applies. Reading the +// distribution rather than the argmax lets us route conjunctive +// intents ("debug this code AND explain the math") to a candidate +// that can serve both. +type ScoreClassifier struct { + policies []ScorePolicy + scorer Scorer + activationThreshold float64 + + // systemPrompt is built once at construction. The same prompt is + // reused on every classification — only the user-turn body changes. + systemPrompt string + + // labels stored in the order configured. Indices into policies[] + // align with the scorer's input/output ordering. + labelOrder []string + + // cache stores label-set-by-(normalised prompt) to skip the + // scoring round-trip on repeat queries. Keys are comma-joined + // label sets sorted lexicographically for deterministic equality. + mu sync.RWMutex + cache map[string][]string + cacheCap int +} + +// NewScoreClassifier panics on caller errors at construction (empty +// policies, missing description, nil scorer) — same rationale as the +// other classifiers. cacheCap=0 disables the cache. +// activationThreshold=0 picks the package default (0.15). +func NewScoreClassifier(policies []ScorePolicy, scorer Scorer, cacheCap int, activationThreshold float64) *ScoreClassifier { + if len(policies) == 0 { + panic("router/score: at least one policy is required") + } + if scorer == nil { + panic("router/score: scorer is required (configure router.classifier_model)") + } + for _, p := range policies { + if p.Label == "" { + panic("router/score: policy has empty label") + } + if p.Description == "" { + panic(fmt.Sprintf("router/score: policy %q has no description", p.Label)) + } + } + labels := make([]string, 0, len(policies)) + for _, p := range policies { + labels = append(labels, p.Label) + } + if cacheCap < 0 { + cacheCap = 0 + } + if activationThreshold <= 0 { + activationThreshold = defaultActivationThreshold + } + return &ScoreClassifier{ + policies: policies, + scorer: scorer, + activationThreshold: activationThreshold, + systemPrompt: buildScoreSystemPrompt(policies), + labelOrder: labels, + cache: make(map[string][]string, cacheCap), + cacheCap: cacheCap, + } +} + +func (c *ScoreClassifier) Name() string { return ClassifierScore } + +func (c *ScoreClassifier) Classify(ctx context.Context, p Probe) (Decision, error) { + start := time.Now() + if hit, ok := c.lookupCache(p.Prompt); ok { + return Decision{Labels: hit, Score: 1.0, Latency: time.Since(start)}, nil + } + prompt := buildScorePrompt(c.systemPrompt, p.Prompt) + results, err := c.scorer.Score(ctx, prompt, c.labelOrder) + if err != nil { + return errDecision(start, fmt.Errorf("score classify: %w", err)) + } + if len(results) != len(c.labelOrder) { + return errDecision(start, fmt.Errorf("score classify: scorer returned %d results for %d policies", len(results), len(c.labelOrder))) + } + + // Convert per-token log-probabilities into a probability + // distribution over labels via softmax. Length-normalisation + // makes labels of different token lengths comparable, then + // softmax converts to probabilities suitable for thresholding. + logProbs := make([]float64, len(results)) + for i, r := range results { + switch { + case r.NumTokens == 0: + logProbs[i] = math.Inf(-1) + case r.LengthNormalizedLogProb != 0: + logProbs[i] = r.LengthNormalizedLogProb + default: + // Backend didn't populate the length-normalised field; + // derive it ourselves so candidates of unequal token + // length stay comparable. + logProbs[i] = r.LogProb / float64(r.NumTokens) + } + } + probs := softmax(logProbs) + + // Threshold to active label set. Top probability is also kept + // for the decision-log "score" field so admins can see + // confidence even when multiple labels are active. + active := make([]string, 0, len(c.labelOrder)) + topProb := 0.0 + for i, prob := range probs { + if prob > topProb { + topProb = prob + } + if prob >= c.activationThreshold { + active = append(active, c.labelOrder[i]) + } + } + // Defensive: if the distribution is so flat that nothing + // crosses threshold, fall back to the argmax so the caller + // always has something to route on. + if len(active) == 0 { + bestIdx := 0 + for i := 1; i < len(probs); i++ { + if probs[i] > probs[bestIdx] { + bestIdx = i + } + } + active = []string{c.labelOrder[bestIdx]} + } + c.storeCache(p.Prompt, active) + return Decision{ + Labels: active, + Score: topProb, + Latency: time.Since(start), + }, nil +} + +// softmax converts an array of log-probabilities into a probability +// distribution. -inf inputs are handled (their exp contributes 0). +// Uses the standard max-subtraction trick for numerical stability. +func softmax(logProbs []float64) []float64 { + if len(logProbs) == 0 { + return nil + } + maxLP := math.Inf(-1) + for _, lp := range logProbs { + if lp > maxLP { + maxLP = lp + } + } + if math.IsInf(maxLP, -1) { + // All -inf: return a uniform distribution as a sensible + // degenerate result. + out := make([]float64, len(logProbs)) + for i := range out { + out[i] = 1.0 / float64(len(logProbs)) + } + return out + } + out := make([]float64, len(logProbs)) + sum := 0.0 + for i, lp := range logProbs { + out[i] = math.Exp(lp - maxLP) + sum += out[i] + } + if sum == 0 { + // Shouldn't happen given the maxLP check above, but guard + // against pathological inputs. + for i := range out { + out[i] = 1.0 / float64(len(out)) + } + return out + } + for i := range out { + out[i] /= sum + } + return out +} + +// cacheKey collapses incidental whitespace and casing so prompts like +// "hello", " hello ", and "Hello" share an entry — agent loops often +// produce minor variations that would otherwise miss. +func cacheKey(prompt string) string { + return strings.ToLower(strings.TrimSpace(prompt)) +} + +func (c *ScoreClassifier) lookupCache(prompt string) ([]string, bool) { + if c.cacheCap == 0 { + return nil, false + } + c.mu.RLock() + defer c.mu.RUnlock() + v, ok := c.cache[cacheKey(prompt)] + return v, ok +} + +func (c *ScoreClassifier) storeCache(prompt string, labels []string) { + if c.cacheCap == 0 { + return + } + c.mu.Lock() + defer c.mu.Unlock() + if len(c.cache) >= c.cacheCap { + for k := range c.cache { + delete(c.cache, k) + break + } + } + // Defensive copy + sort: cached label sets must be stable so + // callers can't mutate the cached value via aliasing, and + // comparing sets in tests doesn't depend on insertion order. + cp := make([]string, len(labels)) + copy(cp, labels) + sort.Strings(cp) + c.cache[cacheKey(prompt)] = cp +} + +// CacheLen returns the number of cached prompts. Test-only API. +func (c *ScoreClassifier) CacheLen() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.cache) +} + +func buildScoreSystemPrompt(policies []ScorePolicy) string { + var b strings.Builder + b.WriteString("You are a routing classifier. Pick the policy whose description best matches the user's request.\n\n") + b.WriteString("Available policies:\n") + for _, p := range policies { + b.WriteString("- ") + b.WriteString(p.Label) + b.WriteString(": ") + b.WriteString(p.Description) + b.WriteString("\n") + } + return b.String() +} + +// buildScorePrompt assembles the Qwen/ChatML-style prompt the +// Arch-Router model was trained on. The candidate label is scored as +// the assistant's first token(s) of response — so we end the prompt +// right at the assistant-turn marker, no trailing newline. +// +// Hard-coded to ChatML for now: Arch-Router is Qwen-2.5-1.5B-Instruct +// based and the published GGUF carries this template natively. When +// we add a non-ChatML scoring model we'll thread the template through +// from ModelConfig. +func buildScorePrompt(system, user string) string { + var b strings.Builder + b.WriteString("<|im_start|>system\n") + b.WriteString(system) + b.WriteString("<|im_end|>\n<|im_start|>user\n") + b.WriteString(user) + b.WriteString("<|im_end|>\n<|im_start|>assistant\n") + return b.String() +} diff --git a/core/services/routing/router/score_test.go b/core/services/routing/router/score_test.go new file mode 100644 index 000000000000..5ab0a6e17146 --- /dev/null +++ b/core/services/routing/router/score_test.go @@ -0,0 +1,231 @@ +package router + +import ( + "context" + "errors" + "sort" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +type stubScorer struct { + results []CandidateScore + err error + calls int + lastP string + lastC []string +} + +func (s *stubScorer) Score(_ context.Context, prompt string, candidates []string) ([]CandidateScore, error) { + s.calls++ + s.lastP = prompt + s.lastC = append(s.lastC[:0], candidates...) + if s.err != nil { + return nil, s.err + } + return s.results, nil +} + +func testPolicies() []ScorePolicy { + return []ScorePolicy{ + {Label: "code-generation", Description: "writing, debugging, or explaining code"}, + {Label: "casual-chat", Description: "small talk and general conversation"}, + {Label: "math-reasoning", Description: "arithmetic, equations, word problems"}, + } +} + +func sortedLabels(d Decision) []string { + out := append([]string(nil), d.Labels...) + sort.Strings(out) + return out +} + +func equalLabels(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +var _ = Describe("ScoreClassifier", func() { + It("returns a single dominant label", func() { + // A confident single-label classification: code-generation + // dominates softmax, the others fall well below the activation + // threshold (default 0.15). + s := &stubScorer{results: []CandidateScore{ + {LogProb: -0.05, LengthNormalizedLogProb: -0.025, NumTokens: 2}, // code + {LogProb: -8.0, LengthNormalizedLogProb: -2.667, NumTokens: 3}, // chat + {LogProb: -10.0, LengthNormalizedLogProb: -2.5, NumTokens: 4}, // math + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + d, err := c.Classify(context.Background(), Probe{Prompt: "fix this null pointer"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + Expect(equalLabels(d.Labels, []string{"code-generation"})).To(BeTrue(), "Labels = %v, want [code-generation]", d.Labels) + // Score is the top softmax probability. Two ~-2.5 distractors + // vs a ~0 winner gives ~0.86 for the winner — high enough to + // signal confidence in the decision log. + Expect(d.Score).To(BeNumerically(">=", 0.8), "want >= 0.8 for dominant single label") + }) + + It("activates multiple labels", func() { + // Two-way tie: code and math each take ~0.5 of the probability + // mass, chat is far behind. Both labels must activate so the + // router can pick a candidate covering both capabilities. + s := &stubScorer{results: []CandidateScore{ + {LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2}, // code ~0.49 + {LogProb: -9.0, LengthNormalizedLogProb: -3.0, NumTokens: 3}, // chat ~0.01 + {LogProb: -4.0, LengthNormalizedLogProb: -1.0, NumTokens: 4}, // math ~0.49 + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + d, err := c.Classify(context.Background(), Probe{Prompt: "write code that solves this word problem"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + got := sortedLabels(d) + want := []string{"code-generation", "math-reasoning"} + Expect(equalLabels(got, want)).To(BeTrue(), "Labels = %v, want %v", got, want) + }) + + It("falls back to argmax on flat distribution", func() { + // All three labels score roughly equally. Strict + // activation-threshold filtering could return zero labels, which + // would leave the router with nothing to match. The classifier + // falls back to argmax in this case so callers always have at + // least one label to route on. + s := &stubScorer{results: []CandidateScore{ + {LogProb: -2.0, LengthNormalizedLogProb: -1.0, NumTokens: 2}, // ~0.33 + {LogProb: -3.0, LengthNormalizedLogProb: -1.0, NumTokens: 3}, // ~0.33 + {LogProb: -4.0, LengthNormalizedLogProb: -1.0, NumTokens: 4}, // ~0.33 + }} + // Threshold above max softmax probability (0.5) forces the + // fallback path. + c := NewScoreClassifier(testPolicies(), s, 0, 0.5) + d, err := c.Classify(context.Background(), Probe{Prompt: "x"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + Expect(d.Labels).To(HaveLen(1), "want fallback to argmax (single label)") + }) + + It("falls back to joint log-prob when length normalisation missing", func() { + // Backend that doesn't honour length_normalize — only LogProb is + // populated. The classifier derives the per-token score itself + // so candidates of different token lengths stay comparable. If + // it didn't, the joint log-probs (-8, -5, -6) would pick chat — + // purely because it has fewer tokens. With length-norm chat is + // behind on per-token quality. + s := &stubScorer{results: []CandidateScore{ + {LogProb: -8.0, NumTokens: 4}, // -2.0 per token + {LogProb: -15.0, NumTokens: 2}, // -7.5 per token — clearly out + {LogProb: -6.0, NumTokens: 3}, // -2.0 per token + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + d, err := c.Classify(context.Background(), Probe{Prompt: "x"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + got := sortedLabels(d) + want := []string{"code-generation", "math-reasoning"} + Expect(equalLabels(got, want)).To(BeTrue(), "Labels = %v, want %v", got, want) + }) + + It("builds ChatML prompt with system and user", func() { + s := &stubScorer{results: []CandidateScore{ + {LogProb: -1, LengthNormalizedLogProb: -0.5, NumTokens: 2}, + {LogProb: -2, LengthNormalizedLogProb: -0.67, NumTokens: 3}, + {LogProb: -3, LengthNormalizedLogProb: -0.75, NumTokens: 4}, + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + _, err := c.Classify(context.Background(), Probe{Prompt: "hello world"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + Expect(s.lastP).To(ContainSubstring("<|im_start|>system")) + Expect(s.lastP).To(ContainSubstring("code-generation: writing, debugging")) + Expect(s.lastP).To(ContainSubstring("<|im_start|>user\nhello world<|im_end|>")) + Expect(strings.HasSuffix(s.lastP, "<|im_start|>assistant\n")).To(BeTrue(), "prompt does not end at assistant marker: %q", s.lastP) + Expect(s.lastC).To(HaveLen(3)) + Expect(s.lastC[0]).To(Equal("code-generation")) + Expect(s.lastC[1]).To(Equal("casual-chat")) + Expect(s.lastC[2]).To(Equal("math-reasoning")) + }) + + It("caches by normalised prompt", func() { + s := &stubScorer{results: []CandidateScore{ + {LogProb: -0.1, LengthNormalizedLogProb: -0.05, NumTokens: 2}, + {LogProb: -5, LengthNormalizedLogProb: -1.67, NumTokens: 3}, + {LogProb: -6, LengthNormalizedLogProb: -1.5, NumTokens: 4}, + }} + c := NewScoreClassifier(testPolicies(), s, 64, 0) + _, err := c.Classify(context.Background(), Probe{Prompt: "Fix Bug"}) + Expect(err).NotTo(HaveOccurred(), "classify 1") + _, err = c.Classify(context.Background(), Probe{Prompt: " fix bug "}) + Expect(err).NotTo(HaveOccurred(), "classify 2") + Expect(s.calls).To(Equal(1), "second classify should hit cache") + Expect(c.CacheLen()).To(Equal(1)) + }) + + It("cache disabled when cap zero", func() { + s := &stubScorer{results: []CandidateScore{ + {LogProb: -1, LengthNormalizedLogProb: -0.5, NumTokens: 2}, + {LogProb: -5, LengthNormalizedLogProb: -1.67, NumTokens: 3}, + {LogProb: -6, LengthNormalizedLogProb: -1.5, NumTokens: 4}, + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + for i := 0; i < 3; i++ { + _, err := c.Classify(context.Background(), Probe{Prompt: "same"}) + Expect(err).NotTo(HaveOccurred(), "classify") + } + Expect(s.calls).To(Equal(3), "cache disabled") + }) + + It("propagates scorer error", func() { + scorerErr := errors.New("boom") + c := NewScoreClassifier(testPolicies(), &stubScorer{err: scorerErr}, 0, 0) + _, err := c.Classify(context.Background(), Probe{Prompt: "x"}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("boom"), "expected scorer error to propagate") + }) + + It("returns result-count mismatch as error", func() { + s := &stubScorer{results: []CandidateScore{ + {LogProb: -1, LengthNormalizedLogProb: -0.5, NumTokens: 2}, + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + _, err := c.Classify(context.Background(), Probe{Prompt: "x"}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("returned 1 results for 3 policies")) + }) + + It("zero-token candidate scores -inf", func() { + // A NumTokens=0 candidate must contribute zero softmax mass and + // never win, even if its raw log-prob looks favourable. + s := &stubScorer{results: []CandidateScore{ + {LogProb: 100, LengthNormalizedLogProb: 100, NumTokens: 0}, // degenerate + {LogProb: -2, LengthNormalizedLogProb: -1.0, NumTokens: 2}, + {LogProb: -3, LengthNormalizedLogProb: -0.75, NumTokens: 4}, + }} + c := NewScoreClassifier(testPolicies(), s, 0, 0) + d, err := c.Classify(context.Background(), Probe{Prompt: "x"}) + Expect(err).NotTo(HaveOccurred(), "Classify") + for _, l := range d.Labels { + Expect(l).NotTo(Equal("code-generation"), "NumTokens=0 label must not be active") + } + }) + + It("panics on empty policies", func() { + Expect(func() { NewScoreClassifier(nil, &stubScorer{}, 0, 0) }).To(Panic()) + }) + + It("panics on nil scorer", func() { + Expect(func() { NewScoreClassifier(testPolicies(), nil, 0, 0) }).To(Panic()) + }) + + It("panics on missing description", func() { + Expect(func() { NewScoreClassifier([]ScorePolicy{{Label: "x"}}, &stubScorer{}, 0, 0) }).To(Panic()) + }) + + It("Name returns the classifier identifier", func() { + c := NewScoreClassifier(testPolicies(), &stubScorer{}, 0, 0) + Expect(c.Name()).To(Equal(ClassifierScore)) + }) +}) diff --git a/core/services/routing/router/types.go b/core/services/routing/router/types.go new file mode 100644 index 000000000000..e196068d0872 --- /dev/null +++ b/core/services/routing/router/types.go @@ -0,0 +1,89 @@ +// Package router holds the routing module's classifier interface and +// the Score implementation. +// +// The dispatch architecture is: a "router model" in ModelConfig (one +// with a Router block) gets matched at request time. The classifier +// inspects the prompt and returns the set of policy labels it considers +// active; the surrounding middleware picks the first candidate whose +// labels are a superset of the active set, rewrites input.Model to that +// candidate, and falls back through the existing model resolution path. +// This keeps ACL checks, disabled-state, and per-model PII consistent — +// the router does *model* selection, nothing else. +// +// The package deliberately has no dependency on core/http or +// core/services — those wire the classifier in and feed it the request +// shape they own. Keeps the classifier easy to unit-test against +// synthetic Probe inputs and reusable from non-HTTP entry points +// (e.g., a future MCP routing tool). +package router + +import ( + "context" + "time" +) + +// Probe is the classifier's input — the parsed prompt content the +// classifier needs to make a decision. Populated by the caller (the +// middleware does the schema-shape extraction); the classifier never +// inspects the original request struct. +type Probe struct { + // Prompt is the merged user-visible text. For chat completions it + // is the concatenation of message contents (separated by newlines); + // for plain completions it is the raw prompt. + Prompt string +} + +// Decision is the classifier's output. Labels carries the SET of +// policy labels the classifier considers active for this probe. The +// surrounding middleware picks the first candidate whose Labels +// superset the active label set; that lets one prompt activate multiple +// policies and route to a model capable of all of them. Score is the +// softmax probability of the top label — kept for the decision log so +// admins can spot uncertain calls. +type Decision struct { + Labels []string `json:"labels"` + Score float64 `json:"score"` + Latency time.Duration `json:"latency"` + + // Cached is true when the decision came from the L2 embedding + // cache rather than a fresh classifier run. CacheSimilarity carries + // the cosine similarity of the cache hit (0 when not cached). + Cached bool `json:"cached,omitempty"` + CacheSimilarity float64 `json:"cache_similarity,omitempty"` +} + +// Classifier is the entry point the middleware calls. The +// implementation honours ctx cancellation so long-running classifiers +// abort when the request context dies. +type Classifier interface { + Classify(ctx context.Context, p Probe) (Decision, error) + // Name is a stable identifier that ends up in RouterDecision rows + // — admins read this to know which classifier produced a given + // decision. + Name() string +} + +// Classifier names. Single source of truth for the YAML +// classifier: field, the buildClassifier dispatch in the +// middleware, and the strings each Classifier returns from Name(). +const ( + // ClassifierScore is the only shipped classifier. It picks + // labels by asking the classifier model to score each policy + // label as a continuation of the routing prompt. Used with + // Arch-Router-style small router models (Qwen-2.5-1.5B-Instruct + // base, trained on policy-continuation). See router/score.go + // for the full rationale. + ClassifierScore = "score" +) + +// LabelFallback is the synthetic label written to the decision +// store when the middleware uses cfg.Router.Fallback rather than a +// classifier-picked candidate. +const LabelFallback = "fallback" + +// errDecision packages an error with a populated Latency so each +// classifier's Classify can return early without restating the +// `Decision{Latency: time.Since(start)}, err` pattern. +func errDecision(start time.Time, err error) (Decision, error) { + return Decision{Latency: time.Since(start)}, err +} diff --git a/docs/content/features/cloud-proxy.md b/docs/content/features/cloud-proxy.md new file mode 100644 index 000000000000..08832ebf31f2 --- /dev/null +++ b/docs/content/features/cloud-proxy.md @@ -0,0 +1,184 @@ ++++ +title = "Cloud passthrough proxy" +weight = 28 +toc = true +description = "Forward requests to OpenAI, Anthropic, or any compatible provider" +tags = ["Proxy", "Cloud", "Routing", "Advanced"] +categories = ["Features"] ++++ + +LocalAI can forward chat-completion and Anthropic Messages requests to an +external provider instead of running them through the local gRPC backend +pipeline. Configure a model whose `backend` starts with `proxy-` and whose +`proxy.upstream_url` is set, and LocalAI bypasses templating, MCP injection, +and the local model loader entirely — the upstream sees the body the client +sent (with only the top-level `model` field optionally rewritten). + +The streaming PII filter still runs over the upstream's SSE stream, so cloud +egress remains subject to the same redaction rules a local model would apply. + +## When to use this + +- Mix local and cloud models in the same LocalAI instance — clients hit one + endpoint, LocalAI dispatches per model. +- Apply LocalAI's auth, usage tracking, and PII redaction to cloud traffic + before the body leaves the network. +- Use the intelligent router to send small or simple prompts to a local model + and complex ones to Claude or GPT-4o. + +## How it works + +1. Request hits LocalAI on `/v1/chat/completions` (OpenAI-shaped) or + `/v1/messages` (Anthropic-shaped). +2. The standard auth and routing middleware runs. +3. Per-model PII redaction runs request-side as it would for any model. +4. The handler checks `IsCloudProxy()`. If true, it skips the gRPC backend + and hands the request body to the cloudproxy package. +5. The proxy POSTs the body to `proxy.upstream_url` with provider-aware + authentication, then streams the SSE response back to the client. +6. The streaming PII filter rewrites per-token text in flight; the upstream's + event names and metadata pass through unchanged. + +The proxy is **wire-format-faithful** — it does not translate request shapes +between providers. A client posting an OpenAI-shaped body to a `proxy-anthropic` +model will get a confused upstream. Use the proxy that matches your client's +wire format. + +## Configuration + +Two backends are supported in the MVP: + +| Backend | Wire shape | Endpoint | Auth | +|---|---|---|---| +| `proxy-openai` | OpenAI chat completions | `/v1/chat/completions` | `Authorization: Bearer $KEY` | +| `proxy-anthropic` | Anthropic Messages | `/v1/messages` | `x-api-key: $KEY` plus `anthropic-version` header | + +API keys are read from environment variables named in the model's YAML. +The key never appears in the config file or the admin UI. + +### OpenAI passthrough + +```yaml +name: gpt-4o-proxy +backend: proxy-openai + +# When set, replaces the client's "model" field before forwarding. +# Useful when the LocalAI alias differs from the upstream's canonical name. +proxy: + upstream_url: https://api.openai.com/v1/chat/completions + api_key_env: OPENAI_API_KEY + upstream_model: gpt-4o + request_timeout_seconds: 120 + +# PII filtering defaults to ON for proxy-* backends. Override by setting +# pii.enabled: false explicitly. Per-pattern action overrides go in +# pii.patterns; see the Middleware admin page or the Middleware feature doc. +pii: + enabled: true +``` + +Then start LocalAI with the API key in the environment: + +```bash +export OPENAI_API_KEY=sk-... +local-ai run +``` + +Clients hit `http://localhost:8080/v1/chat/completions` with `"model": "gpt-4o-proxy"` +and the request lands on OpenAI's API. + +### Anthropic passthrough + +```yaml +name: claude-sonnet-proxy +backend: proxy-anthropic + +proxy: + upstream_url: https://api.anthropic.com/v1/messages + api_key_env: ANTHROPIC_API_KEY + upstream_model: claude-3-5-sonnet-20241022 + request_timeout_seconds: 300 + +pii: + enabled: true + # Block — not just mask — leaked credentials before they reach the upstream. + patterns: + - id: api_key_prefix + action: block +``` + +Anthropic clients hit `http://localhost:8080/v1/messages` with +`"model": "claude-sonnet-proxy"`. + +### Other OpenAI-compatible providers + +Most third-party providers (Together, Groq, DeepInfra, OpenRouter, …) speak +the OpenAI chat-completions wire format. Use `backend: proxy-openai` with the +provider's URL and API key: + +```yaml +name: llama-3-70b-via-together +backend: proxy-openai + +proxy: + upstream_url: https://api.together.xyz/v1/chat/completions + api_key_env: TOGETHER_API_KEY + upstream_model: meta-llama/Llama-3-70b-chat-hf +``` + +## Combining with the intelligent router + +A router model can spread traffic across local and cloud candidates. The +score classifier reads the policy descriptions and routes per request: + +```yaml +name: smart-router +router: + classifier: score + classifier_model: arch-router-1.5b + fallback: qwen-3-7b-local + activation_threshold: 0.40 + policies: + - label: casual + description: small talk, greetings, short answers + - label: code + description: writing or debugging code in any programming language + - label: heavy-reasoning + description: long-form analysis, complex math, multi-step reasoning + candidates: + - model: qwen-3-7b-local + labels: [casual] + - model: gpt-4o-proxy + labels: [casual, code] + - model: claude-sonnet-proxy + labels: [casual, code, heavy-reasoning] +``` + +The router rewrites `input.Model` to the chosen candidate; per-model PII, +ACLs, and the cloud-proxy fork all run against the resolved target. + +See [Middleware: PII filtering and intelligent routing]({{< relref "middleware.md" >}}) +for the full router and PII-filter reference. + +## Limitations in the MVP + +- **No request-shape translation.** A `proxy-anthropic` model only accepts + Anthropic-shaped requests; a `proxy-openai` model only accepts OpenAI-shaped + ones. +- **No output-side PII for non-streaming responses.** Streaming responses are + filtered in flight; buffered responses pass through verbatim. Request-side + PII covers both. +- **No retry or backoff.** Transient upstream failures bubble up to the client + as `502 Bad Gateway`. +- **No request shape validation.** If the upstream rejects the body, its + error envelope is forwarded to the client unchanged. + +## Operational notes + +- The fork happens before all local pipeline work, so cloud-proxy models do + not load gRPC backends. They consume no GPU memory and don't appear in the + VRAM admin view. +- Usage stats and the trace log capture cloud-proxy requests like any other + request. Token counts come from the upstream's `usage` field when present. +- Set `request_timeout_seconds` defensively — a hung upstream otherwise ties + up an HTTP handler until the client disconnects. diff --git a/docs/content/features/middleware.md b/docs/content/features/middleware.md new file mode 100644 index 000000000000..19fd8bde3ed3 --- /dev/null +++ b/docs/content/features/middleware.md @@ -0,0 +1,454 @@ ++++ +title = "Middleware: PII filtering and intelligent routing" +weight = 27 +toc = true +description = "Per-model PII redaction and policy-based request routing" +tags = ["Routing", "Privacy", "PII", "Middleware", "Advanced"] +categories = ["Features"] ++++ + +LocalAI ships a request-middleware layer that sits between the HTTP API and +the backend dispatcher. Two subsystems share that layer because they share +the same lifecycle hook: **PII filtering** scans the request body before it +reaches a backend (and the SSE stream on the way out), and the **intelligent +router** rewrites `input.Model` so a single client-facing model name fans +out across multiple downstream targets. + +Both are inspected and configured from the same admin page +(`/app/middleware`), backed by the same REST surface (`/api/middleware/*`, +`/api/pii/*`, `/api/router/*`) and the same MCP tools. + +## Request lifecycle + +``` +client ── auth ── route-model ── per-model PII ── backend ── streaming PII ── client + │ │ + └─── decision log └─── event log +``` + +The router runs first (it picks the target model so per-model PII has +something to gate on), per-model PII runs next (gated by the resolved +config), the backend executes, and the streaming PII filter rewrites the +SSE response in flight. Each subsystem writes to its own admin-visible +log: `/api/router/decisions` for routing, `/api/pii/events` for redaction +and block actions. + +--- + +## PII filtering + +PII redaction is **per-model and off by default**. The default flips to +**on for any backend whose name starts with `proxy-`** because that traffic +crosses the network to a third-party provider. Explicit `pii.enabled` +in a model's YAML always wins over the backend default. + +### Pattern catalog + +The built-in regex tier ships six patterns. Each has a default action +(`mask`, `block`, or `route_local`) and a length cap that prevents +pathological inputs from blowing up scanning time: + +| ID | Description | Default action | Max length | +|---|---|---|---| +| `email` | Email address | `mask` | 254 | +| `phone` | Phone number (international or US) | `mask` | 24 | +| `ssn` | US Social Security Number | `mask` | 11 | +| `credit_card` | Credit card number (Luhn-verified) | `mask` | 19 | +| `ipv4` | IPv4 address | `mask` | 15 | +| `api_key_prefix` | `sk-`, `pk-`, `xoxb-`, `ghp_`, `github_pat_` | **`block`** | 200 | + +`mask` rewrites the match to `[REDACTED:]` in the request body before +forwarding. `block` returns HTTP 400 with `error.type=pii_blocked` to the +client without forwarding. `route_local` is reserved for the routing +integration (see below) and falls back to `mask` when no local route is +available. + +### Per-model configuration + +Add a `pii:` block to a model YAML to opt in (or out, or to override +per-pattern actions): + +```yaml +# Local model — explicit opt-in so chats with this model get redaction +# applied request-side. +name: qwen-7b-local +backend: llama-cpp +pii: + enabled: true +``` + +```yaml +# Cloud-bound model — defaults to enabled because backend starts with +# proxy-. Tighten api_key_prefix from the global default and downgrade +# email to route_local so emails route to a local model rather than +# leaving the network. +name: claude-strict +backend: proxy-anthropic +proxy: + upstream_url: https://api.anthropic.com/v1/messages + api_key_env: ANTHROPIC_API_KEY +pii: + patterns: + - id: api_key_prefix + action: block # already the default, made explicit for audit + - id: email + action: route_local +``` + +The regex itself stays global — only the action is settable per-model. +Adding new patterns is a build-time concern (extend `patternRegexps` in +`core/services/routing/pii/patterns.go`). + +### NER tier (optional) + +The regex matcher covers high-precision patterns. For natural-language +PII (proper names, addresses, organization names) LocalAI carries an +**encoder NER tier** that runs after the regex pass. It expects a +transformers token-classification model wired through the `TokenClassify` +gRPC primitive (e.g. `dslim/bert-base-NER`). The detector annotates +spans with an entity group (`PER`, `LOC`, `ORG`, `MISC`); per-group +actions are configurable through the same `pii:` block. + +The NER tier ships as a contract (`NERDetector`, `NERConfig` in +`core/services/routing/pii/ner.go`); an operator-facing knob to load and +attach a detector is not plumbed yet. When no detector is configured the +regex tier still runs. + +### Streaming PII filter + +Buffered (`/v1/chat/completions` without `"stream": true`) responses are +forwarded verbatim today — only the request-side scan runs. Streaming +responses run through `pii.StreamFilter` which buffers SSE chunks until +either a full pattern matches or the buffer's max length is reached, +then emits the safe prefix. The streaming filter is what makes the +cloud-proxy backend and the MITM proxy safe to expose to clients that +issue streaming requests. + +The streaming filter is wired automatically for any model with `pii.enabled` +true — there is no separate streaming toggle. + +### Admin page + +The `/app/middleware` page (admin role only) has four tabs — **Filtering**, +**Routing**, **MITM Proxy** (see the [MITM doc]({{< relref "mitm-proxy.md" >}})), +and **Events**. The Filtering tab shows: + +- The pattern catalogue with live action dropdowns. Changing an action via + the UI calls `PUT /api/pii/patterns/:id` and updates the live redactor + in-process. Click **Persist** in the action header to write the current + state into `runtime_settings.json` so the next process start re-applies it. +- A per-model resolved-state table — each model row reports `enabled`, + the per-pattern overrides, and which patterns are effectively active. +- A live test panel that posts sample text to `/api/pii/test` and + highlights matches with their resolved actions, without storing the + text in the event log. + +### REST surface + +| Method | Path | Auth | Purpose | +|---|---|---|---| +| GET | `/api/pii/patterns` | any | Live pattern list with current actions. Used by the UI catalogue. | +| POST | `/api/pii/test` | any | Dry-run the redactor on `{"text":"..."}`. Returns hits and the would-be-rewritten body. Does not write to the event log. | +| GET | `/api/pii/events` | admin | Recent middleware events — PII redactions, MITM connect/traffic, admission denials. Filterable by `correlation_id`, `user_id`, `pattern_id`, `kind`. | +| PUT | `/api/pii/patterns/:id` | admin | Update a pattern in-process. Body accepts `{"action":"mask"\|"block"\|"route_local"}` and/or `{"disabled":true\|false}`. Transient — reverts on restart unless persisted. | +| POST | `/api/pii/patterns/persist` | admin | Snapshot the live per-pattern (action, disabled) state into `runtime_settings.json`. | +| GET | `/api/middleware/status` | admin | Aggregated dashboard data: patterns + per-model resolved state + router status + MITM status + admission status. One round-trip for the UI. | + +### MCP tools + +The same surface is mirrored through the LocalAI Assistant MCP server so +the in-process and stdio assistants can manage the filter conversationally: + +| Tool | Read/Write | Purpose | +|---|---|---| +| `list_pii_patterns` | read | Returns the live pattern list. | +| `get_pii_events` | read | Recent redaction / block events with optional filters. | +| `test_pii_redaction` | read | Dry-run sample text without writing to the event log. | +| `get_middleware_status` | read | Aggregator — the same payload as `GET /api/middleware/status`. | +| `set_pii_pattern_action` | write | Update a pattern's action. Admin-only. | +| `persist_pii_patterns` | write | Snapshot live state to `runtime_settings.json`. Admin-only. | + +--- + +## Intelligent routing + +A **router model** is a model whose YAML carries a `router:` block. When +a client addresses it (`"model": "smart-router"`), the middleware +classifies the prompt, picks a downstream candidate model, rewrites +`input.Model` to the candidate, and the standard model-resolution path +runs against that resolved target. ACL checks, disabled-state, and +per-model PII all apply to the resolved model — the router does +*model selection only*. + +#### Depth-1 invariant + +Candidates **must not** themselves be router models. A +`smart-router → claude-strict → proxy-anthropic` chain is fine +(`claude-strict` is a regular proxy model). A +`smart-router → other-router → real-model` chain is rejected at runtime +by the middleware (the dispatcher returns HTTP 500 with a +`depth-1 invariant` error). This keeps the dispatch graph acyclic and +predictable. + +#### Fallback + +If no candidate's label set covers the active label set from the classifier, +or the classifier errors out, the router uses `cfg.Router.Fallback`. +An empty `fallback` causes the dispatch to fail with HTTP 500 rather +than silently routing somewhere unintended — fail-fast, not +silent-bypass. + +### The Score classifier + +The only classifier shipped today is `score`. It works like this: + +1. Build a Qwen/ChatML system prompt that lists every policy label with + its description and primes the model to emit a label as the assistant + turn. +2. Ask the classifier model to **score every policy label** as the + first-token(s) continuation. This uses the `Score` gRPC primitive + (`backend.proto::Score`), which returns per-candidate log-probabilities + length-normalized so candidates of unequal token length stay + comparable. +3. Softmax the length-normalized log-probabilities into a probability + distribution over labels. +4. Threshold the distribution: every label whose probability passes + `activation_threshold` joins the **active label set**. +5. Pick the FIRST candidate whose `Labels` is a superset of the active + set. Admins order candidates smallest → largest so a single-label + query routes to the smallest capable model, while a query that + activates multiple labels falls to a candidate that covers them all. + +This is the Arch-Router approach extended for multi-label. The +distribution carries more signal than the argmax — reading off the +spread lets one prompt activate multiple policies and route to a model +capable of all of them. + +#### Recommended classifier model + +[Arch-Router-1.5B](https://huggingface.co/katanemo/Arch-Router-1.5B) is +the canonical choice. It's a Qwen-2.5-1.5B-Instruct base trained +specifically on routing-policy continuation, so the ChatML system-prompt ++ label-continuation pattern produces well-separated label probabilities +without prompt tuning. The Q4_K_M GGUF runs on CPU, GPU, and Intel SYCL. + +The classifier model must support the `Score` gRPC primitive (today: the +llama-cpp and vLLM backends) and use the ChatML chat template. Any small +ChatML instruct model works under those constraints, but expect flatter +probability distributions which translate to a higher +`activation_threshold` to keep noise out of the active label set. + +### YAML reference + +```yaml +name: smart-router +known_usecases: + - chat +router: + # The only classifier shipped today. + classifier: score + + # A model loaded by LocalAI that supports the Score gRPC primitive + # (llama-cpp and vLLM ship implementations). Arch-Router-1.5B is the + # canonical choice. + classifier_model: arch-router-1.5b + + # Bounded LRU keyed on (case-folded, whitespace-trimmed) prompt — prompts + # repeat in agent loops; the cache amortises the classifier round-trip + # across them. 0 here means "use the default" (1024); the cache cannot be + # disabled from YAML today. + classifier_cache_size: 256 + + # Softmax probability floor a label must clear to join the active label set. + # 0 = use the package default (0.15). 0.40 is a better empirical + # starting point on Arch-Router-1.5B — see the tuning note below. + activation_threshold: 0.40 + + # Used when no candidate covers the active label set, or the classifier + # itself errors. Empty here = fail-fast with HTTP 500. + fallback: qwen3-0.6b + + # The label vocabulary. Descriptions are fed verbatim into the + # classifier's system prompt — short, action-oriented sentences work + # best ("writing or debugging code", "small talk"). + policies: + - label: code-generation + description: writing, debugging, reading, or explaining code in any programming language + - label: casual-chat + description: small talk, greetings, jokes, or general conversation with no specific task + - label: math-reasoning + description: arithmetic, equations, percentage calculations, or step-by-step word problems + + # Routing table — order matters (smallest → largest). See "Score + # classifier" above for the matching rule. + candidates: + - model: qwen3-0.6b + labels: [casual-chat] + - model: qwen_qwen3.5-2b + labels: [code-generation, casual-chat, math-reasoning] +``` + +### Tuning `activation_threshold` + +The threshold is the single knob you'll want to tune per +(classifier-model, policy-set) pair. On Arch-Router-1.5B with the +three-policy setup above, sweeping the threshold over a hand-labeled +30-prompt corpus produced: + +| Threshold | Label-set accuracy | End-to-end routing accuracy | +|---:|---:|---:| +| 0.15 (package default) | 30% | 73% | +| 0.30 | 57% | 87% | +| **0.40** | **60%** | **90%** | +| 0.45 | 67% | 97% | +| 0.50 | 67% | 97% | + +The classifier's argmax matches the dominant label 93% of the time on +this corpus — what the threshold controls is how much secondary-label +noise leaks into the active label set. Low thresholds push single-label +queries to multi-label-capable (larger) candidates unnecessarily; 0.40 +keeps the dominant label dominant without losing genuine compound +activations. + +Re-tune per (classifier-model, policy-set) pair. The `/api/score` +endpoint (see below) is the convenient probe — it returns the raw +length-normalized log-probabilities so you can sweep thresholds offline +without driving real chat completions. + +### Embedding cache (L2) + +Classification is the most expensive thing the middleware does. The +score classifier already memo-caches verbatim repeats (case- and +whitespace-folded prompt → decision); the **embedding cache** is the +L2 tier that catches *semantically similar* prompts — "How do I exit +vim?" and "i need to quit vim" can share a decision instead of running +the classifier twice. + +Pairs naturally with a larger / slower classifier model: the steady-state +cost on cache hits collapses to one embedding round-trip plus a KNN +search, both well under 100ms with `nomic-embed-text-v1.5` + local-store. + +#### Configuration + +Add an `embedding_cache:` block to a router model: + +```yaml +router: + classifier: score + classifier_model: arch-router-1.5b + policies: [...] + candidates: [...] + + embedding_cache: + embedding_model: nomic-embed-text-v1.5 # any loaded embedding model + similarity_threshold: 0.80 # cosine sim floor for a hit (default 0.80) + confidence_threshold: 0.60 # min top-label prob to cache a decision (default 0.60) + # store_name: router-cache-smart-router # optional override; defaults to "router-cache-" +``` + +Omit the block entirely to disable. The cache adds two new failure modes +(embedder unavailable, store unavailable) — both fall through to the +inner classifier so routing keeps working. + +#### How it works + +For each request: + +1. Embed the probe prompt via the configured `embedding_model`. +2. KNN top-1 against the per-router local-store collection. +3. If similarity ≥ `similarity_threshold`, return the cached decision + (`Cached=true`, `CacheSimilarity=` in the decision log). +4. Miss → run the inner classifier. If `decision.score >= confidence_threshold`, + insert `(embedding, decision)` into the store. Low-confidence + decisions are deliberately skipped so they can't poison future + paraphrases. + +The local-store collection is named `router-cache-` by +default — each router gets its own collection so two routers can't +cross-contaminate. Collections persist on disk (local-store is the +canonical persistent vector backend), so the cache survives restarts. + +#### Tuning notes + +- **Similarity threshold**: 0.80 is the package default — re-tune + per (embedding model, corpus). The histogram on the Routing tab + shows where the cosine distribution actually sits; pick a + threshold above the cross-intent cluster and below the paraphrase + cluster. +- **Confidence threshold**: 0.60 corresponds roughly to "the + classifier is committed to a top label." Don't lower this — caching + unsure decisions propagates the uncertainty. +- **Cache flush**: invalidates automatically when the router YAML + changes (the classifier cache is fingerprinted by `yaml.Marshal`), + but the underlying local-store collection still holds the old + payloads. Manual flush via local-store admin or by renaming + `store_name` if you need a hard reset. +- **Latency budget**: an embedding round-trip (typically 30–80ms for + small embedding models) plus KNN search (~5ms) is added to every + *miss* on top of the classifier latency. Cache hits skip the + classifier entirely. Break-even is around 7–10% hit rate; agent + loops with repeated phrasing easily exceed this. + +### Admin page + +The `/app/middleware` page has a **Routing** tab listing every router +model's classifier, policies, candidates, and fallback. The **Events** +tab shows the decision log — one row per classified request with +correlation ID, requested model, served model, classifier name, active +labels, top-label score, and latency. + +Routing decisions are stored in an in-process ring buffer (default +capacity 5,000). The decision log is for audit and tuning — the +canonical usage log lives in `/api/usage` and correlates by request ID. + +### REST surface + +| Method | Path | Auth | Purpose | +|---|---|---|---| +| GET | `/api/router/status` | any | Router configuration: each router model's classifier, policies, candidates. | +| GET | `/api/router/decisions` | admin | Decision log with optional filters (`correlation_id`, `user_id`, `router_model`, `limit`). | +| POST | `/api/score` | admin | Direct access to the `Score` gRPC primitive — useful for offline threshold tuning. Body: `{"model": "", "prompt": "", "candidates": ["label-a", ...], "length_normalize": true}`. The llama-cpp and vLLM backends implement Score; other backends return `UNIMPLEMENTED`. | + +### MCP tools + +| Tool | Read/Write | Purpose | +|---|---|---| +| `get_router_decisions` | read | Recent decision log with optional filters. | +| `get_middleware_status` | read | Includes the router section listing configured router models. | + +Mutating routing config — adding a candidate, changing the classifier +model — is YAML-only today; reload with `POST /models/reload` to pick +up edits without restarting. + +### Operational notes + +- **Reload after YAML edits.** The router configs are loaded at startup + and cached. `POST /models/reload` re-reads from disk; the next request + rebuilds the classifier from the new config (the classifier cache is + fingerprinted by `yaml.Marshal(RouterConfig)` so it invalidates + automatically). +- **Classifier latency** on Arch-Router-1.5B Q4_K_M is ~500ms steady + for 3 policies on Intel SYCL. The score primitive re-decodes the full + prompt for every candidate today (the KV cache is cleared between + candidates); the prompt-KV-sharing optimization is on the perf TODO + list in `backend/cpp/llama-cpp/grpc-server.cpp::Score`. Until then, + `classifier_cache_size` is the highest-leverage knob for repeat-query + workloads (agent loops). +- **Decision log size**: 5,000-entry ring buffer per process. The + log is in-process and not persisted — pair with the usage log for + long-horizon audit. + +--- + +## Related features + +- [Cloud passthrough proxy]({{< relref "cloud-proxy.md" >}}) — combine + the router with `proxy-*` backends to send simple prompts to local + models and complex ones to cloud providers. +- [MITM proxy]({{< relref "mitm-proxy.md" >}}) — apply the same PII + filter to Claude Code, Codex CLI, and any HTTPS client without + LocalAI holding their API keys. +- [Authentication]({{< relref "authentication.md" >}}) — admin role is + required for mutating endpoints and the `/app/middleware` page; in + no-auth single-user mode the synthetic local user has admin role + automatically. diff --git a/docs/content/features/mitm-proxy.md b/docs/content/features/mitm-proxy.md new file mode 100644 index 000000000000..4c0428df463c --- /dev/null +++ b/docs/content/features/mitm-proxy.md @@ -0,0 +1,159 @@ ++++ +title = "MITM proxy for Claude Code / Codex CLI" +weight = 29 +toc = true +description = "Redact PII from cloud-AI traffic without LocalAI holding API keys" +tags = ["Proxy", "MITM", "Privacy", "Routing", "Advanced"] +categories = ["Features"] ++++ + +LocalAI can act as a local HTTPS proxy that **redacts PII from your Claude +Code, OpenAI Codex CLI, or any HTTPS client** without holding their API keys. +The proxy intercepts only the LLM API endpoints you allowlist (default: +`api.anthropic.com`, `api.openai.com`); everything else — OAuth, telemetry, +package fetches — passes through as a plain TCP tunnel. + +Use this when: + +- You want to use **Claude Code with a Claude Pro/Max subscription** but still + apply the same PII redaction LocalAI applies to API-key traffic. +- You run Codex CLI on a corporate laptop and need an audit trail of prompts. +- You want LocalAI to enforce egress policies for AI traffic without + becoming the API endpoint clients talk to. + +The proxy is **off by default**. Operators opt in by setting `--mitm-listen` +and distributing the generated CA cert. + +## How it works + +1. The proxy generates a private CA on first start (persisted to disk). +2. Clients set `HTTPS_PROXY=http://localai:port` and add the CA to their + trust store (e.g. `NODE_EXTRA_CA_CERTS` for Node-based CLIs like Claude + Code and Codex). +3. The CLI sends `CONNECT api.anthropic.com:443` to the proxy. +4. For allowlisted hosts, the proxy mints a per-host leaf cert signed by + the CA, terminates TLS, parses the HTTP request, applies the global + PII redactor on `/v1/messages` or `/v1/chat/completions`, and forwards + to the real upstream over its own TLS connection. +5. The streaming SSE response runs through the same `pii.StreamFilter` + the cloud-proxy backend uses. +6. For non-allowlisted hosts, the proxy is a plain CONNECT tunnel — no + TLS termination, no inspection, no CA trust required. + +The CLI authenticates with its own subscription / API key as it normally +would. LocalAI never holds the credential — it just observes and rewrites +the request body. + +## Quick start + +Start LocalAI with the MITM listener: + +```bash +local-ai run --mitm-listen :8443 +``` + +The first start generates a CA at `/mitm-ca/{ca.crt,ca.key}`. +Restarting reloads the same CA so clients keep trusting it. + +Download the public CA cert: + +```bash +curl -O http://localhost:8080/api/middleware/proxy-ca.crt +``` + +Configure Claude Code to use the proxy and trust the cert: + +```bash +export HTTPS_PROXY=http://localhost:8443 +export NODE_EXTRA_CA_CERTS=$(pwd)/proxy-ca.crt +claude +``` + +Now any `claude` chat session that touches `api.anthropic.com/v1/messages` +gets its prompts and tool inputs scanned by LocalAI's PII filter, and any +PII the model emits in its streaming response is masked before reaching +your terminal. Events appear in the LocalAI middleware admin page under +**Filtering → Recent events**. + +The same works for Codex CLI — set `HTTPS_PROXY` and `NODE_EXTRA_CA_CERTS` +and run `codex`. + +## Configuration + +| Flag / env | Default | Purpose | +|---|---|---| +| `--mitm-listen` / `LOCALAI_MITM_LISTEN` | empty (disabled) | Address to bind the proxy listener on | +| `--mitm-ca-dir` / `LOCALAI_MITM_CA_DIR` | `/mitm-ca` | Where to persist the CA cert + key | +| `--mitm-intercept-hosts` / `LOCALAI_MITM_INTERCEPT_HOSTS` | `api.anthropic.com,api.openai.com` | Hosts to terminate TLS for; everything else tunnels | + +Hostnames are case-insensitive. Add custom upstreams (e.g. an +OpenAI-compatible third-party provider) by extending the allowlist and +ensuring their endpoint paths match `/v1/chat/completions` or +`/v1/messages`. + +## What gets redacted + +Same patterns the regular request middleware uses: + +- Email addresses → masked +- Phone numbers → masked +- US Social Security Numbers → masked +- Credit card numbers (Luhn-verified) → masked +- IPv4 addresses → masked +- API key prefixes (`sk-`, `pk-`, `ghp_`, `github_pat_`, `xoxb-`) → **blocked** + +A `block` action returns HTTP 400 with `error.type=pii_blocked` to the +client. The CLI sees the rejection and shows it as a request error. + +Events are persisted via the same `pii.EventStore` the rest of LocalAI +uses, so the `/api/pii/events` endpoint and the middleware admin page +include MITM events alongside direct-API events. + +## Security notes + +- **The CA private key is the master credential.** Anyone with read + access to `/mitm-ca/ca.key` can forge TLS for any host the + proxy could intercept. The file is mode 0600; keep it that way. +- The proxy listener accepts plaintext HTTP `CONNECT` requests — bind it + to localhost (`--mitm-listen 127.0.0.1:8443`) unless you've added auth + in front of the listener. There is no built-in API-key check on this + port. +- The MITM CA is **separate** from any TLS cert LocalAI's main HTTP API + uses. Installing the MITM CA grants trust only for traffic that flows + through this proxy. +- The proxy does not pin upstream certificates; it trusts the system + certificate store. If your machine's trust store is compromised, the + proxy is too. +- TLS termination negotiates HTTP/2 by default (ALPN `h2`) and falls + back to HTTP/1.1 for clients that don't speak h2. Modern CLIs (Claude + Code, Codex) and the Anthropic / OpenAI APIs all use h2. + +## Limitations + +- **Only `/v1/messages` and `/v1/chat/completions` get redacted.** Other + paths on the same host (OAuth, model listing) are forwarded verbatim. +- **No request-shape translation.** The proxy assumes the request body + matches the host's wire format; cross-shape forwarding is the cloud + proxy backend's job, not the MITM's. +- **No CA rotation in the MVP.** To rotate, delete `ca.key` and `ca.crt` + and re-distribute the new cert to every client. +- **Cert pinning kills MITM.** Neither Claude Code nor Codex CLI pins + certificates today, but a future SDK update could. If a CLI starts + refusing the proxied handshake, that's the signal. + +## Comparison with the cloud-proxy backend + +LocalAI ships two cloud-related proxy modes; pick by who holds the credential: + +| | Cloud-proxy backend (`backend: proxy-*`) | MITM proxy (`--mitm-listen`) | +|---|---|---| +| Client config | `localai:8080` as **API endpoint** | `localai:8443` as **HTTPS_PROXY** | +| Holds API key | LocalAI | Client (CLI's own auth) | +| Works with subscription auth | No | Yes (CLI uses its own login) | +| Request rewriting | Yes (handler controls it) | Yes (selective per host+path) | +| CA cert distribution | Not needed | Required on every client | +| Routes through LocalAI's auth/usage tracking | Yes | Yes (per-correlation-id events) | + +For shared deployments where LocalAI owns the API key and clients are +unsophisticated (curl, simple webapps), use the cloud-proxy backend. For +"give my Claude Code a privacy filter" use cases, use the MITM proxy. diff --git a/go.mod b/go.mod index 66f981647515..4cfde81de604 100644 --- a/go.mod +++ b/go.mod @@ -288,7 +288,7 @@ require ( go.yaml.in/yaml/v2 v2.4.4 go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/image v0.38.0 // indirect - golang.org/x/net v0.53.0 // indirect; indirect (for websocket) + golang.org/x/net v0.53.0 // indirect (for websocket) golang.org/x/oauth2 v0.36.0 golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa // indirect golang.org/x/time v0.14.0 // indirect diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index eaabea8ef7fb..1843c2f855b2 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -71,6 +71,10 @@ type Backend interface { Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) + TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error) + + Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error) + GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 8360d26452b3..9b649ecaaddd 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -526,6 +526,42 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc. return client.Rerank(ctx, in, opts...) } +func (c *Client) TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := c.dial() + if err != nil { + return nil, err + } + defer func() { _ = conn.Close() }() + client := pb.NewBackendClient(conn) + return client.TokenClassify(ctx, in, opts...) +} + +func (c *Client) Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := c.dial() + if err != nil { + return nil, err + } + defer func() { _ = conn.Close() }() + client := pb.NewBackendClient(conn) + return client.Score(ctx, in, opts...) +} + func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 15d9615c81c2..aefdf7dbdb8a 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -132,6 +132,14 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts .. return e.s.Rerank(ctx, in) } +func (e *embedBackend) TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error) { + return e.s.TokenClassify(ctx, in) +} + +func (e *embedBackend) Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error) { + return e.s.Score(ctx, in) +} + func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) { return e.s.VAD(ctx, in) } diff --git a/pkg/mcp/localaitools/client.go b/pkg/mcp/localaitools/client.go index ac77e789bb50..60090d63aebf 100644 --- a/pkg/mcp/localaitools/client.go +++ b/pkg/mcp/localaitools/client.go @@ -67,4 +67,43 @@ type LocalAIClient interface { // SetBranding updates the text branding fields. Asset uploads are not // exposed over MCP — admins use the Settings UI for binary files. SetBranding(ctx context.Context, req SetBrandingRequest) (*Branding, error) + + // ---- Usage / billing ---- + + // GetUsageStats returns aggregated token usage. In single-user + // no-auth mode this reports the synthetic local user's usage. The + // implementation enforces "admin required to query other users". + GetUsageStats(ctx context.Context, q UsageStatsQuery) (*UsageStats, error) + + // ---- PII filter ---- + // ListPIIPatterns returns the active PII pattern set with each + // one's action. + ListPIIPatterns(ctx context.Context) ([]PIIPattern, error) + // GetPIIEvents returns recent redaction events. Implementation + // enforces "admin required" when auth is on. + GetPIIEvents(ctx context.Context, q PIIEventsQuery) ([]PIIEvent, error) + // TestPIIRedaction dry-runs the redactor against text. No event + // is recorded. + TestPIIRedaction(ctx context.Context, req PIIRedactTestRequest) (*PIIRedactTestResult, error) + // SetPIIPatternAction mutates the named pattern's action and/or + // disabled state in-process. Transient until PersistPIIPatterns is + // called — runtime_settings.json then applies the deltas on the + // next start. Admin-required. + SetPIIPatternAction(ctx context.Context, req PIIPatternActionUpdate) error + + // PersistPIIPatterns snapshots the live redactor's per-pattern + // (action, disabled) state into runtime_settings.json. Admin-required. + PersistPIIPatterns(ctx context.Context) error + + // ---- Middleware admin ---- + // GetMiddlewareStatus returns the aggregated state surfaced on the + // /app/middleware page: active PII patterns, per-model resolved + // enabled state, recent event count, router placeholder. + GetMiddlewareStatus(ctx context.Context) (*MiddlewareStatus, error) + + // ---- Router (intelligent routing) ---- + // GetRouterDecisions returns recent routing decisions for the + // /app/middleware Routing tab and for agent-driven introspection. + // Admin-required when auth is on. + GetRouterDecisions(ctx context.Context, q RouterDecisionsQuery) ([]RouterDecision, error) } diff --git a/pkg/mcp/localaitools/coverage_test.go b/pkg/mcp/localaitools/coverage_test.go index d8054ae04918..8159afcf9e89 100644 --- a/pkg/mcp/localaitools/coverage_test.go +++ b/pkg/mcp/localaitools/coverage_test.go @@ -37,18 +37,26 @@ var toolToHTTPRoute = map[string]string{ ToolListNodes: "GET /api/nodes", ToolVRAMEstimate: "POST /api/models/vram-estimate", ToolGetBranding: "GET /api/branding", + ToolGetUsageStats: "GET /api/usage (or /api/usage/all when all=true)", + ToolListPIIPatterns: "GET /api/pii/patterns", + ToolGetPIIEvents: "GET /api/pii/events", + ToolTestPIIRedaction: "POST /api/pii/test", + ToolGetMiddlewareStatus: "GET /api/middleware/status", + ToolGetRouterDecisions: "GET /api/router/decisions", // Mutating tools. - ToolInstallModel: "POST /models/apply", - ToolImportModelURI: "POST /models/import-uri", - ToolDeleteModel: "POST /models/delete/:name", - ToolEditModelConfig: "PATCH /api/models/config-json/:name", - ToolReloadModels: "POST /models/reload", - ToolInstallBackend: "POST /backends/apply", - ToolUpgradeBackend: "POST /backends/upgrade/:name", - ToolToggleModelState: "PUT /models/toggle-state/:name/:action", - ToolToggleModelPinned: "PUT /models/toggle-pinned/:name/:action", - ToolSetBranding: "POST /api/settings (instance_name, instance_tagline)", + ToolInstallModel: "POST /models/apply", + ToolImportModelURI: "POST /models/import-uri", + ToolDeleteModel: "POST /models/delete/:name", + ToolEditModelConfig: "PATCH /api/models/config-json/:name", + ToolReloadModels: "POST /models/reload", + ToolInstallBackend: "POST /backends/apply", + ToolUpgradeBackend: "POST /backends/upgrade/:name", + ToolToggleModelState: "PUT /models/toggle-state/:name/:action", + ToolToggleModelPinned: "PUT /models/toggle-pinned/:name/:action", + ToolSetBranding: "POST /api/settings (instance_name, instance_tagline)", + ToolSetPIIPatternAction: "PUT /api/pii/patterns/:id", + ToolPersistPIIPatterns: "POST /api/pii/patterns/persist", } // allKnownTools is the union of expectedFullCatalog (defined in diff --git a/pkg/mcp/localaitools/dto.go b/pkg/mcp/localaitools/dto.go index 4816d6d091ce..85136c60ae3a 100644 --- a/pkg/mcp/localaitools/dto.go +++ b/pkg/mcp/localaitools/dto.go @@ -137,6 +137,179 @@ type SetBrandingRequest struct { InstanceTagline *string `json:"instance_tagline,omitempty" jsonschema:"Optional short subtitle shown beneath the instance name. Pass an empty string to clear."` } +// UsageStatsQuery is the input for get_usage_stats. UserID is optional; +// when empty the tool returns the calling user's own usage in auth-on +// mode, or the synthetic local user's usage in single-user no-auth +// mode. Admins (or the local user) may pass UserID to inspect another +// user; the LocalAIClient implementation enforces the role check. +type UsageStatsQuery struct { + Period string `json:"period,omitempty" jsonschema:"Time window. One of: day, week, month, all. Defaults to month."` + UserID string `json:"user_id,omitempty" jsonschema:"Optional user id to query. Empty = caller's own usage. Querying another user requires admin role."` + All bool `json:"all,omitempty" jsonschema:"When true, returns the cluster-wide /api/usage/all view (admin-only when auth is on)."` +} + +// UsageStats is the response shape for get_usage_stats. Mirrors what +// /api/usage and /api/usage/all return so the LLM can correlate +// dashboard numbers with what it pulls via MCP. +type UsageStats struct { + Viewer UsageViewer `json:"viewer"` + Period string `json:"period"` + Totals UsageTotals `json:"totals"` + Buckets []UsageBucket `json:"buckets"` +} + +type UsageViewer struct { + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role,omitempty"` +} + +type UsageTotals struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +type UsageBucket struct { + Bucket string `json:"bucket"` + Model string `json:"model"` + UserID string `json:"user_id,omitempty"` + UserName string `json:"user_name,omitempty"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +// ---- PII / sensitive data tools ---- + +// PIIPattern is one row in the list_pii_patterns response. +type PIIPattern struct { + ID string `json:"id"` + Description string `json:"description"` + Action string `json:"action"` // mask | block | route_local + MaxMatchLength int `json:"max_match_length"` +} + +// PIIEventsQuery filters get_pii_events. +type PIIEventsQuery struct { + CorrelationID string `json:"correlation_id,omitempty" jsonschema:"Optional X-Correlation-ID join key (binds events to the request and usage record)."` + UserID string `json:"user_id,omitempty" jsonschema:"Optional user id to scope the query."` + PatternID string `json:"pattern_id,omitempty" jsonschema:"Optional pattern id (e.g. email, ssn)."` + Limit int `json:"limit,omitempty" jsonschema:"Maximum events. Defaults to 100."` +} + +// PIIEvent is the LLM-facing view of one redaction record. The matched +// value is never exposed; admins audit by hash_prefix. +type PIIEvent struct { + ID string `json:"id"` + CorrelationID string `json:"correlation_id"` + UserID string `json:"user_id"` + Direction string `json:"direction"` + PatternID string `json:"pattern_id"` + ByteOffset int `json:"byte_offset"` + Length int `json:"length"` + HashPrefix string `json:"hash_prefix"` + Action string `json:"action"` + CreatedAt string `json:"created_at"` +} + +// PIIRedactTestRequest is the input for test_pii_redaction. +type PIIRedactTestRequest struct { + Text string `json:"text" jsonschema:"The candidate text. Will be run through the redactor without recording an event."` +} + +// PIIRedactTestResult is the output for test_pii_redaction. spans +// describes where the redactor matched; redacted is the text after +// applying mask actions; blocked / local_only flag stronger actions. +type PIIRedactTestResult struct { + Redacted string `json:"redacted"` + Spans []PIIEventSpan `json:"spans"` + Blocked bool `json:"blocked"` + LocalOnly bool `json:"local_only"` +} + +type PIIEventSpan struct { + Start int `json:"start"` + End int `json:"end"` + Pattern string `json:"pattern"` + HashPrefix string `json:"hash_prefix"` +} + +// PIIPatternActionUpdate is the input for set_pii_pattern_action. +// At least one of Action or Disabled must be set. Mutations are +// transient by default — call persist_pii_patterns to flush them +// to runtime_settings.json so the next start re-applies them. +type PIIPatternActionUpdate struct { + ID string `json:"id" jsonschema:"Pattern id to mutate (e.g. email, ssn, credit_card, api_key_prefix)."` + Action string `json:"action,omitempty" jsonschema:"New action: mask, block, or route_local. Optional — omit to leave the action unchanged."` + Disabled *bool `json:"disabled,omitempty" jsonschema:"Set true to skip this pattern entirely; false to re-enable. Optional — omit to leave enabled-state unchanged."` +} + +// MiddlewareStatus is the aggregated /api/middleware/status payload — +// the React Middleware page renders this in one go. Routing is a +// placeholder until subsystem 2 lands. +type MiddlewareStatus struct { + PII MiddlewarePIIStatus `json:"pii"` + Router MiddlewareRouterStatus `json:"router"` +} + +// MiddlewarePIIStatus shows what the redactor is doing right now and +// which models opt in. enabled_globally=false means --disable-pii. +type MiddlewarePIIStatus struct { + EnabledGlobally bool `json:"enabled_globally"` + Reason string `json:"reason,omitempty"` + DefaultEnabledForBackends []string `json:"default_enabled_for_backends,omitempty"` + Patterns []PIIPattern `json:"patterns"` + Models []MiddlewarePIIModel `json:"models"` + RecentEventCount int `json:"recent_event_count"` +} + +// MiddlewarePIIModel is one model row in the per-model PII table. +type MiddlewarePIIModel struct { + Name string `json:"name"` + Backend string `json:"backend"` + Enabled bool `json:"enabled"` + Explicit bool `json:"explicit"` // Did YAML set Enabled, or did the backend prefix decide? + DefaultForBackend bool `json:"default_for_backend"` // Backend matches the auto-on rule (proxy-*). + Overrides map[string]string `json:"overrides,omitempty"` +} + +// MiddlewareRouterStatus is the placeholder shape the Routing tab +// reads. Subsystem 2 fills in Models with real RouterDecision rows. +type MiddlewareRouterStatus struct { + Configured bool `json:"configured"` + Models []string `json:"models"` + Note string `json:"note,omitempty"` +} + +// RouterDecisionsQuery filters get_router_decisions. +type RouterDecisionsQuery struct { + CorrelationID string `json:"correlation_id,omitempty" jsonschema:"Optional X-Correlation-ID join key (binds decisions to the request and usage record)."` + UserID string `json:"user_id,omitempty" jsonschema:"Optional user id to scope the query."` + RouterModel string `json:"router_model,omitempty" jsonschema:"Optional router model name to filter by (e.g. smart-router)."` + Limit int `json:"limit,omitempty" jsonschema:"Maximum decisions. Defaults to 100."` +} + +// RouterDecision is the LLM-facing view of one routing decision. The +// prompt is NEVER stored; admins audit by hash if they need to dedupe +// recurring routing patterns. +type RouterDecision struct { + ID string `json:"id"` + CorrelationID string `json:"correlation_id"` + UserID string `json:"user_id"` + RouterModel string `json:"router_model"` + RequestedModel string `json:"requested_model"` + ServedModel string `json:"served_model"` + Classifier string `json:"classifier"` + Label string `json:"label"` + Score float64 `json:"score"` + LatencyMs int64 `json:"latency_ms"` + Cached bool `json:"cached"` + CreatedAt string `json:"created_at"` +} + // VRAMEstimateRequest is the input for vram_estimate. The output type is // pkg/vram.EstimateResult — used directly via the LocalAIClient interface // so the LLM sees the same shape (size_bytes/size_display/vram_bytes/ diff --git a/pkg/mcp/localaitools/fakes_test.go b/pkg/mcp/localaitools/fakes_test.go index dcb8abdd39fc..cbe429a081a6 100644 --- a/pkg/mcp/localaitools/fakes_test.go +++ b/pkg/mcp/localaitools/fakes_test.go @@ -3,7 +3,6 @@ package localaitools import ( "context" "errors" - "fmt" "sync" "github.com/mudler/LocalAI/core/config" @@ -45,6 +44,13 @@ type fakeClient struct { toggleModelPinned func(string, modeladmin.Action) error getBranding func() (*Branding, error) setBranding func(SetBrandingRequest) (*Branding, error) + getUsageStats func(UsageStatsQuery) (*UsageStats, error) + listPIIPatterns func() ([]PIIPattern, error) + getPIIEvents func(PIIEventsQuery) ([]PIIEvent, error) + testPIIRedaction func(PIIRedactTestRequest) (*PIIRedactTestResult, error) + setPIIPatternAction func(PIIPatternActionUpdate) error + getMiddlewareStatus func() (*MiddlewareStatus, error) + getRouterDecisions func(RouterDecisionsQuery) ([]RouterDecision, error) } type fakeCall struct { @@ -236,5 +242,74 @@ func (f *fakeClient) SetBranding(_ context.Context, req SetBrandingRequest) (*Br return &Branding{InstanceName: "LocalAI"}, nil } -// boom is a sentinel error used by tests that want a deterministic error string. -var boom = fmt.Errorf("boom") +func (f *fakeClient) GetUsageStats(_ context.Context, q UsageStatsQuery) (*UsageStats, error) { + f.record("GetUsageStats", q) + if f.getUsageStats != nil { + return f.getUsageStats(q) + } + return &UsageStats{ + Viewer: UsageViewer{ID: "fake-user", Name: "fake", Role: "user"}, + Period: "month", + }, nil +} + +func (f *fakeClient) ListPIIPatterns(_ context.Context) ([]PIIPattern, error) { + f.record("ListPIIPatterns", nil) + if f.listPIIPatterns != nil { + return f.listPIIPatterns() + } + return []PIIPattern{}, nil +} + +func (f *fakeClient) GetPIIEvents(_ context.Context, q PIIEventsQuery) ([]PIIEvent, error) { + f.record("GetPIIEvents", q) + if f.getPIIEvents != nil { + return f.getPIIEvents(q) + } + return []PIIEvent{}, nil +} + +func (f *fakeClient) TestPIIRedaction(_ context.Context, req PIIRedactTestRequest) (*PIIRedactTestResult, error) { + f.record("TestPIIRedaction", req) + if f.testPIIRedaction != nil { + return f.testPIIRedaction(req) + } + return &PIIRedactTestResult{Redacted: req.Text}, nil +} + +func (f *fakeClient) SetPIIPatternAction(_ context.Context, req PIIPatternActionUpdate) error { + f.record("SetPIIPatternAction", req) + if f.setPIIPatternAction != nil { + return f.setPIIPatternAction(req) + } + return nil +} + +func (f *fakeClient) PersistPIIPatterns(_ context.Context) error { + f.record("PersistPIIPatterns", nil) + return nil +} + +func (f *fakeClient) GetRouterDecisions(_ context.Context, q RouterDecisionsQuery) ([]RouterDecision, error) { + f.record("GetRouterDecisions", q) + if f.getRouterDecisions != nil { + return f.getRouterDecisions(q) + } + return []RouterDecision{}, nil +} + +func (f *fakeClient) GetMiddlewareStatus(_ context.Context) (*MiddlewareStatus, error) { + f.record("GetMiddlewareStatus", nil) + if f.getMiddlewareStatus != nil { + return f.getMiddlewareStatus() + } + return &MiddlewareStatus{ + PII: MiddlewarePIIStatus{ + EnabledGlobally: true, + Patterns: []PIIPattern{}, + Models: []MiddlewarePIIModel{}, + }, + Router: MiddlewareRouterStatus{Configured: false, Models: []string{}}, + }, nil +} + diff --git a/pkg/mcp/localaitools/httpapi/client.go b/pkg/mcp/localaitools/httpapi/client.go index b32a7600aa95..1e8c08352dc6 100644 --- a/pkg/mcp/localaitools/httpapi/client.go +++ b/pkg/mcp/localaitools/httpapi/client.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" @@ -106,7 +107,7 @@ func (c *Client) do(ctx context.Context, method, path string, body any, out any) if err != nil { return err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -290,7 +291,7 @@ func (c *Client) ImportModelURI(ctx context.Context, req localaitools.ImportMode if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) // 400 with `error: "ambiguous import"` is not a transport error — it's the @@ -506,6 +507,188 @@ func (c *Client) SetBranding(ctx context.Context, req localaitools.SetBrandingRe return c.GetBranding(ctx) } +// ---- Usage / billing ---- + +func (c *Client) GetUsageStats(ctx context.Context, q localaitools.UsageStatsQuery) (*localaitools.UsageStats, error) { + period := q.Period + if period == "" { + period = "month" + } + path := routeUsage + if q.All { + path = routeUsageAll + } + // Build query string. The /api/usage server expects these exact param + // names; any change there must update both sides. + qs := url.Values{} + qs.Set("period", period) + if q.UserID != "" && q.All { + qs.Set("user_id", q.UserID) + } + if enc := qs.Encode(); enc != "" { + path = path + "?" + enc + } + + var raw struct { + Viewer struct { + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role"` + } `json:"viewer"` + Totals struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` + } `json:"totals"` + Usage []struct { + Bucket string `json:"bucket"` + Model string `json:"model"` + UserID string `json:"user_id"` + UserName string `json:"user_name"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` + } `json:"usage"` + } + if err := c.do(ctx, http.MethodGet, path, nil, &raw); err != nil { + return nil, err + } + out := &localaitools.UsageStats{ + Viewer: localaitools.UsageViewer{ID: raw.Viewer.ID, Name: raw.Viewer.Name, Role: raw.Viewer.Role}, + Period: period, + Totals: localaitools.UsageTotals{ + PromptTokens: raw.Totals.PromptTokens, + CompletionTokens: raw.Totals.CompletionTokens, + TotalTokens: raw.Totals.TotalTokens, + RequestCount: raw.Totals.RequestCount, + }, + Buckets: make([]localaitools.UsageBucket, 0, len(raw.Usage)), + } + for _, b := range raw.Usage { + out.Buckets = append(out.Buckets, localaitools.UsageBucket{ + Bucket: b.Bucket, + Model: b.Model, + UserID: b.UserID, + UserName: b.UserName, + PromptTokens: b.PromptTokens, + CompletionTokens: b.CompletionTokens, + TotalTokens: b.TotalTokens, + RequestCount: b.RequestCount, + }) + } + return out, nil +} + +// ---- PII filter ---- + +func (c *Client) ListPIIPatterns(ctx context.Context) ([]localaitools.PIIPattern, error) { + var raw struct { + Patterns []localaitools.PIIPattern `json:"patterns"` + } + if err := c.do(ctx, http.MethodGet, routePIIPatterns, nil, &raw); err != nil { + return nil, err + } + return raw.Patterns, nil +} + +func (c *Client) GetPIIEvents(ctx context.Context, q localaitools.PIIEventsQuery) ([]localaitools.PIIEvent, error) { + qs := url.Values{} + if q.CorrelationID != "" { + qs.Set("correlation_id", q.CorrelationID) + } + if q.UserID != "" { + qs.Set("user_id", q.UserID) + } + if q.PatternID != "" { + qs.Set("pattern_id", q.PatternID) + } + // The MCP get_pii_events tool is PII-shaped; the events store is now + // shared with proxy events that have no pattern_id/action. Scope to + // kind=pii so the LLM-facing audit stays coherent. + qs.Set("kind", "pii") + if q.Limit > 0 { + qs.Set("limit", fmt.Sprintf("%d", q.Limit)) + } + path := routePIIEvents + if enc := qs.Encode(); enc != "" { + path = path + "?" + enc + } + + var raw struct { + Events []localaitools.PIIEvent `json:"events"` + } + if err := c.do(ctx, http.MethodGet, path, nil, &raw); err != nil { + return nil, err + } + return raw.Events, nil +} + +func (c *Client) TestPIIRedaction(ctx context.Context, req localaitools.PIIRedactTestRequest) (*localaitools.PIIRedactTestResult, error) { + var out localaitools.PIIRedactTestResult + if err := c.do(ctx, http.MethodPost, routePIITest, map[string]string{"text": req.Text}, &out); err != nil { + return nil, err + } + return &out, nil +} + +func (c *Client) SetPIIPatternAction(ctx context.Context, req localaitools.PIIPatternActionUpdate) error { + if req.ID == "" { + return fmt.Errorf("pattern id is required") + } + body := map[string]any{} + if req.Action != "" { + body["action"] = req.Action + } + if req.Disabled != nil { + body["disabled"] = *req.Disabled + } + if len(body) == 0 { + return fmt.Errorf("must specify action and/or disabled") + } + return c.do(ctx, http.MethodPut, routePIIPatternByID(req.ID), body, nil) +} + +func (c *Client) PersistPIIPatterns(ctx context.Context) error { + return c.do(ctx, http.MethodPost, routePIIPatternsPersist, nil, nil) +} + +func (c *Client) GetMiddlewareStatus(ctx context.Context) (*localaitools.MiddlewareStatus, error) { + var out localaitools.MiddlewareStatus + if err := c.do(ctx, http.MethodGet, routeMiddleware, nil, &out); err != nil { + return nil, err + } + return &out, nil +} + +func (c *Client) GetRouterDecisions(ctx context.Context, q localaitools.RouterDecisionsQuery) ([]localaitools.RouterDecision, error) { + qs := url.Values{} + if q.CorrelationID != "" { + qs.Set("correlation_id", q.CorrelationID) + } + if q.UserID != "" { + qs.Set("user_id", q.UserID) + } + if q.RouterModel != "" { + qs.Set("router_model", q.RouterModel) + } + if q.Limit > 0 { + qs.Set("limit", fmt.Sprintf("%d", q.Limit)) + } + path := routeRouterDecisions + if enc := qs.Encode(); enc != "" { + path = path + "?" + enc + } + var raw struct { + Decisions []localaitools.RouterDecision `json:"decisions"` + } + if err := c.do(ctx, http.MethodGet, path, nil, &raw); err != nil { + return nil, err + } + return raw.Decisions, nil +} + // ---- helpers ---- func contains(haystack, lowerNeedle string) bool { diff --git a/pkg/mcp/localaitools/httpapi/routes.go b/pkg/mcp/localaitools/httpapi/routes.go index e44c12b972ad..6d99296f7c70 100644 --- a/pkg/mcp/localaitools/httpapi/routes.go +++ b/pkg/mcp/localaitools/httpapi/routes.go @@ -24,8 +24,20 @@ const ( routeVRAMEstimate = "/api/models/vram-estimate" routeBranding = "/api/branding" routeSettings = "/api/settings" + routeUsage = "/api/usage" + routeUsageAll = "/api/usage/all" + routePIIPatterns = "/api/pii/patterns" + routePIIPatternsPersist = "/api/pii/patterns/persist" + routePIIEvents = "/api/pii/events" + routePIITest = "/api/pii/test" + routeMiddleware = "/api/middleware/status" + routeRouterDecisions = "/api/router/decisions" ) +func routePIIPatternByID(id string) string { + return "/api/pii/patterns/" + url.PathEscape(id) +} + func routeJobStatus(jobID string) string { return "/models/jobs/" + url.PathEscape(jobID) } diff --git a/pkg/mcp/localaitools/inproc/client.go b/pkg/mcp/localaitools/inproc/client.go index 85ad821677ea..03d29cb22c8d 100644 --- a/pkg/mcp/localaitools/inproc/client.go +++ b/pkg/mcp/localaitools/inproc/client.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "strings" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" @@ -17,6 +18,10 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/modeladmin" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/routing/pii" + "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/internal" localaitools "github.com/mudler/LocalAI/pkg/mcp/localaitools" "github.com/mudler/LocalAI/pkg/model" @@ -36,12 +41,32 @@ type Client struct { ModelLoader *model.ModelLoader Gallery *galleryop.GalleryService + // StatsRecorder and FallbackUser are optional — they back the + // get_usage_stats tool. nil StatsRecorder makes the tool return an + // "unavailable" error, which keeps the assistant responsive on + // deployments that ran with --disable-stats or where startup wired + // the inproc client before stats were ready. + StatsRecorder *billing.Recorder + FallbackUser *auth.User + + // PIIRedactor and PIIEvents back the list_pii_patterns, + // get_pii_events, and test_pii_redaction tools. nil values cause + // the tools to return a "filter disabled" error. + PIIRedactor *pii.Redactor + PIIEvents pii.EventStore + + // RouterDecisions backs the get_router_decisions tool. nil makes + // the tool return an empty list — same shape the REST endpoint + // returns when stats are disabled. + RouterDecisions router.DecisionStore + modelAdmin *modeladmin.ConfigService } // New builds a Client wired to the given services. All fields are required // except ModelLoader (used only for SystemInfo's loaded-models report and -// best-effort ShutdownModel calls during config edits). +// best-effort ShutdownModel calls during config edits) and the stats +// fields (StatsRecorder, FallbackUser) which gate get_usage_stats. func New(appConfig *config.ApplicationConfig, systemState *system.SystemState, cl *config.ModelConfigLoader, ml *model.ModelLoader, gs *galleryop.GalleryService) *Client { return &Client{ AppConfig: appConfig, @@ -520,6 +545,300 @@ func capabilityToFlag(capability localaitools.Capability) (config.ModelConfigUse return 0, false } +// ---- Usage / billing ---- + +func (c *Client) GetUsageStats(ctx context.Context, q localaitools.UsageStatsQuery) (*localaitools.UsageStats, error) { + if c.StatsRecorder == nil { + return nil, errors.New("usage tracking is not available on this server") + } + period := q.Period + if period == "" { + period = "month" + } + + // Resolve which user this is. In single-user no-auth mode the + // inproc client doesn't have an echo context to read auth.GetUser + // from, so the FallbackUser is the only available identity. When + // auth IS on, the assistant runs under a privileged session and the + // caller can pass q.UserID; we don't enforce admin here because the + // MCP server itself is gated on admin (see prompts/10_safety.md). + var viewerID, viewerName, viewerRole string + switch { + case q.UserID != "": + viewerID = q.UserID + case c.FallbackUser != nil: + viewerID = c.FallbackUser.ID + viewerName = c.FallbackUser.Name + viewerRole = c.FallbackUser.Role + default: + return nil, errors.New("no user context for usage query (auth is on but no user id was provided)") + } + + queryUser := viewerID + if q.All { + // /api/usage/all: cluster-wide by default, but honour the + // optional UserID filter so admins can scope to one user — + // matches the REST endpoint's ?user_id=… query param. Empty + // q.UserID falls through to the cluster-wide aggregate. + queryUser = q.UserID + } + + rows, err := c.StatsRecorder.Aggregate(ctx, billing.AggregateQuery{ + UserID: queryUser, + Period: period, + }) + if err != nil { + return nil, fmt.Errorf("aggregate usage: %w", err) + } + + totals := localaitools.UsageTotals{} + buckets := make([]localaitools.UsageBucket, 0, len(rows)) + for _, r := range rows { + buckets = append(buckets, localaitools.UsageBucket{ + Bucket: r.Bucket, + Model: r.Model, + UserID: r.UserID, + UserName: r.UserName, + PromptTokens: r.PromptTokens, + CompletionTokens: r.CompletionTokens, + TotalTokens: r.TotalTokens, + RequestCount: r.RequestCount, + }) + totals.PromptTokens += r.PromptTokens + totals.CompletionTokens += r.CompletionTokens + totals.TotalTokens += r.TotalTokens + totals.RequestCount += r.RequestCount + } + + return &localaitools.UsageStats{ + Viewer: localaitools.UsageViewer{ID: viewerID, Name: viewerName, Role: viewerRole}, + Period: period, + Totals: totals, + Buckets: buckets, + }, nil +} + +// ---- PII filter ---- + +func (c *Client) ListPIIPatterns(_ context.Context) ([]localaitools.PIIPattern, error) { + if c.PIIRedactor == nil { + return nil, errors.New("PII filter is disabled") + } + patterns := c.PIIRedactor.Patterns() + out := make([]localaitools.PIIPattern, 0, len(patterns)) + for _, p := range patterns { + out = append(out, localaitools.PIIPattern{ + ID: p.ID, + Description: p.Description, + Action: string(p.Action), + MaxMatchLength: p.MaxMatchLength, + }) + } + return out, nil +} + +func (c *Client) GetPIIEvents(ctx context.Context, q localaitools.PIIEventsQuery) ([]localaitools.PIIEvent, error) { + if c.PIIEvents == nil { + return nil, errors.New("PII filter is disabled") + } + events, err := c.PIIEvents.List(ctx, pii.ListQuery{ + CorrelationID: q.CorrelationID, + UserID: q.UserID, + PatternID: q.PatternID, + Kind: pii.KindPII, + Limit: q.Limit, + }) + if err != nil { + return nil, fmt.Errorf("list pii events: %w", err) + } + out := make([]localaitools.PIIEvent, 0, len(events)) + for _, e := range events { + out = append(out, localaitools.PIIEvent{ + ID: e.ID, + CorrelationID: e.CorrelationID, + UserID: e.UserID, + Direction: string(e.Direction), + PatternID: e.PatternID, + ByteOffset: e.ByteOffset, + Length: e.Length, + HashPrefix: e.HashPrefix, + Action: string(e.Action), + CreatedAt: e.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + }) + } + return out, nil +} + +func (c *Client) SetPIIPatternAction(_ context.Context, req localaitools.PIIPatternActionUpdate) error { + if c.PIIRedactor == nil { + return errors.New("PII filter is disabled") + } + if req.ID == "" { + return errors.New("pattern id is required") + } + if req.Action == "" && req.Disabled == nil { + return errors.New("must specify action and/or disabled") + } + if req.Action != "" { + if err := c.PIIRedactor.SetAction(req.ID, pii.Action(req.Action)); err != nil { + return err + } + } + if req.Disabled != nil { + if err := c.PIIRedactor.SetDisabled(req.ID, *req.Disabled); err != nil { + return err + } + } + return nil +} + +// PersistPIIPatterns snapshots the current redactor state into +// runtime_settings.json. Mirrors POST /api/pii/patterns/persist. +func (c *Client) PersistPIIPatterns(_ context.Context) error { + if c.PIIRedactor == nil { + return errors.New("PII filter is disabled") + } + if c.AppConfig == nil { + return errors.New("app config not available") + } + existing, err := c.AppConfig.ReadPersistedSettings() + if err != nil { + return fmt.Errorf("read settings: %w", err) + } + defaults, err := pii.LoadConfig(c.AppConfig.PIIConfigPath) + if err != nil { + return fmt.Errorf("reload defaults: %w", err) + } + defaultByID := make(map[string]pii.Pattern, len(defaults)) + for _, d := range defaults { + defaultByID[d.ID] = d + } + overrides := map[string]config.PIIPatternRuntimeOverride{} + for _, p := range c.PIIRedactor.Patterns() { + d, known := defaultByID[p.ID] + ov := config.PIIPatternRuntimeOverride{} + changed := false + if !known || p.Action != d.Action { + action := string(p.Action) + ov.Action = &action + changed = true + } + if !known || p.Disabled != d.Disabled { + disabled := p.Disabled + ov.Disabled = &disabled + changed = true + } + if changed { + overrides[p.ID] = ov + } + } + existing.PIIPatternOverrides = &overrides + if err := c.AppConfig.WritePersistedSettings(existing); err != nil { + return fmt.Errorf("write settings: %w", err) + } + c.AppConfig.PIIPatternOverrides = overrides + return nil +} + +func (c *Client) GetRouterDecisions(ctx context.Context, q localaitools.RouterDecisionsQuery) ([]localaitools.RouterDecision, error) { + if c.RouterDecisions == nil { + return []localaitools.RouterDecision{}, nil + } + rows, err := c.RouterDecisions.List(ctx, router.DecisionListQuery{ + CorrelationID: q.CorrelationID, + UserID: q.UserID, + RouterModel: q.RouterModel, + Limit: q.Limit, + }) + if err != nil { + return nil, fmt.Errorf("list router decisions: %w", err) + } + out := make([]localaitools.RouterDecision, 0, len(rows)) + for _, r := range rows { + out = append(out, localaitools.RouterDecision{ + ID: r.ID, + CorrelationID: r.CorrelationID, + UserID: r.UserID, + RouterModel: r.RouterModel, + RequestedModel: r.RequestedModel, + ServedModel: r.ServedModel, + Classifier: r.Classifier, + Label: r.Label, + Score: r.Score, + LatencyMs: r.LatencyMs, + Cached: r.Cached, + CreatedAt: r.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + }) + } + return out, nil +} + +func (c *Client) GetMiddlewareStatus(ctx context.Context) (*localaitools.MiddlewareStatus, error) { + router := localaitools.MiddlewareRouterStatus{ + Configured: false, + Models: []string{}, + Note: "Intelligent routing is not yet implemented.", + } + piiSection := localaitools.MiddlewarePIIStatus{ + EnabledGlobally: c.PIIRedactor != nil, + Patterns: []localaitools.PIIPattern{}, + Models: []localaitools.MiddlewarePIIModel{}, + } + if c.PIIRedactor == nil { + piiSection.Reason = "--disable-pii" + return &localaitools.MiddlewareStatus{PII: piiSection, Router: router}, nil + } + piiSection.DefaultEnabledForBackends = []string{"proxy-*"} + for _, p := range c.PIIRedactor.Patterns() { + piiSection.Patterns = append(piiSection.Patterns, localaitools.PIIPattern{ + ID: p.ID, + Description: p.Description, + Action: string(p.Action), + MaxMatchLength: p.MaxMatchLength, + }) + } + if c.ConfigLoader != nil { + for _, cfg := range c.ConfigLoader.GetAllModelsConfigs() { + cfg := cfg + piiSection.Models = append(piiSection.Models, localaitools.MiddlewarePIIModel{ + Name: cfg.Name, + Backend: cfg.Backend, + Enabled: cfg.PIIIsEnabled(), + Explicit: cfg.PII.Enabled != nil, + DefaultForBackend: strings.HasPrefix(cfg.Backend, "proxy-"), + Overrides: cfg.PIIPatternOverrides(), + }) + } + } + if c.PIIEvents != nil { + if n, err := c.PIIEvents.Count(ctx); err == nil { + piiSection.RecentEventCount = n + } + } + return &localaitools.MiddlewareStatus{PII: piiSection, Router: router}, nil +} + +func (c *Client) TestPIIRedaction(_ context.Context, req localaitools.PIIRedactTestRequest) (*localaitools.PIIRedactTestResult, error) { + if c.PIIRedactor == nil { + return nil, errors.New("PII filter is disabled") + } + res := c.PIIRedactor.Redact(req.Text) + out := &localaitools.PIIRedactTestResult{ + Redacted: res.Redacted, + Blocked: res.Blocked, + LocalOnly: res.LocalOnly, + } + for _, s := range res.Spans { + out.Spans = append(out.Spans, localaitools.PIIEventSpan{ + Start: s.Start, + End: s.End, + Pattern: s.Pattern, + HashPrefix: s.HashPrefix, + }) + } + return out, nil +} + func capabilityFlagsOf(m *config.ModelConfig) []string { var out []string for label, flag := range config.GetAllModelConfigUsecases() { diff --git a/pkg/mcp/localaitools/server.go b/pkg/mcp/localaitools/server.go index 88d96ac0aa1a..fd9f5da00ee0 100644 --- a/pkg/mcp/localaitools/server.go +++ b/pkg/mcp/localaitools/server.go @@ -48,6 +48,9 @@ func NewServer(client LocalAIClient, opts Options) *mcp.Server { registerSystemTools(srv, client, opts) registerStateTools(srv, client, opts) registerBrandingTools(srv, client, opts) + registerUsageTools(srv, client, opts) + registerPIITools(srv, client, opts) + registerMiddlewareTools(srv, client, opts) return srv } diff --git a/pkg/mcp/localaitools/server_test.go b/pkg/mcp/localaitools/server_test.go index caf8bfdee969..f82d0ae415c5 100644 --- a/pkg/mcp/localaitools/server_test.go +++ b/pkg/mcp/localaitools/server_test.go @@ -78,7 +78,11 @@ var expectedFullCatalog = sortedStrings( ToolGallerySearch, ToolGetBranding, ToolGetJobStatus, + ToolGetMiddlewareStatus, ToolGetModelConfig, + ToolGetPIIEvents, + ToolGetRouterDecisions, + ToolGetUsageStats, ToolImportModelURI, ToolInstallBackend, ToolInstallModel, @@ -87,9 +91,13 @@ var expectedFullCatalog = sortedStrings( ToolListInstalledModels, ToolListKnownBackends, ToolListNodes, + ToolListPIIPatterns, + ToolPersistPIIPatterns, ToolReloadModels, ToolSetBranding, + ToolSetPIIPatternAction, ToolSystemInfo, + ToolTestPIIRedaction, ToolToggleModelPinned, ToolToggleModelState, ToolUpgradeBackend, @@ -101,13 +109,19 @@ var expectedReadOnlyCatalog = sortedStrings( ToolGallerySearch, ToolGetBranding, ToolGetJobStatus, + ToolGetMiddlewareStatus, ToolGetModelConfig, + ToolGetPIIEvents, + ToolGetRouterDecisions, + ToolGetUsageStats, ToolListBackends, ToolListGalleries, ToolListInstalledModels, ToolListKnownBackends, ToolListNodes, + ToolListPIIPatterns, ToolSystemInfo, + ToolTestPIIRedaction, ToolVRAMEstimate, ) diff --git a/pkg/mcp/localaitools/tools.go b/pkg/mcp/localaitools/tools.go index d5e213f42748..57b2638e3065 100644 --- a/pkg/mcp/localaitools/tools.go +++ b/pkg/mcp/localaitools/tools.go @@ -19,6 +19,12 @@ const ( ToolListNodes = "list_nodes" ToolVRAMEstimate = "vram_estimate" ToolGetBranding = "get_branding" + ToolGetUsageStats = "get_usage_stats" + ToolListPIIPatterns = "list_pii_patterns" + ToolGetPIIEvents = "get_pii_events" + ToolTestPIIRedaction = "test_pii_redaction" + ToolGetMiddlewareStatus = "get_middleware_status" + ToolGetRouterDecisions = "get_router_decisions" // Mutating tools — guarded by Options.DisableMutating and the // LLM-side safety prompt (see prompts/10_safety.md). @@ -32,6 +38,8 @@ const ( ToolToggleModelState = "toggle_model_state" ToolToggleModelPinned = "toggle_model_pinned" ToolSetBranding = "set_branding" + ToolSetPIIPatternAction = "set_pii_pattern_action" + ToolPersistPIIPatterns = "persist_pii_patterns" ) // DefaultServerName is the MCP Implementation.Name surfaced when diff --git a/pkg/mcp/localaitools/tools_middleware.go b/pkg/mcp/localaitools/tools_middleware.go new file mode 100644 index 000000000000..626609bb027e --- /dev/null +++ b/pkg/mcp/localaitools/tools_middleware.go @@ -0,0 +1,78 @@ +package localaitools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// registerMiddlewareTools wires the routing-module admin surface for the +// MCP server. The two tools mirror what the React /app/middleware page +// exposes: +// +// - get_middleware_status: read-only aggregator. The agent can ask +// "what's filtering my requests?" and get back the active PII +// pattern set, the per-model resolved enabled/override state, and +// a placeholder for routing. +// - set_pii_pattern_action: mutating. Mutations are TRANSIENT — they +// live until process restart, when patterns reload from the YAML +// defaults. The skill prompt should warn the user about that +// before applying lasting changes. +func registerMiddlewareTools(s *mcp.Server, client LocalAIClient, opts Options) { + mcp.AddTool(s, &mcp.Tool{ + Name: ToolGetMiddlewareStatus, + Description: "Aggregated routing-module status: PII pattern catalogue with current actions, per-model resolved PII state and overrides, recent event count, plus the active router models and their classifier configs. Read-only.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + status, err := client.GetMiddlewareStatus(ctx) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(status), nil, nil + }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolGetRouterDecisions, + Description: "Recent intelligent-routing decisions. Each row records which router model the client called, which candidate the classifier picked, the classifier's score and latency, and a correlation id that joins back to the usage record. Filter by correlation_id, user_id, or router_model. Read-only.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args RouterDecisionsQuery) (*mcp.CallToolResult, any, error) { + decisions, err := client.GetRouterDecisions(ctx, args) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(decisions), nil, nil + }) + + if opts.DisableMutating { + return + } + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolSetPIIPatternAction, + Description: "Change a PII pattern's action (mask|block|route_local) and/or disabled state in-process. TRANSIENT: the mutation is lost on restart unless followed by persist_pii_patterns. Admin-required.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args PIIPatternActionUpdate) (*mcp.CallToolResult, any, error) { + if args.ID == "" { + return errorResultf("id is required"), nil, nil + } + if args.Action == "" && args.Disabled == nil { + return errorResultf("at least one of action (mask, block, route_local) or disabled must be set"), nil, nil + } + if err := client.SetPIIPatternAction(ctx, args); err != nil { + return errorResult(err), nil, nil + } + return jsonResult(map[string]any{ + "id": args.ID, + "action": args.Action, + "disabled": args.Disabled, + "persisted": false, + }), nil, nil + }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolPersistPIIPatterns, + Description: "Snapshot the live PII redactor's per-pattern (action, disabled) state into runtime_settings.json so it re-applies on the next process start. Pairs with set_pii_pattern_action — that one is in-process; this one persists. Admin-required.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + if err := client.PersistPIIPatterns(ctx); err != nil { + return errorResult(err), nil, nil + } + return jsonResult(map[string]any{"persisted": true}), nil, nil + }) +} diff --git a/pkg/mcp/localaitools/tools_pii.go b/pkg/mcp/localaitools/tools_pii.go new file mode 100644 index 000000000000..e53a27dbeb2b --- /dev/null +++ b/pkg/mcp/localaitools/tools_pii.go @@ -0,0 +1,45 @@ +package localaitools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func registerPIITools(s *mcp.Server, client LocalAIClient, _ Options) { + mcp.AddTool(s, &mcp.Tool{ + Name: ToolListPIIPatterns, + Description: "List the active PII regex pattern set. Each entry shows the pattern id, description, and current action (mask, block, route_local). Read-only.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + patterns, err := client.ListPIIPatterns(ctx) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(patterns), nil, nil + }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolGetPIIEvents, + Description: "Recent PII redaction events. Filter by correlation_id (joins to a usage record), user_id, or pattern_id. Events never carry the matched value — only an 8-char sha256 prefix so admins can dedupe recurring leaks.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args PIIEventsQuery) (*mcp.CallToolResult, any, error) { + events, err := client.GetPIIEvents(ctx, args) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(events), nil, nil + }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolTestPIIRedaction, + Description: "Dry-run the PII redactor against text without recording a real event. Useful for tuning patterns: paste a candidate string and see whether it would be masked, blocked, or routed locally.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args PIIRedactTestRequest) (*mcp.CallToolResult, any, error) { + if args.Text == "" { + return errorResultf("text is required"), nil, nil + } + res, err := client.TestPIIRedaction(ctx, args) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(res), nil, nil + }) +} diff --git a/pkg/mcp/localaitools/tools_usage.go b/pkg/mcp/localaitools/tools_usage.go new file mode 100644 index 000000000000..055118d92b18 --- /dev/null +++ b/pkg/mcp/localaitools/tools_usage.go @@ -0,0 +1,22 @@ +package localaitools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func registerUsageTools(s *mcp.Server, client LocalAIClient, _ Options) { + mcp.AddTool(s, &mcp.Tool{ + Name: ToolGetUsageStats, + Description: "Return aggregated token usage. Defaults to the calling user's own usage over the last month. " + + "Use period=day|week|month|all to change the window. Set all=true for a cluster-wide admin view " + + "(only meaningful when auth is on and the caller is admin; in single-user mode there is only one user).", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args UsageStatsQuery) (*mcp.CallToolResult, any, error) { + stats, err := client.GetUsageStats(ctx, args) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(stats), nil, nil + }) +} diff --git a/pkg/model/connection_evicting_client.go b/pkg/model/connection_evicting_client.go index ade1e294bad6..b101e8f827e7 100644 --- a/pkg/model/connection_evicting_client.go +++ b/pkg/model/connection_evicting_client.go @@ -113,3 +113,15 @@ func (c *ConnectionEvictingClient) Rerank(ctx context.Context, in *pb.RerankRequ c.checkErr(err) return result, err } + +func (c *ConnectionEvictingClient) TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) { + result, err := c.Backend.TokenClassify(ctx, in, opts...) + c.checkErr(err) + return result, err +} + +func (c *ConnectionEvictingClient) Score(ctx context.Context, in *pb.ScoreRequest, opts ...ggrpc.CallOption) (*pb.ScoreResponse, error) { + result, err := c.Backend.Score(ctx, in, opts...) + c.checkErr(err) + return result, err +} diff --git a/pkg/store/client.go b/pkg/store/client.go index 1a1f46ccc578..4fa884b19ed5 100644 --- a/pkg/store/client.go +++ b/pkg/store/client.go @@ -13,24 +13,10 @@ import ( // SetCols sets multiple key-value pairs in the store // It's in columnar format so that keys[i] is associated with values[i] func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - protoValues := make([]*proto.StoresValue, len(values)) - for i, v := range values { - protoValues[i] = &proto.StoresValue{ - Bytes: v, - } - } - setOpts := &proto.StoresSetOptions{ - Keys: protoKeys, - Values: protoValues, - } - - res, err := c.StoresSet(ctx, setOpts) + res, err := c.StoresSet(ctx, &proto.StoresSetOptions{ + Keys: WrapKeys(keys), + Values: WrapValues(values), + }) if err != nil { return err } @@ -51,17 +37,7 @@ func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) // DeleteCols deletes multiple key-value pairs from the store // It's in columnar format so that keys[i] is associated with values[i] func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - deleteOpts := &proto.StoresDeleteOptions{ - Keys: protoKeys, - } - - res, err := c.StoresDelete(ctx, deleteOpts) + res, err := c.StoresDelete(ctx, &proto.StoresDeleteOptions{Keys: WrapKeys(keys)}) if err != nil { return err } @@ -84,31 +60,11 @@ func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error { // Be warned the keys are sorted and will be returned in a different order than they were input // There is no guarantee as to how the keys are sorted func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - getOpts := &proto.StoresGetOptions{ - Keys: protoKeys, - } - - res, err := c.StoresGet(ctx, getOpts) + res, err := c.StoresGet(ctx, &proto.StoresGetOptions{Keys: WrapKeys(keys)}) if err != nil { return nil, nil, err } - - ks := make([][]float32, len(res.Keys)) - for i, k := range res.Keys { - ks[i] = k.Floats - } - vs := make([][]byte, len(res.Values)) - for i, v := range res.Values { - vs[i] = v.Bytes - } - - return ks, vs, nil + return UnwrapKeys(res.Keys), UnwrapValues(res.Values), nil } // GetSingle gets a single key-value pair from the store @@ -128,28 +84,12 @@ func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, erro // Find similar keys to the given key. Returns the keys, values, and similarities func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) { - findOpts := &proto.StoresFindOptions{ - Key: &proto.StoresKey{ - Floats: key, - }, + res, err := c.StoresFind(ctx, &proto.StoresFindOptions{ + Key: &proto.StoresKey{Floats: key}, TopK: int32(topk), - } - - res, err := c.StoresFind(ctx, findOpts) + }) if err != nil { return nil, nil, nil, err } - - ks := make([][]float32, len(res.Keys)) - vs := make([][]byte, len(res.Values)) - - for i, k := range res.Keys { - ks[i] = k.Floats - } - - for i, v := range res.Values { - vs[i] = v.Bytes - } - - return ks, vs, res.Similarities, nil + return UnwrapKeys(res.Keys), UnwrapValues(res.Values), res.Similarities, nil } diff --git a/pkg/store/proto.go b/pkg/store/proto.go new file mode 100644 index 000000000000..1eb5bece94b5 --- /dev/null +++ b/pkg/store/proto.go @@ -0,0 +1,46 @@ +package store + +// pb⇄[][]float32/[][]byte translation helpers shared by the gRPC +// client (this file's package) and the local-store gRPC server in +// backend/go/local-store. Same shape on both sides of the wire so a +// schema bug only needs fixing once. + +import ( + "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// WrapKeys wraps each plain []float32 in a *proto.StoresKey. +func WrapKeys(in [][]float32) []*proto.StoresKey { + out := make([]*proto.StoresKey, len(in)) + for i, k := range in { + out[i] = &proto.StoresKey{Floats: k} + } + return out +} + +// WrapValues wraps each []byte in a *proto.StoresValue. +func WrapValues(in [][]byte) []*proto.StoresValue { + out := make([]*proto.StoresValue, len(in)) + for i, v := range in { + out[i] = &proto.StoresValue{Bytes: v} + } + return out +} + +// UnwrapKeys extracts the inner Floats from a slice of *proto.StoresKey. +func UnwrapKeys(in []*proto.StoresKey) [][]float32 { + out := make([][]float32, len(in)) + for i, k := range in { + out[i] = k.Floats + } + return out +} + +// UnwrapValues extracts the inner Bytes from a slice of *proto.StoresValue. +func UnwrapValues(in []*proto.StoresValue) [][]byte { + out := make([][]byte, len(in)) + for i, v := range in { + out[i] = v.Bytes + } + return out +} diff --git a/tests/e2e-ui/main.go b/tests/e2e-ui/main.go index 7aca8e7e4d69..46dd954c4582 100644 --- a/tests/e2e-ui/main.go +++ b/tests/e2e-ui/main.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "github.com/mudler/LocalAI/core/application" @@ -21,6 +22,20 @@ import ( func main() { mockBackend := flag.String("mock-backend", "", "path to mock-backend binary") port := flag.Int("port", 8089, "port to listen on") + // piiYAML lets a test inject a per-model `pii:` block into the + // auto-generated mock-model.yaml. Used by the middleware end-to-end + // verification (and any future test that wants to exercise per-model + // gating without bringing up a real backend). The argument is the + // body of the pii: block — the leading "pii:\n " is added here. + piiYAML := flag.String("pii-yaml", "", "optional pii: block to merge into mock-model.yaml") + // extraModels accepts repeatable name=yaml pairs that get written + // as additional model files. Used by the routing E2E to seed + // candidate models a router model can dispatch to. + extraModelFlag := flag.String("extra-model", "", "extra model YAML, formatted as 'name|'. Repeatable via comma-then-pipe? — for the router test we ship a single big string with embedded newlines.") + // routerYAML appends a `router:` block to mock-model.yaml. Used by + // the routing E2E to turn mock-model into a smart-router that + // dispatches to extra-models. + routerYAML := flag.String("router-yaml", "", "optional router: block to merge into mock-model.yaml") flag.Parse() if *mockBackend == "" { @@ -71,11 +86,33 @@ func main() { fmt.Fprintf(os.Stderr, "error marshaling config: %v\n", err) os.Exit(1) } - if err := os.WriteFile(filepath.Join(modelsPath, "mock-model.yaml"), configYAML, 0644); err != nil { + body := configYAML + if *piiYAML != "" { + body = append(body, []byte("pii:\n "+*piiYAML+"\n")...) + } + if *routerYAML != "" { + body = append(body, []byte("router:\n "+*routerYAML+"\n")...) + } + if err := os.WriteFile(filepath.Join(modelsPath, "mock-model.yaml"), body, 0644); err != nil { fmt.Fprintf(os.Stderr, "error writing config: %v\n", err) os.Exit(1) } + if *extraModelFlag != "" { + // extra-model format: "name|". The yaml body is + // inlined verbatim — caller controls indentation. Single name + // per flag invocation; multi-flag is fine because flag.String + // only keeps the last but the test passes only one. + parts := strings.SplitN(*extraModelFlag, "|", 2) + if len(parts) == 2 { + extraPath := filepath.Join(modelsPath, parts[0]+".yaml") + if err := os.WriteFile(extraPath, []byte(parts[1]), 0644); err != nil { + fmt.Fprintf(os.Stderr, "error writing extra model: %v\n", err) + os.Exit(1) + } + } + } + // Set up system state systemState, err := system.GetSystemState( system.WithModelPath(modelsPath), diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go index ec0c7735d3aa..46c4e51d6a4a 100644 --- a/tests/e2e/mock-backend/main.go +++ b/tests/e2e/mock-backend/main.go @@ -315,11 +315,18 @@ func (m *MockBackend) PredictStream(in *pb.PredictOptions, stream pb.Backend_Pre var toStream string toolName := mockToolNameFromRequest(in) - if toolName != "" && !promptHasToolResults(in.Prompt) { + switch { + case toolName != "" && !promptHasToolResults(in.Prompt): toStream = fmt.Sprintf(`{"name": "%s", "arguments": {"location": "San Francisco"}}`, toolName) - } else if toolName != "" { + case toolName != "": toStream = "Based on the tool results, the weather in San Francisco is sunny, 72°F." - } else { + case strings.Contains(in.Prompt, "MOCK_LEAK_EMAIL"): + // PII streaming test fixture: emit a response containing an email + // address so the streaming PII filter has something to mask. The + // content is split character-by-character below, so the mask + // must hold across chunk boundaries. + toStream = "Sure — here it is: alice@example.com is the address." + default: toStream = "This is a mocked streaming response." } for i, r := range toStream {