Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
private let engineInUse = Atomic<Bool>(false)
let config: ModelConfig

// Generation lifecycle
private let _activeToken = Mutex<GenerationToken?>(nil)
private let _generationTask = Mutex<Task<Void, Never>?>(nil)

var isBusy: Bool { _activeToken.withLock { $0 != nil } }

init(
config: ModelConfig,
preparedModel: PreparedModel,
Expand Down Expand Up @@ -104,9 +110,20 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
let stopReasonStore = StopReasonStore()
let (base, outputContinuation) =
AsyncThrowingStream<InferenceOutput, any Error>.makeStream()
Task {

let token = GenerationToken()
_activeToken.withLock { $0 = token }

let task = Task {
self.acquireEngine()
defer { self.releaseEngine() }
defer {
self.releaseEngine()
// Only clear if this generation still owns both slots
if self._activeToken.withLock({ $0 === token }) {
self._activeToken.withLock { $0 = nil }
self._generationTask.withLock { $0 = nil }
}
}
do {
let (tokenStream, tokenContinuation) =
AsyncThrowingStream<InferenceEngine.TokenId, any Error>.makeStream()
Expand Down Expand Up @@ -139,6 +156,7 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
outputContinuation.finish(throwing: error)
}
}
_generationTask.withLock { $0 = task }
return GenerationSequence(base: base, stopReasonStore: stopReasonStore)
}

Expand All @@ -154,8 +172,26 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
}
}

func cancel() async throws {
let task: Task<Void, Never>? = _generationTask.withLock { task in
task?.cancel()
defer { task = nil }
return task
}
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
await task?.value
}

