Skip to content

feat(provider): Gemma 4 MTP speculative decoding (engine + VLM image/video)#306

Open
Gajesh2007 wants to merge 4 commits into
masterfrom
feat/gemma4-mtp-engine
Open

feat(provider): Gemma 4 MTP speculative decoding (engine + VLM image/video)#306
Gajesh2007 wants to merge 4 commits into
masterfrom
feat/gemma4-mtp-engine

Conversation

@Gajesh2007

@Gajesh2007 Gajesh2007 commented Jun 10, 2026

Copy link
Copy Markdown
Member

Provider: Gemma 4 MTP speculative decoding (engine + VLM)

Opt-in drafter-based MTP speculative decoding in the provider. Bumps mlx-swift-lm to the MTP engine + vision-tower support (Layr-Labs/mlx-swift-lm#35). Off by default — no serving behavior change until an operator sets DARKBLOOM_ENABLE_MTP=1.

Text (batched engine)

  • Loads a Gemma 4 drafter (gated on DARKBLOOM_ENABLE_MTP + DARKBLOOM_MTP_DRAFTER_PATH), binds it to the text target, and attaches the engine MTP runtime. Speculates at B ≤ DARKBLOOM_MTP_MAX_BATCH (default 2), plain decode above. Greedy rows only.

VLM (image/video) — this PR's focus

  • When the loaded model is an MLXVLM Gemma 4, the same drafter is bound to the vision tower and stashed on the scheduler.
  • VLMRequestInference gains an MTP branch: when a drafter is bound, the request is greedy (temperature == 0, no repetition penalty), and the model is a Gemma 4 VLM, it decodes via multimodal MTP prefill + Gemma4MTPTokenIterator. Otherwise the existing container.generate path runs unchanged.
  • Streams via NaiveStreamingDetokenizer (correct multi-byte/CJK/emoji), the full stop-token set (model EOS ∪ tokenizer EOS ∪ extra-EOS ∪ unknown), EOS suppressed + not counted, maxTokens + cancellation honored. Drafter freed on model unload.

Gating (all must hold, else plain decode)

DARKBLOOM_ENABLE_MTP=1 AND a drafter loaded+bound AND model is Gemma 4 VLM AND greedy (temp 0, no penalty). Default off; safe fallback always available.

Verification

  • Real-image E2E live test (VLM_MTP_E2E=1): plain vs MTP on eigen-labs-logo.pngidentical, coherent output, ~1.9x faster (43.8 → 73.9 tok/s).
  • swift build clean; 22 VLMRequestInference unit tests pass; default-off path unchanged.
  • Two independent reviewers (Codex + Claude) passed; all flagged issues fixed (drafter unload leak, per-token-decode mojibake, incomplete EOS set, EOS over-count).

Notes

  • The VLM tower's multi-token verify is bf16-noisier than the text tower; MTP output diverges from non-MTP decode only at a genuine ~0.14-logit near-tie. Output stays coherent/correct — within the "similar, not bit-identical" tolerance.
  • Concurrent VLM requests serialize their decode under the model lock (the VLM path is already non-batched + uncapped) — acceptable, flagged for awareness.

Test plan

  • swift build — green.
  • VLM_MTP_E2E=1 swift test --filter vlmMTPRealImageEndToEnd (needs the qat-4bit VLM + qat-assistant-4bit drafter in the local HF cache).

Depends on Layr-Labs/mlx-swift-lm#35 (merge that first; this PR pins the submodule to it).

🤖 Generated with Claude Code


View with Codesmith Autofix with Codesmith
Need help on this PR? Tag /codesmith with what you need. Autofix is disabled.

Gajesh2007 and others added 3 commits June 9, 2026 22:31
… off by default)

Bumps the mlx-swift-lm submodule to the engine-MTP integration and turns it
on in the provider when opted in:
- DARKBLOOM_ENABLE_MTP=1 (default off) enables drafter speculative decoding.
- DARKBLOOM_MTP_DRAFTER_PATH=<dir> supplies the Gemma 4 assistant drafter
  (loaded once per model load, before container.perform).
- DARKBLOOM_MTP_MAX_BATCH (default 2) caps the batch size that still
  speculates; above it the engine uses plain batched decode.

makeBatchedEngine attaches a Gemma4MTPEngineRuntime to the scheduler when the
target is a Gemma 4 model and a drafter loaded; non-Gemma4 / bind-failure /
flag-off all fall back to plain decode. Engine-path parity (engine output ==
plain batched greedy at B=1,2) is covered in mlx-swift-lm tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…on + greedy sampler)

Picks up: reusable Gemma4MTPBatchState, engine wiring, the abort/cold-prefill
session-invariant fixes, and the greedy-sampler fix that makes MTP actually
engage in the live batched engine. Verified live: darkbloom start --local with
DARKBLOOM_ENABLE_MTP=1 speculates on text requests (102->114 tok/s B=1) with
correct output; image requests route to the untouched VLM path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Wire drafter-based MTP speculative decoding into the non-batched VLM
inference path so multimodal requests decode ~1.9x faster. Bumps
mlx-swift-lm to the VLM-MTP engine support.

- Load the MTP drafter (gated on DARKBLOOM_ENABLE_MTP +
  DARKBLOOM_MTP_DRAFTER_PATH) and, when the loaded model is an MLXVLM
  Gemma 4, bind it to the VLM tower and stash it on the scheduler.
- VLMRequestInference: when a drafter is bound, the request is greedy
  (temperature 0, no repetition penalty), and the model is a Gemma 4 VLM,
  decode via multimodal MTP prefill + Gemma4MTPTokenIterator; otherwise the
  existing container.generate path runs unchanged. Default OFF — with no
  flag the vision path is byte-identical to before.
- Stream via NaiveStreamingDetokenizer (correct multi-byte text), the full
  stop-token set (model EOS + tokenizer EOS + extra-EOS + unknown), EOS not
  emitted/counted, maxTokens + cancellation honored.
- Free the drafter on model unload.
- Add an env-gated real-image E2E live test (VLM_MTP_E2E=1): coherent,
  similar-to-plain output, ~1.9x faster on a real image.

Depends on Layr-Labs/mlx-swift-lm#35.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@vercel

vercel Bot commented Jun 10, 2026

Copy link
Copy Markdown

The latest updates on your projects. Learn more about Vercel for GitHub.

