feat(provider): Gemma 4 MTP speculative decoding (engine + VLM image/video)#306
feat(provider): Gemma 4 MTP speculative decoding (engine + VLM image/video)#306Gajesh2007 wants to merge 4 commits into
Conversation
… off by default) Bumps the mlx-swift-lm submodule to the engine-MTP integration and turns it on in the provider when opted in: - DARKBLOOM_ENABLE_MTP=1 (default off) enables drafter speculative decoding. - DARKBLOOM_MTP_DRAFTER_PATH=<dir> supplies the Gemma 4 assistant drafter (loaded once per model load, before container.perform). - DARKBLOOM_MTP_MAX_BATCH (default 2) caps the batch size that still speculates; above it the engine uses plain batched decode. makeBatchedEngine attaches a Gemma4MTPEngineRuntime to the scheduler when the target is a Gemma 4 model and a drafter loaded; non-Gemma4 / bind-failure / flag-off all fall back to plain decode. Engine-path parity (engine output == plain batched greedy at B=1,2) is covered in mlx-swift-lm tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…on + greedy sampler) Picks up: reusable Gemma4MTPBatchState, engine wiring, the abort/cold-prefill session-invariant fixes, and the greedy-sampler fix that makes MTP actually engage in the live batched engine. Verified live: darkbloom start --local with DARKBLOOM_ENABLE_MTP=1 speculates on text requests (102->114 tok/s B=1) with correct output; image requests route to the untouched VLM path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Wire drafter-based MTP speculative decoding into the non-batched VLM inference path so multimodal requests decode ~1.9x faster. Bumps mlx-swift-lm to the VLM-MTP engine support. - Load the MTP drafter (gated on DARKBLOOM_ENABLE_MTP + DARKBLOOM_MTP_DRAFTER_PATH) and, when the loaded model is an MLXVLM Gemma 4, bind it to the VLM tower and stash it on the scheduler. - VLMRequestInference: when a drafter is bound, the request is greedy (temperature 0, no repetition penalty), and the model is a Gemma 4 VLM, decode via multimodal MTP prefill + Gemma4MTPTokenIterator; otherwise the existing container.generate path runs unchanged. Default OFF — with no flag the vision path is byte-identical to before. - Stream via NaiveStreamingDetokenizer (correct multi-byte text), the full stop-token set (model EOS + tokenizer EOS + extra-EOS + unknown), EOS not emitted/counted, maxTokens + cancellation honored. - Free the drafter on model unload. - Add an env-gated real-image E2E live test (VLM_MTP_E2E=1): coherent, similar-to-plain output, ~1.9x faster on a real image. Depends on Layr-Labs/mlx-swift-lm#35. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
The latest updates on your projects. Learn more about Vercel for GitHub.
|
ethenotethan
left a comment
There was a problem hiding this comment.
Automated Code Review — Layr-Labs/d-inference#
Verdict: REQUEST_CHANGES
Security — 1 finding(s)
- 🔵 [INFO]
provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:517-540— Environment variable parsing lacks validation for mtpEnabledFlag() and mtpMaxBatch()- Suggestion: Add bounds checking for mtpMaxBatch() to prevent negative values or unreasonably large batch sizes that could cause resource exhaustion
Performance — 3 finding(s) (2 blocking)
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:143-168— Synchronous MTP drafter loading blocks request processing- Suggestion: Move drafter loading to initialization time or use async loading with caching to avoid blocking each request
- 🔵 [INFO]
provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:232-355— Stop token set reconstruction on every MTP request- Suggestion: Cache the stop token set construction or move it to model initialization time
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:560-585— Synchronous drafter loading during engine creation- Suggestion: Load drafter asynchronously before engine creation or implement lazy loading with proper error handling
Type_diligence — 1 finding(s) (1 blocking)
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:143-149— Missing type assertion check for repetitionPenalty comparison- Suggestion: Add explicit type check:
let repPenaltyNeutral = params.repetitionPenalty.map { $0 == 1.0 } ?? trueto avoid potential type confusion when repetitionPenalty is non-nil but not a numeric type
- Suggestion: Add explicit type check:
Additive_complexity — 6 finding(s) (3 blocking)
- 🔵 [INFO]
provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:91-105— VLM MTP configuration adds complexity with multiple environment variables and defaults- Suggestion: Consider consolidating MTP configuration into a single struct or reducing the number of configurable parameters
- 🔵 [INFO]
provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:523-550— Custom environment variable parsing duplicates standard patterns- Suggestion: Use a configuration library or standardize environment variable parsing across the codebase
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:140-170— Complex gating logic for MTP path duplicates validation patterns- Suggestion: Extract MTP eligibility check into a dedicated function to reduce cognitive load and enable reuse
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:248-356— Large MTP streaming implementation embedded in VLMRequestInference adds significant complexity- Suggestion: Extract MTP streaming logic into a dedicated MTPStreamingDecoder class to separate concerns and improve testability
- 🔵 [INFO]
provider-swift/Tests/ProviderCoreTests/VLMRequestInferenceMTPLiveTests.swift:1-300— Entire 300-line test file for single live test case with extensive fixtures- Suggestion: Split into smaller focused tests or move fixture logic to shared test utilities
- 🟡 [MEDIUM]
provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:716-762— MTP runtime attachment logic duplicates binding patterns for text vs VLM models- Suggestion: Extract common MTP binding logic into a helper method to reduce duplication and improve maintainability
11 finding(s) total, 6 blocking. Verdict: REQUEST_CHANGES.
🤖 Automated review by Centaur · DAR-186
| /// `Gemma4MTPEngineRuntime` instead) and whenever MTP is off. The | ||
| /// caller stores this on the `BatchScheduler` so the non-batched VLM | ||
| /// vision path can route its decode through MTP speculative decoding. | ||
| let vlmMtpDrafter: Gemma4AssistantDraftModel? | ||
| } | ||
|
|
||
| // MARK: - MTP (Gemma 4 drafter speculative decoding) config | ||
|
|
||
| /// Whether drafter-based MTP is enabled. Off by default; opt in with | ||
| /// `DARKBLOOM_ENABLE_MTP=1` (also true/yes/on). Mirrors the | ||
| /// `DARKBLOOM_PREFIX_CACHE` parsing style but defaults OFF. | ||
| static func mtpEnabledFlag() -> Bool { | ||
| let v = ProcessInfo.processInfo.environment["DARKBLOOM_ENABLE_MTP"]? | ||
| .trimmingCharacters(in: .whitespaces).lowercased() ?? "" | ||
| return v == "1" || v == "true" || v == "yes" || v == "on" | ||
| } | ||
|
|
||
| /// Explicit drafter directory for MTP, from `DARKBLOOM_MTP_DRAFTER_PATH`. | ||
| /// v1 requires an explicit path (operator-controlled); manifest-driven | ||
| /// auto-discovery is a follow-up. Returns nil if unset/missing. | ||
| static func mtpDrafterPath() -> URL? { | ||
| guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"], | ||
| !p.isEmpty else { return nil } | ||
| let url = URL(fileURLWithPath: p) |
There was a problem hiding this comment.
🔵 [INFO] 🔒 Environment variable parsing lacks validation for mtpEnabledFlag() and mtpMaxBatch()
💡 Suggestion: Add bounds checking for mtpMaxBatch() to prevent negative values or unreasonably large batch sizes that could cause resource exhaustion
📊 Score: 2×3 = 6 · Category: missing-input-validation
| // MTP speculative-decode gate. MTP only matches the plain | ||
| // decode byte-for-byte under GREEDY sampling (temperature 0) | ||
| // with no repetition penalty, so it is gated to exactly that | ||
| // case. Any other sampling config (temperature > 0, top-k/p | ||
| // tweaks paired with sampling, repetition penalty) falls back | ||
| // to the existing prepare/generate path below, unchanged. | ||
| let repPenaltyNeutral = | ||
| params.repetitionPenalty == nil || params.repetitionPenalty == 1.0 | ||
| if let drafter = mtpDrafter, | ||
| params.temperature == 0, | ||
| repPenaltyNeutral | ||
| { | ||
| let handled = try await streamMTP( | ||
| container: container, | ||
| lmInput: lmInput, | ||
| params: params, | ||
| drafter: drafter, | ||
| blockSize: mtpBlockSize, | ||
| startedAt: startedAt, | ||
| continuation: continuation) | ||
| // `streamMTP` returns true when it served the request | ||
| // (it finished the stream itself). false ⇒ the model was | ||
| // not an MTP-capable VLM tower; fall through to plain | ||
| // generate so we never drop the request. | ||
| if handled { return } | ||
| mtpLogger.warning( |
There was a problem hiding this comment.
🟡 [MEDIUM] ⚡ Synchronous MTP drafter loading blocks request processing
💡 Suggestion: Move drafter loading to initialization time or use async loading with caching to avoid blocking each request
📊 Score: 3×4 = 12 · Category: blocking_io
| // MARK: - MTP speculative vision decode | ||
|
|
||
| /// Run the multimodal decode through Gemma 4 MTP speculative decoding. | ||
| /// | ||
| /// The whole decode runs inside `container.perform` so model access stays | ||
| /// single-threaded (the container's contract). We seed the drafter from a | ||
| /// multimodal prefill (`forwardForMTPMultimodal`, which merges the image / | ||
| /// video features), then drive `Gemma4MTPTokenIterator` round-by-round. | ||
| /// Each yielded token id is decoded to text exactly like the library's | ||
| /// `generateGemma4MTP` (`tokenizer.decode(tokenIds:[tok])`) and streamed. | ||
| /// | ||
| /// Returns `true` when it served the request (and finished the stream | ||
| /// itself); `false` when the loaded model is not an `MLXVLM.Gemma4` tower, | ||
| /// in which case the caller falls back to the plain generate path. | ||
| /// Throws on prefill / iterator-init failure so the caller's `catch` | ||
| /// finishes the stream with the error. | ||
| private static func streamMTP( | ||
| container: ModelContainer, | ||
| lmInput: LMInput, | ||
| params: GenerateParameters, | ||
| drafter: Gemma4AssistantDraftModel, | ||
| blockSize: Int, | ||
| startedAt: Date, | ||
| continuation: AsyncThrowingStream<MLXServerGenerationEvent, Error>.Continuation | ||
| ) async throws -> Bool { | ||
| // `LMInput` is non-Sendable; hand it across the container's isolation | ||
| // boundary via the `nonSendable` perform overload rather than capturing | ||
| // it in the @Sendable closure. | ||
| try await container.perform(nonSendable: lmInput) { ctx, lmInput in | ||
| guard let vlm = ctx.model as? MLXVLM.Gemma4 else { | ||
| return false | ||
| } | ||
|
|
||
| let promptTokens = lmInput.text.tokens.size | ||
| // Full stop-token set, mirroring the baseline generate loop | ||
| // (`buildStopTokenIds`): model EOS ids ∪ the tokenizer's EOS ∪ the | ||
| // resolved `extraEOSTokens`; the unknown-token id is also treated as | ||
| // a stop. Using only `configuration.eosTokenIds` could fail to | ||
| // terminate on a tokenizer/extra EOS not present in that set. | ||
| var stopTokenIds = ctx.configuration.eosTokenIds | ||
| if let tokenizerEOS = ctx.tokenizer.eosTokenId { | ||
| stopTokenIds.insert(tokenizerEOS) | ||
| } | ||
| for token in ctx.configuration.extraEOSTokens { | ||
| if let id = ctx.tokenizer.convertTokenToId(token) { | ||
| stopTokenIds.insert(id) | ||
| } | ||
| } | ||
| let unknownId = ctx.tokenizer.unknownTokenId | ||
| // Cap output length the same way the library default does when the | ||
| // caller leaves it open-ended. | ||
| var p = params | ||
| if p.maxTokens == nil { p.maxTokens = 1024 } | ||
| let maxTokens = p.maxTokens | ||
|
|
||
| // Seed: multimodal prefill (vision merge + capturing forward), | ||
| // then a prefilled-state MTP iterator over text-only decode rounds. | ||
| let cache = vlm.mtpNewCache(parameters: p) | ||
| let prefill = try vlm.forwardForMTPMultimodal(lmInput, cache: cache) | ||
| var iter = try Gemma4MTPTokenIterator( | ||
| prefill: prefill, | ||
| cache: cache, | ||
| target: vlm, | ||
| drafter: drafter, | ||
| parameters: p, | ||
| blockSize: blockSize) | ||
|
|
||
| // Streaming detokenizer: buffers a token whose decoded segment ends | ||
| // mid-codepoint (U+FFFD) and emits only complete text — so CJK / | ||
| // emoji / accented output is never split into mojibake. Per-token | ||
| // `decode([id])` (what this replaces) breaks multi-token glyphs. | ||
| var detokenizer = NaiveStreamingDetokenizer(tokenizer: ctx.tokenizer) | ||
|
|
||
| var completionTokens = 0 | ||
| var firstTokenAt: Date? | ||
| var lastTokenAt: Date? | ||
| // Default "length"; flipped to "stop" on an EOS token. Mirrors the | ||
| // batched + library paths (maxTokens hit ⇒ "length", EOS ⇒ "stop"). | ||
| var stopReason = "length" | ||
|
|
||
| while let tokenId = iter.next() { | ||
| if Task.isCancelled { | ||
| continuation.finish() | ||
| return true | ||
| } | ||
| // Stop tokens terminate WITHOUT being emitted or counted — | ||
| // parity with the baseline generate loop (which counts only | ||
| // non-stop tokens and never emits the EOS text). | ||
| if tokenId == unknownId || stopTokenIds.contains(tokenId) { | ||
| stopReason = "stop" | ||
| break | ||
| } | ||
| if firstTokenAt == nil { firstTokenAt = Date() } | ||
| lastTokenAt = Date() | ||
| completionTokens += 1 | ||
|
|
||
| detokenizer.append(token: tokenId) | ||
| if let chunkText = detokenizer.next(), !chunkText.isEmpty { | ||
| continuation.yield(.content(chunkText)) | ||
| } | ||
|
|
||
| if let maxTokens, completionTokens >= maxTokens { | ||
| stopReason = "length" | ||
| break | ||
| } | ||
| } | ||
|
|
||
| let promptTime = (firstTokenAt ?? startedAt).timeIntervalSince(startedAt) | ||
| let generationTime = (lastTokenAt ?? firstTokenAt ?? startedAt) | ||
| .timeIntervalSince(firstTokenAt ?? startedAt) | ||
| continuation.yield( | ||
| .info( | ||
| ServerGenerationInfo( | ||
| promptTokens: promptTokens, | ||
| completionTokens: completionTokens, | ||
| promptTime: max(0, promptTime), | ||
| generationTime: max(0, generationTime), | ||
| stopReason: stopReason | ||
| ) | ||
| ) | ||
| ) | ||
| continuation.finish() | ||
| return true | ||
| } |
There was a problem hiding this comment.
🔵 [INFO] ⚡ Stop token set reconstruction on every MTP request
💡 Suggestion: Cache the stop token set construction or move it to model initialization time
📊 Score: 2×3 = 6 · Category: repeated_work
| architecture: ModelArchitecture, | ||
| diskAccountant: GlobalDiskAccountant? = nil | ||
| ) async -> EngineBuild { | ||
| // MTP drafter (Gemma 4): load before `container.perform` (which is a | ||
| // sync, @Sendable closure) so the async weight load can be awaited and | ||
| // captured immutably. Bound to the target and attached to the scheduler | ||
| // inside the perform block. Off unless DARKBLOOM_ENABLE_MTP=1 and a | ||
| // drafter path is provided. | ||
| let mtpDrafter: Gemma4AssistantDraftModel? = await { | ||
| guard mtpEnabledFlag() else { return nil } | ||
| guard let drafterDir = mtpDrafterPath() else { | ||
| mtpLogger.warning( | ||
| "DARKBLOOM_ENABLE_MTP set but DARKBLOOM_MTP_DRAFTER_PATH missing/invalid; using plain decode.") | ||
| return nil | ||
| } | ||
| do { | ||
| let d = try await Gemma4AssistantDraftModel.load(from: drafterDir) | ||
| eval(d) | ||
| mtpLogger.info("MTP enabled — drafter loaded from \(drafterDir.path)") | ||
| return d | ||
| } catch { | ||
| mtpLogger.error( | ||
| "MTP enabled but drafter load failed (\(error)); using plain decode.") | ||
| return nil | ||
| } | ||
| }() |
There was a problem hiding this comment.
🟡 [MEDIUM] ⚡ Synchronous drafter loading during engine creation
💡 Suggestion: Load drafter asynchronously before engine creation or implement lazy loading with proper error handling
📊 Score: 2×4 = 8 · Category: blocking_io
| // MTP speculative-decode gate. MTP only matches the plain | ||
| // decode byte-for-byte under GREEDY sampling (temperature 0) | ||
| // with no repetition penalty, so it is gated to exactly that | ||
| // case. Any other sampling config (temperature > 0, top-k/p | ||
| // tweaks paired with sampling, repetition penalty) falls back | ||
| // to the existing prepare/generate path below, unchanged. | ||
| let repPenaltyNeutral = |
There was a problem hiding this comment.
🟡 [MEDIUM] 🏷️ Missing type assertion check for repetitionPenalty comparison
💡 Suggestion: Add explicit type check: let repPenaltyNeutral = params.repetitionPenalty.map { $0 == 1.0 } ?? true to avoid potential type confusion when repetitionPenalty is non-nil but not a numeric type
📊 Score: 3×3 = 9 · Category: missing-type-assertion
| // MARK: - MTP (Gemma 4 drafter speculative decoding) config | ||
|
|
||
| /// Whether drafter-based MTP is enabled. Off by default; opt in with | ||
| /// `DARKBLOOM_ENABLE_MTP=1` (also true/yes/on). Mirrors the | ||
| /// `DARKBLOOM_PREFIX_CACHE` parsing style but defaults OFF. | ||
| static func mtpEnabledFlag() -> Bool { | ||
| let v = ProcessInfo.processInfo.environment["DARKBLOOM_ENABLE_MTP"]? | ||
| .trimmingCharacters(in: .whitespaces).lowercased() ?? "" | ||
| return v == "1" || v == "true" || v == "yes" || v == "on" | ||
| } | ||
|
|
||
| /// Explicit drafter directory for MTP, from `DARKBLOOM_MTP_DRAFTER_PATH`. | ||
| /// v1 requires an explicit path (operator-controlled); manifest-driven | ||
| /// auto-discovery is a follow-up. Returns nil if unset/missing. | ||
| static func mtpDrafterPath() -> URL? { | ||
| guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"], | ||
| !p.isEmpty else { return nil } | ||
| let url = URL(fileURLWithPath: p) | ||
| return FileManager.default.fileExists(atPath: url.path) ? url : nil | ||
| } | ||
|
|
||
| /// Max active batch size that still speculates (`DARKBLOOM_MTP_MAX_BATCH`, | ||
| /// default 2). Above it the engine falls back to plain batched decode — | ||
| /// the MoE verify tax makes speculation ≈ break-even by B≈4. | ||
| static func mtpMaxBatch() -> Int { | ||
| guard let v = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_MAX_BATCH"], | ||
| let n = Int(v), n >= 1 else { return 2 } | ||
| return n |
There was a problem hiding this comment.
🔵 [INFO] 🧩 Custom environment variable parsing duplicates standard patterns
💡 Suggestion: Use a configuration library or standardize environment variable parsing across the codebase
📊 Score: 2×2 = 4 · Category: reinventing-stdlib
| repetitionPenalty: request.repetitionPenalty | ||
| ) | ||
|
|
||
| // MTP speculative-decode gate. MTP only matches the plain | ||
| // decode byte-for-byte under GREEDY sampling (temperature 0) | ||
| // with no repetition penalty, so it is gated to exactly that | ||
| // case. Any other sampling config (temperature > 0, top-k/p | ||
| // tweaks paired with sampling, repetition penalty) falls back | ||
| // to the existing prepare/generate path below, unchanged. | ||
| let repPenaltyNeutral = | ||
| params.repetitionPenalty == nil || params.repetitionPenalty == 1.0 | ||
| if let drafter = mtpDrafter, | ||
| params.temperature == 0, | ||
| repPenaltyNeutral | ||
| { | ||
| let handled = try await streamMTP( | ||
| container: container, | ||
| lmInput: lmInput, | ||
| params: params, | ||
| drafter: drafter, | ||
| blockSize: mtpBlockSize, | ||
| startedAt: startedAt, | ||
| continuation: continuation) | ||
| // `streamMTP` returns true when it served the request | ||
| // (it finished the stream itself). false ⇒ the model was | ||
| // not an MTP-capable VLM tower; fall through to plain | ||
| // generate so we never drop the request. | ||
| if handled { return } | ||
| mtpLogger.warning( | ||
| "VLM MTP requested but model is not an MLXVLM.Gemma4 tower; using plain generate.") | ||
| } |
There was a problem hiding this comment.
🟡 [MEDIUM] 🧩 Complex gating logic for MTP path duplicates validation patterns
💡 Suggestion: Extract MTP eligibility check into a dedicated function to reduce cognitive load and enable reuse
📊 Score: 3×4 = 12 · Category: duplicate-logic
| private static func streamMTP( | ||
| container: ModelContainer, | ||
| lmInput: LMInput, | ||
| params: GenerateParameters, | ||
| drafter: Gemma4AssistantDraftModel, | ||
| blockSize: Int, | ||
| startedAt: Date, | ||
| continuation: AsyncThrowingStream<MLXServerGenerationEvent, Error>.Continuation | ||
| ) async throws -> Bool { | ||
| // `LMInput` is non-Sendable; hand it across the container's isolation | ||
| // boundary via the `nonSendable` perform overload rather than capturing | ||
| // it in the @Sendable closure. | ||
| try await container.perform(nonSendable: lmInput) { ctx, lmInput in | ||
| guard let vlm = ctx.model as? MLXVLM.Gemma4 else { | ||
| return false | ||
| } | ||
|
|
||
| let promptTokens = lmInput.text.tokens.size | ||
| // Full stop-token set, mirroring the baseline generate loop | ||
| // (`buildStopTokenIds`): model EOS ids ∪ the tokenizer's EOS ∪ the | ||
| // resolved `extraEOSTokens`; the unknown-token id is also treated as | ||
| // a stop. Using only `configuration.eosTokenIds` could fail to | ||
| // terminate on a tokenizer/extra EOS not present in that set. | ||
| var stopTokenIds = ctx.configuration.eosTokenIds | ||
| if let tokenizerEOS = ctx.tokenizer.eosTokenId { | ||
| stopTokenIds.insert(tokenizerEOS) | ||
| } | ||
| for token in ctx.configuration.extraEOSTokens { | ||
| if let id = ctx.tokenizer.convertTokenToId(token) { | ||
| stopTokenIds.insert(id) | ||
| } | ||
| } | ||
| let unknownId = ctx.tokenizer.unknownTokenId | ||
| // Cap output length the same way the library default does when the | ||
| // caller leaves it open-ended. | ||
| var p = params | ||
| if p.maxTokens == nil { p.maxTokens = 1024 } | ||
| let maxTokens = p.maxTokens | ||
|
|
||
| // Seed: multimodal prefill (vision merge + capturing forward), | ||
| // then a prefilled-state MTP iterator over text-only decode rounds. | ||
| let cache = vlm.mtpNewCache(parameters: p) | ||
| let prefill = try vlm.forwardForMTPMultimodal(lmInput, cache: cache) | ||
| var iter = try Gemma4MTPTokenIterator( | ||
| prefill: prefill, | ||
| cache: cache, | ||
| target: vlm, | ||
| drafter: drafter, | ||
| parameters: p, | ||
| blockSize: blockSize) | ||
|
|
||
| // Streaming detokenizer: buffers a token whose decoded segment ends | ||
| // mid-codepoint (U+FFFD) and emits only complete text — so CJK / | ||
| // emoji / accented output is never split into mojibake. Per-token | ||
| // `decode([id])` (what this replaces) breaks multi-token glyphs. | ||
| var detokenizer = NaiveStreamingDetokenizer(tokenizer: ctx.tokenizer) | ||
|
|
||
| var completionTokens = 0 | ||
| var firstTokenAt: Date? | ||
| var lastTokenAt: Date? | ||
| // Default "length"; flipped to "stop" on an EOS token. Mirrors the | ||
| // batched + library paths (maxTokens hit ⇒ "length", EOS ⇒ "stop"). | ||
| var stopReason = "length" | ||
|
|
||
| while let tokenId = iter.next() { | ||
| if Task.isCancelled { | ||
| continuation.finish() | ||
| return true | ||
| } | ||
| // Stop tokens terminate WITHOUT being emitted or counted — | ||
| // parity with the baseline generate loop (which counts only | ||
| // non-stop tokens and never emits the EOS text). | ||
| if tokenId == unknownId || stopTokenIds.contains(tokenId) { | ||
| stopReason = "stop" | ||
| break | ||
| } | ||
| if firstTokenAt == nil { firstTokenAt = Date() } | ||
| lastTokenAt = Date() | ||
| completionTokens += 1 | ||
|
|
||
| detokenizer.append(token: tokenId) | ||
| if let chunkText = detokenizer.next(), !chunkText.isEmpty { | ||
| continuation.yield(.content(chunkText)) | ||
| } | ||
|
|
||
| if let maxTokens, completionTokens >= maxTokens { | ||
| stopReason = "length" | ||
| break | ||
| } | ||
| } | ||
|
|
||
| let promptTime = (firstTokenAt ?? startedAt).timeIntervalSince(startedAt) | ||
| let generationTime = (lastTokenAt ?? firstTokenAt ?? startedAt) | ||
| .timeIntervalSince(firstTokenAt ?? startedAt) | ||
| continuation.yield( | ||
| .info( | ||
| ServerGenerationInfo( | ||
| promptTokens: promptTokens, | ||
| completionTokens: completionTokens, | ||
| promptTime: max(0, promptTime), | ||
| generationTime: max(0, generationTime), | ||
| stopReason: stopReason | ||
| ) | ||
| ) | ||
| ) | ||
| continuation.finish() | ||
| return true | ||
| } | ||
| } |
There was a problem hiding this comment.
🟡 [MEDIUM] 🧩 Large MTP streaming implementation embedded in VLMRequestInference adds significant complexity
💡 Suggestion: Extract MTP streaming logic into a dedicated MTPStreamingDecoder class to separate concerns and improve testability
📊 Score: 4×3 = 12 · Category: misplaced-responsibility
| // Live, real-image end-to-end coverage for the VLM MTP speculative-decode | ||
| // path (`VLMRequestInference.stream` with `mtpDrafter:`). | ||
| // | ||
| // This is the final proof that multimodal Gemma 4 MTP works against a real | ||
| // image and a real ~14 GB qat-4bit VLM tower + its qat-4bit assistant | ||
| // drafter, through the exact `ModelContainer` / `container.perform` path | ||
| // ProviderLoop serves with. It runs TWO streams over the same request — a | ||
| // plain (mtpDrafter: nil) baseline and an MTP stream — and asserts the MTP | ||
| // output is coherent and broadly similar to the baseline. | ||
| // | ||
| // Gated by VLM_MTP_E2E=1 and requires both checkpoints in the local cache. | ||
| // Skips cleanly (recorded as a warning) when any precondition is unmet, so | ||
| // the default suite on a model-less / GPU-less runner stays green. | ||
| // | ||
| // cd provider-swift | ||
| // VLM_MTP_E2E=1 swift test --filter vlmMTPRealImageEndToEnd 2>&1 | tail -60 | ||
| // | ||
| // NOTE on byte-identity: MTP and plain are NOT asserted byte-equal. The | ||
| // VLM tower verifies several drafted tokens per round in a single batched | ||
| // forward, and bf16 rounding makes a handful of near-tie logits resolve | ||
| // differently between the width-1 (plain) and width-N (verify) passes. That | ||
| // divergence is known and accepted; we assert a *loose* similarity (shared | ||
| // prefix or strong word overlap) instead of equality. | ||
|
|
||
| import CoreImage | ||
| import Foundation | ||
| import Testing | ||
| import MLX | ||
| import MLXLLM | ||
| import MLXLMCommon | ||
| import MLXLMServer | ||
| import MLXVLM | ||
|
|
||
| @testable import ProviderCore | ||
|
|
||
| private enum VLMMTPLiveFixtures { | ||
|
|
||
| /// Opt-in env var for this (expensive, ~14 GB) live test. | ||
| static let envVar = "VLM_MTP_E2E" | ||
|
|
||
| static var enabled: Bool { | ||
| ProcessInfo.processInfo.environment[envVar].map { !$0.isEmpty } ?? false | ||
| } | ||
|
|
||
| /// Multimodal Gemma 4 target (config declares `vision_config`). | ||
| /// Overridable via VLM_MTP_TARGET_DIR for a differently-located snapshot. | ||
| static var targetDir: String { | ||
| ProcessInfo.processInfo.environment["VLM_MTP_TARGET_DIR"] | ||
| ?? "/Users/gaj/.cache/huggingface/hub/models--mlx-community--gemma-4-26B-A4B-it-qat-4bit/snapshots/0e3cbab38ce568cf6e23543010d08d03b731910c" | ||
| } | ||
|
|
||
| /// Matching qat-4bit assistant drafter. Overridable via VLM_MTP_DRAFTER_DIR. | ||
| static var drafterDir: String { | ||
| ProcessInfo.processInfo.environment["VLM_MTP_DRAFTER_DIR"] | ||
| ?? "/Users/gaj/.cache/huggingface/hub/models--mlx-community--gemma-4-26B-A4B-it-qat-assistant-4bit/snapshots/bb94eae1b70a80dac16cbf959bb4b7d56bd1fb8c" | ||
| } | ||
|
|
||
| /// First existing candidate image, or nil if none are present. | ||
| static var testImage: URL? { | ||
| let candidates = [ | ||
| ProcessInfo.processInfo.environment["VLM_MTP_IMAGE"], | ||
| "/Users/gaj/Downloads/eigen-labs-logo.png", | ||
| "/Users/gaj/Downloads/exo-logo.png", | ||
| ].compactMap { $0 } | ||
| for path in candidates where FileManager.default.fileExists(atPath: path) { | ||
| return URL(fileURLWithPath: path) | ||
| } | ||
| return nil | ||
| } | ||
| } | ||
|
|
||
| /// Outcome of consuming one `VLMRequestInference.stream`. | ||
| private struct StreamRun { | ||
| var text: String | ||
| var completionTokens: Int | ||
| var promptTokens: Int | ||
| var generationTime: TimeInterval | ||
| /// Wall-clock seconds around stream consumption. | ||
| var wallSeconds: Double | ||
|
|
||
| /// Decode tokens/sec from the engine-reported generation time, falling | ||
| /// back to wall clock when the engine didn't report a positive time. | ||
| var tokensPerSecond: Double { | ||
| if generationTime > 0 { return Double(completionTokens) / generationTime } | ||
| if wallSeconds > 0 { return Double(completionTokens) / wallSeconds } | ||
| return 0 | ||
| } | ||
| } | ||
|
|
||
| @Suite("VLM MTP real-image end-to-end (live)", .serialized) | ||
| struct VLMRequestInferenceMTPLiveTests { | ||
|
|
||
| @Test( | ||
| "VLM MTP describes a real image, coherent and similar to plain decode", | ||
| .enabled(if: VLMMTPLiveFixtures.enabled) | ||
| ) | ||
| func vlmMTPRealImageEndToEnd() async throws { | ||
| // 0. Preconditions — skip cleanly (recorded warning) if unmet so the | ||
| // default suite stays green on a runner without these assets. | ||
| guard LiveInferenceFixtures.ensureMetallibColocated() != nil else { | ||
| withKnownIssue("skipped: \(LiveFixtureSkip.missingMetallib)") { | ||
| Issue.record("\(LiveFixtureSkip.missingMetallib)") | ||
| } | ||
| return | ||
| } | ||
| let targetURL = URL(fileURLWithPath: VLMMTPLiveFixtures.targetDir) | ||
| let drafterURL = URL(fileURLWithPath: VLMMTPLiveFixtures.drafterDir) | ||
| let fm = FileManager.default | ||
| guard fm.fileExists(atPath: targetURL.appendingPathComponent("config.json").path) else { | ||
| withKnownIssue("skipped: VLM target not found at \(targetURL.path)") { | ||
| Issue.record("VLM target not found at \(targetURL.path)") | ||
| } | ||
| return | ||
| } | ||
| guard fm.fileExists(atPath: drafterURL.appendingPathComponent("config.json").path) else { | ||
| withKnownIssue("skipped: drafter not found at \(drafterURL.path)") { | ||
| Issue.record("drafter not found at \(drafterURL.path)") | ||
| } | ||
| return | ||
| } | ||
| guard let imageURL = VLMMTPLiveFixtures.testImage else { | ||
| withKnownIssue("skipped: no test image present") { | ||
| Issue.record("no test image present (set VLM_MTP_IMAGE)") | ||
| } | ||
| return | ||
| } | ||
|
|
||
| // Cap MLX memory the same way ProviderLoop does; this is a big model. | ||
| LiveInferenceFixtures.applyMemoryBudget(maxBytes: 24 * 1024 * 1024 * 1024) | ||
|
|
||
| // 1. Load the real VLM ModelContainer via the SAME factory path | ||
| // ProviderLoop.loadModelContainer uses for a vision_config model. | ||
| #expect( | ||
| ProviderLoop.modelIsVLM(at: targetURL), | ||
| "target must declare vision_config to exercise the VLM path") | ||
| let container = try await VLMModelFactory.shared.loadContainer( | ||
| from: targetURL, | ||
| using: LocalTokenizerLoader() | ||
| ) | ||
|
|
||
| // 2. Load the assistant drafter. | ||
| let drafter = try await Gemma4AssistantDraftModel.load(from: drafterURL) | ||
|
|
||
| // 3. Build the OpenAI request: one text part + one inline base64 image. | ||
| let imageBytes = try Data(contentsOf: imageURL) | ||
| let dataURI = "data:image/png;base64," + imageBytes.base64EncodedString() | ||
| let request = OpenAIChatCompletionRequest( | ||
| model: "gemma-4-vlm", | ||
| messages: [ | ||
| .init( | ||
| role: .user, | ||
| content: .parts([ | ||
| .text("Describe this image in one sentence."), | ||
| .imageURL(dataURI), | ||
| ])) | ||
| ], | ||
| temperature: 0, | ||
| maxTokens: 64 | ||
| ) | ||
| #expect(VLMRequestInference.hasMedia(request), "request must be detected as multimodal") | ||
|
|
||
| // 4. Run TWO streams: plain baseline, then MTP. | ||
| let plain = try await runStream( | ||
| container: container, request: request, drafter: nil) | ||
| let mtp = try await runStream( | ||
| container: container, request: request, drafter: drafter) | ||
|
|
||
| // 5. Coherence + similarity assertions. | ||
| let plainTrim = plain.text.trimmingCharacters(in: .whitespacesAndNewlines) | ||
| let mtpTrim = mtp.text.trimmingCharacters(in: .whitespacesAndNewlines) | ||
|
|
||
| let letterCount = mtpTrim.filter { $0.isLetter }.count | ||
| let coherent = mtpTrim.count > 0 && letterCount >= 10 && !isDegenerate(mtpTrim) | ||
|
|
||
| let prefix = commonPrefixLength(plainTrim, mtpTrim) | ||
| let jaccard = wordJaccard(plainTrim, mtpTrim) | ||
| // "Similar" = a meaningful shared prefix OR strong word overlap. bf16 | ||
| // near-ties in the width-N verify pass can flip a token mid-stream, so | ||
| // we deliberately do NOT require byte-identity. | ||
| let similar = prefix >= 8 || jaccard >= 0.4 | ||
|
|
||
| let speedup = plain.tokensPerSecond > 0 | ||
| ? mtp.tokensPerSecond / plain.tokensPerSecond : 0 | ||
|
|
||
| // 6. Full diagnostic dump for the manual record. | ||
| print("================ VLM MTP real-image E2E ================") | ||
| print("image: \(imageURL.lastPathComponent) (\(imageBytes.count) bytes)") | ||
| print("target: \(VLMMTPLiveFixtures.targetDir)") | ||
| print("drafter: \(VLMMTPLiveFixtures.drafterDir)") | ||
| print("----------------- PLAIN (baseline) --------------------") | ||
| print(plainTrim) | ||
| print("[plain] completion_tokens=\(plain.completionTokens) " | ||
| + "prompt_tokens=\(plain.promptTokens) " | ||
| + String(format: "tok/s=%.2f wall=%.2fs gen=%.2fs", | ||
| plain.tokensPerSecond, plain.wallSeconds, plain.generationTime)) | ||
| print("----------------- MTP ---------------------------------") | ||
| print(mtpTrim) | ||
| print("[mtp] completion_tokens=\(mtp.completionTokens) " | ||
| + "prompt_tokens=\(mtp.promptTokens) " | ||
| + String(format: "tok/s=%.2f wall=%.2fs gen=%.2fs", | ||
| mtp.tokensPerSecond, mtp.wallSeconds, mtp.generationTime)) | ||
| print("----------------- SIMILARITY --------------------------") | ||
| print(String( | ||
| format: "common_prefix_chars=%d word_jaccard=%.3f speedup=%.2fx coherent=%@ similar=%@", | ||
| prefix, jaccard, speedup, | ||
| coherent ? "yes" : "no", similar ? "yes" : "no")) | ||
| print("=======================================================") | ||
|
|
||
| // Assertions. | ||
| let coherentMsg = "MTP output must be coherent (>=10 letters, not degenerate): \(mtpTrim.prefix(120))" | ||
| let similarMsg = "MTP must be loosely similar to plain (prefix=\(prefix) jaccard=\(jaccard)). " | ||
| + "plain=<\(plainTrim.prefix(120))> mtp=<\(mtpTrim.prefix(120))>" | ||
| #expect(mtpTrim.count > 0, "MTP output must be non-empty") | ||
| #expect(coherent, "\(coherentMsg)") | ||
| #expect(plainTrim.count > 0, "plain baseline must be non-empty") | ||
| #expect(similar, "\(similarMsg)") | ||
| } | ||
|
|
||
| // MARK: - Stream consumption | ||
|
|
||
| /// Consume one `VLMRequestInference.stream`, timing the wall clock around | ||
| /// the iteration and collecting content + the final `.info`. | ||
| private func runStream( | ||
| container: ModelContainer, | ||
| request: OpenAIChatCompletionRequest, | ||
| drafter: Gemma4AssistantDraftModel? | ||
| ) async throws -> StreamRun { | ||
| let stream = VLMRequestInference.stream( | ||
| container: container, | ||
| request: request, | ||
| defaultMaxTokens: 64, | ||
| mtpDrafter: drafter, | ||
| mtpBlockSize: 3 | ||
| ) | ||
| var text = "" | ||
| var completionTokens = 0 | ||
| var promptTokens = 0 | ||
| var generationTime: TimeInterval = 0 | ||
| let start = Date() | ||
| for try await event in stream { | ||
| switch event { | ||
| case .content(let chunk): | ||
| text += chunk | ||
| case .info(let info): | ||
| completionTokens = info.completionTokens | ||
| promptTokens = info.promptTokens | ||
| generationTime = info.generationTime | ||
| case .toolCall: | ||
| continue | ||
| } | ||
| } | ||
| let wall = Date().timeIntervalSince(start) | ||
| return StreamRun( | ||
| text: text, | ||
| completionTokens: completionTokens, | ||
| promptTokens: promptTokens, | ||
| generationTime: generationTime, | ||
| wallSeconds: wall) | ||
| } | ||
|
|
||
| // MARK: - Similarity / coherence helpers | ||
|
|
||
| /// Length (in characters) of the common leading prefix of two strings. | ||
| private func commonPrefixLength(_ a: String, _ b: String) -> Int { | ||
| var count = 0 | ||
| var ai = a.startIndex | ||
| var bi = b.startIndex | ||
| while ai < a.endIndex, bi < b.endIndex, a[ai] == b[bi] { | ||
| count += 1 | ||
| ai = a.index(after: ai) | ||
| bi = b.index(after: bi) | ||
| } | ||
| return count | ||
| } | ||
|
|
||
| /// Jaccard similarity over lowercased word sets. | ||
| private func wordJaccard(_ a: String, _ b: String) -> Double { | ||
| let wa = Set(words(a)) | ||
| let wb = Set(words(b)) | ||
| if wa.isEmpty && wb.isEmpty { return 1 } | ||
| let inter = wa.intersection(wb).count | ||
| let union = wa.union(wb).count | ||
| return union == 0 ? 0 : Double(inter) / Double(union) | ||
| } | ||
|
|
||
| private func words(_ s: String) -> [String] { | ||
| s.lowercased() | ||
| .components(separatedBy: CharacterSet.alphanumerics.inverted) | ||
| .filter { !$0.isEmpty } | ||
| } | ||
|
|
||
| /// Detect obviously-broken output: a single token/word repeated many | ||
| /// times with almost no lexical variety. | ||
| private func isDegenerate(_ s: String) -> Bool { | ||
| let ws = words(s) | ||
| guard ws.count >= 6 else { return false } | ||
| let unique = Set(ws).count | ||
| return Double(unique) / Double(ws.count) < 0.2 | ||
| } | ||
| } |
There was a problem hiding this comment.
🔵 [INFO] 🧩 Entire 300-line test file for single live test case with extensive fixtures
💡 Suggestion: Split into smaller focused tests or move fixture logic to shared test utilities
📊 Score: 2×2 = 4 · Category: over-abstraction
| eosTokenIds: eosTokenIds, | ||
| prefixCache: enginePrefixCache // nil unless .engine + flag (TB-007) | ||
| ) | ||
| // Attach the Gemma 4 MTP runtime if a drafter was loaded and the | ||
| // target is a Gemma 4 model. The scheduler threads it into every | ||
| // GenerationBatch it builds; eligibility (size ≤ maxBatch, greedy) | ||
| // is decided per decode step. Non-Gemma-4 or bind failure → plain | ||
| // decode (runtime stays nil). | ||
| // | ||
| // VLM (image/video) models never flow through the batched engine | ||
| // for multimodal requests — those are served by the non-batched | ||
| // `VLMRequestInference.stream` vision path. So the text-target cast | ||
| // below is nil for an `MLXVLM.Gemma4`. In that case we bind the SAME | ||
| // drafter to the VLM tower and hand it out via `EngineBuild` so the | ||
| // vision path can route greedy decode through MTP. Carried out of | ||
| // this static closure on `vlmMtpDrafter`. | ||
| var vlmMtpDrafter: Gemma4AssistantDraftModel? = nil | ||
| if let drafter = mtpDrafter { | ||
| let target: Gemma4TextModel? = | ||
| (ctx.model as? Gemma4TextModel) | ||
| ?? (ctx.model as? Gemma4Model)?.textModel | ||
| if let target { | ||
| do { | ||
| scheduler.mtpRuntime = try Gemma4MTPEngineRuntime( | ||
| target: target, drafter: drafter, maxBatch: mtpMaxBatchCap) | ||
| mtpLogger.info( | ||
| "MTP runtime attached for \(modelId) (maxBatch=\(mtpMaxBatchCap)).") | ||
| } catch { | ||
| mtpLogger.error("MTP bind failed (\(error)); using plain decode.") | ||
| } | ||
| } else if let vlm = ctx.model as? MLXVLM.Gemma4 { | ||
| // VLM tower: bind the drafter to it for the non-batched | ||
| // vision path. `bind` is idempotent on the same target. | ||
| do { | ||
| try drafter.bind(target: vlm) | ||
| vlmMtpDrafter = drafter | ||
| mtpLogger.info( | ||
| "MTP drafter bound to VLM tower for \(modelId); vision decode will speculate.") | ||
| } catch { | ||
| mtpLogger.error( | ||
| "MTP VLM bind failed for \(modelId) (\(error)); vision decode uses plain generate.") | ||
| } | ||
| } else { | ||
| mtpLogger.warning( | ||
| "MTP enabled but \(modelId) is not a Gemma 4 model; using plain decode.") | ||
| } | ||
| } |
There was a problem hiding this comment.
🟡 [MEDIUM] 🧩 MTP runtime attachment logic duplicates binding patterns for text vs VLM models
💡 Suggestion: Extract common MTP binding logic into a helper method to reduce duplication and improve maintainability
📊 Score: 3×3 = 9 · Category: duplicate-logic
|
This PR adds MTP (Multi-Token Prediction) speculative decoding for VLM (Gemma 4) inference and is security-neutral overall — no existing mitigations are weakened and no new trust-boundary crossings are introduced, but two narrow issues in the new code path warrant attention. Trust Boundaries Touched
Per-Threat AnalysisT-007 — Provider serves manipulated model outputs ℹ️ Neutral
T-028 — Residual inference data in GPU memory ℹ️ Neutral
T-041 — Cross-tenant prefix-cache sharing + TTFT timing oracle
New Attack Surface Not Covered by an Existing ThreatOperator-controlled drafter path (
static func mtpDrafterPath() -> URL? {
guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"],
!p.isEmpty else { return nil }
let url = URL(fileURLWithPath: p)
return FileManager.default.fileExists(atPath: url.path) ? url : nil
}There is no validation that the path is within an expected directory (e.g.
Recommendation: Hash the drafter directory with
static func mtpMaxBatch() -> Int {
guard let v = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_MAX_BATCH"],
let n = Int(v), n >= 1 else { return 2 }
return n
}
SEC-* Open Findings ResolvedNone — this PR does not address any open SEC-* finding. 🔐 Threat model: |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b35956bf06
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if let drafter = mtpDrafter, | ||
| params.temperature == 0, | ||
| repPenaltyNeutral | ||
| { |
There was a problem hiding this comment.
Preserve tool-call parsing when enabling VLM MTP
When DARKBLOOM_ENABLE_MTP is on, any greedy multimodal request with mtpDrafter takes this branch, but the custom streamMTP loop only yields .content and .info; it never runs the library/tool-call handler that the plain container.generate path can surface as .toolCall and that this function forwards below. For VLM requests that include or produce tool calls, enabling MTP changes the wire response from structured tool-call events to raw text content, so gate MTP off for tool-enabled requests or add the same tool-call parsing to the MTP path.
Useful? React with 👍 / 👎.
| return nil | ||
| } | ||
| do { | ||
| let d = try await Gemma4AssistantDraftModel.load(from: drafterDir) |
There was a problem hiding this comment.
Avoid loading the drafter for non-Gemma models
With DARKBLOOM_ENABLE_MTP=1, this loads the assistant drafter before the code has checked whether the resident model is actually Gemma 4; the non-Gemma case is only detected later in container.perform and then falls back to plain decode. On providers that serve Qwen/Llama/etc. while MTP is globally enabled, every model load still allocates the large drafter weights first, which can slow loads or exhaust memory even though the drafter is unusable for that model. Defer the drafter load until after confirming the target model/config is MTP-capable.
Useful? React with 👍 / 👎.
… load for non-Gemma - Gate VLM MTP off when the request carries tools: `streamMTP` only emits `.content`/`.info`, so taking the MTP branch would drop the structured `.toolCall` events the plain `container.generate` path surfaces — a wire contract change. Tool requests now stay on the plain path. - Skip the (large) drafter weight load when the resident model is not an MTP-capable Gemma 4 tower (text or VLM). Detected by the actual loaded type in `snapshotContainer` (`Gemma4TextModel`/`Gemma4Model`/`MLXVLM.Gemma4`), not by model-id/alias strings. On a multi-model provider with MTP globally enabled, a Qwen/Llama load no longer allocates the unusable drafter. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Review triageCodex (chatgpt-codex-connector) — both fixed in
Automated reviewer (ethenotethan): mostly style/complexity opinions; two are misreads worth noting:
|
|
Found 1 test failure on Blacksmith runners: Failure
|

Provider: Gemma 4 MTP speculative decoding (engine + VLM)
Opt-in drafter-based MTP speculative decoding in the provider. Bumps
mlx-swift-lmto the MTP engine + vision-tower support (Layr-Labs/mlx-swift-lm#35). Off by default — no serving behavior change until an operator setsDARKBLOOM_ENABLE_MTP=1.Text (batched engine)
DARKBLOOM_ENABLE_MTP+DARKBLOOM_MTP_DRAFTER_PATH), binds it to the text target, and attaches the engine MTP runtime. Speculates atB ≤ DARKBLOOM_MTP_MAX_BATCH(default 2), plain decode above. Greedy rows only.VLM (image/video) — this PR's focus
VLMRequestInferencegains an MTP branch: when a drafter is bound, the request is greedy (temperature == 0, no repetition penalty), and the model is a Gemma 4 VLM, it decodes via multimodal MTP prefill +Gemma4MTPTokenIterator. Otherwise the existingcontainer.generatepath runs unchanged.NaiveStreamingDetokenizer(correct multi-byte/CJK/emoji), the full stop-token set (model EOS ∪ tokenizer EOS ∪ extra-EOS ∪ unknown), EOS suppressed + not counted,maxTokens+ cancellation honored. Drafter freed on model unload.Gating (all must hold, else plain decode)
DARKBLOOM_ENABLE_MTP=1AND a drafter loaded+bound AND model is Gemma 4 VLM AND greedy (temp 0, no penalty). Default off; safe fallback always available.Verification
VLM_MTP_E2E=1): plain vs MTP oneigen-labs-logo.png→ identical, coherent output, ~1.9x faster (43.8 → 73.9 tok/s).swift buildclean; 22VLMRequestInferenceunit tests pass; default-off path unchanged.Notes
Test plan
swift build— green.VLM_MTP_E2E=1 swift test --filter vlmMTPRealImageEndToEnd(needs the qat-4bit VLM + qat-assistant-4bit drafter in the local HF cache).Depends on Layr-Labs/mlx-swift-lm#35 (merge that first; this PR pins the submodule to it).
🤖 Generated with Claude Code
Need help on this PR? Tag
/codesmithwith what you need. Autofix is disabled.