Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 0 additions & 38 deletions Sources/ArgmaxCore/MLTensorExtensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,43 +44,5 @@ public extension MLTensor {
fatalError("Unsupported scalar type: \(scalarType)")
}
}

// MARK: Sync (legacy — uses DispatchSemaphore, unsafe in concurrent async contexts)

@available(*, deprecated, message: "Use await toIntArray() instead.")
func asIntArray() -> [Int] {
let semaphore = DispatchSemaphore(value: 0)
var result: [Int] = []
Task(priority: .high) {
result = await self.toIntArray()
semaphore.signal()
}
semaphore.wait()
return result
}

@available(*, deprecated, message: "Use await toFloatArray() instead.")
func asFloatArray() -> [Float] {
let semaphore = DispatchSemaphore(value: 0)
var result: [Float] = []
Task(priority: .high) {
result = await self.toFloatArray()
semaphore.signal()
}
semaphore.wait()
return result
}

@available(*, deprecated, message: "Use await toMLMultiArray() instead.")
func asMLMultiArray() -> MLMultiArray {
let semaphore = DispatchSemaphore(value: 0)
var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0)
Task(priority: .high) {
result = await self.toMLMultiArray()
semaphore.signal()
}
semaphore.wait()
return result
}
}
#endif
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Accelerate
import AVFoundation
@preconcurrency import AVFoundation
import CoreAudio
import CoreML

Expand Down
16 changes: 9 additions & 7 deletions Sources/WhisperKit/Core/Text/TokenSampler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import CoreML
import Foundation

public protocol TokenSampling {
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult
func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult
}

Expand Down Expand Up @@ -39,7 +39,7 @@ open class GreedyTokenSampler: TokenSampling {

#if canImport(CoreML.MLState)
@available(macOS 15, iOS 18, watchOS 11, visionOS 2, *)
private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) {
private func sampleWithMLTensor(logits: MLMultiArray) async -> (token: Int, logprob: Float) {
// Use MLTensor operations if available for sampling
// Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
var logitsTensor = MLTensor(MLShapedArray<FloatType>(logits)).cast(to: Float.self)
Expand Down Expand Up @@ -76,9 +76,11 @@ open class GreedyTokenSampler: TokenSampling {
nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log()
}

async let nextTokenArray = nextTokenTensor.toIntArray()
async let nextLogprobArray = nextLogprobTensor.toFloatArray()
return (
token: nextTokenTensor.asIntArray()[0],
logprob: nextLogprobTensor.asFloatArray()[0]
token: await nextTokenArray[0],
logprob: await nextLogprobArray[0]
)
}
#endif
Expand Down Expand Up @@ -212,15 +214,15 @@ open class GreedyTokenSampler: TokenSampling {
return (token: nextToken!, logprob: nextLogprob)
}

public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
var nextTokens = tokens
var nextLogprobs = logProbs
var completed = false

var result: (token: Int, logprob: Float)
#if canImport(CoreML.MLState)
if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
result = sampleWithMLTensor(logits: logits)
result = await sampleWithMLTensor(logits: logits)
} else {
result = sampleWithBNNS(logits: logits)
}
Expand Down Expand Up @@ -278,7 +280,7 @@ open class BeamSearchTokenSampler: TokenSampling {
finishedSequences = []
}

public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
// TODO: Implement
fatalError("Not implemented: \(#function)")
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

let samplingStartTime = Date()

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
logProbs = sampleResult.logProbs
Expand Down Expand Up @@ -839,7 +839,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

let samplingStartTime = Date()

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
let nextTokenLogProb = sampleResult.logProbs.last!
Expand Down
3 changes: 2 additions & 1 deletion Sources/WhisperKit/Utilities/Extensions+Public.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.

import AVFoundation
import CoreML
// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16 and macOS 13.
@preconcurrency import CoreML

public extension Array where Element == TranscriptionSegment {
func contains(segment: TranscriptionSegment) -> Bool {
Expand Down
77 changes: 77 additions & 0 deletions Tests/WhisperKitTests/MLTensorExtensionsTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

#if canImport(CoreML.MLState)
import CoreML
@testable import WhisperKit
import XCTest

@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
final class MLTensorExtensionsTests: XCTestCase {
func testAsIntArrayReturnsExpectedScalars() async {
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [1, -2, 42], shape: [3]))

let result = await tensor.toIntArray()

XCTAssertEqual(result, [1, -2, 42])
}

func testAsFloatArraySupportsFloat32Tensor() async {
let tensor = MLTensor(MLShapedArray<Float32>(scalars: [0.25, -1.5, 2.0], shape: [3]))

let result = await tensor.toFloatArray()

assertEqual(result, [0.25, -1.5, 2.0], accuracy: 0.0001)
}

func testAsFloatArraySupportsFloatTypeTensor() async {
let expected = [FloatType(0.125), FloatType(-0.75), FloatType(3.5)]
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))

let result = await tensor.toFloatArray()
let expectedFloats: [Float] = expected.map { Float($0) }

assertEqual(result, expectedFloats, accuracy: 0.0001)
}

func testAsFloatArraySupportsInt32Tensor() async {
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [-3, 0, 7], shape: [3]))

let result = await tensor.toFloatArray()

assertEqual(result, [-3, 0, 7], accuracy: 0.0001)
}

func testAsMLMultiArrayRoundTripsFloatTypeTensor() async {
let expected = [FloatType(1.25), FloatType(-0.5), FloatType(3.75)]
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))

let result = await tensor.toMLMultiArray()
let shapedArray = MLShapedArray<FloatType>(result)
let resultFloats: [Float] = shapedArray.scalars.map { Float($0) }
let expectedFloats: [Float] = expected.map { Float($0) }

XCTAssertEqual(result.shape, [3])
XCTAssertEqual(shapedArray.scalars.count, expected.count)
assertEqual(resultFloats, expectedFloats, accuracy: 0.0001)
}

func testAsMLMultiArrayRoundTripsInt32Tensor() async {
let expected: [Int32] = [-9, 4, 12]
let tensor = MLTensor(MLShapedArray<Int32>(scalars: expected, shape: [3]))

let result = await tensor.toMLMultiArray()
let shapedArray = MLShapedArray<Int32>(result)

XCTAssertEqual(result.shape, [3])
XCTAssertEqual(shapedArray.scalars, expected)
}

private func assertEqual(_ lhs: [Float], _ rhs: [Float], accuracy: Float) {
XCTAssertEqual(lhs.count, rhs.count)
for (actual, expected) in zip(lhs, rhs) {
XCTAssertEqual(actual, expected, accuracy: accuracy)
}
}
}
#endif