Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Cotabby/Services/Runtime/LlamaRuntimeCore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,12 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable {
// Accumulate raw bytes and decode once at the end: a single token may carry only part of
// a multi-byte UTF-8 scalar, so per-token String decoding would corrupt CJK / emoji.
let tokenBytes = profile.bytes(for: tokenID)
if let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) {
// Scoring each step calls `logSumExp` over the full vocabulary — an O(vocab) exp pass per
// token. It only feeds the confidence-floor suppression check, which `shouldSuppress`
// short-circuits to a no-op when the floor is -infinity (the shipped default). Skip the
// work entirely in that case and only pay it when a caller has actually raised the floor.
if options.confidenceFloor > -.infinity,
let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) {
sumLogprob += logProb
}
generatedBytes.append(contentsOf: tokenBytes)
Expand Down
54 changes: 46 additions & 8 deletions Cotabby/Support/ConstrainedSampler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,58 @@ enum ConstrainedSampler {
}

/// The token ids to consider this step, ordered by id. When `limit` is at least `count` every id
/// is returned (still id-ordered). Otherwise the ids are ranked by descending logit, the top
/// `limit` are kept, and that subset is re-sorted by id so downstream tie-breaking stays stable.
/// is returned (still id-ordered). Otherwise the `limit` highest-logit ids are kept and returned
/// re-sorted by id so downstream tie-breaking stays stable.
///
/// Performance invariant: this runs once per generated token and `count` is the full vocabulary
/// (~150k-256k for the shipped base models), so it must not sort the whole vocabulary. The earlier
/// `(0..<count).sorted` did exactly that — an O(count log count) closure sort plus a count-sized
/// allocation on every decode step — which made generation take seconds once the constrained
/// decoder became the only decode path. We instead select the top `limit` in a single O(count)
/// scan against a fixed-size buffer. Determinism is preserved bit-for-bit: ids are scanned
/// ascending and a candidate only displaces the current worst on a STRICTLY higher logit, so
/// equal-logit ties resolve to the lower id exactly as the full sort's `lhs < rhs` cut did.
private static func candidatePool(count: Int, logits: [Float], limit: Int) -> [Int] {
guard limit < count else {
return Array(0..<count)
}
let ranked = (0..<count).sorted { lhs, rhs in
if logits[lhs] != logits[rhs] {
return logits[lhs] > logits[rhs]
var keptIDs = [Int](repeating: 0, count: limit)
var filled = 0
// Index into `keptIDs` of the candidate to evict first (lowest logit; ties toward the larger
// id). Only meaningful once the buffer is full; recomputed after every displacement.
var worstIndex = 0
for id in 0 ..< count {
if filled < limit {
keptIDs[filled] = id
filled += 1
if filled == limit {
worstIndex = worstCandidateIndex(in: keptIDs, count: limit, logits: logits)
}
continue
}
if logits[id] > logits[keptIDs[worstIndex]] {
keptIDs[worstIndex] = id
worstIndex = worstCandidateIndex(in: keptIDs, count: limit, logits: logits)
}
}
return keptIDs.sorted()
}

/// Index into `keptIDs` of the candidate that should leave the kept set first: the lowest logit,
/// breaking ties toward the larger id so the smaller id is retained. This matches the top-`limit`
/// cut line of a full `(logit desc, id asc)` sort, which is what `candidatePool` reproduces without
/// sorting. `count` candidates is at most `limit` (small), so this O(limit) scan is cheap.
private static func worstCandidateIndex(in keptIDs: [Int], count: Int, logits: [Float]) -> Int {
var worst = 0
for index in 1 ..< count {
let candidate = keptIDs[index]
let current = keptIDs[worst]
if logits[candidate] < logits[current]
|| (logits[candidate] == logits[current] && candidate > current) {
worst = index
}
// Stable id ordering for equal logits so the kept set is deterministic at the cut line.
return lhs < rhs
}
return ranked.prefix(limit).sorted()
return worst
}

/// Numerically stable log(sum(exp(row))): subtract the max before exponentiating so large logits
Expand Down
104 changes: 104 additions & 0 deletions CotabbyTests/ConstrainedSamplerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,110 @@ final class ConstrainedSamplerTests: XCTestCase {
XCTAssertNil(id)
}

// MARK: - candidatePool equivalence (top-K selection without a full sort)

/// Deterministic, seedable RNG so the randomized equivalence sweep is reproducible across runs and
/// machines (no dependence on `SystemRandomNumberGenerator`).
private struct SplitMix64: RandomNumberGenerator {
private var state: UInt64
init(seed: UInt64) { state = seed }
mutating func next() -> UInt64 {
state &+= 0x9E37_79B9_7F4A_7C15
var mixed = state
mixed = (mixed ^ (mixed >> 30)) &* 0xBF58_476D_1CE4_E5B9
mixed = (mixed ^ (mixed >> 27)) &* 0x94D0_49BB_1331_11EB
return mixed ^ (mixed >> 31)
}
}

/// Reference selection that mirrors the *old* implementation exactly: rank the whole vocabulary by
/// (logit desc, id asc), keep the top `topK`, then argmax the survivors. `selectToken` now skips
/// the full sort, so this is the oracle the fast path must reproduce bit-for-bit.
private func referenceSelect(
logits: [Float],
control: Set<Int>,
admissible: Set<Int>?,
topK: Int,
blocked: Set<Int>
) -> Int? {
guard topK > 0, !logits.isEmpty else { return nil }
if let admissible, admissible.isEmpty { return nil }
let ranked = (0 ..< logits.count).sorted { lhs, rhs in
if logits[lhs] != logits[rhs] { return logits[lhs] > logits[rhs] }
return lhs < rhs
}
let pool = Array(ranked.prefix(topK)).sorted()
var best: Int?
var bestLogit: Float = -.infinity
for id in pool {
if control.contains(id) || blocked.contains(id) { continue }
if let admissible, !admissible.contains(id) { continue }
if best == nil || logits[id] > bestLogit {
best = id
bestLogit = logits[id]
}
}
return best
}

/// The fast top-K selection must match the old full-sort behavior for every combination of vocab
/// size, tie structure, topK cut, exclusions, blocks, and admissibility. Logits are quantized to a
/// few distinct values on many trials so exact-tie cut-line behavior (lower id wins) is exercised
/// heavily, which is where a hand-rolled top-K is most likely to diverge from a full sort.
func test_select_matchesFullSortReferenceAcrossRandomInputs() {
var rng = SplitMix64(seed: 0xC0FFEE_D00D)
for trial in 0 ..< 4000 {
let count = Int.random(in: 1 ... 120, using: &rng)
// Alternate between coarse (tie-heavy) and fine logit granularity.
let distinctValues = trial.isMultiple(of: 2) ? 4 : 64
let logits = (0 ..< count).map { _ in
Float(Int.random(in: 0 ..< distinctValues, using: &rng))
}
let control = Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 5, using: &rng) == 0 })
let blocked = Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 6, using: &rng) == 0 })
let admissible: Set<Int>? = Int.random(in: 0 ..< 3, using: &rng) == 0
? nil
: Set((0 ..< count).filter { _ in Int.random(in: 0 ..< 2, using: &rng) == 0 })
let topK = Int.random(in: 0 ... (count + 2), using: &rng)