Project Deployment Actions Updated (UTC)
d-inference Ready Ready Preview Jun 10, 2026 7:00pm
d-inference-console-ui-dev Ready Ready Preview Jun 10, 2026 7:00pm
d-inference-landing Ready Ready Preview Jun 10, 2026 7:00pm

Request Review

@ethenotethan ethenotethan left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Automated Code Review — Layr-Labs/d-inference#

Verdict: REQUEST_CHANGES

Security — 1 finding(s)

  • 🔵 [INFO] provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:517-540 — Environment variable parsing lacks validation for mtpEnabledFlag() and mtpMaxBatch()
    • Suggestion: Add bounds checking for mtpMaxBatch() to prevent negative values or unreasonably large batch sizes that could cause resource exhaustion

Performance — 3 finding(s) (2 blocking)

  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:143-168 — Synchronous MTP drafter loading blocks request processing
    • Suggestion: Move drafter loading to initialization time or use async loading with caching to avoid blocking each request
  • 🔵 [INFO] provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:232-355 — Stop token set reconstruction on every MTP request
    • Suggestion: Cache the stop token set construction or move it to model initialization time
  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:560-585 — Synchronous drafter loading during engine creation
    • Suggestion: Load drafter asynchronously before engine creation or implement lazy loading with proper error handling

Type_diligence — 1 finding(s) (1 blocking)

  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:143-149 — Missing type assertion check for repetitionPenalty comparison
    • Suggestion: Add explicit type check: let repPenaltyNeutral = params.repetitionPenalty.map { $0 == 1.0 } ?? true to avoid potential type confusion when repetitionPenalty is non-nil but not a numeric type

Additive_complexity — 6 finding(s) (3 blocking)

  • 🔵 [INFO] provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:91-105 — VLM MTP configuration adds complexity with multiple environment variables and defaults
    • Suggestion: Consider consolidating MTP configuration into a single struct or reducing the number of configurable parameters
  • 🔵 [INFO] provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:523-550 — Custom environment variable parsing duplicates standard patterns
    • Suggestion: Use a configuration library or standardize environment variable parsing across the codebase
  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:140-170 — Complex gating logic for MTP path duplicates validation patterns
    • Suggestion: Extract MTP eligibility check into a dedicated function to reduce cognitive load and enable reuse
  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift:248-356 — Large MTP streaming implementation embedded in VLMRequestInference adds significant complexity
    • Suggestion: Extract MTP streaming logic into a dedicated MTPStreamingDecoder class to separate concerns and improve testability
  • 🔵 [INFO] provider-swift/Tests/ProviderCoreTests/VLMRequestInferenceMTPLiveTests.swift:1-300 — Entire 300-line test file for single live test case with extensive fixtures
    • Suggestion: Split into smaller focused tests or move fixture logic to shared test utilities
  • 🟡 [MEDIUM] provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift:716-762 — MTP runtime attachment logic duplicates binding patterns for text vs VLM models
    • Suggestion: Extract common MTP binding logic into a helper method to reduce duplication and improve maintainability

11 finding(s) total, 6 blocking. Verdict: REQUEST_CHANGES.

🤖 Automated review by Centaur · DAR-186

Comment on lines +517 to +540
/// `Gemma4MTPEngineRuntime` instead) and whenever MTP is off. The
/// caller stores this on the `BatchScheduler` so the non-batched VLM
/// vision path can route its decode through MTP speculative decoding.
let vlmMtpDrafter: Gemma4AssistantDraftModel?
}

// MARK: - MTP (Gemma 4 drafter speculative decoding) config

/// Whether drafter-based MTP is enabled. Off by default; opt in with
/// `DARKBLOOM_ENABLE_MTP=1` (also true/yes/on). Mirrors the
/// `DARKBLOOM_PREFIX_CACHE` parsing style but defaults OFF.
static func mtpEnabledFlag() -> Bool {
let v = ProcessInfo.processInfo.environment["DARKBLOOM_ENABLE_MTP"]?
.trimmingCharacters(in: .whitespaces).lowercased() ?? ""
return v == "1" || v == "true" || v == "yes" || v == "on"
}

