From 221ec4ccbe710e17627f5b7a48101be39321c4a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aykut=20Gu=CC=88ven?= Date: Tue, 24 Feb 2026 20:38:29 +0100 Subject: [PATCH 1/4] Mark CoreML and AVFoundation @preconcurrency --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 +- Sources/WhisperKit/Utilities/Extensions+Public.swift | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 18ae990..55cbd02 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -2,7 +2,7 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import Accelerate -import AVFoundation +@preconcurrency import AVFoundation import CoreAudio import CoreML diff --git a/Sources/WhisperKit/Utilities/Extensions+Public.swift b/Sources/WhisperKit/Utilities/Extensions+Public.swift index fd0e8a3..ced782c 100644 --- a/Sources/WhisperKit/Utilities/Extensions+Public.swift +++ b/Sources/WhisperKit/Utilities/Extensions+Public.swift @@ -2,7 +2,8 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import AVFoundation -import CoreML +// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16, macOS 14. +@preconcurrency import CoreML public extension Array where Element == TranscriptionSegment { func contains(segment: TranscriptionSegment) -> Bool { From f67a7e9b6b7b3511de35c81413117946d8a9422b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aykut=20Gu=CC=88ven?= Date: Tue, 24 Feb 2026 20:50:42 +0100 Subject: [PATCH 2/4] Fix concurrency errors in MLTensor extension --- Sources/ArgmaxCore/MLTensorExtensions.swift | 38 ---------- .../WhisperKit/Core/Text/TokenSampler.swift | 16 ++-- Sources/WhisperKit/Core/TextDecoder.swift | 4 +- .../MLTensorExtensionsTests.swift | 74 +++++++++++++++++++ 4 files changed, 85 insertions(+), 47 deletions(-) create mode 100644 Tests/WhisperKitTests/MLTensorExtensionsTests.swift diff --git a/Sources/ArgmaxCore/MLTensorExtensions.swift b/Sources/ArgmaxCore/MLTensorExtensions.swift index 44c0e75..fc7a394 100644 --- a/Sources/ArgmaxCore/MLTensorExtensions.swift +++ b/Sources/ArgmaxCore/MLTensorExtensions.swift @@ -44,43 +44,5 @@ public extension MLTensor { fatalError("Unsupported scalar type: \(scalarType)") } } - - // MARK: Sync (legacy — uses DispatchSemaphore, unsafe in concurrent async contexts) - - @available(*, deprecated, message: "Use await toIntArray() instead.") - func asIntArray() -> [Int] { - let semaphore = DispatchSemaphore(value: 0) - var result: [Int] = [] - Task(priority: .high) { - result = await self.toIntArray() - semaphore.signal() - } - semaphore.wait() - return result - } - - @available(*, deprecated, message: "Use await toFloatArray() instead.") - func asFloatArray() -> [Float] { - let semaphore = DispatchSemaphore(value: 0) - var result: [Float] = [] - Task(priority: .high) { - result = await self.toFloatArray() - semaphore.signal() - } - semaphore.wait() - return result - } - - @available(*, deprecated, message: "Use await toMLMultiArray() instead.") - func asMLMultiArray() -> MLMultiArray { - let semaphore = DispatchSemaphore(value: 0) - var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0) - Task(priority: .high) { - result = await self.toMLMultiArray() - semaphore.signal() - } - semaphore.wait() - return result - } } #endif diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index 8d71051..b18608a 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -6,7 +6,7 @@ import CoreML import Foundation public protocol TokenSampling { - func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult + func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult } @@ -39,7 +39,7 @@ open class GreedyTokenSampler: TokenSampling { #if canImport(CoreML.MLState) @available(macOS 15, iOS 18, watchOS 11, visionOS 2, *) - private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) { + private func sampleWithMLTensor(logits: MLMultiArray) async -> (token: Int, logprob: Float) { // Use MLTensor operations if available for sampling // Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift var logitsTensor = MLTensor(MLShapedArray(logits)).cast(to: Float.self) @@ -76,9 +76,11 @@ open class GreedyTokenSampler: TokenSampling { nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log() } + async let nextTokenArray = nextTokenTensor.toIntArray() + async let nextLogprobArray = nextLogprobTensor.toFloatArray() return ( - token: nextTokenTensor.asIntArray()[0], - logprob: nextLogprobTensor.asFloatArray()[0] + token: await nextTokenArray[0], + logprob: await nextLogprobArray[0] ) } #endif @@ -212,7 +214,7 @@ open class GreedyTokenSampler: TokenSampling { return (token: nextToken!, logprob: nextLogprob) } - public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult { var nextTokens = tokens var nextLogprobs = logProbs var completed = false @@ -220,7 +222,7 @@ open class GreedyTokenSampler: TokenSampling { var result: (token: Int, logprob: Float) #if canImport(CoreML.MLState) if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { - result = sampleWithMLTensor(logits: logits) + result = await sampleWithMLTensor(logits: logits) } else { result = sampleWithBNNS(logits: logits) } @@ -278,7 +280,7 @@ open class BeamSearchTokenSampler: TokenSampling { finishedSequences = [] } - public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult { // TODO: Implement fatalError("Not implemented: \(#function)") } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 84ab2b2..fc73de8 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -687,7 +687,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let samplingStartTime = Date() - let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) + let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) nextToken = sampleResult.tokens.last! logProbs = sampleResult.logProbs @@ -839,7 +839,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let samplingStartTime = Date() - let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) + let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) nextToken = sampleResult.tokens.last! let nextTokenLogProb = sampleResult.logProbs.last! diff --git a/Tests/WhisperKitTests/MLTensorExtensionsTests.swift b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift new file mode 100644 index 0000000..6b766ec --- /dev/null +++ b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift @@ -0,0 +1,74 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +#if canImport(CoreML.MLState) +import CoreML +@testable import WhisperKit +import XCTest + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) +final class MLTensorExtensionsTests: XCTestCase { + func testAsIntArrayReturnsExpectedScalars() async { + let tensor = MLTensor(MLShapedArray(scalars: [1, -2, 42], shape: [3])) + + let result = await tensor.toIntArray() + + XCTAssertEqual(result, [1, -2, 42]) + } + + func testAsFloatArraySupportsFloat32Tensor() async { + let tensor = MLTensor(MLShapedArray(scalars: [0.25, -1.5, 2.0], shape: [3])) + + let result = await tensor.toFloatArray() + + assertEqual(result, [0.25, -1.5, 2.0], accuracy: 0.0001) + } + + func testAsFloatArraySupportsFloatTypeTensor() async { + let expected = [FloatType(0.125), FloatType(-0.75), FloatType(3.5)] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.toFloatArray() + + assertEqual(result, expected.map(Float.init), accuracy: 0.0001) + } + + func testAsFloatArraySupportsInt32Tensor() async { + let tensor = MLTensor(MLShapedArray(scalars: [-3, 0, 7], shape: [3])) + + let result = await tensor.toFloatArray() + + assertEqual(result, [-3, 0, 7], accuracy: 0.0001) + } + + func testAsMLMultiArrayRoundTripsFloatTypeTensor() async { + let expected = [FloatType(1.25), FloatType(-0.5), FloatType(3.75)] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.toMLMultiArray() + let shapedArray = MLShapedArray(result) + + XCTAssertEqual(result.shape, [3]) + XCTAssertEqual(shapedArray.scalars.count, expected.count) + assertEqual(shapedArray.scalars.map(Float.init), expected.map(Float.init), accuracy: 0.0001) + } + + func testAsMLMultiArrayRoundTripsInt32Tensor() async { + let expected: [Int32] = [-9, 4, 12] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.toMLMultiArray() + let shapedArray = MLShapedArray(result) + + XCTAssertEqual(result.shape, [3]) + XCTAssertEqual(shapedArray.scalars, expected) + } + + private func assertEqual(_ lhs: [Float], _ rhs: [Float], accuracy: Float) { + XCTAssertEqual(lhs.count, rhs.count) + for (actual, expected) in zip(lhs, rhs) { + XCTAssertEqual(actual, expected, accuracy: accuracy) + } + } +} +#endif From 6cd63e9e0f1cffc51c7d5952a730f494c5f0681b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aykut=20G=C3=BCven?= Date: Tue, 24 Feb 2026 22:52:09 +0100 Subject: [PATCH 3/4] Fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- Sources/WhisperKit/Utilities/Extensions+Public.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/WhisperKit/Utilities/Extensions+Public.swift b/Sources/WhisperKit/Utilities/Extensions+Public.swift index ced782c..f860996 100644 --- a/Sources/WhisperKit/Utilities/Extensions+Public.swift +++ b/Sources/WhisperKit/Utilities/Extensions+Public.swift @@ -2,7 +2,7 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import AVFoundation -// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16, macOS 14. +// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16 and macOS 13. @preconcurrency import CoreML public extension Array where Element == TranscriptionSegment { From 9e10fb7badba62f6a7ee9247b596846aaeb7f1dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aykut=20Gu=CC=88ven?= Date: Wed, 25 Feb 2026 09:58:46 +0100 Subject: [PATCH 4/4] Fix type inference issue --- Tests/WhisperKitTests/MLTensorExtensionsTests.swift | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Tests/WhisperKitTests/MLTensorExtensionsTests.swift b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift index 6b766ec..eb3fd54 100644 --- a/Tests/WhisperKitTests/MLTensorExtensionsTests.swift +++ b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift @@ -29,8 +29,9 @@ final class MLTensorExtensionsTests: XCTestCase { let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) let result = await tensor.toFloatArray() + let expectedFloats: [Float] = expected.map { Float($0) } - assertEqual(result, expected.map(Float.init), accuracy: 0.0001) + assertEqual(result, expectedFloats, accuracy: 0.0001) } func testAsFloatArraySupportsInt32Tensor() async { @@ -47,10 +48,12 @@ final class MLTensorExtensionsTests: XCTestCase { let result = await tensor.toMLMultiArray() let shapedArray = MLShapedArray(result) + let resultFloats: [Float] = shapedArray.scalars.map { Float($0) } + let expectedFloats: [Float] = expected.map { Float($0) } XCTAssertEqual(result.shape, [3]) XCTAssertEqual(shapedArray.scalars.count, expected.count) - assertEqual(shapedArray.scalars.map(Float.init), expected.map(Float.init), accuracy: 0.0001) + assertEqual(resultFloats, expectedFloats, accuracy: 0.0001) } func testAsMLMultiArrayRoundTripsInt32Tensor() async {