diff --git a/README.md b/README.md
index 79cd7ea..2a971fd 100644
--- a/README.md
+++ b/README.md
@@ -8,15 +8,19 @@
-# WhisperKit Android (Beta)
+# WhisperKit Android
+
+[](https://github.com/argmaxinc/whisperkitandroid/actions/workflows/pr-checks.yml)
+[](LICENSE.md)
+[](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit)
+[](https://discord.gg/G5F5GZGecC)
-[](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit)
WhisperKit Android brings Foundation Models On Device for Automatic Speech Recognition. It extends the performance and feature set of [WhisperKit](https://github.com/argmaxinc/WhisperKit) from Apple platforms to Android and Linux. The current feature set is a subset of the iOS counterpart,
but we are continuing to invest in Android and now welcome contributions from the community.
-[Example App (Coming Soon)] [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools)
+[[Example App]](https://play.google.com/store/apps/details?id=com.argmaxinc.whisperax) [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools)
## Table of Contents
@@ -37,7 +41,7 @@ To use WhisperKit in your Android app, you need to:
```kotlin
dependencies {
// 1. WhisperKit SDK
- implementation("com.argmaxinc:whisperkit:0.3.0") // Check badge above for latest version
+ implementation("com.argmaxinc:whisperkit:0.3.2") // Check badge above for latest version
// 2. QNN dependencies for hardware acceleration
implementation("com.qualcomm.qnn:qnn-runtime:2.34.0")
@@ -73,7 +77,7 @@ class YourActivity : AppCompatActivity() {
whisperKit = WhisperKit.Builder()
.setModel(WhisperKit.OPENAI_TINY_EN)
.setApplicationContext(applicationContext)
- .setCallback { what, timestamp, msg ->
+ .setCallback { what, result ->
// Handle transcription output
when (what) {
WhisperKit.TextOutputCallback.MSG_INIT -> {
@@ -81,10 +85,14 @@ class YourActivity : AppCompatActivity() {
}
WhisperKit.TextOutputCallback.MSG_TEXT_OUT -> {
// New transcription available
- val text = msg
- val time = timestamp
+ val fullText = result.text
+ val segments = result.segments
// Process the transcribed text as it becomes available
// This callback will be called multiple times as more audio is processed
+ segments.forEach { segment ->
+ // Process each segment
+ val segmentText = segment.text
+ }
}
WhisperKit.TextOutputCallback.MSG_CLOSE -> {
// Cleanup complete
diff --git a/android/config/detekt.yml b/android/config/detekt.yml
index 9b823e6..11bfe92 100644
--- a/android/config/detekt.yml
+++ b/android/config/detekt.yml
@@ -14,7 +14,7 @@ complexity:
thresholdInObjects: 10
LongParameterList:
functionThreshold: 8
- constructorThreshold: 7
+ constructorThreshold: 8
CyclomaticComplexMethod:
threshold: 20
NestedBlockDepth:
diff --git a/android/examples/WhisperAX/build.gradle.kts b/android/examples/WhisperAX/build.gradle.kts
index 794608d..d83199b 100644
--- a/android/examples/WhisperAX/build.gradle.kts
+++ b/android/examples/WhisperAX/build.gradle.kts
@@ -12,7 +12,7 @@ android {
applicationId = "com.argmaxinc.whisperax"
minSdk = 26
targetSdk = 35
- versionCode = 3
+ versionCode = 6
versionName = "0.1.0"
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt
index 3fb1717..95c4a7f 100644
--- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt
+++ b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt
@@ -31,6 +31,7 @@ import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
+import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
@@ -40,6 +41,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.rotate
import androidx.compose.ui.unit.dp
+import com.argmaxinc.whisperax.WhisperViewModel.Companion.MODELS_SUPPORTING_NPU
import com.argmaxinc.whisperkit.ExperimentalWhisperKit
import com.argmaxinc.whisperkit.WhisperKit
@@ -50,11 +52,18 @@ enum class ComputeUnits(val displayName: String, val backendValue: Int) {
CPU_AND_NPU("NPU", WhisperKit.Builder.CPU_AND_NPU),
}
+@OptIn(ExperimentalWhisperKit::class)
@Composable
fun ComputeUnitsView(viewModel: WhisperViewModel) {
val modelState by viewModel.modelState.collectAsState()
val encoderState by viewModel.encoderState.collectAsState()
val decoderState by viewModel.decoderState.collectAsState()
+ val selectedModel by viewModel.selectedModel.collectAsState()
+ val shouldEnableNPUForEncoderDecoder by remember {
+ derivedStateOf {
+ selectedModel in MODELS_SUPPORTING_NPU
+ }
+ }
val isEnabled = modelState == ModelState.LOADED || modelState == ModelState.UNLOADED
var whisperKitExpanded by remember { mutableStateOf(true) }
@@ -75,6 +84,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) {
currentState = encoderState,
currentUnit = viewModel.encoderComputeUnits.collectAsState().value,
onUnitSelected = { viewModel.setEncoderComputeUnits(it) },
+ shouldEnableNPU = shouldEnableNPUForEncoderDecoder,
enabled = isEnabled,
)
@@ -85,6 +95,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) {
currentState = decoderState,
currentUnit = viewModel.decoderComputeUnits.collectAsState().value,
onUnitSelected = { viewModel.setDecoderComputeUnits(it) },
+ shouldEnableNPU = shouldEnableNPUForEncoderDecoder,
enabled = isEnabled,
)
}
@@ -185,6 +196,7 @@ fun ComputeUnitRow(
currentState: ModelState,
currentUnit: ComputeUnits,
onUnitSelected: (ComputeUnits) -> Unit,
+ shouldEnableNPU: Boolean = true,
enabled: Boolean = true,
) {
val infiniteTransition = rememberInfiniteTransition(label = "loading animation")
@@ -248,7 +260,11 @@ fun ComputeUnitRow(
expanded = expanded,
onDismissRequest = { expanded = false },
) {
- ComputeUnits.values().forEach { unit ->
+ if (shouldEnableNPU) {
+ listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU, ComputeUnits.CPU_AND_NPU)
+ } else {
+ listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU)
+ }.forEach { unit ->
DropdownMenuItem(
text = { Text(unit.displayName) },
onClick = {
diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt
index 331bbfb..791fd9c 100644
--- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt
+++ b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt
@@ -34,6 +34,7 @@ import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.MenuAnchorType
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
@@ -111,7 +112,8 @@ fun ModelSelectorView(viewModel: WhisperViewModel) {
},
modifier = Modifier
.fillMaxWidth()
- .weight(1f),
+ .weight(1f)
+ .menuAnchor(MenuAnchorType.PrimaryNotEditable),
)
ExposedDropdownMenu(
diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt
index 897f32c..379c424 100644
--- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt
+++ b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt
@@ -14,6 +14,8 @@ import androidx.compose.runtime.mutableStateListOf
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.argmaxinc.whisperkit.ExperimentalWhisperKit
+import com.argmaxinc.whisperkit.TranscriptionResult
+import com.argmaxinc.whisperkit.TranscriptionSegment
import com.argmaxinc.whisperkit.WhisperKit
import com.argmaxinc.whisperkit.WhisperKit.TextOutputCallback
import com.argmaxinc.whisperkit.WhisperKitException
@@ -33,22 +35,13 @@ import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
-data class TranscriptionSegment(
- val text: String,
- val start: Float,
- val end: Float,
- val tokens: List = emptyList(),
-)
-
-data class TranscriptionResult(
- val text: String = "",
- val segments: List = emptyList(),
-)
-
@OptIn(ExperimentalWhisperKit::class)
class WhisperViewModel : ViewModel() {
companion object {
const val TAG = "WhisperViewModel"
+
+ // Models currently supporting NPU backend, don't enable NPU for other models
+ val MODELS_SUPPORTING_NPU = listOf(WhisperKit.Builder.QUALCOMM_TINY_EN, WhisperKit.Builder.QUALCOMM_BASE_EN)
}
private lateinit var appContext: Context
@@ -190,10 +183,11 @@ class WhisperViewModel : ViewModel() {
cacheDir = context.cacheDir.absolutePath
}
- fun onTextOutput(what: Int, timestamp: Float, msg: String) {
+ fun onTextOutput(what: Int, result: TranscriptionResult) {
+ val segments = result.segments
when (what) {
TextOutputCallback.MSG_INIT -> {
- Log.i(MainActivity.TAG, "TFLite initialized: $msg")
+ Log.i(MainActivity.TAG, "TFLite initialized: ${result.text}")
startTime = System.currentTimeMillis()
_pipelineStart.value = startTime.toDouble() / 1000.0
_isInitializing.value = false
@@ -201,14 +195,13 @@ class WhisperViewModel : ViewModel() {
TextOutputCallback.MSG_TEXT_OUT -> {
Log.i(MainActivity.TAG, "TEXT OUT THREAD")
- if (msg.isNotEmpty()) {
+ if (segments.isNotEmpty()) {
if (!firstTokenReceived) {
firstTokenReceived = true
firstTokenTimestamp = System.currentTimeMillis()
_firstTokenTime.value = (firstTokenTimestamp - startTime).toDouble() / 1000.0
}
-
- val newTokens = msg.length / 4
+ val newTokens = segments.joinToString("") { it.text }.length / 4
totalTokens += newTokens
val currentTime = System.currentTimeMillis()
@@ -220,14 +213,14 @@ class WhisperViewModel : ViewModel() {
}
lastTokenTimestamp = currentTime
- updateTranscript(msg)
+ updateTranscript(segments)
}
}
TextOutputCallback.MSG_CLOSE -> {
Log.i(MainActivity.TAG, "Transcription completed.")
- if (msg.isNotEmpty()) {
- val newTokens = msg.length / 4
+ if (segments.isNotEmpty()) {
+ val newTokens = segments.joinToString("") { it.text }.length / 4
totalTokens += newTokens
val totalTime = (System.currentTimeMillis() - startTime).toDouble() / 1000.0
@@ -236,8 +229,7 @@ class WhisperViewModel : ViewModel() {
updateRealtimeMetrics(totalTime)
}
-
- updateTranscript(msg)
+ updateTranscript(segments)
}
}
@@ -247,25 +239,8 @@ class WhisperViewModel : ViewModel() {
}
}
- private fun updateTranscript(chunkText: String, withTimestamps: Boolean = false) {
- var processedText = chunkText
-
- val timestamps = if (withTimestamps) {
- val timestampPattern = "<\\|(\\d+\\.\\d+)\\|>".toRegex()
- val timestampMatches = timestampPattern.findAll(chunkText).toList()
- timestampMatches.map { it.groupValues[1].toFloat() }
- } else {
- emptyList()
- }
-
- if (!withTimestamps) {
- processedText = processedText
- .replace("<\\|[^>]*\\|>".toRegex(), "")
- .trim()
- } else {
- processedText = processedText.trim()
- }
-
+ private fun updateTranscript(segments: List<TranscriptionSegment>) {
+ val processedText = segments.joinToString("") { it.text }
if (processedText.isNotEmpty()) {
if (allText.isNotEmpty()) {
allText.append("\n")
@@ -284,13 +259,12 @@ class WhisperViewModel : ViewModel() {
fun listModels() {
viewModelScope.launch {
val modelDirs = listOf(
- // TODO: enable when models are ready
- // WhisperKit.Builder.OPENAI_TINY_EN,
- // WhisperKit.Builder.OPENAI_BASE_EN,
- // WhisperKit.Builder.OPENAI_SMALL_EN,
WhisperKit.Builder.QUALCOMM_TINY_EN,
WhisperKit.Builder.QUALCOMM_BASE_EN,
- // WhisperKit.Builder.QUALCOMM_SMALL_EN
+ WhisperKit.Builder.OPENAI_TINY_EN,
+ WhisperKit.Builder.OPENAI_BASE_EN,
+ WhisperKit.Builder.OPENAI_TINY,
+ WhisperKit.Builder.OPENAI_BASE,
)
availableModels.clear()
availableModels.addAll(modelDirs)
@@ -364,6 +338,21 @@ class WhisperViewModel : ViewModel() {
fun selectModel(model: String) {
_selectedModel.value = model
+ if (model in MODELS_SUPPORTING_NPU) {
+ _encoderComputeUnits.update {
+ ComputeUnits.CPU_AND_NPU
+ }
+ _decoderComputeUnits.update {
+ ComputeUnits.CPU_AND_NPU
+ }
+ } else {
+ _encoderComputeUnits.update {
+ ComputeUnits.CPU_ONLY
+ }
+ _decoderComputeUnits.update {
+ ComputeUnits.CPU_ONLY
+ }
+ }
_modelState.value = ModelState.UNLOADED
_encoderState.value = ModelState.UNLOADED
_decoderState.value = ModelState.UNLOADED
diff --git a/android/whisperkit/build.gradle.kts b/android/whisperkit/build.gradle.kts
index f0e66d0..c2394a0 100644
--- a/android/whisperkit/build.gradle.kts
+++ b/android/whisperkit/build.gradle.kts
@@ -66,7 +66,7 @@ dependencies {
mavenPublishing {
- coordinates("com.argmaxinc", "whisperkit", "0.3.0")
+ coordinates("com.argmaxinc", "whisperkit", "0.3.2")
pom {
name.set("WhisperKit")
description.set("On-device Speech Recognition for Android")
diff --git a/android/whisperkit/detekt-baseline.xml b/android/whisperkit/detekt-baseline.xml
index 1df0484..4c4b90c 100644
--- a/android/whisperkit/detekt-baseline.xml
+++ b/android/whisperkit/detekt-baseline.xml
@@ -2,8 +2,10 @@
+ LargeClass:ArgmaxModelDownloaderImplTest.kt$ArgmaxModelDownloaderImplTest
ThrowsCount:WhisperKit.kt$WhisperKit.Builder$@Throws(WhisperKitException::class) fun build(): WhisperKit
TooGenericExceptionCaught:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$e: Exception
TooGenericExceptionCaught:WhisperKitImpl.kt$WhisperKitImpl$e: Exception
+ UnusedParameter:WhisperKitImpl.kt$WhisperKitImpl$timestamp: Float
diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt
index 4e6f7b2..254de56 100644
--- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt
+++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt
@@ -7,6 +7,23 @@ import android.content.Context
import com.argmaxinc.whisperkit.huggingface.HuggingFaceApi
import kotlinx.coroutines.flow.Flow
+/**
+ * Contains the complete transcription result.
+ * The result includes both the full text and individual segments of the transcription.
+ */
+data class TranscriptionResult(
+ val text: String,
+ val segments: List<TranscriptionSegment>,
+)
+
+/**
+ * Represents a single segment of transcribed text.
+ * Each segment contains the text content of a portion of the transcription.
+ */
+data class TranscriptionSegment(
+ val text: String,
+)
+
/**
* WhisperKit is a speech recognition library that provides real-time transcription capabilities.
* It supports both OpenAI and Qualcomm Whisper models, with various size options and compute backend configurations.
@@ -138,16 +155,11 @@ interface WhisperKit {
* - MSG_INIT (0): init() succeeded, model is ready
* - MSG_TEXT_OUT (1): transcription results from previous transcribe() call
* - MSG_CLOSE (2): deinitialize() succeeded, cleanup complete
- * @param timestamp The timestamp of the transcribed segment (only valid for MSG_TEXT_OUT)
- * @param msg The transcribed text or status message:
- * - For MSG_INIT: initialization status
- * - For MSG_TEXT_OUT: transcribed text
- * - For MSG_CLOSE: cleanup status
+ * @param result The transcription result containing the raw text and segments
*/
fun onTextOutput(
what: Int,
- timestamp: Float,
- msg: String,
+ result: TranscriptionResult,
)
}
@@ -160,15 +172,27 @@ interface WhisperKit {
// Model variants
const val OPENAI_TINY_EN = "whisperkit-litert/openai_whisper-tiny.en"
const val OPENAI_BASE_EN = "whisperkit-litert/openai_whisper-base.en"
- const val OPENAI_SMALL_EN = "whisperkit-litert/openai_whisper-small.en"
+ const val OPENAI_TINY = "whisperkit-litert/openai_whisper-tiny"
+ const val OPENAI_BASE = "whisperkit-litert/openai_whisper-base"
const val QUALCOMM_TINY_EN = "qualcomm/Whisper_Tiny_En"
const val QUALCOMM_BASE_EN = "qualcomm/Whisper_Base_En"
- const val QUALCOMM_SMALL_EN = "qualcomm/Whisper_Small_En"
+
+ // Small models are not supported yet
+ internal const val OPENAI_SMALL_EN = "whisperkit-litert/openai_whisper-small.en"
+ internal const val QUALCOMM_SMALL_EN = "qualcomm/Whisper_Small_En"
// Compute units used for encoder/decoder backend
const val CPU_ONLY = 1
const val CPU_AND_GPU = 2
const val CPU_AND_NPU = 3
+ val SUPPORTED_MODELS = listOf(
+ OPENAI_TINY_EN,
+ OPENAI_BASE_EN,
+ OPENAI_TINY,
+ OPENAI_BASE,
+ QUALCOMM_TINY_EN,
+ QUALCOMM_BASE_EN,
+ )
}
private var model: String? = null
@@ -185,17 +209,8 @@ interface WhisperKit {
*/
@Throws(WhisperKitException::class)
fun setModel(model: String): Builder {
- if (model !in
- listOf(
- OPENAI_TINY_EN,
- OPENAI_BASE_EN,
- OPENAI_SMALL_EN,
- QUALCOMM_TINY_EN,
- QUALCOMM_BASE_EN,
- QUALCOMM_SMALL_EN,
- )
- ) {
- throw WhisperKitException("Model must be one of the predefined variants")
+ if (model !in SUPPORTED_MODELS) {
+ throw WhisperKitException("Model must be one of the predefined variants: $SUPPORTED_MODELS")
}
this.model = model
return this
diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt
index 47d0aad..e18c7ca 100644
--- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt
+++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt
@@ -9,6 +9,8 @@ import com.argmaxinc.whisperkit.huggingface.HuggingFaceApi
import com.argmaxinc.whisperkit.network.ArgmaxModel
import com.argmaxinc.whisperkit.network.ArgmaxModelDownloader
import com.argmaxinc.whisperkit.network.ArgmaxModelDownloaderImpl
+import com.argmaxinc.whisperkit.util.MessageProcessor
+import com.argmaxinc.whisperkit.util.SegmentTextOnlyMessageProcessor
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.mapLatest
@@ -22,6 +24,7 @@ internal class WhisperKitImpl(
decoderBackend: Int,
private val callback: WhisperKit.TextOutputCallback,
private val argmaxModelDownloader: ArgmaxModelDownloader = ArgmaxModelDownloaderImpl(),
+ private val messageProcessor: MessageProcessor = SegmentTextOnlyMessageProcessor(),
) : WhisperKit {
companion object {
private const val TAG = "WhisperKitImpl"
@@ -144,14 +147,17 @@ internal class WhisperKitImpl(
}
}
- // Callback from JNI native code
+ // Callback from JNI native code, note at the moment timestamp is always 0
private fun onTextOutput(
what: Int,
timestamp: Float,
msg: String,
) {
try {
- callback.onTextOutput(what, timestamp, msg)
+ callback.onTextOutput(
+ what,
+ messageProcessor.process(msg),
+ )
} catch (e: Exception) {
Log.e(TAG, "Callback execution failed: ${e.message}")
throw WhisperKitException("Callback execution failed: ${e.message}", e)
diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt
index b1a4500..59c04ed 100644
--- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt
+++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt
@@ -58,7 +58,7 @@ class AndroidHuggingFaceLogger(private val tag: String) : HuggingFaceLogger {
* This class provides functionality to download required model files for a specific variant using
* HuggingFaceApi, supporting automatic retries and progress reporting during downloads.
*/
-class ArgmaxModelDownloaderImpl(
+internal class ArgmaxModelDownloaderImpl(
private val huggingFaceApi: HuggingFaceApi =
KtorHuggingFaceApiImpl(
config =
@@ -69,6 +69,7 @@ class ArgmaxModelDownloaderImpl(
) : ArgmaxModelDownloader {
companion object {
private const val TOKENIZER_REPO = "TOKENIZER_REPO"
+ private const val CONFIG_REPO = "CONFIG_REPO"
private const val ENCODER_DECODER_REPO = "ENCODER_DECODER_REPO"
// dir path under argmaxinc/whisperkit-litert to look up for MelSpectrogram.tflite
@@ -79,36 +80,56 @@ class ArgmaxModelDownloaderImpl(
mapOf(
WhisperKit.Builder.OPENAI_TINY_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-tiny.en",
TOKENIZER_REPO to "openai/whisper-tiny.en",
ENCODER_DECODER_REPO to "openai_whisper-tiny.en",
FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny.en",
),
WhisperKit.Builder.OPENAI_BASE_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-base.en",
TOKENIZER_REPO to "openai/whisper-base.en",
ENCODER_DECODER_REPO to "openai_whisper-base.en",
FEATURE_EXTRACTOR_PATH to "openai_whisper-base.en",
),
+ WhisperKit.Builder.OPENAI_TINY to
+ mapOf(
+ CONFIG_REPO to "openai/whisper-tiny",
+ TOKENIZER_REPO to "openai/whisper-tiny",
+ ENCODER_DECODER_REPO to "openai_whisper-tiny",
+ FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny",
+ ),
+ WhisperKit.Builder.OPENAI_BASE to
+ mapOf(
+ CONFIG_REPO to "openai/whisper-base",
+ TOKENIZER_REPO to "openai/whisper-base",
+ ENCODER_DECODER_REPO to "openai_whisper-base",
+ FEATURE_EXTRACTOR_PATH to "openai_whisper-base",
+ ),
WhisperKit.Builder.OPENAI_SMALL_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-small.en",
TOKENIZER_REPO to "openai/whisper-small.en",
ENCODER_DECODER_REPO to "openai_whisper-small.en",
FEATURE_EXTRACTOR_PATH to "openai_whisper-small.en",
),
WhisperKit.Builder.QUALCOMM_TINY_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-tiny.en",
TOKENIZER_REPO to "openai/whisper-tiny.en",
ENCODER_DECODER_REPO to "qualcomm/Whisper-Tiny-En",
FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-tiny.en",
),
WhisperKit.Builder.QUALCOMM_BASE_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-base.en",
TOKENIZER_REPO to "openai/whisper-base.en",
ENCODER_DECODER_REPO to "qualcomm/Whisper-Base-En",
FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-base.en",
),
WhisperKit.Builder.QUALCOMM_SMALL_EN to
mapOf(
+ CONFIG_REPO to "openai/whisper-small.en",
TOKENIZER_REPO to "openai/whisper-small.en",
ENCODER_DECODER_REPO to "qualcomm/Whisper-Small-En",
FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-small.en",
@@ -120,12 +141,12 @@ class ArgmaxModelDownloaderImpl(
* Downloads model files for a specific variant and returns a flow of download progress.
*
* For OpenAI models (whisperkit-litert/openai_*):
- * - Downloads tokenizer.json from openai/whisper-*.en
+ * - Downloads config.json and tokenizer.json from openai/whisper-*
* - Downloads AudioEncoder.tflite and TextDecoder.tflite from argmaxinc/whisperkit-litert/openai_whisper-*
* - Downloads MelSpectrogram.tflite from argmaxinc/whisperkit-litert/openai_whisper-*
*
* For Qualcomm models (qualcomm/Whisper_*_En):
- * - Downloads tokenizer.json from openai/whisper-*.en
+ * - Downloads config.json and tokenizer.json from openai/whisper-*.en
* - Downloads WhisperEncoder.tflite and WhisperDecoder.tflite from qualcomm/Whisper-*-En
* and renames them to AudioEncoder.tflite and TextDecoder.tflite respectively
* - Downloads MelSpectrogram.tflite from argmaxinc/whisperkit-litert/quic_openai_whisper-*
@@ -133,6 +154,8 @@ class ArgmaxModelDownloaderImpl(
* @param variant The model variant to download. Must be one of:
* - [WhisperKit.Builder.OPENAI_TINY_EN]
* - [WhisperKit.Builder.OPENAI_BASE_EN]
+ * - [WhisperKit.Builder.OPENAI_TINY]
+ * - [WhisperKit.Builder.OPENAI_BASE]
* - [WhisperKit.Builder.OPENAI_SMALL_EN]
* - [WhisperKit.Builder.QUALCOMM_TINY_EN]
* - [WhisperKit.Builder.QUALCOMM_BASE_EN]
@@ -147,20 +170,53 @@ class ArgmaxModelDownloaderImpl(
): Flow<HuggingFaceApi.Progress> {
val config =
modelConfigs[variant] ?: throw IllegalArgumentException("Invalid variant: $variant")
-
return combine(
+ downloadConfig(config, root),
downloadTokenizer(config, root),
downloadEncoderDecoder(variant, config, root),
downloadFeatureExtractor(config, root),
- ) { tokenizer, encoderDecoder, featureExtractor ->
- // Each flow contributes 1/3 to the total progress
+ ) { config, tokenizer, encoderDecoder, featureExtractor ->
+ // Each flow contributes 1/4 to the total progress
HuggingFaceApi.Progress(
- fractionCompleted =
- (
- tokenizer.fractionCompleted + encoderDecoder.fractionCompleted +
- featureExtractor.fractionCompleted
- ) / 3.0f,
+ fractionCompleted = (
+ config.fractionCompleted + tokenizer.fractionCompleted +
+ encoderDecoder.fractionCompleted + featureExtractor.fractionCompleted
+ ) / 4.0f,
+ )
+ }.onCompletion {
+ // Clean up model directories after all downloads are complete
+ if (!variant.startsWith("qualcomm/")) {
+ // For OpenAI models, clean up the model directory
+ File(root, config[ENCODER_DECODER_REPO]!!).deleteRecursively()
+ }
+ // Clean up feature extractor directory
+ File(root, config[FEATURE_EXTRACTOR_PATH]!!).deleteRecursively()
+ }
+ }
+
+ @OptIn(ExperimentalCoroutinesApi::class)
+ private fun downloadConfig(
+ config: Map<String, String>,
+ root: File,
+ ): Flow<HuggingFaceApi.Progress> {
+ return flow {
+ emit(
+ huggingFaceApi.getFileMetadata(
+ from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS),
+ filename = "config.json",
+ ),
)
}.flatMapLatest { configMetadata ->
+ val cachedConfigFile = File(root, "config.json")
+ if (cachedConfigFile.exists() && cachedConfigFile.length() == configMetadata.size) {
+ flowOf(HuggingFaceApi.Progress(1.0f))
+ } else {
+ huggingFaceApi.snapshot(
+ from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS),
+ globFilters = listOf("config.json"),
+ baseDir = root,
+ )
+ }
}
}
@@ -314,7 +370,6 @@ class ArgmaxModelDownloaderImpl(
"MelSpectrogram.tflite",
),
)
- File(root, modelDir).deleteRecursively()
}
}
}
diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt
new file mode 100644
index 0000000..d5e3e2a
--- /dev/null
+++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt
@@ -0,0 +1,10 @@
+package com.argmaxinc.whisperkit.util
+
+import com.argmaxinc.whisperkit.TranscriptionResult
+
+/**
+ * Processor to convert raw model output Strings into [TranscriptionResult]
+ */
+internal interface MessageProcessor {
+ fun process(rawMsg: String): TranscriptionResult
+}
diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt
new file mode 100644
index 0000000..8aa2d6a
--- /dev/null
+++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt
@@ -0,0 +1,46 @@
+package com.argmaxinc.whisperkit.util
+
+import com.argmaxinc.whisperkit.TranscriptionResult
+import com.argmaxinc.whisperkit.TranscriptionSegment
+
+/**
+ * A processor to only extract segment text from raw string, ignoring all timestamps or windows
+ */
+internal class SegmentTextOnlyMessageProcessor : MessageProcessor {
+ private companion object {
+ private val TIMESTAMP_PATTERN = "<\\|(\\d+\\.\\d+)\\|>".toRegex()
+
+ // Pattern to match any <|str|> that's not a timestamp
+ private val NON_TIMESTAMP_PATTERN = "<\\|(?!\\d+\\.\\d+)[^>]*\\|>".toRegex()
+ }
+
+ override fun process(rawMsg: String): TranscriptionResult {
+ // Remove any markers that aren't timestamps
+ val cleanMsg = rawMsg.replace(NON_TIMESTAMP_PATTERN, "")
+
+ val segments = mutableListOf<TranscriptionSegment>()
+
+ // Find all timestamp markers
+ val matches = TIMESTAMP_PATTERN.findAll(cleanMsg).toList()
+
+ for (i in 0 until matches.size - 1) {
+ val startMatch = matches[i]
+ val endMatch = matches[i + 1]
+
+ // TODO: add start and end to each segment
+ // val start = startMatch.groupValues[1].toFloat()
+ // val end = endMatch.groupValues[1].toFloat()
+
+ // Extract text between timestamps
+ val textStart = startMatch.range.last + 1
+ val textEnd = endMatch.range.first
+ val text = cleanMsg.substring(textStart, textEnd)
+
+ if (text != "\n" && text.isNotEmpty()) {
+ segments.add(TranscriptionSegment(text))
+ }
+ }
+
+ return TranscriptionResult(text = rawMsg, segments = segments)
+ }
+}
diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt
index a3f8237..28c1cbc 100644
--- a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt
+++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt
@@ -45,6 +45,14 @@ class ArgmaxModelDownloaderImplTest {
) {
val qualcommModel = modelVariant.startsWith("qualcomm/")
+ // Mock config metadata
+ coEvery {
+ huggingFaceApi.getFileMetadata(
+ from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)),
+ filename = eq("config.json"),
+ )
+ } returns HuggingFaceApi.FileMetadata(500L, "config.json")
+
// Mock tokenizer metadata
coEvery {
huggingFaceApi.getFileMetadata(
@@ -115,7 +123,16 @@ class ArgmaxModelDownloaderImplTest {
expectedMelSpectrogramPath: String,
) {
// Create cached files upfront
- // 1. Create tokenizer.json
+ // 1. Create config.json
+ val configFile = File(root, "config.json")
+ configFile.createNewFile()
+ configFile.setWritable(true)
+ configFile.setReadable(true)
+ configFile.setExecutable(true)
+ configFile.setLastModified(System.currentTimeMillis())
+ configFile.writeBytes(ByteArray(500))
+
+ // 2. Create tokenizer.json
val tokenizerFile = File(root, "tokenizer.json")
tokenizerFile.createNewFile()
tokenizerFile.setWritable(true)
@@ -124,7 +141,7 @@ class ArgmaxModelDownloaderImplTest {
tokenizerFile.setLastModified(System.currentTimeMillis())
tokenizerFile.writeBytes(ByteArray(1000)) // Set size to match metadata
- // 2. Create encoder/decoder files
+ // 3. Create encoder/decoder files
val audioEncoderFile = File(root, "AudioEncoder.tflite")
val textDecoderFile = File(root, "TextDecoder.tflite")
audioEncoderFile.createNewFile()
@@ -132,7 +149,7 @@ class ArgmaxModelDownloaderImplTest {
audioEncoderFile.writeBytes(ByteArray(2000)) // Set size to match metadata
textDecoderFile.writeBytes(ByteArray(3000)) // Set size to match metadata
- // 3. Create MelSpectrogram.tflite
+ // 4. Create MelSpectrogram.tflite
val melSpectrogramFile = File(root, "MelSpectrogram.tflite")
melSpectrogramFile.createNewFile()
melSpectrogramFile.writeBytes(ByteArray(4000)) // Set size to match metadata
@@ -197,6 +214,13 @@ class ArgmaxModelDownloaderImplTest {
"MelSpectrogram.tflite",
).exists(),
) { "MelSpectrogram.tflite should exist in root directory" }
+
+ // Add directory deletion verification
+ if (!qualcommModel) {
+ assert(!File(root, expectedMelSpectrogramPath).exists()) {
+ "Model directory should be deleted after download"
+ }
+ }
}
private fun testNonCached(
@@ -207,6 +231,15 @@ class ArgmaxModelDownloaderImplTest {
expectedEncoderDecoderRepo: String,
expectedEncoderDecoderGlobFilters: List<String>,
) {
+ // Mock config snapshot
+ every {
+ huggingFaceApi.snapshot(
+ from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)),
+ globFilters = eq(listOf("config.json")),
+ baseDir = eq(root),
+ )
+ } returns flowOf(HuggingFaceApi.Progress(1.0f))
+
// Mock tokenizer snapshot
every {
huggingFaceApi.snapshot(
@@ -264,7 +297,14 @@ class ArgmaxModelDownloaderImplTest {
downloader.download(ArgmaxModel.WHISPER, modelVariant, root).collect {}
}
- // Verify snapshot was called exactly 3 times for each component
+ // Verify snapshot was called exactly 4 times for each component
+ verify(exactly = 1) {
+ huggingFaceApi.snapshot(
+ from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)),
+ globFilters = eq(listOf("config.json")),
+ baseDir = eq(root),
+ )
+ }
verify(exactly = 1) {
huggingFaceApi.snapshot(
from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)),
@@ -346,6 +386,13 @@ class ArgmaxModelDownloaderImplTest {
"MelSpectrogram.tflite",
).exists(),
) { "MelSpectrogram.tflite should exist in root directory" }
+
+ // Add directory deletion verification
+ if (!qualcommModel) {
+ assert(!File(root, expectedMelSpectrogramPath).exists()) {
+ "Model directory should be deleted after download"
+ }
+ }
}
@Test
@@ -362,6 +409,20 @@ class ArgmaxModelDownloaderImplTest {
expectedMelSpectrogramPath = "openai_whisper-tiny.en",
)
+ @Test
+ fun `download creates correct flows for OpenAI tiny multilingual model`() =
+ testDownload(
+ modelVariant = WhisperKit.Builder.OPENAI_TINY,
+ expectedTokenizerRepo = "openai/whisper-tiny",
+ expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert",
+ expectedEncoderDecoderGlobFilters =
+ listOf(
+ "openai_whisper-tiny/AudioEncoder.tflite",
+ "openai_whisper-tiny/TextDecoder.tflite",
+ ),
+ expectedMelSpectrogramPath = "openai_whisper-tiny",
+ )
+
@Test
fun `download creates correct flows for OpenAI base model`() =
testDownload(
@@ -376,6 +437,20 @@ class ArgmaxModelDownloaderImplTest {
expectedMelSpectrogramPath = "openai_whisper-base.en",
)
+ @Test
+ fun `download creates correct flows for OpenAI base multilingual model`() =
+ testDownload(
+ modelVariant = WhisperKit.Builder.OPENAI_BASE,
+ expectedTokenizerRepo = "openai/whisper-base",
+ expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert",
+ expectedEncoderDecoderGlobFilters =
+ listOf(
+ "openai_whisper-base/AudioEncoder.tflite",
+ "openai_whisper-base/TextDecoder.tflite",
+ ),
+ expectedMelSpectrogramPath = "openai_whisper-base",
+ )
+
@Test
fun `download creates correct flows for OpenAI small model`() =
testDownload(
@@ -447,6 +522,21 @@ class ArgmaxModelDownloaderImplTest {
cached = true,
)
+ @Test
+ fun `download uses cached files when available for OpenAI tiny multilingual model`() =
+ testDownload(
+ modelVariant = WhisperKit.Builder.OPENAI_TINY,
+ expectedTokenizerRepo = "openai/whisper-tiny",
+ expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert",
+ expectedEncoderDecoderGlobFilters =
+ listOf(
+ "openai_whisper-tiny/AudioEncoder.tflite",
+ "openai_whisper-tiny/TextDecoder.tflite",
+ ),
+ expectedMelSpectrogramPath = "openai_whisper-tiny",
+ cached = true,
+ )
+
@Test
fun `download uses cached files when available for OpenAI base model`() =
testDownload(
@@ -462,6 +552,21 @@ class ArgmaxModelDownloaderImplTest {
cached = true,
)
+ @Test
+ fun `download uses cached files when available for OpenAI base multilingual model`() =
+ testDownload(
+ modelVariant = WhisperKit.Builder.OPENAI_BASE,
+ expectedTokenizerRepo = "openai/whisper-base",
+ expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert",
+ expectedEncoderDecoderGlobFilters =
+ listOf(
+ "openai_whisper-base/AudioEncoder.tflite",
+ "openai_whisper-base/TextDecoder.tflite",
+ ),
+ expectedMelSpectrogramPath = "openai_whisper-base",
+ cached = true,
+ )
+
@Test
fun `download uses cached files when available for OpenAI small model`() =
testDownload(
@@ -531,11 +636,20 @@ class ArgmaxModelDownloaderImplTest {
@Test
fun `download combines progress from all flows correctly`() =
runTest {
+ val configProgress = 0.2f
val tokenizerProgress = 0.3f
val encoderDecoderProgress = 0.6f
val melSpectrogramProgress = 0.9f
val expectedProgress =
- (tokenizerProgress + encoderDecoderProgress + melSpectrogramProgress) / 3.0f
+ (configProgress + tokenizerProgress + encoderDecoderProgress + melSpectrogramProgress) / 4.0f
+
+ // Mock config metadata
+ coEvery {
+ huggingFaceApi.getFileMetadata(
+ from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)),
+ filename = eq("config.json"),
+ )
+ } returns HuggingFaceApi.FileMetadata(500L, "config.json")
// Mock tokenizer metadata
coEvery {
@@ -570,6 +684,15 @@ class ArgmaxModelDownloaderImplTest {
)
// Mock tokenizer snapshot
+ every {
+ huggingFaceApi.snapshot(
+ from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)),
+ globFilters = eq(listOf("config.json")),
+ baseDir = eq(root),
+ )
+ } returns flowOf(HuggingFaceApi.Progress(configProgress))
+
+ // Mock tokenizer snapshot
every {
huggingFaceApi.snapshot(
from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)),
diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt
new file mode 100644
index 0000000..52c375f
--- /dev/null
+++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt
@@ -0,0 +1,97 @@
+package com.argmaxinc.whisperkit.util
+
+import com.argmaxinc.whisperkit.TranscriptionSegment
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class SegmentTextOnlyMessageProcessorTest {
+ private val processor = SegmentTextOnlyMessageProcessor()
+
+ @Test
+ fun `process extracts text segments correctly`() {
+ // Given
+ val rawMsg = """
+ <|startoftranscript|><|0.00|> When I was 27 years old, I left a very demanding job in management consulting.<|18.48|><|18.48|> For a job that was even more demanding, teaching.<|21.84|><|endoftext|>
+ <|startoftranscript|><|0.00|> I went to teach seventh graders, Math, in the New York City Public Schools.<|5.88|><|5.88|> And like any teacher, I made quizzes and tests.<|8.08|><|endoftext|>
+ """.trimIndent()
+
+ // When
+ val result = processor.process(rawMsg)
+
+ // Then
+ val expectedSegments = listOf(
+ TranscriptionSegment(" When I was 27 years old, I left a very demanding job in management consulting."),
+ TranscriptionSegment(" For a job that was even more demanding, teaching."),
+ TranscriptionSegment(" I went to teach seventh graders, Math, in the New York City Public Schools."),
+ TranscriptionSegment(" And like any teacher, I made quizzes and tests."),
+ )
+
+ assertEquals(expectedSegments, result.segments)
+ }
+
+ @Test
+ fun `process handles multiple segments with same timestamps`() {
+ // Given
+ val rawMsg = """
+ <|startoftranscript|><|0.00|> not fixed, that it can change with your effort.<|4.48|><|4.48|> Dr. Dweck has shown that when kids read and learn<|7.44|><|7.44|> about the brain and how it changes and grows<|10.72|><|10.72|> in response to challenge, they're much more likely<|13.64|><|13.64|> to persevere when they fail because they don't believe<|18.80|><|18.80|> that failure is a permanent condition.<|21.56|><|endoftext|>
+ <|startoftranscript|><|0.00|> So, growth mindset is a great idea for building grit, but we need more.<|6.60|><|6.60|> And that's where I'm going to end my remarks, because that's where we are.<|9.28|><|endoftext|>
+ """.trimIndent()
+
+ // When
+ val result = processor.process(rawMsg)
+
+ // Then
+ val expectedSegments = listOf(
+ TranscriptionSegment(" not fixed, that it can change with your effort."),
+ TranscriptionSegment(" Dr. Dweck has shown that when kids read and learn"),
+ TranscriptionSegment(" about the brain and how it changes and grows"),
+ TranscriptionSegment(" in response to challenge, they're much more likely"),
+ TranscriptionSegment(" to persevere when they fail because they don't believe"),
+ TranscriptionSegment(" that failure is a permanent condition."),
+ TranscriptionSegment(" So, growth mindset is a great idea for building grit, but we need more."),
+ TranscriptionSegment(" And that's where I'm going to end my remarks, because that's where we are."),
+ )
+
+ assertEquals(expectedSegments, result.segments)
+ }
+
+ @Test
+ fun `process handles special characters and short segments`() {
+ // Given
+ val rawMsg = "<|startoftranscript|><|0.00|> ♪♪<|2.00|><|endoftext|>"
+
+ // When
+ val result = processor.process(rawMsg)
+
+ // Then
+ val expectedSegments = listOf(
+ TranscriptionSegment(" ♪♪"),
+ )
+
+ assertEquals(expectedSegments, result.segments)
+ }
+
+ @Test
+ fun `process handles empty input`() {
+ // Given
+ val rawMsg = ""
+
+ // When
+ val result = processor.process(rawMsg)
+
+ // Then
+ assertEquals(emptyList<TranscriptionSegment>(), result.segments)
+ }
+
+ @Test
+ fun `process handles input with no segments`() {
+ // Given
+ val rawMsg = "<|startoftranscript|><|endoftext|>"
+
+ // When
+ val result = processor.process(rawMsg)
+
+ // Then
+ assertEquals(emptyList<TranscriptionSegment>(), result.segments)
+ }
+}
diff --git a/cpp/src/Models/TextDecoder.cpp b/cpp/src/Models/TextDecoder.cpp
index 297bc94..aaa1a79 100644
--- a/cpp/src/Models/TextDecoder.cpp
+++ b/cpp/src/Models/TextDecoder.cpp
@@ -12,7 +12,154 @@
using namespace WhisperKit;
// 'Monolithic KV Cache' ~ corresponds to the QUIC exported Whisper models
-// TODO: make a metadata class for the model signature and tendor indices,
+namespace WhisperKit {
+constexpr const int kKvFactor = 2;
+constexpr const int kLayersWhisperTiny = 4;
+constexpr const int kLayersWhisperBase = 6;
+constexpr const int kLayersWhisperSmall = 12;
+constexpr const int kLayersWhisperMedium = 24;
+constexpr const int kLayersWhisperLarge = 32;
+
+constexpr const char* kVariantWhisperTiny = "tiny";
+constexpr const char* kVariantWhisperBase = "base";
+constexpr const char* kVariantWhisperSmall = "small";
+constexpr const char* kVariantWhisperMedium = "medium";
+constexpr const char* kVariantWhisperLarge = "large";
+constexpr const char* kVariantNone = "none";
+} // namespace WhisperKit
+
+namespace {
+
+std::string normalize_name(const std::string& name) {
+ // Names were padded with null characters to ensure alignment with original exported names
+ // Remove extra null characters, or naive string matching will fail.
+ auto name_copy = name.c_str();
+ return std::string(name_copy);
+};
+
+} // namespace
+
+class FlatBuffersMetadata {
+ public:
+ FlatBuffersMetadata(const std::string& tflite_model_path) {
+ _model_file_path = tflite_model_path;
+ std::ifstream file(tflite_model_path, std::ios::binary | std::ios::ate);
+ if (!file) throw std::runtime_error("Failed to open file");
+
+ std::streamsize size = file.tellg();
+ file.seekg(0, std::ios::beg);
+
+ _buffer = std::vector<char>(size);
+ if (!file.read(_buffer.data(), size)) throw std::runtime_error("Failed to read file");
+
+ const tflite::Model* model = tflite::GetModel(_buffer.data());
+
+ if (!model) {
+ throw std::runtime_error("Model is null");
+ }
+ if (!model->subgraphs()) {
+ throw std::runtime_error("Model has no subgraphs");
+ }
+ _model = model;
+ parse_model_metadata();
+ }
+
+ ~FlatBuffersMetadata() {
+ _model = nullptr;
+ _input_tensor_indices.clear();
+ _output_tensor_indices.clear();
+ _subgraphs = nullptr;
+ _buffer.clear();
+ _buffer.shrink_to_fit();
+ }
+
+ const std::string& get_model_file_path() const { return _model_file_path; }
+
+ const std::unordered_map<std::string, std::pair<int, int>> get_input_tensor_indices(int subgraph_index = 0) const;
+ const std::unordered_map<std::string, std::pair<int, int>> get_output_tensor_indices(int subgraph_index = 0) const;
+ void print_metadata();
+ const tflite::Model* get_model() const { return _model; }
+
+ private:
+ void parse_model_metadata();
+ std::string _model_file_path;
+ const tflite::Model* _model;
+ std::vector<char> _buffer;
+ const ::flatbuffers::Vector<::flatbuffers::Offset<tflite::SubGraph>>* _subgraphs;
+
+ // name -> (tensor_index, io_index)
+ std::vector<std::unordered_map<std::string, std::pair<int, int>>> _input_tensor_indices;
+ std::vector<std::unordered_map<std::string, std::pair<int, int>>> _output_tensor_indices;
+};
+
+void FlatBuffersMetadata::parse_model_metadata() {
+ _subgraphs = _model->subgraphs();
+
+ // Primary subgraph is the only one we care about for now.
+ bool only_first_subgraph = true;
+ int max_subgraph_index = only_first_subgraph ? 1 : _subgraphs->size();
+
+ _input_tensor_indices.resize(max_subgraph_index);
+ _output_tensor_indices.resize(max_subgraph_index);
+
+ for (int i = 0; i < max_subgraph_index; i++) {
+ const tflite::SubGraph* subgraph = _subgraphs->Get(i);
+ std::unordered_map<std::string, std::pair<int, int>> input_tensor_indices;
+ std::unordered_map<std::string, std::pair<int, int>> output_tensor_indices;
+
+ const auto* inputs = subgraph->inputs();
+ const auto* outputs = subgraph->outputs();
+ const auto* tensors = subgraph->tensors();
+
+ for (int i = 0; i < inputs->size(); ++i) {
+ int tensor_index = inputs->Get(i);
+ auto name = normalize_name(tensors->Get(tensor_index)->name()->str());
+ input_tensor_indices[name] = std::make_pair(tensor_index, i);
+ }
+
+ for (int i = 0; i < outputs->size(); ++i) {
+ int tensor_index = outputs->Get(i);
+ auto name = normalize_name(tensors->Get(tensor_index)->name()->str());
+ output_tensor_indices[name] = std::make_pair(tensor_index, i);
+ }
+ _input_tensor_indices[i] = input_tensor_indices;
+ _output_tensor_indices[i] = output_tensor_indices;
+ }
+}
+
+void FlatBuffersMetadata::print_metadata() {
+ std::cout << "Model file path: " << _model_file_path << std::endl;
+ for (int i = 0; i < _input_tensor_indices.size(); i++) {
+ std::cout << "Subgraph " << i << " input tensor indices:" << std::endl;
+ for (const auto& [name, indices] : _input_tensor_indices[i]) {
+ std::cout << " " << name << ": (" << indices.first << ", " << indices.second << ")" << std::endl;
+ }
+ }
+ for (int i = 0; i < _output_tensor_indices.size(); i++) {
+ std::cout << "Subgraph " << i << " output tensor indices:" << std::endl;
+ for (const auto& [name, indices] : _output_tensor_indices[i]) {
+ std::cout << " " << name << ": (" << indices.first << ", " << indices.second << ")" << std::endl;
+ }
+ }
+}
+
+const std::unordered_map<std::string, std::pair<int, int>> FlatBuffersMetadata::get_input_tensor_indices(
+ int subgraph_index) const {
+ if (subgraph_index >= _input_tensor_indices.size()) {
+ throw std::runtime_error("Subgraph index out of bounds");
+ }
+ return _input_tensor_indices[subgraph_index];
+}
+
+const std::unordered_map<std::string, std::pair<int, int>> FlatBuffersMetadata::get_output_tensor_indices(
+ int subgraph_index) const {
+ if (subgraph_index >= _output_tensor_indices.size()) {
+ throw std::runtime_error("Subgraph index out of bounds");
+ }
+ return _output_tensor_indices[subgraph_index];
+}
+
+// TODO: make a metadata class for the model signature and tensor indices,
// pass to the subclasses so we don't have to reopen the file outside of
// loading the actual tflite model. TFLIte APIs require signature runner
// to get this information from interpreter, which is not available if the
@@ -22,7 +169,17 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) {
"x", "index", "k_cache_cross", "v_cache_cross", "k_cache_self", "v_cache_self"};
const std::unordered_set<std::string> expected_output_names = {"logits", "k_cache", "v_cache"};
- const auto* subgraph = model->subgraphs()->Get(0);
+ if (!model->subgraphs()) {
+ throw std::runtime_error("Model has no subgraphs");
+ }
+
+ const auto* subgraphs = model->subgraphs();
+
+ if (subgraphs->size() == 0) {
+ throw std::runtime_error("Model has no subgraphs");
+ }
+
+ const auto* subgraph = subgraphs->Get(0);
const auto* inputs = subgraph->inputs();
const auto* outputs = subgraph->outputs();
@@ -31,11 +188,11 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) {
}
const auto* tensors = subgraph->tensors();
-
std::unordered_set<std::string> input_names;
for (int i = 0; i < inputs->size(); ++i) {
int tensor_index = inputs->Get(i);
- input_names.insert(tensors->Get(tensor_index)->name()->str());
+ auto name = tensors->Get(tensor_index)->name()->str();
+ input_names.insert(name);
}
std::unordered_set<std::string> output_names;
@@ -51,22 +208,6 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) {
return true;
}
-namespace WhisperKit {
-constexpr const int kKvFactor = 2;
-constexpr const int kLayersWhisperTiny = 4;
-constexpr const int kLayersWhisperBase = 6;
-constexpr const int kLayersWhisperSmall = 12;
-constexpr const int kLayersWhisperMedium = 24;
-constexpr const int kLayersWhisperLarge = 32;
-
-constexpr const char* kVariantWhisperTiny = "tiny";
-constexpr const char* kVariantWhisperBase = "base";
-constexpr const char* kVariantWhisperSmall = "small";
-constexpr const char* kVariantWhisperMedium = "medium";
-constexpr const char* kVariantWhisperLarge = "large";
-constexpr const char* kVariantNone = "none";
-} // namespace WhisperKit
-
const int layers_for_variant(const std::string& variant) {
if (variant == kVariantWhisperTiny) {
return kLayersWhisperTiny;
@@ -82,6 +223,25 @@ const int layers_for_variant(const std::string& variant) {
return 0;
}
+std::unordered_set<std::string> get_expected_input_names_for_variant(const char* variant) {
+ auto input_names_for_variant_with_layers = [](const int num_layers) -> auto{
+ std::unordered_set<std::string> input_names;
+ input_names.insert(std::string("x"));
+ input_names.insert(std::string("index"));
+ input_names.insert(std::string("k_cache_cross"));
+ input_names.insert(std::string("v_cache_cross"));
+ for (int i = 0; i < num_layers; ++i) {
+ input_names.insert(std::string("k_cache_self_" + std::to_string(i)));
+ input_names.insert(std::string("v_cache_self_" + std::to_string(i)));
+ }
+ return input_names;
+ };
+
+ int num_layers = layers_for_variant(variant);
+
+ return input_names_for_variant_with_layers(num_layers);
+}
+
bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model* model) {
const auto* subgraph = model->subgraphs()->Get(0);
const auto* inputs = subgraph->inputs();
@@ -119,30 +279,17 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model
return output_names;
};
- auto input_names_for_variant_with_layers = [](const int num_layers) -> auto{
- std::unordered_set input_names;
- input_names.insert(std::string("x"));
- input_names.insert(std::string("index"));
- input_names.insert(std::string("k_cache_cross"));
- input_names.insert(std::string("v_cache_cross"));
- for (int i = 0; i < num_layers; ++i) {
- input_names.insert(std::string("k_cache_self_" + std::to_string(i)));
- input_names.insert(std::string("v_cache_self_" + std::to_string(i)));
- }
- return input_names;
- };
-
char* variant = const_cast<char*>(kVariantNone);
- if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperTiny)) {
+ if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperTiny).size()) {
variant = const_cast<char*>(kVariantWhisperTiny);
- } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperBase)) {
+ } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperBase).size()) {
variant = const_cast<char*>(kVariantWhisperBase);
- } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperSmall)) {
+ } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperSmall).size()) {
variant = const_cast<char*>(kVariantWhisperSmall);
- } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperMedium)) {
+ } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperMedium).size()) {
variant = const_cast<char*>(kVariantWhisperMedium);
- } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperLarge)) {
+ } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperLarge).size()) {
variant = const_cast<char*>(kVariantWhisperLarge);
}
@@ -173,20 +320,13 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model
}
}
- auto expected_input_names = input_names_for_variant_with_layers(layers_for_variant(variant));
+ auto expected_input_names = get_expected_input_names_for_variant(variant);
auto expected_output_names = output_names_for_variant_with_layers(layers_for_variant(variant));
std::unordered_set<std::string> input_names;
std::unordered_set<std::string> output_names;
const auto* tensors = subgraph->tensors();
- auto normalize_name = [](const std::string& name) -> std::string {
- // Names were padded with null characters to ensure alignment with original exported names
- // Remove extra null characters, or naive string matching will fail.
- auto name_copy = name.c_str();
- return std::string(name_copy);
- };
-
for (int i = 0; i < num_inputs; ++i) {
auto name = normalize_name(tensors->Get(inputs->Get(i))->name()->str());
input_names.insert(name);
@@ -208,27 +348,21 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model
return true;
}
-std::unique_ptr<TextDecoder> TextDecoderFactory::CreateFromFile(const std::string& tflite_model_path) {
- std::ifstream file(tflite_model_path, std::ios::binary | std::ios::ate);
- if (!file) throw std::runtime_error("Failed to open file");
-
- std::streamsize size = file.tellg();
- file.seekg(0, std::ios::beg);
-
- std::vector<char> buffer(size);
- if (!file.read(buffer.data(), size)) throw std::runtime_error("Failed to read file");
-
- const tflite::Model* model = tflite::GetModel(buffer.data());
-
- if (!model) throw std::runtime_error("Failed to load model");
-
- auto is_monolithic_kv_cache = is_exact_match_for_monolithic_kv_cache(model);
+TextDecoder::~TextDecoder() {}
+std::unique_ptr<TextDecoder> TextDecoderFactory::CreateFromFile(const std::string& tflite_model_path) {
+ auto metadata = std::make_unique<FlatBuffersMetadata>(tflite_model_path);
+ auto is_monolithic_kv_cache = is_exact_match_for_monolithic_kv_cache(metadata->get_model());
if (is_monolithic_kv_cache) {
return std::make_unique<MonolithicKVDecoder>(tflite_model_path);
}
- auto is_separate_kv_cache_no_alignment_heads = is_exact_match_for_separate_kv_cache_no_alignment_heads(model);
+ auto is_separate_kv_cache_no_alignment_heads =
+ is_exact_match_for_separate_kv_cache_no_alignment_heads(metadata->get_model());
+
+ if (is_separate_kv_cache_no_alignment_heads) {
+ return std::make_unique<PerLayerKVDecoder>(tflite_model_path);
+ }
throw std::runtime_error("Decoder model signature not recognized");
}
@@ -237,14 +371,22 @@ std::pair MonolithicKVDecoder::get_logits_tensor() { return decoder_
MonolithicKVDecoder::MonolithicKVDecoder(const std::string& tflite_model_path) {
_model_path = tflite_model_path;
+ metadata = std::make_unique<FlatBuffersMetadata>(tflite_model_path);
+ // metadata->print_metadata();
// Note that the decoder model is not initialized here, it is initialized in the initialize method
_decoder_model = std::make_unique("TextDecoder");
+
if (!_decoder_model) {
throw std::runtime_error("Decoder model not initialized");
}
}
+MonolithicKVDecoder::~MonolithicKVDecoder() {
+ _decoder_model.reset();
+ metadata.reset();
+}
+
bool MonolithicKVDecoder::initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend,
bool debug) {
return _decoder_model->initialize(model_path, lib_dir, cache_dir, backend, debug);
@@ -317,3 +459,293 @@ std::vector> MonolithicKVDecoder::get_output_ptrs() { retu
int MonolithicKVDecoder::get_inference_num() { return _decoder_model->get_inference_num(); }
float MonolithicKVDecoder::get_latency_sum() { return _decoder_model->get_latency_sum(); }
+void MonolithicKVDecoder::dump_input_tensors() {
+ // Not yet implemented
+}
+
+void MonolithicKVDecoder::dump_output_tensors() {
+ // Not yet implemented
+}
+
+std::pair<char*, int> PerLayerKVDecoder::get_logits_tensor() {
+ if (decoder_outputs.empty()) {
+ decoder_outputs = _decoder_model->get_output_ptrs();
+ }
+ auto logits_index = output_tensor_indices.at("logits");
+ return decoder_outputs[logits_index];
+}
+
+PerLayerKVDecoder::PerLayerKVDecoder(const std::string& tflite_model_path) {
+ _model_path = tflite_model_path;
+ metadata = std::make_unique<FlatBuffersMetadata>(tflite_model_path);
+ initialize_io_metadata();
+ metadata.reset(); // to close the .tflite file
+
+ // Note that the decoder model is not initialized here, it is initialized in the initialize method
+ _decoder_model = std::make_unique("TextDecoder");
+ if (!_decoder_model) {
+ throw std::runtime_error("Decoder model not initialized");
+ }
+}
+
+PerLayerKVDecoder::~PerLayerKVDecoder() {
+ _decoder_model.reset();
+ metadata.reset();
+}
+
+bool PerLayerKVDecoder::initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend,
+ bool debug) {
+ return _decoder_model->initialize(model_path, lib_dir, cache_dir, backend, debug);
+}
+
+void PerLayerKVDecoder::uninitialize() { _decoder_model->uninitialize(); }
+
+void PerLayerKVDecoder::read_input_data(char* input_data, int idx) { _decoder_model->read_input_data(input_data, idx); }
+
+void PerLayerKVDecoder::bind_input_tensor(char* input_data, const std::string& tensor_name) {
+ if (tensor_name == "x") {
+ // get value in int32 and upcast to int64 before passing to read_input_data, which
+ // does a simple memcpy of all the bytes.
+ int32_t* input_data_int32 = reinterpret_cast<int32_t*>(input_data);
+ int64_t x = static_cast<int64_t>(*input_data_int32);
+ auto x_tensor_index = input_tensor_indices["x"];
+ _decoder_model->read_input_data(reinterpret_cast<char*>(&x), x_tensor_index);
+ return;
+ }
+
+ if (tensor_name == "index") {
+ int* input_data_int32 = reinterpret_cast<int*>(input_data);
+ int64_t index = static_cast<int64_t>(*input_data_int32);
+ auto index_tensor_index = input_tensor_indices["index"];
+
+ _decoder_model->read_input_data(reinterpret_cast<char*>(&index), index_tensor_index);
+ return;
+ }
+
+ if (tensor_name == "k_cache_cross") {
+ auto k_cache_cross_tensor_index = input_tensor_indices["k_cache_cross"];
+ _decoder_model->read_input_data(input_data, k_cache_cross_tensor_index);
+ return;
+ }
+
+ if (tensor_name == "v_cache_cross") {
+ auto v_cache_cross_tensor_index = input_tensor_indices["v_cache_cross"];
+ _decoder_model->read_input_data(input_data, v_cache_cross_tensor_index);
+ return;
+ }
+
+ else {
+ auto tensor_index = kv_cache_input_tensor_indices[tensor_name];
+ _decoder_model->read_input_data(input_data, tensor_index);
+ return;
+ }
+
+ throw std::runtime_error("Invalid tensor name");
+}
+
+void PerLayerKVDecoder::invoke(bool measure_time) { _decoder_model->invoke(measure_time); }
+
+template <typename T>
+void save_to_binary_file(const std::string& filename, const std::vector<T>& data) {
+ std::ofstream outfile(filename, std::ios::binary);
+ if (!outfile) {
+ std::cerr << "Error opening file for writing: " << filename << std::endl;
+ return;
+ }
+
+ printf("Saving to binary file: %s\n", filename.c_str());
+ outfile.write(reinterpret_cast<const char*>(data.data()), data.size() * sizeof(T));
+ outfile.close();
+}
+
+void PerLayerKVDecoder::dump_input_tensors() {
+ printf("Dumping input tensors\n");
+ auto input_ptrs = _decoder_model->get_input_ptrs();
+
+ for (int index = 0; index < _decoder_model->_interpreter->inputs().size(); index++) {
+ auto name = _decoder_model->_interpreter->GetInputName(index);
+ auto safe_name = normalize_name(name);
+ printf("safe name %s index %d\n", safe_name.c_str(), index);
+ if (safe_name == "x" || safe_name == "index") {
+ printf("Dumping input tensor: %s, index=%d\n", name, index);
+ auto tensor_ptr = input_ptrs[index];
+ int tensor_size = tensor_ptr.second;
+ printf("Tensor size: %d\n", tensor_size);
+ int64_t* input_data = reinterpret_cast<int64_t*>(tensor_ptr.first);
+ size_t num_elements = tensor_size / sizeof(int64_t);
+ std::vector<int64_t> data(num_elements);
+ printf("==============================================\n");
+ printf("name %s index %d num_elements: %ld\n", name, index, num_elements);
+ for (int i = 0; i < num_elements; i++) {
+ data[i] = input_data[i];
+ printf("name %s index %d data[%d] = %ld\n", name, index, i, data[i]);
+ }
+ printf("==============================================\n");
+
+ std::string _name = std::string(safe_name);
+ std::string filename = "/src/AXIE/debug_inputs/input_" + _name + ".bin";
+ save_to_binary_file(filename, data); // float32
+ } else {
+ std::string _name = std::string(safe_name);
+
+ printf("In float tensor path");
+ printf("Dumping input tensor: %s, index=%d\n", _name.c_str(), index);
+ auto tensor_ptr = input_ptrs[index];
+ auto tensor_size = tensor_ptr.second;
+ printf("Tensor size: %d\n", tensor_size);
+ std::vector<float> data(tensor_size / sizeof(float));
+
+ bool print_for_debug = true;
+ if (data.size() > 100) {
+ print_for_debug = false;
+ }
+
+ for (int i = 0; i < data.size(); i++) {
+ data[i] = *reinterpret_cast<float*>(input_ptrs[index].first + i * sizeof(float));
+ if (print_for_debug) {
+ printf("name %s index %d data[%d] = %f\n", _name.c_str(), index, i, data[i]);
+ }
+ }
+ std::string filename = "/src/AXIE/debug_inputs/input_" + _name + ".bin";
+ save_to_binary_file(filename, data);
+ }
+ }
+}
+
+void PerLayerKVDecoder::dump_output_tensors() {
+ printf("Dumping output tensors\n");
+ auto output_ptrs = _decoder_model->get_output_ptrs();
+
+ for (int index = 0; index < _decoder_model->_interpreter->outputs().size(); index++) {
+ auto name = _decoder_model->_interpreter->GetOutputName(index);
+ auto safe_name = normalize_name(name);
+ printf("safe name %s index %d\n", safe_name.c_str(), index);
+
+ if (safe_name == "logits") {
+ std::string _name = std::string(safe_name);
+
+ printf("In float tensor path");
+ printf("Dumping output tensor: %s, index=%d\n", _name.c_str(), index);
+ auto tensor_ptr = output_ptrs[index];
+ auto tensor_size = tensor_ptr.second;
+ printf("Tensor size: %d\n", tensor_size);
+ std::vector<float> data(tensor_size / sizeof(float));
+
+ bool print_for_debug = true;
+ if (data.size() > 100) {
+ print_for_debug = false;
+ }
+ for (int i = 0; i < data.size(); i++) {
+ data[i] = *reinterpret_cast<float*>(output_ptrs[index].first + i * sizeof(float));
+ if (print_for_debug) {
+ printf("name %s index %d data[%d] = %f\n", _name.c_str(), index, i, data[i]);
+ }
+ }
+ std::string filename = "/src/AXIE/debug_inputs/output_" + _name + ".bin";
+ save_to_binary_file(filename, data);
+ }
+ }
+}
+
+void PerLayerKVDecoder::initialize_io_metadata() {
+ // self attention kv cache tensors
+ const auto& all_input_tensor_indices = metadata->get_input_tensor_indices(0);
+ const auto& all_output_tensor_indices = metadata->get_output_tensor_indices(0);
+
+ // store only the relative indices within the i/o tensor vectors
+ auto logits_indices = all_output_tensor_indices.at("logits");
+ output_tensor_indices["logits"] = logits_indices.second;
+
+ auto token_indices = all_input_tensor_indices.at("x");
+ input_tensor_indices["x"] = token_indices.second;
+
+ auto index_indices = all_input_tensor_indices.at("index");
+ input_tensor_indices["index"] = index_indices.second;
+
+ auto k_cache_cross_indices = all_input_tensor_indices.at("k_cache_cross");
+ input_tensor_indices["k_cache_cross"] = k_cache_cross_indices.second;
+
+ auto v_cache_cross_indices = all_input_tensor_indices.at("v_cache_cross");
+ input_tensor_indices["v_cache_cross"] = v_cache_cross_indices.second;
+
+ for (const auto& [name, indices] : all_input_tensor_indices) {
+ if (name == "x" || name == "index" || name == "k_cache_cross" || name == "v_cache_cross") {
+ continue;
+ }
+ kv_cache_input_tensor_indices[name] = indices.second;
+ }
+
+ for (const auto& [name, indices] : all_output_tensor_indices) {
+ if (name == "logits") {
+ continue;
+ }
+
+ if (name.find("k_cache_") == 0 || name.find("v_cache_") == 0) {
+ kv_cache_output_tensor_indices[name] = indices.second;
+ }
+ }
+
+ auto extractNumericSuffix = [](const std::string& s) -> int {
+ size_t pos = s.find_last_of('_');
+ if (pos == std::string::npos || pos == s.length() - 1) {
+ throw std::invalid_argument("String does not contain a valid numeric suffix");
+ }
+ return std::stoi(s.substr(pos + 1));
+ };
+
+ for (const auto& [name, index] : kv_cache_output_tensor_indices) {
+ if (name.find("k_cache_") == 0) {
+ auto layer_num = extractNumericSuffix(name);
+ auto input_name = "k_cache_self_" + std::to_string(layer_num);
+
+ kv_cache_io_tensor_names[input_name] = name;
+ } else if (name.find("v_cache_") == 0) {
+ auto layer_num = extractNumericSuffix(name);
+ auto input_name = "v_cache_self_" + std::to_string(layer_num);
+
+ kv_cache_io_tensor_names[input_name] = name;
+ }
+ }
+}
+
+void PerLayerKVDecoder::update_kv_cache() {
+ if (decoder_outputs.empty()) {
+ decoder_outputs = _decoder_model->get_output_ptrs();
+ }
+
+ for (const auto& [input_name, output_name] : kv_cache_io_tensor_names) {
+ auto input_tensor_index = kv_cache_input_tensor_indices[input_name];
+ auto output_tensor_index = kv_cache_output_tensor_indices[output_name];
+ _decoder_model->read_input_data(decoder_outputs[output_tensor_index].first, input_tensor_index);
+ }
+}
+
+void PerLayerKVDecoder::initialize_kv_cache() {
+ if (decoder_outputs.empty()) {
+ decoder_outputs = _decoder_model->get_output_ptrs();
+ }
+
+ auto input_ptrs = _decoder_model->get_input_ptrs();
+
+ for (const auto& [name, index] : kv_cache_input_tensor_indices) {
+ memset(input_ptrs[index].first, 0, input_ptrs[index].second);
+ }
+
+ for (const auto& [name, index] : kv_cache_output_tensor_indices) {
+ memset(decoder_outputs[index].first, 0, decoder_outputs[index].second);
+ }
+}
+
+float PerLayerKVDecoder::get_latency_median() { return _decoder_model->get_latency_median(); }
+
+float PerLayerKVDecoder::get_latency_avg() { return _decoder_model->get_latency_avg(); }
+
+std::unique_ptr PerLayerKVDecoder::get_latency_json() { return _decoder_model->get_latency_json(); }
+
+std::vector<std::pair<char*, int>> PerLayerKVDecoder::get_input_ptrs() { return _decoder_model->get_input_ptrs(); }
+
+std::vector<std::pair<char*, int>> PerLayerKVDecoder::get_output_ptrs() { return _decoder_model->get_output_ptrs(); }
+
+int PerLayerKVDecoder::get_inference_num() { return _decoder_model->get_inference_num(); }
+
+float PerLayerKVDecoder::get_latency_sum() { return _decoder_model->get_latency_sum(); }
diff --git a/cpp/src/Models/TextDecoder.hpp b/cpp/src/Models/TextDecoder.hpp
index 59f3590..6d0a31b 100644
--- a/cpp/src/Models/TextDecoder.hpp
+++ b/cpp/src/Models/TextDecoder.hpp
@@ -14,12 +14,14 @@ enum DecoderKVCacheType {
};
}
+class FlatBuffersMetadata;
+
// TODO:
// remove extraneous functions used for passthrough to MODEL_SUPER_CLASS
// to expedite integration
class TextDecoder {
public:
- virtual ~TextDecoder() = default;
+ virtual ~TextDecoder();
virtual void initialize_kv_cache() = 0;
virtual void read_input_data(char* data, int index) = 0;
@@ -38,15 +40,21 @@ class TextDecoder {
virtual float get_latency_median() = 0;
    virtual std::unique_ptr<json> get_latency_json() = 0;
+ virtual void dump_input_tensors() = 0;
+ virtual void dump_output_tensors() = 0;
+
protected:
+    std::unique_ptr<FlatBuffersMetadata> metadata;
// TODO: modify to hold tflite model from tensorflow & use delegate manager
    std::unique_ptr<MODEL_SUPER_CLASS> _decoder_model;
std::string _model_path;
+    std::vector<std::pair<char*, int>> decoder_outputs;
};
class MonolithicKVDecoder : public TextDecoder {
public:
MonolithicKVDecoder(const std::string& tflite_model_path);
+ ~MonolithicKVDecoder();
void initialize_kv_cache() override;
bool initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend,
@@ -65,8 +73,47 @@ class MonolithicKVDecoder : public TextDecoder {
float get_latency_median() override;
    std::unique_ptr<json> get_latency_json() override;
+ void dump_input_tensors() override;
+ void dump_output_tensors() override;
+};
+
+class PerLayerKVDecoder : public TextDecoder {
+ public:
+ PerLayerKVDecoder(const std::string& tflite_model_path);
+ ~PerLayerKVDecoder();
+ void initialize_kv_cache() override;
+
+ bool initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend,
+ bool debug) override;
+ void uninitialize() override;
+ void read_input_data(char* input_data, int idx) override;
+ void invoke(bool measure_time = false) override;
+ void update_kv_cache() override;
+    std::vector<std::pair<char*, int>> get_input_ptrs() override;
+    std::vector<std::pair<char*, int>> get_output_ptrs() override;
+ void bind_input_tensor(char* input_data, const std::string& tensor_name) override;
+    std::pair<char*, int> get_logits_tensor() override;
+ int get_inference_num() override;
+ float get_latency_sum() override;
+ float get_latency_avg() override;
+ float get_latency_median() override;
+    std::unique_ptr<json> get_latency_json() override;
+
+ void dump_input_tensors() override;
+ void dump_output_tensors() override;
+
private:
-    std::vector<std::pair<char*, int>> decoder_outputs;
+ void initialize_io_metadata();
+ // self attention kv cache tensors
+    std::unordered_map<std::string, std::string> kv_cache_io_tensor_names;  // <input_name, output_name>
+    std::unordered_map<std::string, int> kv_cache_input_tensor_indices;     // <input_name, index>
+    std::unordered_map<std::string, int> kv_cache_output_tensor_indices;    // <output_name, index>
+
+ // non-kv cache tensors
+    std::unordered_map<std::string, int>
+        input_tensor_indices;  // <name, index>, non-kv cache tensors
+    std::unordered_map<std::string, int>
+        output_tensor_indices;  // <name, index>, non-kv cache tensors
};
class TextDecoderFactory {
diff --git a/cpp/src/Models/tflite_model.cpp b/cpp/src/Models/tflite_model.cpp
index b0eaffc..9c9884e 100644
--- a/cpp/src/Models/tflite_model.cpp
+++ b/cpp/src/Models/tflite_model.cpp
@@ -502,6 +502,9 @@ vector<pair<char*, int>> TFLiteModel::get_input_ptrs() {
case kTfLiteInt32:
            input_ptr = _interpreter->typed_input_tensor<int>(idx);
break;
+ case kTfLiteInt64:
+            input_ptr = _interpreter->typed_input_tensor<int64_t>(idx);
+ break;
default:
fprintf(stderr, "Error: unsupported tensor type: %d\n", tensor->type);
exit(-1);
@@ -535,6 +538,16 @@ vector<pair<char*, int>> TFLiteModel::get_output_ptrs() {
return _output_ptrs;
}
+std::pair<char*, int> TFLiteModel::get_output_with_name(const std::string& name) {
+ for (int idx = 0; idx < _interpreter->outputs().size(); idx++) {
+ auto* tensor = _interpreter->tensor(_interpreter->outputs()[idx]);
+ if (strcmp(tensor->name, name.c_str()) == 0) {
+            return make_pair(reinterpret_cast<char*>(tensor->data.f), tensor->bytes);
+ }
+ }
+ return make_pair(nullptr, 0);
+}
+
void TFLiteModel::invoke(bool measure_time) {
chrono::time_point before_exec;
if (measure_time) {
diff --git a/cpp/src/Models/tflite_model.hpp b/cpp/src/Models/tflite_model.hpp
index 2914b87..aa13ea0 100644
--- a/cpp/src/Models/tflite_model.hpp
+++ b/cpp/src/Models/tflite_model.hpp
@@ -54,6 +54,7 @@ class TFLiteModel {
void read_input_data(char* input_data, int idx);
    std::vector<std::pair<char*, int>> get_input_ptrs();
    std::vector<std::pair<char*, int>> get_output_ptrs();
+    std::pair<char*, int> get_output_with_name(const std::string& name);
void print_tensor_dims();
    std::unique_ptr<json> get_latency_json();
@@ -66,10 +67,12 @@ class TFLiteModel {
    std::vector<float> _latencies;
+    std::unique_ptr<tflite::Interpreter> _interpreter;
+
protected:
std::mutex _mutex;
    std::unique_ptr<tflite::FlatBufferModel> _model;
-    std::unique_ptr<tflite::Interpreter> _interpreter;
+
flatbuffers::FlatBufferBuilder _builder;
TfLiteDelegate* _delegate = nullptr;
std::string _model_name;
diff --git a/cpp/src/Text/Tokenizer.cpp b/cpp/src/Text/Tokenizer.cpp
index 81815ba..6e42172 100644
--- a/cpp/src/Text/Tokenizer.cpp
+++ b/cpp/src/Text/Tokenizer.cpp
@@ -54,105 +54,21 @@ void init_special_tokens(Tokenizer *tokenizer) {
tokenizer->specialTokens = special_tokens;
}
-void init_non_speech_tokens(Tokenizer *tokenizer) {
- std::vector non_speech_tokens{"!",
- "\"",
- "#",
- "(",
- ")",
- "*",
- "+",
- "/",
- ":",
- ";",
- "<",
- "=",
- ">",
- "@",
- "[",
- "\\",
- "]",
- "^",
- "_",
- "`",
- "{",
- "|",
- "}",
- "~",
- " (",
- " \"",
- "--",
- " -",
- " [",
- " '",
- " =",
- " |",
- " :",
- " /",
- " )",
- " <",
- " #",
- " +",
- " --",
- " {",
- " *",
- " }",
- " >",
- " ;",
- " ]",
- " @",
- " \\",
- "))",
- ">>",
- " `",
- " _",
- " ~",
- " (\"",
- "---",
- "(\"",
- " >>",
- " <<",
- " ^",
- "('",
- " ---",
- "}}",
- "]]",
- " >>>",
- "「",
- "」",
- " ((",
- " ))",
- " [[",
- "<<",
- "�",
- " (\'",
- "((",
- " �",
- ")))",
- " {{",
- "{{",
- "[[",
- "-(",
- ">>>",
- " }}",
- " 「",
- "『",
- "』",
- " )))",
- "-[",
- "<|startoftranscript|>",
- "<|translate|>",
- "<|transcribe|>",
- "<|startoflm|>",
- "<|startofprev|>",
- "<|nocaptions|>"};
- tokenizer->numNonSpeechTokens = non_speech_tokens.size();
+void init_non_speech_tokens(Tokenizer *tokenizer, const std::unique_ptr<json> &config) {
+    std::vector<int> non_speech_token_ids = config->at("suppress_tokens");
+
+ tokenizer->numNonSpeechTokens = non_speech_token_ids.size();
tokenizer->nonSpeechTokens = (int *)malloc(sizeof(int) * tokenizer->numNonSpeechTokens);
for (auto i = 0; i < tokenizer->numNonSpeechTokens; i++) {
- tokenizer->nonSpeechTokens[i] = tokenizer_convert_token_to_id(tokenizer, non_speech_tokens[i].c_str());
+ tokenizer->nonSpeechTokens[i] = non_speech_token_ids[i];
}
}
+bool tokenizer_is_multilingual(const Tokenizer *tokenizer) {
+ constexpr const int ENGLISH_VOCAB_SIZE = 51864;
+ return tokenizer->vocabSize != ENGLISH_VOCAB_SIZE;
+}
+
int tokenizer_convert_token_to_id(const Tokenizer *tokenizer, const char *token_string) {
// Encode token
CEncoding *encoding = tokenizer_encode(tokenizer->handle, token_string, false);
@@ -168,7 +84,7 @@ int tokenizer_convert_token_to_id(const Tokenizer *tokenizer, const char *token_
return static_cast(id);
}
-Tokenizer *tokenizer_init_from_file(const char *path) {
+Tokenizer *tokenizer_init_from_file(const char *path, const char *config_path) {
// Dynamically allocate tokenizer memory
Tokenizer *tokenizer = (Tokenizer *)malloc(sizeof(Tokenizer));
if (!tokenizer) {
@@ -178,6 +94,7 @@ Tokenizer *tokenizer_init_from_file(const char *path) {
// Load file to check existence and get vocabulary size.
std::ifstream file(path);
+ std::ifstream config_file(config_path);
if (!file) {
LOGE("Error loading provided tokenizer JSON. File may not exist!\n");
return NULL;
@@ -189,8 +106,13 @@ Tokenizer *tokenizer_init_from_file(const char *path) {
LOGE("Error parsing the provided tokenizer JSON!");
return NULL;
}
- tokenizer->vocabSize = (*json_file)["model"]["vocab"].size();
- LOGI("postproc vocab size: %d\n", tokenizer->vocabSize);
+
+    auto json_config = std::make_unique<json>(json::parse(config_file));
+ if (!json_config) {
+ LOGE("Error parsing the provided tokenizer config JSON!");
+ return NULL;
+ }
+ tokenizer->vocabSize = (*json_config)["vocab_size"];
// Load vocabulary from tokenizer file
tokenizer->handle = tokenizer_from_file(path);
@@ -200,7 +122,7 @@ Tokenizer *tokenizer_init_from_file(const char *path) {
}
init_special_tokens(tokenizer);
- init_non_speech_tokens(tokenizer);
+ init_non_speech_tokens(tokenizer, json_config);
return tokenizer;
}
diff --git a/cpp/src/Text/Tokenizer.h b/cpp/src/Text/Tokenizer.h
index 6aede7e..ccb1c9c 100644
--- a/cpp/src/Text/Tokenizer.h
+++ b/cpp/src/Text/Tokenizer.h
@@ -32,11 +32,13 @@ typedef struct {
} Tokenizer;
// Initialize the tokenizer
-Tokenizer* tokenizer_init_from_file(const char* path);
+Tokenizer* tokenizer_init_from_file(const char* path, const char* config_path);
// Decode token IDs into a string
char* tokenizer_decode(const Tokenizer* tokenizer, const int* tokens, int tokenCount, bool skipSpecialTokens);
+bool tokenizer_is_multilingual(const Tokenizer* tokenizer);
+
// Convert token string to ID
int tokenizer_convert_token_to_id(const Tokenizer* tokenizer, const char* tokenString);
diff --git a/cpp/src/Text/post_proc.cpp b/cpp/src/Text/post_proc.cpp
index fada65b..eee56ca 100644
--- a/cpp/src/Text/post_proc.cpp
+++ b/cpp/src/Text/post_proc.cpp
@@ -70,11 +70,16 @@ int PostProcModel::process(int idx, float* logits, int logits_size, vector&
logits[_tokenizer->specialTokens.endOfTranscriptToken] = -1e9;
logits[_tokenizer->specialTokens.blankToken] = -1e9;
}
- for (int i = 0; i < _tokenizer->numNonSpeechTokens; i++) {
- auto token = _tokenizer->nonSpeechTokens[i];
- logits[token] = -1e9;
+
+ // TODO: unblocking multilingual models
+ if (!tokenizer_is_multilingual(_tokenizer)) {
+ for (int i = 0; i < _tokenizer->numNonSpeechTokens; i++) {
+ auto token = _tokenizer->nonSpeechTokens[i];
+ logits[token] = -1e9;
+ }
+
+ apply_timestamp_rules(logits, logits_size, decoded_tokens);
}
- apply_timestamp_rules(logits, logits_size, decoded_tokens);
// logits
read_input_data(reinterpret_cast(logits), 0);
auto inputs = get_input_ptrs();
diff --git a/cpp/src/TranscribeTask.cpp b/cpp/src/TranscribeTask.cpp
index 3d2dda2..a95ad59 100644
--- a/cpp/src/TranscribeTask.cpp
+++ b/cpp/src/TranscribeTask.cpp
@@ -211,19 +211,31 @@ void Runtime::init() {
LOGI("SoC: \tgeneric CPU (x86, arm64, etc) \n");
#endif
- // TODO: this should be using std::filesystem..
std::string tokenizer_json = config.get_model_path() + "/tokenizer.json";
+ std::string tokenizer_config_json = config.get_model_path() + "/config.json";
std::string melspectro_model = config.get_model_path() + "/MelSpectrogram.tflite";
std::string encoder_model = config.get_model_path() + "/AudioEncoder.tflite";
std::string decoder_model = config.get_model_path() + "/TextDecoder.tflite";
+    std::vector<std::string> required_files = {tokenizer_json, tokenizer_config_json, melspectro_model, encoder_model,
+ decoder_model};
+ for (const auto& file : required_files) {
+ if (!std::filesystem::exists(file)) {
+ LOGE("File does not exist: %s", file.c_str());
+ std::stringstream ss;
+ ss << file << " : required file not found";
+ throw std::runtime_error(ss.str());
+ }
+ }
+
    melspectro = make_unique<TFLiteModel>("mel_spectrogram");
    encoder = make_unique<TFLiteModel>("whisper_encoder");
decoder = TextDecoderFactory::CreateFromFile(decoder_model);
// TODO move this to somewhere user accessible.
- tokenizer = tokenizer_init_from_file(tokenizer_json.c_str());
+ tokenizer = tokenizer_init_from_file(tokenizer_json.c_str(), tokenizer_config_json.c_str());
+
postproc = make_unique(tokenizer);
lib_dir = std::string(TRANSCRIBE_TASK_DEFAULT_LIB_DIR);
@@ -344,14 +356,25 @@ void Runtime::encode_decode_postproc(float timestamp) {
encoder->get_mutex()->lock();
encoder->read_input_data(melspectro_outputs[0].first, 0);
encoder->get_mutex()->unlock();
- // Perform encoder inference
encoder->invoke(true);
- const auto& k_cache_cross = encoder_outputs[0].first;
- const auto& v_cache_cross = encoder_outputs[1].first;
+ auto k_cache_cross = encoder->get_output_with_name("k_cache_cross");
+ if (k_cache_cross.first == nullptr) {
+ k_cache_cross = encoder->get_output_with_name("k_cache");
+ }
+ auto v_cache_cross = encoder->get_output_with_name("v_cache_cross");
+ if (v_cache_cross.first == nullptr) {
+ v_cache_cross = encoder->get_output_with_name("v_cache");
+ }
+
+ if (k_cache_cross.first == nullptr || v_cache_cross.first == nullptr) {
+ LOGE("Failed to get k_cache_cross or v_cache_cross");
+ return;
+ }
+
+ decoder->bind_input_tensor(k_cache_cross.first, "k_cache_cross");
+ decoder->bind_input_tensor(v_cache_cross.first, "v_cache_cross");
- decoder->bind_input_tensor(k_cache_cross, "k_cache_cross");
- decoder->bind_input_tensor(v_cache_cross, "v_cache_cross");
decoder->initialize_kv_cache();
constexpr const int MAX_DECODING_STEPS = 224;
diff --git a/scripts/Dockerfile b/scripts/Dockerfile
index fd73198..0bd3f91 100644
--- a/scripts/Dockerfile
+++ b/scripts/Dockerfile
@@ -37,7 +37,7 @@ ENV QNN_RUNTIME_ROOT=/opt/qnn-runtime
ENV AXIE_ROOT=/src/AXIE
ARG ANDROID_NDK_ZIP=$ANDROID_NDK_VERSION-linux.zip
-ARG BAZEL_INSTALLER=bazel-6.5.0-installer-linux-x86_64.sh
+ARG BAZEL_INSTALLER=bazel-7.4.1-installer-linux-x86_64.sh
ARG BAZEL_DIR=/opt/bazel
ARG QNN_RUNTIME=qnn-runtime-2.33.0.aar
ARG QNN_TFLITE_DELEGATE=qnn-litert-delegate-2.33.0.aar
diff --git a/scripts/adb_push.sh b/scripts/adb_push.sh
index 46dd3aa..ed7260f 100755
--- a/scripts/adb_push.sh
+++ b/scripts/adb_push.sh
@@ -2,22 +2,32 @@
# For licensing see accompanying LICENSE file.
# Copyright © 2024 Argmax, Inc. All rights reserved.
+# This script is used to push the files and folders to all connected Android devices
+#
+# whisperkit-cli: CLI app to run transcription on Android device, with files input.
+# libwhisperkit.so: main whisperkit lib
+# models: all models for whisperkit in models/ folder
+# test/jfk_441khz.m4a: sample test audio file for whisperkit
+# scripts/run_on_android.sh: script to run the app on Android device
+
+# Usage: make adb-push [forced]
+# or ./adb_push.sh [forced]
+# If forced is provided, the script will push the files and folders to
+# the Android device even if they already exist. Otherwise, it will
+# skip the push if the files already exist.
+
CURRENT_DIR="$(dirname "$(realpath "$0")")"
SOURCE_DIR="$CURRENT_DIR/.."
WHISPERKIT_CLI="$SOURCE_DIR/build/android/whisperkit-cli"
AXIE_TFLITE_LIB="$SOURCE_DIR/build/android/libwhisperkit.so"
LOCAL_LIBS="$SOURCE_DIR/external/libs/android"
-LOCAL_TINY_DIR="$SOURCE_DIR/models/openai_whisper-tiny"
-LOCAL_BASE_DIR="$SOURCE_DIR/models/openai_whisper-base"
-LOCAL_SMALL_DIR="$SOURCE_DIR/models/openai_whisper-small"
+LOCAL_MODELS_DIR="$SOURCE_DIR/models"
DEVICE_BIN_DIR="/data/local/tmp/bin"
DEVICE_LIB_DIR="/data/local/tmp/lib"
DEVICE_SDROOT_DIR="/sdcard/argmax/tflite"
-DEVICE_TINY_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-tiny"
-DEVICE_BASE_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-base"
-DEVICE_SMALL_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-small"
+DEVICE_MODELS_DIR="${DEVICE_SDROOT_DIR}/models"
DEVICE_INPUTS_DIR="${DEVICE_SDROOT_DIR}/inputs"
EXEC_SCRIPT="$SOURCE_DIR/scripts/run_on_android.sh"
@@ -88,8 +98,6 @@ do
adb -s $DEVICE shell "chmod 777 $DEVICE_SDROOT_DIR/run_on_android.sh"
adb -s $DEVICE shell "chmod 777 $DEVICE_BIN_DIR/whisperkit-cli"
- push_if_not_exists "$LOCAL_TINY_DIR" "$DEVICE_TINY_DIR" $FORCED
- push_if_not_exists "$LOCAL_BASE_DIR" "$DEVICE_BASE_DIR" $FORCED
- push_if_not_exists "$LOCAL_SMALL_DIR" "$DEVICE_SMALL_DIR" $FORCED
+ push_if_not_exists "$LOCAL_MODELS_DIR" "$DEVICE_MODELS_DIR" $FORCED
done
diff --git a/scripts/build_tensorflow.sh b/scripts/build_tensorflow.sh
index a3593b3..6b99966 100755
--- a/scripts/build_tensorflow.sh
+++ b/scripts/build_tensorflow.sh
@@ -22,8 +22,22 @@ export CC_OPT_FLAGS=-Wno-sign-compare
# nightly tf commit needs bazel 7.4.1
USING_NIGHTLY_TF_COMMIT=1
+REQUIRED_BAZEL_VERSION="7.4.1"
+BAZEL_BIN_DIR="/usr/local/lib/bazel/bin"
+BAZEL_FILENAME="bazel-${REQUIRED_BAZEL_VERSION}-linux-x86_64"
+BAZEL_PATH="${BAZEL_BIN_DIR}/${BAZEL_FILENAME}"
+
if [ "$USING_NIGHTLY_TF_COMMIT" = "1" ]; then
- cd "/usr/local/lib/bazel/bin" && curl -fLO https://releases.bazel.build/7.4.1/release/bazel-7.4.1-linux-x86_64 && chmod +x bazel-7.4.1-linux-x86_64
+ if [ -f "$BAZEL_PATH" ]; then
+ echo "Bazel $REQUIRED_BAZEL_VERSION already exists at $BAZEL_PATH. Skipping download."
+ else
+ echo "Downloading Bazel $REQUIRED_BAZEL_VERSION..."
+ mkdir -p "$BAZEL_BIN_DIR"
+ cd "$BAZEL_BIN_DIR" || exit 1
+ curl -fLO "https://releases.bazel.build/${REQUIRED_BAZEL_VERSION}/release/${BAZEL_FILENAME}"
+ chmod +x "$BAZEL_FILENAME"
+ echo "Bazel $REQUIRED_BAZEL_VERSION downloaded to $BAZEL_PATH."
+ fi
fi
if [ "$PLATFORM" = "android" ]; then
diff --git a/scripts/dev_env.sh b/scripts/dev_env.sh
index 0314901..248b23d 100755
--- a/scripts/dev_env.sh
+++ b/scripts/dev_env.sh
@@ -42,7 +42,7 @@ if ! $(docker image inspect $IMAGE_NAME > /dev/null 2>&1) || $FORCE_REBUILD; the
BUILD_DIR="$SOURCE_DIR/.source"
echo "Checking and retrieving dependencies..."
if command -v aria2c &> /dev/null; then
- aria2c $ARIA_OPTIONS -d $BUILD_DIR https://github.com/bazelbuild/bazel/releases/download/6.5.0/bazel-6.5.0-installer-linux-x86_64.sh
+ aria2c $ARIA_OPTIONS -d $BUILD_DIR https://github.com/bazelbuild/bazel/releases/download/7.4.1/bazel-7.4.1-installer-linux-x86_64.sh
aria2c $ARIA_OPTIONS -d $BUILD_DIR https://dl.google.com/android/repository/android-ndk-r25c-linux.zip
aria2c $ARIA_OPTIONS -d $BUILD_DIR https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip
aria2c $ARIA_OPTIONS -d $BUILD_DIR https://repo1.maven.org/maven2/com/qualcomm/qti/qnn-runtime/2.33.0/qnn-runtime-2.33.0.aar
@@ -53,10 +53,11 @@ if ! $(docker image inspect $IMAGE_NAME > /dev/null 2>&1) || $FORCE_REBUILD; the
fi
if [ ! -d "$BUILD_DIR/tensorflow" ]; then
echo "Cloning tensorflow..."
- NIGHTLY_TF_PIN=e2ccfb8b4e0
- git clone --no-checkout --filter=blob:none https://github.com/tensorflow/tensorflow.git "$BUILD_DIR/tensorflow"
+ NIGHTLY_TF_PIN=e2ccfb8b4e03baee112efbcfe824dbac995afb38
+ # Use shallow clone to reduce size
+ git clone --depth 1 --branch=nightly https://github.com/tensorflow/tensorflow.git "$BUILD_DIR/tensorflow"
cd $BUILD_DIR/tensorflow
- git fetch --depth=365 origin nightly
+ git fetch origin $NIGHTLY_TF_PIN --depth=1
git checkout $NIGHTLY_TF_PIN
cd -
fi
diff --git a/scripts/download_models.sh b/scripts/download_models.sh
index 398c429..f73430b 100755
--- a/scripts/download_models.sh
+++ b/scripts/download_models.sh
@@ -11,56 +11,127 @@ MODELS_DIR="$SOURCE_DIR/models"
ARIA_OPTIONS="-x 8 -s 8 --continue --file-allocation=none"
# Set directories
+QCOM_TINY_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-tiny.en"
+QCOM_BASE_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-base.en"
+QCOM_SMALL_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-small.en"
+
TINY_MODELS_DIR="$MODELS_DIR/openai_whisper-tiny"
BASE_MODELS_DIR="$MODELS_DIR/openai_whisper-base"
-SMALL_MODELS_DIR="$MODELS_DIR/openai_whisper-small"
-
-function SAFE_MODEL_DIRECTORY(){
- if [ ! -d "${1}" ]; then
- echo "mkdir ${1} .."
- mkdir -p "${1}"
- fi
-}
-
-SAFE_MODEL_DIRECTORY $TINY_MODELS_DIR
-SAFE_MODEL_DIRECTORY $BASE_MODELS_DIR
-SAFE_MODEL_DIRECTORY $SMALL_MODELS_DIR
+TINY_EN_MODELS_DIR="$MODELS_DIR/openai_whisper-tiny.en"
+BASE_EN_MODELS_DIR="$MODELS_DIR/openai_whisper-base.en"
# Download Whisper auxiliary models
HF_ARGMAX_URL="https://huggingface.co/argmaxinc/whisperkit-litert/resolve/main"
-HF_OPENAI_TINY_URL="https://huggingface.co/openai/whisper-tiny.en/resolve/main"
-
-if [ ! -f $TINY_MODELS_DIR/tokenizer.json ]; then
- aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o tokenizer.json $HF_OPENAI_TINY_URL/tokenizer.json
- aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o MelSpectrogram.tflite $HF_ARGMAX_URL/quic_openai_whisper-tiny.en/MelSpectrogram.tflite
-fi
-if [ ! -f $BASE_MODELS_DIR/tokenizer.json ]; then
- cp $TINY_MODELS_DIR/tokenizer.json $BASE_MODELS_DIR/.
- cp $TINY_MODELS_DIR/MelSpectrogram.tflite $BASE_MODELS_DIR/.
-fi
-if [ ! -f $SMALL_MODELS_DIR/tokenizer.json ]; then
- cp $TINY_MODELS_DIR/tokenizer.json $SMALL_MODELS_DIR/.
- cp $TINY_MODELS_DIR/MelSpectrogram.tflite $SMALL_MODELS_DIR/.
-fi
-
# Download Qualcomm models
HF_QUALCOMM_URL="https://huggingface.co/qualcomm"
-if [ ! -f $TINY_MODELS_DIR/TextDecoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperDecoder.tflite
-fi
-if [ ! -f $TINY_MODELS_DIR/AudioEncoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperEncoder.tflite
-fi
-if [ ! -f $BASE_MODELS_DIR/TextDecoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperDecoder.tflite
-fi
-if [ ! -f $BASE_MODELS_DIR/AudioEncoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperEncoder.tflite
-fi
-if [ ! -f $SMALL_MODELS_DIR/TextDecoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperDecoder.tflite
-fi
-if [ ! -f $SMALL_MODELS_DIR/AudioEncoder.tflite ]; then
- aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperEncoder.tflite
-fi
+HF_OPENAI_TINY_EN_URL="https://huggingface.co/openai/whisper-tiny.en/resolve/main"
+HF_OPENAI_TINY_URL="https://huggingface.co/openai/whisper-tiny/resolve/main"
+HF_OPENAI_BASE_EN_URL="https://huggingface.co/openai/whisper-base.en/resolve/main"
+HF_OPENAI_BASE_URL="https://huggingface.co/openai/whisper-base/resolve/main"
+HF_OPENAI_SMALL_EN_URL="https://huggingface.co/openai/whisper-small.en/resolve/main"
+HF_OPENAI_SMALL_URL="https://huggingface.co/openai/whisper-small/resolve/main"
+
+HF_QCOM_TINY_EN_URL=$HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main
+HF_QCOM_BASE_EN_URL=$HF_QUALCOMM_URL/Whisper-Base-En/resolve/main
+HF_QCOM_SMALL_EN_URL=$HF_QUALCOMM_URL/Whisper-Small-En/resolve/main
+
+HF_AX_TINY_EN_URL=$HF_ARGMAX_URL/openai_whisper-tiny.en
+HF_AX_TINY_URL=$HF_ARGMAX_URL/openai_whisper-tiny
+HF_AX_BASE_EN_URL=$HF_ARGMAX_URL/openai_whisper-base.en
+HF_AX_BASE_URL=$HF_ARGMAX_URL/openai_whisper-base
+
+QCOM_MODELS=(
+ "$HF_QCOM_TINY_EN_URL"
+ "$HF_QCOM_BASE_EN_URL"
+ "$HF_QCOM_SMALL_EN_URL"
+)
+
+ARGMAX_MODELS=(
+ "$HF_AX_TINY_EN_URL"
+ "$HF_AX_TINY_URL"
+ "$HF_AX_BASE_EN_URL"
+ "$HF_AX_BASE_URL"
+)
+
+ARGMAX_MODEL_DIRECTORIES=(
+ "$TINY_EN_MODELS_DIR"
+ "$TINY_MODELS_DIR"
+ "$BASE_EN_MODELS_DIR"
+ "$BASE_MODELS_DIR"
+)
+
+
+QCOM_MODEL_DIRECTORIES=(
+ "$QCOM_TINY_EN_MODELS_DIR"
+ "$QCOM_BASE_EN_MODELS_DIR"
+ "$QCOM_SMALL_EN_MODELS_DIR"
+)
+
+MODEL_DIRECTORIES=(
+ "$QCOM_TINY_EN_MODELS_DIR"
+ "$QCOM_BASE_EN_MODELS_DIR"
+ "$QCOM_SMALL_EN_MODELS_DIR"
+ "$TINY_MODELS_DIR"
+ "$BASE_MODELS_DIR"
+ "$TINY_EN_MODELS_DIR"
+ "$BASE_EN_MODELS_DIR"
+)
+
+TOKENIZER_ENDPOINTS=(
+ "$HF_OPENAI_TINY_EN_URL"
+ "$HF_OPENAI_BASE_EN_URL"
+ "$HF_OPENAI_SMALL_EN_URL"
+ "$HF_OPENAI_TINY_URL"
+ "$HF_OPENAI_BASE_URL"
+ "$HF_OPENAI_TINY_EN_URL"
+ "$HF_OPENAI_BASE_EN_URL"
+)
+
+
+ARIA2_OPTS="--quiet=true --summary-interval=0 --download-result=hide --continue=true --max-connection-per-server=4"
+
+melspec_endpoint="${HF_ARGMAX_URL}/openai_whisper-tiny/MelSpectrogram.tflite"
+aria2c $ARIA2_OPTS -d "/tmp/" -o MelSpectrogram.tflite $melspec_endpoint
+
+for i in "${!MODEL_DIRECTORIES[@]}"; do
+ model_dir="${MODEL_DIRECTORIES[$i]}"
+ tokenizer_endpoint="${TOKENIZER_ENDPOINTS[$i]}"
+ echo "Creating directory: $model_dir"
+ mkdir -p "$model_dir"
+
+ echo "Downloading tokenizer.json and config.json from $tokenizer_endpoint to $model_dir"
+
+ aria2c $ARIA2_OPTS -d "$model_dir" -o tokenizer.json "$tokenizer_endpoint/tokenizer.json"
+ aria2c $ARIA2_OPTS -d "$model_dir" -o config.json "$tokenizer_endpoint/config.json"
+ echo "Done with $model_dir"
+done
+
+# Qualcomm models: rename to [AudioEncoder.tflite, TextDecoder.tflite]
+#
+# Argmax models: already named as [AudioEncoder.tflite, TextDecoder.tflite]
+
+for i in "${!QCOM_MODEL_DIRECTORIES[@]}"; do
+ model_dir="${QCOM_MODEL_DIRECTORIES[$i]}"
+ model_endpoint="${QCOM_MODELS[$i]}"
+ echo "Downloading QCOM [AudioEncoder, TextDecoder] from $model_endpoint to $model_dir"
+ echo "Downloading Encoder: ${model_endpoint}/WhisperEncoder.tflite"
+ echo "Downloading Decoder: ${model_endpoint}/WhisperDecoder.tflite"
+ aria2c $ARIA2_OPTS -d "$model_dir" -o TextDecoder.tflite $model_endpoint/WhisperDecoder.tflite
+ aria2c $ARIA2_OPTS -d "$model_dir" -o AudioEncoder.tflite $model_endpoint/WhisperEncoder.tflite
+ cp /tmp/MelSpectrogram.tflite "$model_dir/MelSpectrogram.tflite"
+ echo "Done with $model_dir"
+done
+
+for i in "${!ARGMAX_MODEL_DIRECTORIES[@]}"; do
+ model_dir="${ARGMAX_MODEL_DIRECTORIES[$i]}"
+ model_endpoint="${ARGMAX_MODELS[$i]}"
+ echo "Downloading Argmax [AudioEncoder, TextDecoder] from $model_endpoint to $model_dir"
+ echo "Downloading Encoder: ${model_endpoint}/AudioEncoder.tflite"
+ echo "Downloading Decoder: ${model_endpoint}/TextDecoder.tflite"
+ echo "Downloading MelSpec: ${model_endpoint}/MelSpectrogram.tflite"
+ aria2c $ARIA2_OPTS -d "$model_dir" -o TextDecoder.tflite ${model_endpoint}/TextDecoder.tflite
+ aria2c $ARIA2_OPTS -d "$model_dir" -o AudioEncoder.tflite ${model_endpoint}/AudioEncoder.tflite
+ aria2c $ARIA2_OPTS -d "$model_dir" -o MelSpectrogram.tflite ${model_endpoint}/MelSpectrogram.tflite
+ echo "Done with $model_dir"
+done