let actual = ConstrainedSampler.selectToken(
logits: logits,
profile: plainProfile(count: count, control: control),
admissibleTokenIDs: admissible,
topK: topK,
blockedTokenIDs: blocked
)
let expected = referenceSelect(
logits: logits,
control: control,
admissible: admissible,
topK: topK,
blocked: blocked
)
XCTAssertEqual(
actual,
expected,
"trial \(trial): count=\(count) topK=\(topK) diverged from full-sort reference"
)
}
}

/// Cut-line tie-break: when the top-`topK` boundary falls in a run of equal logits, the lower ids
/// must be the ones kept (so the selected token is the lowest id in the tied run), exactly as the
/// previous full sort guaranteed. A large vocab makes a regression in the bounded selection obvious.
func test_select_largeVocabEqualLogits_keepsLowestIDsAtCut() {
let count = 5000
let logits = [Float](repeating: 1.0, count: count)
let id = ConstrainedSampler.selectToken(
logits: logits,
profile: plainProfile(count: count),
admissibleTokenIDs: nil,
topK: 20
)
// All logits equal -> the kept pool is ids 0...19 and argmax breaks to the lowest id.
XCTAssertEqual(id, 0)
}

// MARK: - averageLogProb

func test_averageLogProb_uniformRow_matchesNegativeLogVocab() {
Expand Down