func reset() {
drain()
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
_generationTask.withLock { $0 = nil }
guard tryAcquireEngine() else { return }
defer { releaseEngine() }
engine.reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,16 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable
// Track processed tokens for incremental inference
private var processedTokenCount: Int = 0

// Track in-flight generation for drain (same pattern as pipelined engine)
private let generating = Mutex(false)
// Track in-flight generation via token (replaces simple bool lock)
private let _activeToken = Mutex<GenerationToken?>(nil)

public var isBusy: Bool { _activeToken.withLock { $0 != nil } }

/// Clear the engine's active token if it matches the given token.
/// Called by the iterator when generation finishes or is cancelled.
func clearTokenIfActive(_ token: GenerationToken) {
_activeToken.withLock { if $0 === token { $0 = nil } }
}

// MARK: - Init

Expand Down Expand Up @@ -337,11 +345,14 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
GenerationSequence(
let token = GenerationToken()
_activeToken.withLock { $0 = token }
return GenerationSequence(
engine: self,
input: input,
samplingConfiguration: samplingConfiguration,
inferenceOptions: inferenceOptions
inferenceOptions: inferenceOptions,
generationToken: token
)
}

Expand All @@ -350,7 +361,7 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable
/// Wait for any in-flight generate() Task to finish.
private func drain() {
var attempts = 0
while generating.withLock({ $0 }) {
while _activeToken.withLock({ $0 != nil }) {
attempts += 1
if attempts > 5000 {
fatalError("Sequential engine drain() timeout — generation Task stuck?")
Expand All @@ -359,8 +370,18 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable
}
}

public func cancel() async throws {
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
}

public func reset() {
drain()
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
let resetSpan = InstrumentsProfiler.beginReset(engine: "CoreAIClean")
processedTokenCount = 0
zeroFill(&keyCache)
Expand Down Expand Up @@ -464,6 +485,7 @@ extension CoreAISequentialEngine {
let input: [CoreAISequentialEngine.TokenId]
let samplingConfiguration: SamplingConfiguration
let inferenceOptions: InferenceOptions
let generationToken: GenerationToken

/// Shared with the iterator so the caller can read why generation ended.
let stopReasonStore = StopReasonStore()
Expand All @@ -480,7 +502,8 @@ extension CoreAISequentialEngine {
input: input,
samplingConfiguration: samplingConfiguration,
inferenceOptions: inferenceOptions,
stopReasonStore: stopReasonStore
stopReasonStore: stopReasonStore,
generationToken: generationToken
)
}
}
Expand All @@ -497,24 +520,26 @@ extension CoreAISequentialEngine.GenerationSequence {
private let forcedContinuation: [CoreAISequentialEngine.TokenId]?
private let maxTokens: Int
private let stopReasonStore: StopReasonStore
private let generationToken: GenerationToken

private var inputTokens: [CoreAISequentialEngine.TokenId]
private var step: Int = 0
private var didAcquireLock: Bool = false
private var finished: Bool = false

init(
engine: CoreAISequentialEngine,
input: [CoreAISequentialEngine.TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions,
stopReasonStore: StopReasonStore
stopReasonStore: StopReasonStore,
generationToken: GenerationToken
) {
self.engine = engine
self.samplingConfiguration = samplingConfiguration
self.returnsLogits = inferenceOptions.includeLogits
self.forcedContinuation = inferenceOptions.forcedContinuation
self.stopReasonStore = stopReasonStore
self.generationToken = generationToken
self.inputTokens = input
if let forced = inferenceOptions.forcedContinuation {
self.maxTokens = forced.count
Expand All @@ -527,17 +552,16 @@ extension CoreAISequentialEngine.GenerationSequence {
}

deinit {
if didAcquireLock {
engine.generating.withLock { $0 = false }
}
engine.clearTokenIfActive(generationToken)
}

public func next() async throws -> InferenceOutput? {
if finished { return nil }

if !didAcquireLock {
engine.generating.withLock { $0 = true }
didAcquireLock = true
if generationToken.isCancelled {
stopReasonStore.set(.cancelled)
finishAndRelease()
return nil
}

guard step < maxTokens else {
Expand Down Expand Up @@ -572,6 +596,13 @@ extension CoreAISequentialEngine.GenerationSequence {
logitBuffer = lastLogits
}

// Check cancellation after inference step
if generationToken.isCancelled {
stopReasonStore.set(.cancelled)
finishAndRelease()
return nil
}

let nextToken: Int32
if let forced = forcedContinuation {
nextToken = forced[step]
Expand Down Expand Up @@ -603,10 +634,7 @@ extension CoreAISequentialEngine.GenerationSequence {
return
}
finished = true
if didAcquireLock {
engine.generating.withLock { $0 = false }
didAcquireLock = false
}
engine.clearTokenIfActive(generationToken)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import CoreAI
import CoreAIShared
import Foundation
import Synchronization

/// Static-shape inference engine using Core AI models.
public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable {
Expand Down Expand Up @@ -47,6 +48,16 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable {
// Number of tokens already processed in the current sequence.
private var processedTokenCount: Int = 0

// Track in-flight generation via token
private let _activeToken = Mutex<GenerationToken?>(nil)

public var isBusy: Bool { _activeToken.withLock { $0 != nil } }

/// Clear the engine's active token if it matches the given token.
func clearTokenIfActive(_ token: GenerationToken) {
_activeToken.withLock { if $0 === token { $0 = nil } }
}

// MARK: - Initialization

public init(configuration: ModelConfig, preparedModel: PreparedModel) async throws {
Expand Down Expand Up @@ -317,11 +328,14 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable {
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
GenerationSequence(
let token = GenerationToken()
_activeToken.withLock { $0 = token }
return GenerationSequence(
engine: self,
input: input,
samplingConfiguration: samplingConfiguration,
inferenceOptions: inferenceOptions
inferenceOptions: inferenceOptions,
generationToken: token
)
}

Expand Down Expand Up @@ -531,7 +545,18 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable {

// MARK: - Lifecycle

public func cancel() async throws {
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
}

public func reset() {
_activeToken.withLock {
$0?.cancel()
$0 = nil
}
let resetSpan = InstrumentsProfiler.beginReset(engine: "StaticShape")
processedTokenCount = 0
resetSpan.end()
Expand Down Expand Up @@ -559,6 +584,7 @@ extension StaticShapeEngine {
let input: [TokenId]
let samplingConfiguration: SamplingConfiguration
let inferenceOptions: InferenceOptions
let generationToken: GenerationToken

/// Shared with the iterator so the caller can read why generation ended.
let stopReasonStore = StopReasonStore()
Expand All @@ -575,7 +601,8 @@ extension StaticShapeEngine {
input: input,
samplingConfiguration: samplingConfiguration,
inferenceOptions: inferenceOptions,
stopReasonStore: stopReasonStore
stopReasonStore: stopReasonStore,
generationToken: generationToken
)
}
}
Expand All @@ -592,22 +619,26 @@ extension StaticShapeEngine.GenerationSequence {
private let forcedContinuation: [StaticShapeEngine.TokenId]?
private let maxTokens: Int
private let stopReasonStore: StopReasonStore
private let generationToken: GenerationToken

private var inputTokens: [StaticShapeEngine.TokenId]
private var step: Int = 0
private var finished: Bool = false

init(
engine: StaticShapeEngine,
input: [StaticShapeEngine.TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions,
stopReasonStore: StopReasonStore
stopReasonStore: StopReasonStore,
generationToken: GenerationToken
) {
self.engine = engine
self.samplingConfiguration = samplingConfiguration
self.returnsLogits = inferenceOptions.includeLogits
self.forcedContinuation = inferenceOptions.forcedContinuation
self.stopReasonStore = stopReasonStore
self.generationToken = generationToken
self.inputTokens = input
if let forced = inferenceOptions.forcedContinuation {
self.maxTokens = forced.count
Expand All @@ -620,9 +651,18 @@ extension StaticShapeEngine.GenerationSequence {
}

public mutating func next() async throws -> InferenceOutput? {
if finished { return nil }

if generationToken.isCancelled {
stopReasonStore.set(.cancelled)
finishAndRelease()
return nil
}

guard step < maxTokens else {
// Natural exhaustion. Don't clobber a reason a decoder set (e.g. `.eos`).
stopReasonStore.setIfUnset(.maxTokens)
finishAndRelease()
return nil
}

Expand All @@ -637,6 +677,13 @@ extension StaticShapeEngine.GenerationSequence {
returnsLogits: returnsLogits || forcedContinuation != nil
)

// Check cancellation after inference step
if generationToken.isCancelled {
stopReasonStore.set(.cancelled)
finishAndRelease()
return nil
}

let nextToken = forcedContinuation?[step] ?? sampledToken
inputTokens.append(nextToken)
step += 1
Expand All @@ -647,11 +694,19 @@ extension StaticShapeEngine.GenerationSequence {
)
} catch is CancellationError {
stopReasonStore.set(.cancelled)
finishAndRelease()
throw CancellationError()
} catch {
stopReasonStore.set(.error)
finishAndRelease()
throw error
}
}

private mutating func finishAndRelease() {
guard !finished else { return }
finished = true
engine.clearTokenIfActive(generationToken)
}
}
}
Loading