diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift index 95302b7..bd2e84f 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift @@ -41,6 +41,12 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { private let engineInUse = Atomic(false) let config: ModelConfig + // Generation lifecycle + private let _activeToken = Mutex(nil) + private let _generationTask = Mutex?>(nil) + + var isBusy: Bool { _activeToken.withLock { $0 != nil } } + init( config: ModelConfig, preparedModel: PreparedModel, @@ -104,9 +110,20 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { let stopReasonStore = StopReasonStore() let (base, outputContinuation) = AsyncThrowingStream.makeStream() - Task { + + let token = GenerationToken() + _activeToken.withLock { $0 = token } + + let task = Task { self.acquireEngine() - defer { self.releaseEngine() } + defer { + self.releaseEngine() + // Only clear if this generation still owns both slots + if self._activeToken.withLock({ $0 === token }) { + self._activeToken.withLock { $0 = nil } + self._generationTask.withLock { $0 = nil } + } + } do { let (tokenStream, tokenContinuation) = AsyncThrowingStream.makeStream() @@ -139,6 +156,7 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { outputContinuation.finish(throwing: error) } } + _generationTask.withLock { $0 = task } return GenerationSequence(base: base, stopReasonStore: stopReasonStore) } @@ -154,8 +172,26 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { } } + func cancel() async throws { + let task: Task? = _generationTask.withLock { task in + task?.cancel() + defer { task = nil } + return task + } + _activeToken.withLock { + $0?.cancel() + $0 = nil + } + await task?.value + } + func reset() { drain() + _activeToken.withLock { + $0?.cancel() + $0 = nil + } + _generationTask.withLock { $0 = nil } guard tryAcquireEngine() else { return } defer { releaseEngine() } engine.reset() diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift index f20bc05..f62ce35 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift @@ -71,8 +71,16 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable // Track processed tokens for incremental inference private var processedTokenCount: Int = 0 - // Track in-flight generation for drain (same pattern as pipelined engine) - private let generating = Mutex(false) + // Track in-flight generation via token (replaces simple bool lock) + private let _activeToken = Mutex(nil) + + public var isBusy: Bool { _activeToken.withLock { $0 != nil } } + + /// Clear the engine's active token if it matches the given token. + /// Called by the iterator when generation finishes or is cancelled. + func clearTokenIfActive(_ token: GenerationToken) { + _activeToken.withLock { if $0 === token { $0 = nil } } + } // MARK: - Init @@ -337,11 +345,14 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions ) throws -> GenerationSequence { - GenerationSequence( + let token = GenerationToken() + _activeToken.withLock { $0 = token } + return GenerationSequence( engine: self, input: input, samplingConfiguration: samplingConfiguration, - inferenceOptions: inferenceOptions + inferenceOptions: inferenceOptions, + generationToken: token ) } @@ -350,7 +361,7 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable /// Wait for any in-flight generate() Task to finish. private func drain() { var attempts = 0 - while generating.withLock({ $0 }) { + while _activeToken.withLock({ $0 != nil }) { attempts += 1 if attempts > 5000 { fatalError("Sequential engine drain() timeout — generation Task stuck?") @@ -359,8 +370,18 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable } } + public func cancel() async throws { + _activeToken.withLock { + $0?.cancel() + $0 = nil + } + } + public func reset() { - drain() + _activeToken.withLock { + $0?.cancel() + $0 = nil + } let resetSpan = InstrumentsProfiler.beginReset(engine: "CoreAIClean") processedTokenCount = 0 zeroFill(&keyCache) @@ -464,6 +485,7 @@ extension CoreAISequentialEngine { let input: [CoreAISequentialEngine.TokenId] let samplingConfiguration: SamplingConfiguration let inferenceOptions: InferenceOptions + let generationToken: GenerationToken /// Shared with the iterator so the caller can read why generation ended. let stopReasonStore = StopReasonStore() @@ -480,7 +502,8 @@ extension CoreAISequentialEngine { input: input, samplingConfiguration: samplingConfiguration, inferenceOptions: inferenceOptions, - stopReasonStore: stopReasonStore + stopReasonStore: stopReasonStore, + generationToken: generationToken ) } } @@ -497,10 +520,10 @@ extension CoreAISequentialEngine.GenerationSequence { private let forcedContinuation: [CoreAISequentialEngine.TokenId]? private let maxTokens: Int private let stopReasonStore: StopReasonStore + private let generationToken: GenerationToken private var inputTokens: [CoreAISequentialEngine.TokenId] private var step: Int = 0 - private var didAcquireLock: Bool = false private var finished: Bool = false init( @@ -508,13 +531,15 @@ extension CoreAISequentialEngine.GenerationSequence { input: [CoreAISequentialEngine.TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions, - stopReasonStore: StopReasonStore + stopReasonStore: StopReasonStore, + generationToken: GenerationToken ) { self.engine = engine self.samplingConfiguration = samplingConfiguration self.returnsLogits = inferenceOptions.includeLogits self.forcedContinuation = inferenceOptions.forcedContinuation self.stopReasonStore = stopReasonStore + self.generationToken = generationToken self.inputTokens = input if let forced = inferenceOptions.forcedContinuation { self.maxTokens = forced.count @@ -527,17 +552,16 @@ extension CoreAISequentialEngine.GenerationSequence { } deinit { - if didAcquireLock { - engine.generating.withLock { $0 = false } - } + engine.clearTokenIfActive(generationToken) } public func next() async throws -> InferenceOutput? { if finished { return nil } - if !didAcquireLock { - engine.generating.withLock { $0 = true } - didAcquireLock = true + if generationToken.isCancelled { + stopReasonStore.set(.cancelled) + finishAndRelease() + return nil } guard step < maxTokens else { @@ -572,6 +596,13 @@ extension CoreAISequentialEngine.GenerationSequence { logitBuffer = lastLogits } + // Check cancellation after inference step + if generationToken.isCancelled { + stopReasonStore.set(.cancelled) + finishAndRelease() + return nil + } + let nextToken: Int32 if let forced = forcedContinuation { nextToken = forced[step] @@ -603,10 +634,7 @@ extension CoreAISequentialEngine.GenerationSequence { return } finished = true - if didAcquireLock { - engine.generating.withLock { $0 = false } - didAcquireLock = false - } + engine.clearTokenIfActive(generationToken) } } } diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift index c6ef5b7..466dfaf 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift @@ -6,6 +6,7 @@ import CoreAI import CoreAIShared import Foundation +import Synchronization /// Static-shape inference engine using Core AI models. public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable { @@ -47,6 +48,16 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable { // Number of tokens already processed in the current sequence. private var processedTokenCount: Int = 0 + // Track in-flight generation via token + private let _activeToken = Mutex(nil) + + public var isBusy: Bool { _activeToken.withLock { $0 != nil } } + + /// Clear the engine's active token if it matches the given token. + func clearTokenIfActive(_ token: GenerationToken) { + _activeToken.withLock { if $0 === token { $0 = nil } } + } + // MARK: - Initialization public init(configuration: ModelConfig, preparedModel: PreparedModel) async throws { @@ -317,11 +328,14 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable { samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions ) throws -> GenerationSequence { - GenerationSequence( + let token = GenerationToken() + _activeToken.withLock { $0 = token } + return GenerationSequence( engine: self, input: input, samplingConfiguration: samplingConfiguration, - inferenceOptions: inferenceOptions + inferenceOptions: inferenceOptions, + generationToken: token ) } @@ -531,7 +545,18 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable { // MARK: - Lifecycle + public func cancel() async throws { + _activeToken.withLock { + $0?.cancel() + $0 = nil + } + } + public func reset() { + _activeToken.withLock { + $0?.cancel() + $0 = nil + } let resetSpan = InstrumentsProfiler.beginReset(engine: "StaticShape") processedTokenCount = 0 resetSpan.end() @@ -559,6 +584,7 @@ extension StaticShapeEngine { let input: [TokenId] let samplingConfiguration: SamplingConfiguration let inferenceOptions: InferenceOptions + let generationToken: GenerationToken /// Shared with the iterator so the caller can read why generation ended. let stopReasonStore = StopReasonStore() @@ -575,7 +601,8 @@ extension StaticShapeEngine { input: input, samplingConfiguration: samplingConfiguration, inferenceOptions: inferenceOptions, - stopReasonStore: stopReasonStore + stopReasonStore: stopReasonStore, + generationToken: generationToken ) } } @@ -592,22 +619,26 @@ extension StaticShapeEngine.GenerationSequence { private let forcedContinuation: [StaticShapeEngine.TokenId]? private let maxTokens: Int private let stopReasonStore: StopReasonStore + private let generationToken: GenerationToken private var inputTokens: [StaticShapeEngine.TokenId] private var step: Int = 0 + private var finished: Bool = false init( engine: StaticShapeEngine, input: [StaticShapeEngine.TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions, - stopReasonStore: StopReasonStore + stopReasonStore: StopReasonStore, + generationToken: GenerationToken ) { self.engine = engine self.samplingConfiguration = samplingConfiguration self.returnsLogits = inferenceOptions.includeLogits self.forcedContinuation = inferenceOptions.forcedContinuation self.stopReasonStore = stopReasonStore + self.generationToken = generationToken self.inputTokens = input if let forced = inferenceOptions.forcedContinuation { self.maxTokens = forced.count @@ -620,9 +651,18 @@ extension StaticShapeEngine.GenerationSequence { } public mutating func next() async throws -> InferenceOutput? { + if finished { return nil } + + if generationToken.isCancelled { + stopReasonStore.set(.cancelled) + finishAndRelease() + return nil + } + guard step < maxTokens else { // Natural exhaustion. Don't clobber a reason a decoder set (e.g. `.eos`). stopReasonStore.setIfUnset(.maxTokens) + finishAndRelease() return nil } @@ -637,6 +677,13 @@ extension StaticShapeEngine.GenerationSequence { returnsLogits: returnsLogits || forcedContinuation != nil ) + // Check cancellation after inference step + if generationToken.isCancelled { + stopReasonStore.set(.cancelled) + finishAndRelease() + return nil + } + let nextToken = forcedContinuation?[step] ?? sampledToken inputTokens.append(nextToken) step += 1 @@ -647,11 +694,19 @@ extension StaticShapeEngine.GenerationSequence { ) } catch is CancellationError { stopReasonStore.set(.cancelled) + finishAndRelease() throw CancellationError() } catch { stopReasonStore.set(.error) + finishAndRelease() throw error } } + + private mutating func finishAndRelease() { + guard !finished else { return } + finished = true + engine.clearTokenIfActive(generationToken) + } } } diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/GenerationToken.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/GenerationToken.swift new file mode 100644 index 0000000..4f40e14 --- /dev/null +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/GenerationToken.swift @@ -0,0 +1,19 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import Synchronization + +/// A token representing an active generation session. +/// +/// Created by `generate()`, held by the iterator. The engine retains a +/// reference to the active token and can cancel it at any time. The iterator +/// checks `isCancelled` on each `next()` call. +public final class GenerationToken: Sendable { + private let _cancelled = Mutex(false) + + public var isCancelled: Bool { _cancelled.withLock { $0 } } + + public func cancel() { _cancelled.withLock { $0 = true } } +} diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift index e536596..42b069d 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift @@ -112,6 +112,16 @@ public protocol InferenceEngine: Sendable { /// Run dummy inference to trigger kernel compilation. func warmup(queryLength: Int, sampling: SamplingConfiguration?) async throws + // MARK: - Cancellation + + /// Whether the engine has an active generation in progress. + var isBusy: Bool { get } + + /// Cancel any in-flight generation. Invalidates the current GenerationToken. + /// For pull-based engines, takes effect on the next `next()` call. + /// For push-based engines (pipelined), also cancels the background Task. + func cancel() async throws + // MARK: - Capabilities /// Whether this engine supports per-step logits extraction. @@ -162,6 +172,14 @@ extension InferenceEngine { public var supportsLogits: Bool { false } } +extension InferenceEngine { + /// Default: engine is not busy. + public var isBusy: Bool { false } + + /// Default: no-op cancel. Engines with active generation override this. + public func cancel() async throws {} +} + extension InferenceEngine { /// Default no-op implementation of warmup. public func warmup(queryLength: Int, sampling: SamplingConfiguration?) async throws { diff --git a/swift/Tests/LanguageModelsTests/CancelAPITests.swift b/swift/Tests/LanguageModelsTests/CancelAPITests.swift new file mode 100644 index 0000000..ca60398 --- /dev/null +++ b/swift/Tests/LanguageModelsTests/CancelAPITests.swift @@ -0,0 +1,120 @@ +// Copyright 2026 Apple Inc. +// +// Use of this source code is governed by a BSD-3-clause license that can +// be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import Testing + +@testable import CoreAILanguageModels + +@Suite("InferenceEngine cancel API") +struct CancelAPITests { + @Test("isBusy is false when idle") + func idleNotBusy() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + #expect(!engine.isBusy) + } + + @Test("cancel() is safe when idle") + func cancelWhenIdle() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + try await engine.cancel() + #expect(!engine.isBusy) + } + + @Test("engine is busy during generation") + func busyDuringGeneration() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + let stream = try engine.generate( + with: [1], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 100) + ) + #expect(engine.isBusy) + + // Consume to release + for try await _ in stream {} + #expect(!engine.isBusy) + } + + @Test("cancel() stops generation and marks .cancelled") + func cancelStopsGeneration() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + let stream = try engine.generate( + with: [1], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 100) + ) + #expect(engine.isBusy) + + try await engine.cancel() + #expect(!engine.isBusy) + + // Stream should yield nil on next iteration + var count = 0 + for try await _ in stream { + count += 1 + } + #expect(count == 0) + #expect(stream.stopReason == .cancelled) + } + + @Test("engine becomes idle after generation completes naturally") + func idleAfterNaturalCompletion() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + let stream = try engine.generate( + with: [1], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 3) + ) + #expect(engine.isBusy) + + for try await _ in stream {} + + #expect(!engine.isBusy) + #expect(stream.stopReason == .maxTokens) + } + + @Test("reset() cancels active generation") + func resetCancelsGeneration() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + let stream = try engine.generate( + with: [1], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 100) + ) + #expect(engine.isBusy) + + try await engine.reset() + #expect(!engine.isBusy) + #expect(engine.resetCalled) + + // Stream should yield nil + var count = 0 + for try await _ in stream { + count += 1 + } + #expect(count == 0) + } + + @Test("GenerationToken starts not cancelled") + func tokenStartsNotCancelled() { + let token = GenerationToken() + #expect(!token.isCancelled) + } + + @Test("GenerationToken cancel() sets isCancelled") + func tokenCancelSetsFlag() { + let token = GenerationToken() + token.cancel() + #expect(token.isCancelled) + } + + @Test("GenerationToken cancel() is idempotent") + func tokenCancelIdempotent() { + let token = GenerationToken() + token.cancel() + token.cancel() + #expect(token.isCancelled) + } +} diff --git a/swift/Tests/LanguageModelsTests/TestUtilities.swift b/swift/Tests/LanguageModelsTests/TestUtilities.swift index 63c21ef..a77fcbc 100644 --- a/swift/Tests/LanguageModelsTests/TestUtilities.swift +++ b/swift/Tests/LanguageModelsTests/TestUtilities.swift @@ -4,6 +4,7 @@ // be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause import Foundation +import Synchronization import TestUtilities import Tokenizers @@ -37,6 +38,16 @@ class MockEngine: InferenceEngine, @unchecked Sendable { /// Tracks whether reset() was called private(set) var resetCalled: Bool = false + // Generation lifecycle + private let _activeToken = Mutex(nil) + + var isBusy: Bool { _activeToken.withLock { $0 != nil } } + + /// Clear the engine's active token if it matches the given token. + func clearTokenIfActive(_ token: GenerationToken) { + _activeToken.withLock { if $0 === token { $0 = nil } } + } + init( tokens: [Int32] = [10, 20, 30, 40, 50], maxContextLength: Int = 4096, @@ -54,10 +65,13 @@ class MockEngine: InferenceEngine, @unchecked Sendable { samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions ) throws -> GenerationSequence { - GenerationSequence( + let token = GenerationToken() + _activeToken.withLock { $0 = token } + return GenerationSequence( engine: self, input: input, - inferenceOptions: inferenceOptions + inferenceOptions: inferenceOptions, + generationToken: token ) } @@ -68,6 +82,7 @@ class MockEngine: InferenceEngine, @unchecked Sendable { let engine: MockEngine let input: [TokenId] let inferenceOptions: InferenceOptions + let generationToken: GenerationToken let stopReasonStore = StopReasonStore() @@ -82,7 +97,8 @@ class MockEngine: InferenceEngine, @unchecked Sendable { engine: engine, input: input, inferenceOptions: inferenceOptions, - stopReasonStore: stopReasonStore + stopReasonStore: stopReasonStore, + generationToken: generationToken ) } @@ -95,19 +111,23 @@ class MockEngine: InferenceEngine, @unchecked Sendable { let forcedContinuation: [TokenId]? let maxTokens: Int let stopReasonStore: StopReasonStore + let generationToken: GenerationToken var step: Int = 0 + var finished: Bool = false init( engine: MockEngine, input: [TokenId], inferenceOptions: InferenceOptions, - stopReasonStore: StopReasonStore + stopReasonStore: StopReasonStore, + generationToken: GenerationToken ) { self.engine = engine self.returnsLogits = inferenceOptions.includeLogits self.forcedContinuation = inferenceOptions.forcedContinuation self.stopReasonStore = stopReasonStore + self.generationToken = generationToken if let forced = inferenceOptions.forcedContinuation { self.maxTokens = forced.count } else { @@ -119,8 +139,17 @@ class MockEngine: InferenceEngine, @unchecked Sendable { } mutating func next() async throws -> InferenceOutput? { + if finished { return nil } + + if generationToken.isCancelled { + stopReasonStore.set(.cancelled) + finishAndRelease() + return nil + } + guard step < maxTokens else { stopReasonStore.setIfUnset(.maxTokens) + finishAndRelease() return nil } @@ -149,16 +178,32 @@ class MockEngine: InferenceEngine, @unchecked Sendable { return InferenceOutput(tokenId: nextToken, logits: logits) } catch is CancellationError { stopReasonStore.set(.cancelled) + finishAndRelease() throw CancellationError() } catch { stopReasonStore.set(.error) + finishAndRelease() throw error } } + + private mutating func finishAndRelease() { + guard !finished else { return } + finished = true + engine.clearTokenIfActive(generationToken) + } + } + } + + func cancel() async throws { + _activeToken.withLock { + $0?.cancel() + $0 = nil } } func reset() async throws { + try await cancel() resetCalled = true inferenceCallCount = 0 }