From 1817bd5761154fc3277fb53431d8713d98cea8b0 Mon Sep 17 00:00:00 2001 From: Jacob Fu <141651335+FuJacob@users.noreply.github.com> Date: Mon, 1 Jun 2026 23:14:16 -0700 Subject: [PATCH] Fix constrained-decode latency: bounded top-K, skip unused logSumExp Making the constrained decoder the only llama decode path (#532) moved every suggestion onto two O(vocab) operations per generated token that the deleted native sampler never paid: - ConstrainedSampler.candidatePool sorted the full vocabulary (150k-256k tokens) every step just to take the top topK. Replace the full sort with a single-pass bounded top-K selection that keeps the same membership and the same lower-id tie-break. - runConstrainedDecode scored every token with a full-vocab logSumExp to feed the confidence floor, which shouldSuppress treats as a no-op at the default floor of -infinity. Skip it unless a caller raises the floor. Token selection dropped from ~8.0s to ~0.55s per suggestion in a debug build (200k vocab, 25-token budget) with identical selected tokens. A 4000-trial randomized equivalence test pins the fast path to the old full-sort behavior bit-for-bit. --- .../Services/Runtime/LlamaRuntimeCore.swift | 7 +- Cotabby/Support/ConstrainedSampler.swift | 54 +++++++-- CotabbyTests/ConstrainedSamplerTests.swift | 104 ++++++++++++++++++ 3 files changed, 156 insertions(+), 9 deletions(-) diff --git a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift index 92f02dd..a62bdc1 100644 --- a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift +++ b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift @@ -335,7 +335,12 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { // Accumulate raw bytes and decode once at the end: a single token may carry only part of // a multi-byte UTF-8 scalar, so per-token String decoding would corrupt CJK / emoji. let tokenBytes = profile.bytes(for: tokenID) - if let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) { + // Scoring each step calls `logSumExp` over the full vocabulary — an O(vocab) exp pass per + // token. It only feeds the confidence-floor suppression check, which `shouldSuppress` + // short-circuits to a no-op when the floor is -infinity (the shipped default). Skip the + // work entirely in that case and only pay it when a caller has actually raised the floor. + if options.confidenceFloor > -.infinity, + let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) { sumLogprob += logProb } generatedBytes.append(contentsOf: tokenBytes) diff --git a/Cotabby/Support/ConstrainedSampler.swift b/Cotabby/Support/ConstrainedSampler.swift index 5a83013..174f4e1 100644 --- a/Cotabby/Support/ConstrainedSampler.swift +++ b/Cotabby/Support/ConstrainedSampler.swift @@ -140,20 +140,58 @@ enum ConstrainedSampler { } /// The token ids to consider this step, ordered by id. When `limit` is at least `count` every id - /// is returned (still id-ordered). Otherwise the ids are ranked by descending logit, the top - /// `limit` are kept, and that subset is re-sorted by id so downstream tie-breaking stays stable. + /// is returned (still id-ordered). Otherwise the `limit` highest-logit ids are kept and returned + /// re-sorted by id so downstream tie-breaking stays stable. + /// + /// Performance invariant: this runs once per generated token and `count` is the full vocabulary + /// (~150k-256k for the shipped base models), so it must not sort the whole vocabulary. The earlier + /// `(0.. [Int] { guard limit < count else { return Array(0.. logits[rhs] + var keptIDs = [Int](repeating: 0, count: limit) + var filled = 0 + // Index into `keptIDs` of the candidate to evict first (lowest logit; ties toward the larger + // id). Only meaningful once the buffer is full; recomputed after every displacement. + var worstIndex = 0 + for id in 0 ..< count { + if filled < limit { + keptIDs[filled] = id + filled += 1 + if filled == limit { + worstIndex = worstCandidateIndex(in: keptIDs, count: limit, logits: logits) + } + continue + } + if logits[id] > logits[keptIDs[worstIndex]] { + keptIDs[worstIndex] = id + worstIndex = worstCandidateIndex(in: keptIDs, count: limit, logits: logits) + } + } + return keptIDs.sorted() + } + + /// Index into `keptIDs` of the candidate that should leave the kept set first: the lowest logit, + /// breaking ties toward the larger id so the smaller id is retained. This matches the top-`limit` + /// cut line of a full `(logit desc, id asc)` sort, which is what `candidatePool` reproduces without + /// sorting. `count` candidates is at most `limit` (small), so this O(limit) scan is cheap. + private static func worstCandidateIndex(in keptIDs: [Int], count: Int, logits: [Float]) -> Int { + var worst = 0 + for index in 1 ..< count { + let candidate = keptIDs[index] + let current = keptIDs[worst] + if logits[candidate] < logits[current] + || (logits[candidate] == logits[current] && candidate > current) { + worst = index } - // Stable id ordering for equal logits so the kept set is deterministic at the cut line. - return lhs < rhs } - return ranked.prefix(limit).sorted() + return worst } /// Numerically stable log(sum(exp(row))): subtract the max before exponentiating so large logits diff --git a/CotabbyTests/ConstrainedSamplerTests.swift b/CotabbyTests/ConstrainedSamplerTests.swift index 4af68a8..a1b5be1 100644 --- a/CotabbyTests/ConstrainedSamplerTests.swift +++ b/CotabbyTests/ConstrainedSamplerTests.swift @@ -188,6 +188,110 @@ final class ConstrainedSamplerTests: XCTestCase { XCTAssertNil(id) } + // MARK: - candidatePool equivalence (top-K selection without a full sort) + + /// Deterministic, seedable RNG so the randomized equivalence sweep is reproducible across runs and + /// machines (no dependence on `SystemRandomNumberGenerator`). + private struct SplitMix64: RandomNumberGenerator { + private var state: UInt64 + init(seed: UInt64) { state = seed } + mutating func next() -> UInt64 { + state &+= 0x9E37_79B9_7F4A_7C15 + var mixed = state + mixed = (mixed ^ (mixed >> 30)) &* 0xBF58_476D_1CE4_E5B9 + mixed = (mixed ^ (mixed >> 27)) &* 0x94D0_49BB_1331_11EB + return mixed ^ (mixed >> 31) + } + } + + /// Reference selection that mirrors the *old* implementation exactly: rank the whole vocabulary by + /// (logit desc, id asc), keep the top `topK`, then argmax the survivors. `selectToken` now skips + /// the full sort, so this is the oracle the fast path must reproduce bit-for-bit. + private func referenceSelect( + logits: [Float], + control: Set, + admissible: Set?, + topK: Int, + blocked: Set + ) -> Int? { + guard topK > 0, !logits.isEmpty else { return nil } + if let admissible, admissible.isEmpty { return nil } + let ranked = (0 ..< logits.count).sorted { lhs, rhs in + if logits[lhs] != logits[rhs] { return logits[lhs] > logits[rhs] } + return lhs < rhs + } + let pool = Array(ranked.prefix(topK)).sorted() + var best: Int? + var bestLogit: Float = -.infinity + for id in pool { + if control.contains(id) || blocked.contains(id) { continue } + if let admissible, !admissible.contains(id) { continue } + if best == nil || logits[id] > bestLogit { + best = id + bestLogit = logits[id] + } + } + return best + } + + /// The fast top-K selection must match the old full-sort behavior for every combination of vocab + /// size, tie structure, topK cut, exclusions, blocks, and admissibility. Logits are quantized to a + /// few distinct values on many trials so exact-tie cut-line behavior (lower id wins) is exercised + /// heavily, which is where a hand-rolled top-K is most likely to diverge from a full sort. + func test_select_matchesFullSortReferenceAcrossRandomInputs() { + var rng = SplitMix64(seed: 0xC0FFEE_D00D) + for trial in 0 ..< 4000 { + let count = Int.random(in: 1 ... 120, using: &rng) + // Alternate between coarse (tie-heavy) and fine logit granularity. + let distinctValues = trial.isMultiple(of: 2) ? 4 : 64 + let logits = (0 ..< count).map { _ in + Float(Int.random(in: 0 ..< distinctValues, using: &rng)) + } + let control = Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 5, using: &rng) == 0 }) + let blocked = Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 6, using: &rng) == 0 }) + let admissible: Set? = Int.random(in: 0 ..< 3, using: &rng) == 0 + ? nil + : Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 2, using: &rng) == 0 }) + let topK = Int.random(in: 0 ... (count + 2), using: &rng) + + let actual = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: count, control: control), + admissibleTokenIDs: admissible, + topK: topK, + blockedTokenIDs: blocked + ) + let expected = referenceSelect( + logits: logits, + control: control, + admissible: admissible, + topK: topK, + blocked: blocked + ) + XCTAssertEqual( + actual, + expected, + "trial \(trial): count=\(count) topK=\(topK) diverged from full-sort reference" + ) + } + } + + /// Cut-line tie-break: when the top-`topK` boundary falls in a run of equal logits, the lower ids + /// must be the ones kept (so the selected token is the lowest id in the tied run), exactly as the + /// previous full sort guaranteed. A large vocab makes a regression in the bounded selection obvious. + func test_select_largeVocabEqualLogits_keepsLowestIDsAtCut() { + let count = 5000 + let logits = [Float](repeating: 1.0, count: count) + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: count), + admissibleTokenIDs: nil, + topK: 20 + ) + // All logits equal -> the kept pool is ids 0...19 and argmax breaks to the lowest id. + XCTAssertEqual(id, 0) + } + // MARK: - averageLogProb func test_averageLogProb_uniformRow_matchesNegativeLogVocab() {