/// Explicit drafter directory for MTP, from `DARKBLOOM_MTP_DRAFTER_PATH`.
/// v1 requires an explicit path (operator-controlled); manifest-driven
/// auto-discovery is a follow-up. Returns nil if unset/missing.
static func mtpDrafterPath() -> URL? {
guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"],
!p.isEmpty else { return nil }
let url = URL(fileURLWithPath: p)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 [INFO] 🔒 Environment variable parsing lacks validation for mtpEnabledFlag() and mtpMaxBatch()

💡 Suggestion: Add bounds checking for mtpMaxBatch() to prevent negative values or unreasonably large batch sizes that could cause resource exhaustion

📊 Score: 2×3 = 6 · Category: missing-input-validation

Comment on lines +143 to +168
// MTP speculative-decode gate. MTP only matches the plain
// decode byte-for-byte under GREEDY sampling (temperature 0)
// with no repetition penalty, so it is gated to exactly that
// case. Any other sampling config (temperature > 0, top-k/p
// tweaks paired with sampling, repetition penalty) falls back
// to the existing prepare/generate path below, unchanged.
let repPenaltyNeutral =
params.repetitionPenalty == nil || params.repetitionPenalty == 1.0
if let drafter = mtpDrafter,
params.temperature == 0,
repPenaltyNeutral
{
let handled = try await streamMTP(
container: container,
lmInput: lmInput,
params: params,
drafter: drafter,
blockSize: mtpBlockSize,
startedAt: startedAt,
continuation: continuation)
// `streamMTP` returns true when it served the request
// (it finished the stream itself). false ⇒ the model was
// not an MTP-capable VLM tower; fall through to plain
// generate so we never drop the request.
if handled { return }
mtpLogger.warning(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] ⚡ Synchronous MTP drafter loading blocks request processing

💡 Suggestion: Move drafter loading to initialization time or use async loading with caching to avoid blocking each request

📊 Score: 3×4 = 12 · Category: blocking_io

Comment on lines +232 to +355
// MARK: - MTP speculative vision decode

/// Run the multimodal decode through Gemma 4 MTP speculative decoding.
///
/// The whole decode runs inside `container.perform` so model access stays
/// single-threaded (the container's contract). We seed the drafter from a
/// multimodal prefill (`forwardForMTPMultimodal`, which merges the image /
/// video features), then drive `Gemma4MTPTokenIterator` round-by-round.
/// Each yielded token id is decoded to text exactly like the library's
/// `generateGemma4MTP` (`tokenizer.decode(tokenIds:[tok])`) and streamed.
///
/// Returns `true` when it served the request (and finished the stream
/// itself); `false` when the loaded model is not an `MLXVLM.Gemma4` tower,
/// in which case the caller falls back to the plain generate path.
/// Throws on prefill / iterator-init failure so the caller's `catch`
/// finishes the stream with the error.
private static func streamMTP(
container: ModelContainer,
lmInput: LMInput,
params: GenerateParameters,
drafter: Gemma4AssistantDraftModel,
blockSize: Int,
startedAt: Date,
continuation: AsyncThrowingStream<MLXServerGenerationEvent, Error>.Continuation
) async throws -> Bool {
// `LMInput` is non-Sendable; hand it across the container's isolation
// boundary via the `nonSendable` perform overload rather than capturing
// it in the @Sendable closure.
try await container.perform(nonSendable: lmInput) { ctx, lmInput in
guard let vlm = ctx.model as? MLXVLM.Gemma4 else {
return false
}

let promptTokens = lmInput.text.tokens.size
// Full stop-token set, mirroring the baseline generate loop
// (`buildStopTokenIds`): model EOS ids ∪ the tokenizer's EOS ∪ the
// resolved `extraEOSTokens`; the unknown-token id is also treated as
// a stop. Using only `configuration.eosTokenIds` could fail to
// terminate on a tokenizer/extra EOS not present in that set.
var stopTokenIds = ctx.configuration.eosTokenIds
if let tokenizerEOS = ctx.tokenizer.eosTokenId {
stopTokenIds.insert(tokenizerEOS)
}
for token in ctx.configuration.extraEOSTokens {
if let id = ctx.tokenizer.convertTokenToId(token) {
stopTokenIds.insert(id)
}
}
let unknownId = ctx.tokenizer.unknownTokenId
// Cap output length the same way the library default does when the
// caller leaves it open-ended.
var p = params
if p.maxTokens == nil { p.maxTokens = 1024 }
let maxTokens = p.maxTokens

// Seed: multimodal prefill (vision merge + capturing forward),
// then a prefilled-state MTP iterator over text-only decode rounds.
let cache = vlm.mtpNewCache(parameters: p)
let prefill = try vlm.forwardForMTPMultimodal(lmInput, cache: cache)
var iter = try Gemma4MTPTokenIterator(
prefill: prefill,
cache: cache,
target: vlm,
drafter: drafter,
parameters: p,
blockSize: blockSize)

// Streaming detokenizer: buffers a token whose decoded segment ends
// mid-codepoint (U+FFFD) and emits only complete text — so CJK /
// emoji / accented output is never split into mojibake. Per-token
// `decode([id])` (what this replaces) breaks multi-token glyphs.
var detokenizer = NaiveStreamingDetokenizer(tokenizer: ctx.tokenizer)

var completionTokens = 0
var firstTokenAt: Date?
var lastTokenAt: Date?
// Default "length"; flipped to "stop" on an EOS token. Mirrors the
// batched + library paths (maxTokens hit ⇒ "length", EOS ⇒ "stop").
var stopReason = "length"

while let tokenId = iter.next() {
if Task.isCancelled {
continuation.finish()
return true
}
// Stop tokens terminate WITHOUT being emitted or counted —
// parity with the baseline generate loop (which counts only
// non-stop tokens and never emits the EOS text).
if tokenId == unknownId || stopTokenIds.contains(tokenId) {
stopReason = "stop"
break
}
if firstTokenAt == nil { firstTokenAt = Date() }
lastTokenAt = Date()
completionTokens += 1

detokenizer.append(token: tokenId)
if let chunkText = detokenizer.next(), !chunkText.isEmpty {
continuation.yield(.content(chunkText))
}

if let maxTokens, completionTokens >= maxTokens {
stopReason = "length"
break
}
}

let promptTime = (firstTokenAt ?? startedAt).timeIntervalSince(startedAt)
let generationTime = (lastTokenAt ?? firstTokenAt ?? startedAt)
.timeIntervalSince(firstTokenAt ?? startedAt)
continuation.yield(
.info(
ServerGenerationInfo(
promptTokens: promptTokens,
completionTokens: completionTokens,
promptTime: max(0, promptTime),
generationTime: max(0, generationTime),
stopReason: stopReason
)
)
)
continuation.finish()
return true
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 [INFO] ⚡ Stop token set reconstruction on every MTP request

💡 Suggestion: Cache the stop token set construction or move it to model initialization time

📊 Score: 2×3 = 6 · Category: repeated_work

Comment on lines 560 to +585
architecture: ModelArchitecture,
diskAccountant: GlobalDiskAccountant? = nil
) async -> EngineBuild {
// MTP drafter (Gemma 4): load before `container.perform` (which is a
// sync, @Sendable closure) so the async weight load can be awaited and
// captured immutably. Bound to the target and attached to the scheduler
// inside the perform block. Off unless DARKBLOOM_ENABLE_MTP=1 and a
// drafter path is provided.
let mtpDrafter: Gemma4AssistantDraftModel? = await {
guard mtpEnabledFlag() else { return nil }
guard let drafterDir = mtpDrafterPath() else {
mtpLogger.warning(
"DARKBLOOM_ENABLE_MTP set but DARKBLOOM_MTP_DRAFTER_PATH missing/invalid; using plain decode.")
return nil
}
do {
let d = try await Gemma4AssistantDraftModel.load(from: drafterDir)
eval(d)
mtpLogger.info("MTP enabled — drafter loaded from \(drafterDir.path)")
return d
} catch {
mtpLogger.error(
"MTP enabled but drafter load failed (\(error)); using plain decode.")
return nil
}
}()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] ⚡ Synchronous drafter loading during engine creation

💡 Suggestion: Load drafter asynchronously before engine creation or implement lazy loading with proper error handling

📊 Score: 2×4 = 8 · Category: blocking_io

Comment on lines +143 to +149
// MTP speculative-decode gate. MTP only matches the plain
// decode byte-for-byte under GREEDY sampling (temperature 0)
// with no repetition penalty, so it is gated to exactly that
// case. Any other sampling config (temperature > 0, top-k/p
// tweaks paired with sampling, repetition penalty) falls back
// to the existing prepare/generate path below, unchanged.
let repPenaltyNeutral =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] 🏷️ Missing type assertion check for repetitionPenalty comparison

💡 Suggestion: Add explicit type check: let repPenaltyNeutral = params.repetitionPenalty.map { $0 == 1.0 } ?? true to avoid potential type confusion when repetitionPenalty is non-nil but not a numeric type

📊 Score: 3×3 = 9 · Category: missing-type-assertion

Comment on lines +523 to +550
// MARK: - MTP (Gemma 4 drafter speculative decoding) config

/// Whether drafter-based MTP is enabled. Off by default; opt in with
/// `DARKBLOOM_ENABLE_MTP=1` (also true/yes/on). Mirrors the
/// `DARKBLOOM_PREFIX_CACHE` parsing style but defaults OFF.
static func mtpEnabledFlag() -> Bool {
let v = ProcessInfo.processInfo.environment["DARKBLOOM_ENABLE_MTP"]?
.trimmingCharacters(in: .whitespaces).lowercased() ?? ""
return v == "1" || v == "true" || v == "yes" || v == "on"
}

/// Explicit drafter directory for MTP, from `DARKBLOOM_MTP_DRAFTER_PATH`.
/// v1 requires an explicit path (operator-controlled); manifest-driven
/// auto-discovery is a follow-up. Returns nil if unset/missing.
static func mtpDrafterPath() -> URL? {
guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"],
!p.isEmpty else { return nil }
let url = URL(fileURLWithPath: p)
return FileManager.default.fileExists(atPath: url.path) ? url : nil
}

/// Max active batch size that still speculates (`DARKBLOOM_MTP_MAX_BATCH`,
/// default 2). Above it the engine falls back to plain batched decode —
/// the MoE verify tax makes speculation ≈ break-even by B≈4.
static func mtpMaxBatch() -> Int {
guard let v = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_MAX_BATCH"],
let n = Int(v), n >= 1 else { return 2 }
return n

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 [INFO] 🧩 Custom environment variable parsing duplicates standard patterns

💡 Suggestion: Use a configuration library or standardize environment variable parsing across the codebase

📊 Score: 2×2 = 4 · Category: reinventing-stdlib

Comment on lines 140 to +170
repetitionPenalty: request.repetitionPenalty
)

// MTP speculative-decode gate. MTP only matches the plain
// decode byte-for-byte under GREEDY sampling (temperature 0)
// with no repetition penalty, so it is gated to exactly that
// case. Any other sampling config (temperature > 0, top-k/p
// tweaks paired with sampling, repetition penalty) falls back
// to the existing prepare/generate path below, unchanged.
let repPenaltyNeutral =
params.repetitionPenalty == nil || params.repetitionPenalty == 1.0
if let drafter = mtpDrafter,
params.temperature == 0,
repPenaltyNeutral
{
let handled = try await streamMTP(
container: container,
lmInput: lmInput,
params: params,
drafter: drafter,
blockSize: mtpBlockSize,
startedAt: startedAt,
continuation: continuation)
// `streamMTP` returns true when it served the request
// (it finished the stream itself). false ⇒ the model was
// not an MTP-capable VLM tower; fall through to plain
// generate so we never drop the request.
if handled { return }
mtpLogger.warning(
"VLM MTP requested but model is not an MLXVLM.Gemma4 tower; using plain generate.")
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] 🧩 Complex gating logic for MTP path duplicates validation patterns

💡 Suggestion: Extract MTP eligibility check into a dedicated function to reduce cognitive load and enable reuse

📊 Score: 3×4 = 12 · Category: duplicate-logic

Comment on lines +248 to +356
private static func streamMTP(
container: ModelContainer,
lmInput: LMInput,
params: GenerateParameters,
drafter: Gemma4AssistantDraftModel,
blockSize: Int,
startedAt: Date,
continuation: AsyncThrowingStream<MLXServerGenerationEvent, Error>.Continuation
) async throws -> Bool {
// `LMInput` is non-Sendable; hand it across the container's isolation
// boundary via the `nonSendable` perform overload rather than capturing
// it in the @Sendable closure.
try await container.perform(nonSendable: lmInput) { ctx, lmInput in
guard let vlm = ctx.model as? MLXVLM.Gemma4 else {
return false
}

let promptTokens = lmInput.text.tokens.size
// Full stop-token set, mirroring the baseline generate loop
// (`buildStopTokenIds`): model EOS ids ∪ the tokenizer's EOS ∪ the
// resolved `extraEOSTokens`; the unknown-token id is also treated as
// a stop. Using only `configuration.eosTokenIds` could fail to
// terminate on a tokenizer/extra EOS not present in that set.
var stopTokenIds = ctx.configuration.eosTokenIds
if let tokenizerEOS = ctx.tokenizer.eosTokenId {
stopTokenIds.insert(tokenizerEOS)
}
for token in ctx.configuration.extraEOSTokens {
if let id = ctx.tokenizer.convertTokenToId(token) {
stopTokenIds.insert(id)
}
}
let unknownId = ctx.tokenizer.unknownTokenId
// Cap output length the same way the library default does when the
// caller leaves it open-ended.
var p = params
if p.maxTokens == nil { p.maxTokens = 1024 }
let maxTokens = p.maxTokens

// Seed: multimodal prefill (vision merge + capturing forward),
// then a prefilled-state MTP iterator over text-only decode rounds.
let cache = vlm.mtpNewCache(parameters: p)
let prefill = try vlm.forwardForMTPMultimodal(lmInput, cache: cache)
var iter = try Gemma4MTPTokenIterator(
prefill: prefill,
cache: cache,
target: vlm,
drafter: drafter,
parameters: p,
blockSize: blockSize)

// Streaming detokenizer: buffers a token whose decoded segment ends
// mid-codepoint (U+FFFD) and emits only complete text — so CJK /
// emoji / accented output is never split into mojibake. Per-token
// `decode([id])` (what this replaces) breaks multi-token glyphs.
var detokenizer = NaiveStreamingDetokenizer(tokenizer: ctx.tokenizer)

var completionTokens = 0
var firstTokenAt: Date?
var lastTokenAt: Date?
// Default "length"; flipped to "stop" on an EOS token. Mirrors the
// batched + library paths (maxTokens hit ⇒ "length", EOS ⇒ "stop").
var stopReason = "length"

while let tokenId = iter.next() {
if Task.isCancelled {
continuation.finish()
return true
}
// Stop tokens terminate WITHOUT being emitted or counted —
// parity with the baseline generate loop (which counts only
// non-stop tokens and never emits the EOS text).
if tokenId == unknownId || stopTokenIds.contains(tokenId) {
stopReason = "stop"
break
}
if firstTokenAt == nil { firstTokenAt = Date() }
lastTokenAt = Date()
completionTokens += 1

detokenizer.append(token: tokenId)
if let chunkText = detokenizer.next(), !chunkText.isEmpty {
continuation.yield(.content(chunkText))
}

if let maxTokens, completionTokens >= maxTokens {
stopReason = "length"
break
}
}

let promptTime = (firstTokenAt ?? startedAt).timeIntervalSince(startedAt)
let generationTime = (lastTokenAt ?? firstTokenAt ?? startedAt)
.timeIntervalSince(firstTokenAt ?? startedAt)
continuation.yield(
.info(
ServerGenerationInfo(
promptTokens: promptTokens,
completionTokens: completionTokens,
promptTime: max(0, promptTime),
generationTime: max(0, generationTime),
stopReason: stopReason
)
)
)
continuation.finish()
return true
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] 🧩 Large MTP streaming implementation embedded in VLMRequestInference adds significant complexity

💡 Suggestion: Extract MTP streaming logic into a dedicated MTPStreamingDecoder class to separate concerns and improve testability

📊 Score: 4×3 = 12 · Category: misplaced-responsibility

Comment on lines +1 to +300
// Live, real-image end-to-end coverage for the VLM MTP speculative-decode
// path (`VLMRequestInference.stream` with `mtpDrafter:`).
//
// This is the final proof that multimodal Gemma 4 MTP works against a real
// image and a real ~14 GB qat-4bit VLM tower + its qat-4bit assistant
// drafter, through the exact `ModelContainer` / `container.perform` path
// ProviderLoop serves with. It runs TWO streams over the same request — a
// plain (mtpDrafter: nil) baseline and an MTP stream — and asserts the MTP
// output is coherent and broadly similar to the baseline.
//
// Gated by VLM_MTP_E2E=1 and requires both checkpoints in the local cache.
// Skips cleanly (recorded as a warning) when any precondition is unmet, so
// the default suite on a model-less / GPU-less runner stays green.
//
// cd provider-swift
// VLM_MTP_E2E=1 swift test --filter vlmMTPRealImageEndToEnd 2>&1 | tail -60
//
// NOTE on byte-identity: MTP and plain are NOT asserted byte-equal. The
// VLM tower verifies several drafted tokens per round in a single batched
// forward, and bf16 rounding makes a handful of near-tie logits resolve
// differently between the width-1 (plain) and width-N (verify) passes. That
// divergence is known and accepted; we assert a *loose* similarity (shared
// prefix or strong word overlap) instead of equality.

import CoreImage
import Foundation
import Testing
import MLX
import MLXLLM
import MLXLMCommon
import MLXLMServer
import MLXVLM

@testable import ProviderCore

private enum VLMMTPLiveFixtures {

/// Opt-in env var for this (expensive, ~14 GB) live test.
static let envVar = "VLM_MTP_E2E"

static var enabled: Bool {
ProcessInfo.processInfo.environment[envVar].map { !$0.isEmpty } ?? false
}

/// Multimodal Gemma 4 target (config declares `vision_config`).
/// Overridable via VLM_MTP_TARGET_DIR for a differently-located snapshot.
static var targetDir: String {
ProcessInfo.processInfo.environment["VLM_MTP_TARGET_DIR"]
?? "/Users/gaj/.cache/huggingface/hub/models--mlx-community--gemma-4-26B-A4B-it-qat-4bit/snapshots/0e3cbab38ce568cf6e23543010d08d03b731910c"
}

/// Matching qat-4bit assistant drafter. Overridable via VLM_MTP_DRAFTER_DIR.
static var drafterDir: String {
ProcessInfo.processInfo.environment["VLM_MTP_DRAFTER_DIR"]
?? "/Users/gaj/.cache/huggingface/hub/models--mlx-community--gemma-4-26B-A4B-it-qat-assistant-4bit/snapshots/bb94eae1b70a80dac16cbf959bb4b7d56bd1fb8c"
}

/// First existing candidate image, or nil if none are present.
static var testImage: URL? {
let candidates = [
ProcessInfo.processInfo.environment["VLM_MTP_IMAGE"],
"/Users/gaj/Downloads/eigen-labs-logo.png",
"/Users/gaj/Downloads/exo-logo.png",
].compactMap { $0 }
for path in candidates where FileManager.default.fileExists(atPath: path) {
return URL(fileURLWithPath: path)
}
return nil
}
}

/// Outcome of consuming one `VLMRequestInference.stream`.
private struct StreamRun {
var text: String
var completionTokens: Int
var promptTokens: Int
var generationTime: TimeInterval
/// Wall-clock seconds around stream consumption.
var wallSeconds: Double

/// Decode tokens/sec from the engine-reported generation time, falling
/// back to wall clock when the engine didn't report a positive time.
var tokensPerSecond: Double {
if generationTime > 0 { return Double(completionTokens) / generationTime }
if wallSeconds > 0 { return Double(completionTokens) / wallSeconds }
return 0
}
}

@Suite("VLM MTP real-image end-to-end (live)", .serialized)
struct VLMRequestInferenceMTPLiveTests {

@Test(
"VLM MTP describes a real image, coherent and similar to plain decode",
.enabled(if: VLMMTPLiveFixtures.enabled)
)
func vlmMTPRealImageEndToEnd() async throws {
// 0. Preconditions — skip cleanly (recorded warning) if unmet so the
// default suite stays green on a runner without these assets.
guard LiveInferenceFixtures.ensureMetallibColocated() != nil else {
withKnownIssue("skipped: \(LiveFixtureSkip.missingMetallib)") {
Issue.record("\(LiveFixtureSkip.missingMetallib)")
}
return
}
let targetURL = URL(fileURLWithPath: VLMMTPLiveFixtures.targetDir)
let drafterURL = URL(fileURLWithPath: VLMMTPLiveFixtures.drafterDir)
let fm = FileManager.default
guard fm.fileExists(atPath: targetURL.appendingPathComponent("config.json").path) else {
withKnownIssue("skipped: VLM target not found at \(targetURL.path)") {
Issue.record("VLM target not found at \(targetURL.path)")
}
return
}
guard fm.fileExists(atPath: drafterURL.appendingPathComponent("config.json").path) else {
withKnownIssue("skipped: drafter not found at \(drafterURL.path)") {
Issue.record("drafter not found at \(drafterURL.path)")
}
return
}
guard let imageURL = VLMMTPLiveFixtures.testImage else {
withKnownIssue("skipped: no test image present") {
Issue.record("no test image present (set VLM_MTP_IMAGE)")
}
return
}

// Cap MLX memory the same way ProviderLoop does; this is a big model.
LiveInferenceFixtures.applyMemoryBudget(maxBytes: 24 * 1024 * 1024 * 1024)

// 1. Load the real VLM ModelContainer via the SAME factory path
// ProviderLoop.loadModelContainer uses for a vision_config model.
#expect(
ProviderLoop.modelIsVLM(at: targetURL),
"target must declare vision_config to exercise the VLM path")
let container = try await VLMModelFactory.shared.loadContainer(
from: targetURL,
using: LocalTokenizerLoader()
)

// 2. Load the assistant drafter.
let drafter = try await Gemma4AssistantDraftModel.load(from: drafterURL)

// 3. Build the OpenAI request: one text part + one inline base64 image.
let imageBytes = try Data(contentsOf: imageURL)
let dataURI = "data:image/png;base64," + imageBytes.base64EncodedString()
let request = OpenAIChatCompletionRequest(
model: "gemma-4-vlm",
messages: [
.init(
role: .user,
content: .parts([
.text("Describe this image in one sentence."),
.imageURL(dataURI),
]))
],
temperature: 0,
maxTokens: 64
)
#expect(VLMRequestInference.hasMedia(request), "request must be detected as multimodal")

// 4. Run TWO streams: plain baseline, then MTP.
let plain = try await runStream(
container: container, request: request, drafter: nil)
let mtp = try await runStream(
container: container, request: request, drafter: drafter)

// 5. Coherence + similarity assertions.
let plainTrim = plain.text.trimmingCharacters(in: .whitespacesAndNewlines)
let mtpTrim = mtp.text.trimmingCharacters(in: .whitespacesAndNewlines)

let letterCount = mtpTrim.filter { $0.isLetter }.count
let coherent = mtpTrim.count > 0 && letterCount >= 10 && !isDegenerate(mtpTrim)

let prefix = commonPrefixLength(plainTrim, mtpTrim)
let jaccard = wordJaccard(plainTrim, mtpTrim)
// "Similar" = a meaningful shared prefix OR strong word overlap. bf16
// near-ties in the width-N verify pass can flip a token mid-stream, so
// we deliberately do NOT require byte-identity.
let similar = prefix >= 8 || jaccard >= 0.4

let speedup = plain.tokensPerSecond > 0
? mtp.tokensPerSecond / plain.tokensPerSecond : 0

// 6. Full diagnostic dump for the manual record.
print("================ VLM MTP real-image E2E ================")
print("image: \(imageURL.lastPathComponent) (\(imageBytes.count) bytes)")
print("target: \(VLMMTPLiveFixtures.targetDir)")
print("drafter: \(VLMMTPLiveFixtures.drafterDir)")
print("----------------- PLAIN (baseline) --------------------")
print(plainTrim)
print("[plain] completion_tokens=\(plain.completionTokens) "
+ "prompt_tokens=\(plain.promptTokens) "
+ String(format: "tok/s=%.2f wall=%.2fs gen=%.2fs",
plain.tokensPerSecond, plain.wallSeconds, plain.generationTime))
print("----------------- MTP ---------------------------------")
print(mtpTrim)
print("[mtp] completion_tokens=\(mtp.completionTokens) "
+ "prompt_tokens=\(mtp.promptTokens) "
+ String(format: "tok/s=%.2f wall=%.2fs gen=%.2fs",
mtp.tokensPerSecond, mtp.wallSeconds, mtp.generationTime))
print("----------------- SIMILARITY --------------------------")
print(String(
format: "common_prefix_chars=%d word_jaccard=%.3f speedup=%.2fx coherent=%@ similar=%@",
prefix, jaccard, speedup,
coherent ? "yes" : "no", similar ? "yes" : "no"))
print("=======================================================")

// Assertions.
let coherentMsg = "MTP output must be coherent (>=10 letters, not degenerate): \(mtpTrim.prefix(120))"
let similarMsg = "MTP must be loosely similar to plain (prefix=\(prefix) jaccard=\(jaccard)). "
+ "plain=<\(plainTrim.prefix(120))> mtp=<\(mtpTrim.prefix(120))>"
#expect(mtpTrim.count > 0, "MTP output must be non-empty")
#expect(coherent, "\(coherentMsg)")
#expect(plainTrim.count > 0, "plain baseline must be non-empty")
#expect(similar, "\(similarMsg)")
}

// MARK: - Stream consumption

/// Consume one `VLMRequestInference.stream`, timing the wall clock around
/// the iteration and collecting content + the final `.info`.
private func runStream(
container: ModelContainer,
request: OpenAIChatCompletionRequest,
drafter: Gemma4AssistantDraftModel?
) async throws -> StreamRun {
let stream = VLMRequestInference.stream(
container: container,
request: request,
defaultMaxTokens: 64,
mtpDrafter: drafter,
mtpBlockSize: 3
)
var text = ""
var completionTokens = 0
var promptTokens = 0
var generationTime: TimeInterval = 0
let start = Date()
for try await event in stream {
switch event {
case .content(let chunk):
text += chunk
case .info(let info):
completionTokens = info.completionTokens
promptTokens = info.promptTokens
generationTime = info.generationTime
case .toolCall:
continue
}
}
let wall = Date().timeIntervalSince(start)
return StreamRun(
text: text,
completionTokens: completionTokens,
promptTokens: promptTokens,
generationTime: generationTime,
wallSeconds: wall)
}

// MARK: - Similarity / coherence helpers

/// Length (in characters) of the common leading prefix of two strings.
private func commonPrefixLength(_ a: String, _ b: String) -> Int {
var count = 0
var ai = a.startIndex
var bi = b.startIndex
while ai < a.endIndex, bi < b.endIndex, a[ai] == b[bi] {
count += 1
ai = a.index(after: ai)
bi = b.index(after: bi)
}
return count
}

/// Jaccard similarity over lowercased word sets.
private func wordJaccard(_ a: String, _ b: String) -> Double {
let wa = Set(words(a))
let wb = Set(words(b))
if wa.isEmpty && wb.isEmpty { return 1 }
let inter = wa.intersection(wb).count
let union = wa.union(wb).count
return union == 0 ? 0 : Double(inter) / Double(union)
}

private func words(_ s: String) -> [String] {
s.lowercased()
.components(separatedBy: CharacterSet.alphanumerics.inverted)
.filter { !$0.isEmpty }
}

/// Detect obviously-broken output: a single token/word repeated many
/// times with almost no lexical variety.
private func isDegenerate(_ s: String) -> Bool {
let ws = words(s)
guard ws.count >= 6 else { return false }
let unique = Set(ws).count
return Double(unique) / Double(ws.count) < 0.2
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 [INFO] 🧩 Entire 300-line test file for single live test case with extensive fixtures

💡 Suggestion: Split into smaller focused tests or move fixture logic to shared test utilities

📊 Score: 2×2 = 4 · Category: over-abstraction

Comment on lines 716 to +762
eosTokenIds: eosTokenIds,
prefixCache: enginePrefixCache // nil unless .engine + flag (TB-007)
)
// Attach the Gemma 4 MTP runtime if a drafter was loaded and the
// target is a Gemma 4 model. The scheduler threads it into every
// GenerationBatch it builds; eligibility (size ≤ maxBatch, greedy)
// is decided per decode step. Non-Gemma-4 or bind failure → plain
// decode (runtime stays nil).
//
// VLM (image/video) models never flow through the batched engine
// for multimodal requests — those are served by the non-batched
// `VLMRequestInference.stream` vision path. So the text-target cast
// below is nil for an `MLXVLM.Gemma4`. In that case we bind the SAME
// drafter to the VLM tower and hand it out via `EngineBuild` so the
// vision path can route greedy decode through MTP. Carried out of
// this static closure on `vlmMtpDrafter`.
var vlmMtpDrafter: Gemma4AssistantDraftModel? = nil
if let drafter = mtpDrafter {
let target: Gemma4TextModel? =
(ctx.model as? Gemma4TextModel)
?? (ctx.model as? Gemma4Model)?.textModel
if let target {
do {
scheduler.mtpRuntime = try Gemma4MTPEngineRuntime(
target: target, drafter: drafter, maxBatch: mtpMaxBatchCap)
mtpLogger.info(
"MTP runtime attached for \(modelId) (maxBatch=\(mtpMaxBatchCap)).")
} catch {
mtpLogger.error("MTP bind failed (\(error)); using plain decode.")
}
} else if let vlm = ctx.model as? MLXVLM.Gemma4 {
// VLM tower: bind the drafter to it for the non-batched
// vision path. `bind` is idempotent on the same target.
do {
try drafter.bind(target: vlm)
vlmMtpDrafter = drafter
mtpLogger.info(
"MTP drafter bound to VLM tower for \(modelId); vision decode will speculate.")
} catch {
mtpLogger.error(
"MTP VLM bind failed for \(modelId) (\(error)); vision decode uses plain generate.")
}
} else {
mtpLogger.warning(
"MTP enabled but \(modelId) is not a Gemma 4 model; using plain decode.")
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 [MEDIUM] 🧩 MTP runtime attachment logic duplicates binding patterns for text vs VLM models

💡 Suggestion: Extract common MTP binding logic into a helper method to reduce duplication and improve maintainability

📊 Score: 3×3 = 9 · Category: duplicate-logic

@github-actions

github-actions Bot commented Jun 10, 2026

Copy link
Copy Markdown

This PR adds MTP (Multi-Token Prediction) speculative decoding for VLM (Gemma 4) inference and is security-neutral overall — no existing mitigations are weakened and no new trust-boundary crossings are introduced, but two narrow issues in the new code path warrant attention.


Trust Boundaries Touched

Boundary Relevance
TB-003 Provider operator vs. process New env-var flags (DARKBLOOM_ENABLE_MTP, DARKBLOOM_MTP_DRAFTER_PATH, DARKBLOOM_MTP_MAX_BATCH) are parsed from ProcessInfo.processInfo.environment
TB-007 Provider inference engine VLM inference path is modified; weight loading, generation loop, and streaming are all touched

Per-Threat Analysis

T-007 — Provider serves manipulated model outputs ℹ️ Neutral

  • The MTP drafter (Gemma4AssistantDraftModel) is loaded from DARKBLOOM_MTP_DRAFTER_PATH, an operator-supplied filesystem path. The drafter participates in the speculative decode loop (Gemma4MTPTokenIterator) and its outputs influence which tokens are accepted or rejected.
  • The drafter's weight hash is not included in the attestation blob and will not be advertised in registration heartbeats. For the threat model this is tolerable only because drafter weights affect latency (speculation quality), not correctness — the target model vlm (MLXVLM.Gemma4) remains the verifying model. A manipulated drafter that always proposes the correct next token still produces outputs identical to running the target model autoregressively. A pathologically bad drafter degrades TTFT but cannot alter accepted token distributions.
  • However: this assumption holds only as long as the verify step in Gemma4MTPTokenIterator is faithfully implemented in libs/mlx-swift-lm (one of the two uncovered files). If that library's accept/reject logic is ever weakened (e.g. always accepting draft tokens), the drafter becomes a tamper vector. This should be explicitly called out in any future review of libs/mlx-swift-lm.

T-028 — Residual inference data in GPU memory ℹ️ Neutral

  • VLMRequestInference.swift introduces a new MTP streaming path (streamMTP) with a NaiveStreamingDetokenizer and a Gemma4MTPTokenIterator. Neither zeroes intermediate buffers.
  • This does not regress the existing open finding (GPU-side MLX buffers are already not zeroed between requests). No new surfaces are introduced relative to the existing non-MTP VLM path.

T-041 — Cross-tenant prefix-cache sharing + TTFT timing oracle ⚠️ Minor new exposure point

  • BatchScheduler.swift adds vlmMtpDrafter and vlmMtpBlockSize as shared actor state. These are set once per loadModel() call and read by MultiModelBatchSchedulerEngine per request (await scheduler.vlmMtpDrafter / await scheduler.vlmMtpBlockSize).
  • The MTP path produces measurably different TTFT than the non-MTP path. A consumer who can observe TTFT can now infer whether DARKBLOOM_ENABLE_MTP is active (larger first-chunk latency variance is MTP-characteristic). This is a new observability channel for provider configuration, but it does not widen the cross-tenant prompt oracle described in SEC-035 — it only reveals an operator config flag, which is non-sensitive.
  • No new prefix-cache sharing is introduced by this PR.

New Attack Surface Not Covered by an Existing Threat

Operator-controlled drafter path (DARKBLOOM_MTP_DRAFTER_PATH) — unconstrained

BatchScheduler.mtpDrafterPath() (BatchScheduler.swift, new static method) does only a FileManager.default.fileExists check before passing the URL to the drafter loader:

static func mtpDrafterPath() -> URL? {
    guard let p = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_DRAFTER_PATH"],
          !p.isEmpty else { return nil }
    let url = URL(fileURLWithPath: p)
    return FileManager.default.fileExists(atPath: url.path) ? url : nil
}

There is no validation that the path is within an expected directory (e.g. ~/.cache/huggingface). An operator can point DARKBLOOM_MTP_DRAFTER_PATH at any world-readable path on the filesystem. This is consistent with ADV-001's existing capability (the operator already controls model weight paths), but:

  1. The drafter weights are never hashed and never reported to the coordinator. Unlike the primary model weights (WeightHasher covers *.safetensors, *.bin, config, tokenizer), the drafter is loaded silently. An operator can swap the drafter without any heartbeat or hash mismatch being reported.
  2. Although a correct verify step in the MTP iterator prevents an incorrect drafter from altering token distributions, there is no coordinator-observable signal that the drafter was loaded or replaced. This is a gap in the telemetry/auditability posture, even if it is not currently exploitable for output manipulation.

Recommendation: Hash the drafter directory with WeightHasher (or a minimal subset — config.json + *.safetensors) and include it in the heartbeat as a separate drafterWeightHash field. This keeps the coordinator informed and creates a detectable signal if a drafter substitution occurs, consistent with the existing weight-hash telemetry pattern.

DARKBLOOM_MTP_MAX_BATCH integer parse — no upper bound

static func mtpMaxBatch() -> Int {
    guard let v = ProcessInfo.processInfo.environment["DARKBLOOM_MTP_MAX_BATCH"],
          let n = Int(v), n >= 1 else { return 2 }
    return n
}

Int(v) on a 64-bit platform accepts values up to Int.max. An operator setting DARKBLOOM_MTP_MAX_BATCH=9999999 would push all batch sizes through the speculative path, potentially increasing GPU memory pressure in ways that interact with the idle-timeout and buffer-zeroing gaps (T-028, T-029). This is low severity (operator-controlled, ADV-001 already controls the machine) but should have a reasonable cap (e.g. n <= 32) for defensive programming.


SEC-* Open Findings Resolved

None — this PR does not address any open SEC-* finding.


🔐 Threat model: docs/threat-model.yaml · Updates on each push to this PR

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b35956bf06

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +151 to +154
if let drafter = mtpDrafter,
params.temperature == 0,
repPenaltyNeutral
{

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve tool-call parsing when enabling VLM MTP

When DARKBLOOM_ENABLE_MTP is on, any greedy multimodal request with mtpDrafter takes this branch, but the custom streamMTP loop only yields .content and .info; it never runs the library/tool-call handler that the plain container.generate path can surface as .toolCall and that this function forwards below. For VLM requests that include or produce tool calls, enabling MTP changes the wire response from structured tool-call events to raw text content, so gate MTP off for tool-enabled requests or add the same tool-call parsing to the MTP path.

Useful? React with 👍 / 👎.

return nil
}
do {
let d = try await Gemma4AssistantDraftModel.load(from: drafterDir)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid loading the drafter for non-Gemma models

With DARKBLOOM_ENABLE_MTP=1, this loads the assistant drafter before the code has checked whether the resident model is actually Gemma 4; the non-Gemma case is only detected later in container.perform and then falls back to plain decode. On providers that serve Qwen/Llama/etc. while MTP is globally enabled, every model load still allocates the large drafter weights first, which can slow loads or exhaust memory even though the drafter is unusable for that model. Defer the drafter load until after confirming the target model/config is MTP-capable.

Useful? React with 👍 / 👎.

… load for non-Gemma

- Gate VLM MTP off when the request carries tools: `streamMTP` only emits
  `.content`/`.info`, so taking the MTP branch would drop the structured
  `.toolCall` events the plain `container.generate` path surfaces — a wire
  contract change. Tool requests now stay on the plain path.
- Skip the (large) drafter weight load when the resident model is not an
  MTP-capable Gemma 4 tower (text or VLM). Detected by the actual loaded type
  in `snapshotContainer` (`Gemma4TextModel`/`Gemma4Model`/`MLXVLM.Gemma4`), not
  by model-id/alias strings. On a multi-model provider with MTP globally
  enabled, a Qwen/Llama load no longer allocates the unusable drafter.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@Gajesh2007

Copy link
Copy Markdown
Member Author

Review triage

Codex (chatgpt-codex-connector) — both fixed in 830b7205:

  • P2 VLMRequestInference.swift:154 — preserve tool-call parsing — Fixed. MTP is now gated off when the request carries tools (request.tools non-empty), so tool-enabled VLM requests stay on the container.generate path that surfaces .toolCall. streamMTP is content-only and would otherwise drop those events.
  • P2 BatchScheduler.swift:576 — avoid loading the drafter for non-Gemma models — Fixed. The drafter load is now gated on MTP-capability detected by the actual loaded model type in snapshotContainer (Gemma4TextModel/Gemma4Model/MLXVLM.Gemma4). On a multi-model provider with MTP globally enabled, a Qwen/Llama load no longer allocates the unusable drafter.

Automated reviewer (ethenotethan): mostly style/complexity opinions; two are misreads worth noting:

  • "Synchronous drafter loading blocks request processing" / "during engine creation" — the drafter is loaded once at engine build (already await-ed, off the request path), not per request.
  • "Missing type assertion for repetitionPenalty" — repetitionPenalty is a statically-typed Float?; the == 1.0 comparison has no type-confusion path in Swift.
  • mtpMaxBatch() already clamps to >= 1. The remaining "extract MTPStreamingDecoder / consolidate config" items are reasonable refactors but out of scope for this PR.

@blacksmith-sh

blacksmith-sh Bot commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Found 1 test failure on Blacksmith runners:

Failure

Test View Logs
github.com/eigeninference/d-inference/e2e/TestIntegration_ConcurrentRequests View Logs

Fix with Codesmith
Need help on this PR? Tag /codesmith with what you need.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants