diff --git a/README.md b/README.md index 79cd7ea..2a971fd 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,19 @@ WhisperKit -# WhisperKit Android (Beta) +# WhisperKit Android + +[![Tests](https://github.com/argmaxinc/whisperkitandroid/actions/workflows/pr-checks.yml/badge.svg)](https://github.com/argmaxinc/whisperkitandroid/actions/workflows/pr-checks.yml) +[![License](https://img.shields.io/github/license/argmaxinc/whisperkitandroid?logo=github&logoColor=969da4&label=License&labelColor=353a41&color=32d058)](LICENSE.md) +[![Maven Central](https://img.shields.io/maven-central/v/com.argmaxinc/whisperkit?logo=sonatype&logoColor=969da4&label=Maven%20Central&labelColor=353a41&color=32d058)](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit) +[![Discord](https://img.shields.io/discord/1171912382512115722?style=flat&logo=discord&logoColor=969da4&label=Discord&labelColor=353a41&color=32d058&link=https%3A%2F%2Fdiscord.gg%2FG5F5GZGecC)](https://discord.gg/G5F5GZGecC) -[![Maven Central](https://img.shields.io/maven-central/v/com.argmaxinc/whisperkit?color=32d058)](https://central.sonatype.com/artifact/com.argmaxinc/whisperkit) WhisperKit Android brings Foundation Models On Device for Automatic Speech Recognition. It extends the performance and feature set of [WhisperKit](https://github.com/argmaxinc/WhisperKit) from Apple platforms to Android and Linux. The current feature set is a subset of the iOS counterpart, but we are continuing to invest in Android and now welcome contributions from the community. 
-[Example App (Coming Soon)] [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools) +[[Example App]](https://play.google.com/store/apps/details?id=com.argmaxinc.whisperax) [[Blog Post]](https://takeargmax.com/blog/android) [[Python Tools Repo]](https://github.com/argmaxinc/whisperkittools) ## Table of Contents @@ -37,7 +41,7 @@ To use WhisperKit in your Android app, you need to: ```kotlin dependencies { // 1. WhisperKit SDK - implementation("com.argmaxinc:whisperkit:0.3.0") // Check badge above for latest version + implementation("com.argmaxinc:whisperkit:0.3.2") // Check badge above for latest version // 2. QNN dependencies for hardware acceleration implementation("com.qualcomm.qnn:qnn-runtime:2.34.0") @@ -73,7 +77,7 @@ class YourActivity : AppCompatActivity() { whisperKit = WhisperKit.Builder() .setModel(WhisperKit.OPENAI_TINY_EN) .setApplicationContext(applicationContext) - .setCallback { what, timestamp, msg -> + .setCallback { what, result -> // Handle transcription output when (what) { WhisperKit.TextOutputCallback.MSG_INIT -> { @@ -81,10 +85,14 @@ class YourActivity : AppCompatActivity() { } WhisperKit.TextOutputCallback.MSG_TEXT_OUT -> { // New transcription available - val text = msg - val time = timestamp + val fullText = result.text + val segments = result.segments // Process the transcribed text as it becomes available // This callback will be called multiple times as more audio is processed + segments.forEach { segment -> + // Process each segment + val segmentText = segment.text + } } WhisperKit.TextOutputCallback.MSG_CLOSE -> { // Cleanup complete diff --git a/android/config/detekt.yml b/android/config/detekt.yml index 9b823e6..11bfe92 100644 --- a/android/config/detekt.yml +++ b/android/config/detekt.yml @@ -14,7 +14,7 @@ complexity: thresholdInObjects: 10 LongParameterList: functionThreshold: 8 - constructorThreshold: 7 + constructorThreshold: 8 CyclomaticComplexMethod: threshold: 20 
NestedBlockDepth: diff --git a/android/examples/WhisperAX/build.gradle.kts b/android/examples/WhisperAX/build.gradle.kts index 794608d..d83199b 100644 --- a/android/examples/WhisperAX/build.gradle.kts +++ b/android/examples/WhisperAX/build.gradle.kts @@ -12,7 +12,7 @@ android { applicationId = "com.argmaxinc.whisperax" minSdk = 26 targetSdk = 35 - versionCode = 3 + versionCode = 6 versionName = "0.1.0" testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt index 3fb1717..95c4a7f 100644 --- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt +++ b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ComputeUnitsView.kt @@ -31,6 +31,7 @@ import androidx.compose.material3.Surface import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember @@ -40,6 +41,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.rotate import androidx.compose.ui.unit.dp +import com.argmaxinc.whisperax.WhisperViewModel.Companion.MODELS_SUPPORTING_NPU import com.argmaxinc.whisperkit.ExperimentalWhisperKit import com.argmaxinc.whisperkit.WhisperKit @@ -50,11 +52,18 @@ enum class ComputeUnits(val displayName: String, val backendValue: Int) { CPU_AND_NPU("NPU", WhisperKit.Builder.CPU_AND_NPU), } +@OptIn(ExperimentalWhisperKit::class) @Composable fun ComputeUnitsView(viewModel: WhisperViewModel) { val modelState by viewModel.modelState.collectAsState() val encoderState by viewModel.encoderState.collectAsState() val decoderState by 
viewModel.decoderState.collectAsState() + val selectedModel by viewModel.selectedModel.collectAsState() + val shouldEnableNPUForEncoderDecoder by remember { + derivedStateOf { + selectedModel in MODELS_SUPPORTING_NPU + } + } val isEnabled = modelState == ModelState.LOADED || modelState == ModelState.UNLOADED var whisperKitExpanded by remember { mutableStateOf(true) } @@ -75,6 +84,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) { currentState = encoderState, currentUnit = viewModel.encoderComputeUnits.collectAsState().value, onUnitSelected = { viewModel.setEncoderComputeUnits(it) }, + shouldEnableNPU = shouldEnableNPUForEncoderDecoder, enabled = isEnabled, ) @@ -85,6 +95,7 @@ fun ComputeUnitsView(viewModel: WhisperViewModel) { currentState = decoderState, currentUnit = viewModel.decoderComputeUnits.collectAsState().value, onUnitSelected = { viewModel.setDecoderComputeUnits(it) }, + shouldEnableNPU = shouldEnableNPUForEncoderDecoder, enabled = isEnabled, ) } @@ -185,6 +196,7 @@ fun ComputeUnitRow( currentState: ModelState, currentUnit: ComputeUnits, onUnitSelected: (ComputeUnits) -> Unit, + shouldEnableNPU: Boolean = true, enabled: Boolean = true, ) { val infiniteTransition = rememberInfiniteTransition(label = "loading animation") @@ -248,7 +260,11 @@ fun ComputeUnitRow( expanded = expanded, onDismissRequest = { expanded = false }, ) { - ComputeUnits.values().forEach { unit -> + if (shouldEnableNPU) { + listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU, ComputeUnits.CPU_AND_NPU) + } else { + listOf(ComputeUnits.CPU_ONLY, ComputeUnits.CPU_AND_GPU) + }.forEach { unit -> DropdownMenuItem( text = { Text(unit.displayName) }, onClick = { diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt index 331bbfb..791fd9c 100644 --- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt +++ 
b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/ModelSelectorView.kt @@ -34,6 +34,7 @@ import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.MenuAnchorType import androidx.compose.material3.OutlinedTextField import androidx.compose.material3.Surface import androidx.compose.material3.Text @@ -111,7 +112,8 @@ fun ModelSelectorView(viewModel: WhisperViewModel) { }, modifier = Modifier .fillMaxWidth() - .weight(1f), + .weight(1f) + .menuAnchor(MenuAnchorType.PrimaryNotEditable), ) ExposedDropdownMenu( diff --git a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt index 897f32c..379c424 100644 --- a/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt +++ b/android/examples/WhisperAX/src/main/java/com/argmaxinc/whisperax/WhisperViewModel.kt @@ -14,6 +14,8 @@ import androidx.compose.runtime.mutableStateListOf import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.argmaxinc.whisperkit.ExperimentalWhisperKit +import com.argmaxinc.whisperkit.TranscriptionResult +import com.argmaxinc.whisperkit.TranscriptionSegment import com.argmaxinc.whisperkit.WhisperKit import com.argmaxinc.whisperkit.WhisperKit.TextOutputCallback import com.argmaxinc.whisperkit.WhisperKitException @@ -33,22 +35,13 @@ import java.text.SimpleDateFormat import java.util.Date import java.util.Locale -data class TranscriptionSegment( - val text: String, - val start: Float, - val end: Float, - val tokens: List = emptyList(), -) - -data class TranscriptionResult( - val text: String = "", - val segments: List = emptyList(), -) - @OptIn(ExperimentalWhisperKit::class) class WhisperViewModel : ViewModel() { companion object { const val TAG = 
"WhisperViewModel" + + // Models currently supporting NPU backend, don't enable NPU for other models + val MODELS_SUPPORTING_NPU = listOf(WhisperKit.Builder.QUALCOMM_TINY_EN, WhisperKit.Builder.QUALCOMM_BASE_EN) } private lateinit var appContext: Context @@ -190,10 +183,11 @@ class WhisperViewModel : ViewModel() { cacheDir = context.cacheDir.absolutePath } - fun onTextOutput(what: Int, timestamp: Float, msg: String) { + fun onTextOutput(what: Int, result: TranscriptionResult) { + val segments = result.segments when (what) { TextOutputCallback.MSG_INIT -> { - Log.i(MainActivity.TAG, "TFLite initialized: $msg") + Log.i(MainActivity.TAG, "TFLite initialized: ${result.text}") startTime = System.currentTimeMillis() _pipelineStart.value = startTime.toDouble() / 1000.0 _isInitializing.value = false @@ -201,14 +195,13 @@ class WhisperViewModel : ViewModel() { TextOutputCallback.MSG_TEXT_OUT -> { Log.i(MainActivity.TAG, "TEXT OUT THREAD") - if (msg.isNotEmpty()) { + if (segments.isNotEmpty()) { if (!firstTokenReceived) { firstTokenReceived = true firstTokenTimestamp = System.currentTimeMillis() _firstTokenTime.value = (firstTokenTimestamp - startTime).toDouble() / 1000.0 } - - val newTokens = msg.length / 4 + val newTokens = segments.joinToString("") { it.text }.length / 4 totalTokens += newTokens val currentTime = System.currentTimeMillis() @@ -220,14 +213,14 @@ class WhisperViewModel : ViewModel() { } lastTokenTimestamp = currentTime - updateTranscript(msg) + updateTranscript(segments) } } TextOutputCallback.MSG_CLOSE -> { Log.i(MainActivity.TAG, "Transcription completed.") - if (msg.isNotEmpty()) { - val newTokens = msg.length / 4 + if (segments.isNotEmpty()) { + val newTokens = segments.joinToString("") { it.text }.length / 4 totalTokens += newTokens val totalTime = (System.currentTimeMillis() - startTime).toDouble() / 1000.0 @@ -236,8 +229,7 @@ class WhisperViewModel : ViewModel() { updateRealtimeMetrics(totalTime) } - - updateTranscript(msg) + 
updateTranscript(segments) } } @@ -247,25 +239,8 @@ class WhisperViewModel : ViewModel() { } } - private fun updateTranscript(chunkText: String, withTimestamps: Boolean = false) { - var processedText = chunkText - - val timestamps = if (withTimestamps) { - val timestampPattern = "<\\|(\\d+\\.\\d+)\\|>".toRegex() - val timestampMatches = timestampPattern.findAll(chunkText).toList() - timestampMatches.map { it.groupValues[1].toFloat() } - } else { - emptyList() - } - - if (!withTimestamps) { - processedText = processedText - .replace("<\\|[^>]*\\|>".toRegex(), "") - .trim() - } else { - processedText = processedText.trim() - } - + private fun updateTranscript(segments: List) { + val processedText = segments.joinToString("") { it.text } if (processedText.isNotEmpty()) { if (allText.isNotEmpty()) { allText.append("\n") @@ -284,13 +259,12 @@ class WhisperViewModel : ViewModel() { fun listModels() { viewModelScope.launch { val modelDirs = listOf( - // TODO: enable when models are ready - // WhisperKit.Builder.OPENAI_TINY_EN, - // WhisperKit.Builder.OPENAI_BASE_EN, - // WhisperKit.Builder.OPENAI_SMALL_EN, WhisperKit.Builder.QUALCOMM_TINY_EN, WhisperKit.Builder.QUALCOMM_BASE_EN, - // WhisperKit.Builder.QUALCOMM_SMALL_EN + WhisperKit.Builder.OPENAI_TINY_EN, + WhisperKit.Builder.OPENAI_BASE_EN, + WhisperKit.Builder.OPENAI_TINY, + WhisperKit.Builder.OPENAI_BASE, ) availableModels.clear() availableModels.addAll(modelDirs) @@ -364,6 +338,21 @@ class WhisperViewModel : ViewModel() { fun selectModel(model: String) { _selectedModel.value = model + if (model in MODELS_SUPPORTING_NPU) { + _encoderComputeUnits.update { + ComputeUnits.CPU_AND_NPU + } + _decoderComputeUnits.update { + ComputeUnits.CPU_AND_NPU + } + } else { + _encoderComputeUnits.update { + ComputeUnits.CPU_ONLY + } + _decoderComputeUnits.update { + ComputeUnits.CPU_ONLY + } + } _modelState.value = ModelState.UNLOADED _encoderState.value = ModelState.UNLOADED _decoderState.value = ModelState.UNLOADED diff --git 
a/android/whisperkit/build.gradle.kts b/android/whisperkit/build.gradle.kts index f0e66d0..c2394a0 100644 --- a/android/whisperkit/build.gradle.kts +++ b/android/whisperkit/build.gradle.kts @@ -66,7 +66,7 @@ dependencies { mavenPublishing { - coordinates("com.argmaxinc", "whisperkit", "0.3.0") + coordinates("com.argmaxinc", "whisperkit", "0.3.2") pom { name.set("WhisperKit") description.set("On-device Speech Recognition for Android") diff --git a/android/whisperkit/detekt-baseline.xml b/android/whisperkit/detekt-baseline.xml index 1df0484..4c4b90c 100644 --- a/android/whisperkit/detekt-baseline.xml +++ b/android/whisperkit/detekt-baseline.xml @@ -2,8 +2,10 @@ + LargeClass:ArgmaxModelDownloaderImplTest.kt$ArgmaxModelDownloaderImplTest ThrowsCount:WhisperKit.kt$WhisperKit.Builder$@Throws(WhisperKitException::class) fun build(): WhisperKit TooGenericExceptionCaught:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$e: Exception TooGenericExceptionCaught:WhisperKitImpl.kt$WhisperKitImpl$e: Exception + UnusedParameter:WhisperKitImpl.kt$WhisperKitImpl$timestamp: Float diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt index 4e6f7b2..254de56 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKit.kt @@ -7,6 +7,23 @@ import android.content.Context import com.argmaxinc.whisperkit.huggingface.HuggingFaceApi import kotlinx.coroutines.flow.Flow +/** + * Contains the complete transcription result. + * The result includes both the full text and individual segments of the transcription. + */ +data class TranscriptionResult( + val text: String, + val segments: List, +) + +/** + * Represents a single segment of transcribed text. + * Each segment contains the text content of a portion of the transcription. 
+ */ +data class TranscriptionSegment( + val text: String, +) + /** * WhisperKit is a speech recognition library that provides real-time transcription capabilities. * It supports both OpenAI and Qualcomm Whisper models, with various size options and compute backend configurations. @@ -138,16 +155,11 @@ interface WhisperKit { * - MSG_INIT (0): init() succeeded, model is ready * - MSG_TEXT_OUT (1): transcription results from previous transcribe() call * - MSG_CLOSE (2): deinitialize() succeeded, cleanup complete - * @param timestamp The timestamp of the transcribed segment (only valid for MSG_TEXT_OUT) - * @param msg The transcribed text or status message: - * - For MSG_INIT: initialization status - * - For MSG_TEXT_OUT: transcribed text - * - For MSG_CLOSE: cleanup status + * @param result The transcription result containing the raw text and segments */ fun onTextOutput( what: Int, - timestamp: Float, - msg: String, + result: TranscriptionResult, ) } @@ -160,15 +172,27 @@ interface WhisperKit { // Model variants const val OPENAI_TINY_EN = "whisperkit-litert/openai_whisper-tiny.en" const val OPENAI_BASE_EN = "whisperkit-litert/openai_whisper-base.en" - const val OPENAI_SMALL_EN = "whisperkit-litert/openai_whisper-small.en" + const val OPENAI_TINY = "whisperkit-litert/openai_whisper-tiny" + const val OPENAI_BASE = "whisperkit-litert/openai_whisper-base" const val QUALCOMM_TINY_EN = "qualcomm/Whisper_Tiny_En" const val QUALCOMM_BASE_EN = "qualcomm/Whisper_Base_En" - const val QUALCOMM_SMALL_EN = "qualcomm/Whisper_Small_En" + + // Small models are not supported yet + internal const val OPENAI_SMALL_EN = "whisperkit-litert/openai_whisper-small.en" + internal const val QUALCOMM_SMALL_EN = "qualcomm/Whisper_Small_En" // Compute units used for encoder/decoder backend const val CPU_ONLY = 1 const val CPU_AND_GPU = 2 const val CPU_AND_NPU = 3 + val SUPPORTED_MODELS = listOf( + OPENAI_TINY_EN, + OPENAI_BASE_EN, + OPENAI_TINY, + OPENAI_BASE, + QUALCOMM_TINY_EN, + 
QUALCOMM_BASE_EN, + ) } private var model: String? = null @@ -185,17 +209,8 @@ interface WhisperKit { */ @Throws(WhisperKitException::class) fun setModel(model: String): Builder { - if (model !in - listOf( - OPENAI_TINY_EN, - OPENAI_BASE_EN, - OPENAI_SMALL_EN, - QUALCOMM_TINY_EN, - QUALCOMM_BASE_EN, - QUALCOMM_SMALL_EN, - ) - ) { - throw WhisperKitException("Model must be one of the predefined variants") + if (model !in SUPPORTED_MODELS) { + throw WhisperKitException("Model must be one of the predefined variants: $SUPPORTED_MODELS") } this.model = model return this diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt index 47d0aad..e18c7ca 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/WhisperKitImpl.kt @@ -9,6 +9,8 @@ import com.argmaxinc.whisperkit.huggingface.HuggingFaceApi import com.argmaxinc.whisperkit.network.ArgmaxModel import com.argmaxinc.whisperkit.network.ArgmaxModelDownloader import com.argmaxinc.whisperkit.network.ArgmaxModelDownloaderImpl +import com.argmaxinc.whisperkit.util.MessageProcessor +import com.argmaxinc.whisperkit.util.SegmentTextOnlyMessageProcessor import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.mapLatest @@ -22,6 +24,7 @@ internal class WhisperKitImpl( decoderBackend: Int, private val callback: WhisperKit.TextOutputCallback, private val argmaxModelDownloader: ArgmaxModelDownloader = ArgmaxModelDownloaderImpl(), + private val messageProcessor: MessageProcessor = SegmentTextOnlyMessageProcessor(), ) : WhisperKit { companion object { private const val TAG = "WhisperKitImpl" @@ -144,14 +147,17 @@ internal class WhisperKitImpl( } } - // Callback from JNI native code + // Callback from JNI native code, note at the moment timestamp is always 0 private fun 
onTextOutput( what: Int, timestamp: Float, msg: String, ) { try { - callback.onTextOutput(what, timestamp, msg) + callback.onTextOutput( + what, + messageProcessor.process(msg), + ) } catch (e: Exception) { Log.e(TAG, "Callback execution failed: ${e.message}") throw WhisperKitException("Callback execution failed: ${e.message}", e) diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt index b1a4500..59c04ed 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt @@ -58,7 +58,7 @@ class AndroidHuggingFaceLogger(private val tag: String) : HuggingFaceLogger { * This class provides functionality to download required model files for a specific variant using * HuggingFaceApi, supporting automatic retries and progress reporting during downloads. 
*/ -class ArgmaxModelDownloaderImpl( +internal class ArgmaxModelDownloaderImpl( private val huggingFaceApi: HuggingFaceApi = KtorHuggingFaceApiImpl( config = @@ -69,6 +69,7 @@ class ArgmaxModelDownloaderImpl( ) : ArgmaxModelDownloader { companion object { private const val TOKENIZER_REPO = "TOKENIZER_REPO" + private const val CONFIG_REPO = "CONFIG_REPO" private const val ENCODER_DECODER_REPO = "ENCODER_DECODER_REPO" // dir path under argmaxinc/whisperkit-litert to look up for MelSpectrogram.tflite @@ -79,36 +80,56 @@ class ArgmaxModelDownloaderImpl( mapOf( WhisperKit.Builder.OPENAI_TINY_EN to mapOf( + CONFIG_REPO to "openai/whisper-tiny.en", TOKENIZER_REPO to "openai/whisper-tiny.en", ENCODER_DECODER_REPO to "openai_whisper-tiny.en", FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny.en", ), WhisperKit.Builder.OPENAI_BASE_EN to mapOf( + CONFIG_REPO to "openai/whisper-base.en", TOKENIZER_REPO to "openai/whisper-base.en", ENCODER_DECODER_REPO to "openai_whisper-base.en", FEATURE_EXTRACTOR_PATH to "openai_whisper-base.en", ), + WhisperKit.Builder.OPENAI_TINY to + mapOf( + CONFIG_REPO to "openai/whisper-tiny", + TOKENIZER_REPO to "openai/whisper-tiny", + ENCODER_DECODER_REPO to "openai_whisper-tiny", + FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny", + ), + WhisperKit.Builder.OPENAI_BASE to + mapOf( + CONFIG_REPO to "openai/whisper-base", + TOKENIZER_REPO to "openai/whisper-base", + ENCODER_DECODER_REPO to "openai_whisper-base", + FEATURE_EXTRACTOR_PATH to "openai_whisper-base", + ), WhisperKit.Builder.OPENAI_SMALL_EN to mapOf( + CONFIG_REPO to "openai/whisper-small.en", TOKENIZER_REPO to "openai/whisper-small.en", ENCODER_DECODER_REPO to "openai_whisper-small.en", FEATURE_EXTRACTOR_PATH to "openai_whisper-small.en", ), WhisperKit.Builder.QUALCOMM_TINY_EN to mapOf( + CONFIG_REPO to "openai/whisper-tiny.en", TOKENIZER_REPO to "openai/whisper-tiny.en", ENCODER_DECODER_REPO to "qualcomm/Whisper-Tiny-En", FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-tiny.en", ), 
WhisperKit.Builder.QUALCOMM_BASE_EN to mapOf( + CONFIG_REPO to "openai/whisper-base.en", TOKENIZER_REPO to "openai/whisper-base.en", ENCODER_DECODER_REPO to "qualcomm/Whisper-Base-En", FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-base.en", ), WhisperKit.Builder.QUALCOMM_SMALL_EN to mapOf( + CONFIG_REPO to "openai/whisper-small.en", TOKENIZER_REPO to "openai/whisper-small.en", ENCODER_DECODER_REPO to "qualcomm/Whisper-Small-En", FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-small.en", @@ -120,12 +141,12 @@ class ArgmaxModelDownloaderImpl( * Downloads model files for a specific variant and returns a flow of download progress. * * For OpenAI models (whisperkit-litert/openai_*): - * - Downloads tokenizer.json from openai/whisper-*.en + * - Downloads config.json and tokenizer.json from openai/whisper-* * - Downloads AudioEncoder.tflite and TextDecoder.tflite from argmaxinc/whisperkit-litert/openai_whisper-* * - Downloads MelSpectrogram.tflite from argmaxinc/whisperkit-litert/openai_whisper-* * * For Qualcomm models (qualcomm/Whisper_*_En): - * - Downloads tokenizer.json from openai/whisper-*.en + * - Downloads config.json and tokenizer.json from openai/whisper-*.en * - Downloads WhisperEncoder.tflite and WhisperDecoder.tflite from qualcomm/Whisper-*-En * and renames them to AudioEncoder.tflite and TextDecoder.tflite respectively * - Downloads MelSpectrogram.tflite from argmaxinc/whisperkit-litert/quic_openai_whisper-* @@ -133,6 +154,8 @@ class ArgmaxModelDownloaderImpl( * @param variant The model variant to download. 
Must be one of: * - [WhisperKit.Builder.OPENAI_TINY_EN] * - [WhisperKit.Builder.OPENAI_BASE_EN] + * - [WhisperKit.Builder.OPENAI_TINY] + * - [WhisperKit.Builder.OPENAI_BASE] * - [WhisperKit.Builder.OPENAI_SMALL_EN] * - [WhisperKit.Builder.QUALCOMM_TINY_EN] * - [WhisperKit.Builder.QUALCOMM_BASE_EN] @@ -147,20 +170,53 @@ class ArgmaxModelDownloaderImpl( ): Flow { val config = modelConfigs[variant] ?: throw IllegalArgumentException("Invalid variant: $variant") - return combine( + downloadConfig(config, root), downloadTokenizer(config, root), downloadEncoderDecoder(variant, config, root), downloadFeatureExtractor(config, root), - ) { tokenizer, encoderDecoder, featureExtractor -> - // Each flow contributes 1/3 to the total progress + ) { config, tokenizer, encoderDecoder, featureExtractor -> + // Each flow contributes 1/4 to the total progress HuggingFaceApi.Progress( - fractionCompleted = - ( - tokenizer.fractionCompleted + encoderDecoder.fractionCompleted + - featureExtractor.fractionCompleted - ) / 3.0f, + fractionCompleted = ( + config.fractionCompleted + tokenizer.fractionCompleted + + encoderDecoder.fractionCompleted + featureExtractor.fractionCompleted + ) / 4.0f, + ) + }.onCompletion { + // Clean up model directories after all downloads are complete + if (!variant.startsWith("qualcomm/")) { + // For OpenAI models, clean up the model directory + File(root, config[ENCODER_DECODER_REPO]!!).deleteRecursively() + } + // Clean up feature extractor directory + File(root, config[FEATURE_EXTRACTOR_PATH]!!).deleteRecursively() + } + } + + @OptIn(ExperimentalCoroutinesApi::class) + private fun downloadConfig( + config: Map, + root: File, + ): Flow { + return flow { + emit( + huggingFaceApi.getFileMetadata( + from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS), + filename = "config.json", + ), ) + }.flatMapLatest { tokenizerMetadata -> + val cachedTokenizerFile = File(root, "config.json") + if (cachedTokenizerFile.exists() && cachedTokenizerFile.length() == 
tokenizerMetadata.size) { + flowOf(HuggingFaceApi.Progress(1.0f)) + } else { + huggingFaceApi.snapshot( + from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS), + globFilters = listOf("config.json"), + baseDir = root, + ) + } } } @@ -314,7 +370,6 @@ class ArgmaxModelDownloaderImpl( "MelSpectrogram.tflite", ), ) - File(root, modelDir).deleteRecursively() } } } diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt new file mode 100644 index 0000000..d5e3e2a --- /dev/null +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/MessageProcessor.kt @@ -0,0 +1,10 @@ +package com.argmaxinc.whisperkit.util + +import com.argmaxinc.whisperkit.TranscriptionResult + +/** + * Processor to convert raw model output Strings into [TranscriptionResult] + */ +internal interface MessageProcessor { + fun process(rawMsg: String): TranscriptionResult +} diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt new file mode 100644 index 0000000..8aa2d6a --- /dev/null +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessor.kt @@ -0,0 +1,46 @@ +package com.argmaxinc.whisperkit.util + +import com.argmaxinc.whisperkit.TranscriptionResult +import com.argmaxinc.whisperkit.TranscriptionSegment + +/** + * A processor to only extract segment text from raw string, ignoring all timestamps or windows + */ +internal class SegmentTextOnlyMessageProcessor : MessageProcessor { + private companion object { + private val TIMESTAMP_PATTERN = "<\\|(\\d+\\.\\d+)\\|>".toRegex() + + // Pattern to match any <|str|> that's not a timestamp + private val NON_TIMESTAMP_PATTERN = "<\\|(?!\\d+\\.\\d+)[^>]*\\|>".toRegex() + } + + override fun process(rawMsg: String): TranscriptionResult { 
+ // Remove any markers that aren't timestamps + val cleanMsg = rawMsg.replace(NON_TIMESTAMP_PATTERN, "") + + val segments = mutableListOf() + + // Find all timestamp markers + val matches = TIMESTAMP_PATTERN.findAll(cleanMsg).toList() + + for (i in 0 until matches.size - 1) { + val startMatch = matches[i] + val endMatch = matches[i + 1] + + // TODO: add start and end to each segment + // val start = startMatch.groupValues[1].toFloat() + // val end = endMatch.groupValues[1].toFloat() + + // Extract text between timestamps + val textStart = startMatch.range.last + 1 + val textEnd = endMatch.range.first + val text = cleanMsg.substring(textStart, textEnd) + + if (text != "\n" && text.isNotEmpty()) { + segments.add(TranscriptionSegment(text)) + } + } + + return TranscriptionResult(text = rawMsg, segments = segments) + } +} diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt index a3f8237..28c1cbc 100644 --- a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt +++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt @@ -45,6 +45,14 @@ class ArgmaxModelDownloaderImplTest { ) { val qualcommModel = modelVariant.startsWith("qualcomm/") + // Mock config metadata + coEvery { + huggingFaceApi.getFileMetadata( + from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + filename = eq("config.json"), + ) + } returns HuggingFaceApi.FileMetadata(500L, "config.json") + // Mock tokenizer metadata coEvery { huggingFaceApi.getFileMetadata( @@ -115,7 +123,16 @@ class ArgmaxModelDownloaderImplTest { expectedMelSpectrogramPath: String, ) { // Create cached files upfront - // 1. 
Create config.json + val configFile = File(root, "config.json") + configFile.createNewFile() + configFile.setWritable(true) + configFile.setReadable(true) + configFile.setExecutable(true) + configFile.setLastModified(System.currentTimeMillis()) + configFile.writeBytes(ByteArray(500)) + + // 2. Create tokenizer.json val tokenizerFile = File(root, "tokenizer.json") tokenizerFile.createNewFile() tokenizerFile.setWritable(true) @@ -124,7 +141,7 @@ class ArgmaxModelDownloaderImplTest { tokenizerFile.setLastModified(System.currentTimeMillis()) tokenizerFile.writeBytes(ByteArray(1000)) // Set size to match metadata - // 2. Create encoder/decoder files + // 3. Create encoder/decoder files val audioEncoderFile = File(root, "AudioEncoder.tflite") val textDecoderFile = File(root, "TextDecoder.tflite") audioEncoderFile.createNewFile() @@ -132,7 +149,7 @@ class ArgmaxModelDownloaderImplTest { audioEncoderFile.writeBytes(ByteArray(2000)) // Set size to match metadata textDecoderFile.writeBytes(ByteArray(3000)) // Set size to match metadata - // 3. Create MelSpectrogram.tflite + // 4. 
Create MelSpectrogram.tflite val melSpectrogramFile = File(root, "MelSpectrogram.tflite") melSpectrogramFile.createNewFile() melSpectrogramFile.writeBytes(ByteArray(4000)) // Set size to match metadata @@ -197,6 +214,13 @@ class ArgmaxModelDownloaderImplTest { "MelSpectrogram.tflite", ).exists(), ) { "MelSpectrogram.tflite should exist in root directory" } + + // Add directory deletion verification + if (!qualcommModel) { + assert(!File(root, expectedMelSpectrogramPath).exists()) { + "Model directory should be deleted after download" + } + } } private fun testNonCached( @@ -207,6 +231,15 @@ class ArgmaxModelDownloaderImplTest { expectedEncoderDecoderRepo: String, expectedEncoderDecoderGlobFilters: List, ) { + // Mock config snapshot + every { + huggingFaceApi.snapshot( + from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + globFilters = eq(listOf("config.json")), + baseDir = eq(root), + ) + } returns flowOf(HuggingFaceApi.Progress(1.0f)) + // Mock tokenizer snapshot every { huggingFaceApi.snapshot( @@ -264,7 +297,14 @@ class ArgmaxModelDownloaderImplTest { downloader.download(ArgmaxModel.WHISPER, modelVariant, root).collect {} } - // Verify snapshot was called exactly 3 times for each component + // Verify snapshot was called exactly 4 times for each component + verify(exactly = 1) { + huggingFaceApi.snapshot( + from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + globFilters = eq(listOf("config.json")), + baseDir = eq(root), + ) + } verify(exactly = 1) { huggingFaceApi.snapshot( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), @@ -346,6 +386,13 @@ class ArgmaxModelDownloaderImplTest { "MelSpectrogram.tflite", ).exists(), ) { "MelSpectrogram.tflite should exist in root directory" } + + // Add directory deletion verification + if (!qualcommModel) { + assert(!File(root, expectedMelSpectrogramPath).exists()) { + "Model directory should be deleted after download" + } + } } @Test @@ -362,6 +409,20 @@ class ArgmaxModelDownloaderImplTest { 
expectedMelSpectrogramPath = "openai_whisper-tiny.en", ) + @Test + fun `download creates correct flows for OpenAI tiny multilingual model`() = + testDownload( + modelVariant = WhisperKit.Builder.OPENAI_TINY, + expectedTokenizerRepo = "openai/whisper-tiny", + expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert", + expectedEncoderDecoderGlobFilters = + listOf( + "openai_whisper-tiny/AudioEncoder.tflite", + "openai_whisper-tiny/TextDecoder.tflite", + ), + expectedMelSpectrogramPath = "openai_whisper-tiny", + ) + @Test fun `download creates correct flows for OpenAI base model`() = testDownload( @@ -376,6 +437,20 @@ class ArgmaxModelDownloaderImplTest { expectedMelSpectrogramPath = "openai_whisper-base.en", ) + @Test + fun `download creates correct flows for OpenAI base multilingual model`() = + testDownload( + modelVariant = WhisperKit.Builder.OPENAI_BASE, + expectedTokenizerRepo = "openai/whisper-base", + expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert", + expectedEncoderDecoderGlobFilters = + listOf( + "openai_whisper-base/AudioEncoder.tflite", + "openai_whisper-base/TextDecoder.tflite", + ), + expectedMelSpectrogramPath = "openai_whisper-base", + ) + @Test fun `download creates correct flows for OpenAI small model`() = testDownload( @@ -447,6 +522,21 @@ class ArgmaxModelDownloaderImplTest { cached = true, ) + @Test + fun `download uses cached files when available for OpenAI tiny multilingual model`() = + testDownload( + modelVariant = WhisperKit.Builder.OPENAI_TINY, + expectedTokenizerRepo = "openai/whisper-tiny", + expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert", + expectedEncoderDecoderGlobFilters = + listOf( + "openai_whisper-tiny/AudioEncoder.tflite", + "openai_whisper-tiny/TextDecoder.tflite", + ), + expectedMelSpectrogramPath = "openai_whisper-tiny", + cached = true, + ) + @Test fun `download uses cached files when available for OpenAI base model`() = testDownload( @@ -462,6 +552,21 @@ class ArgmaxModelDownloaderImplTest { cached = 
true, ) + @Test + fun `download uses cached files when available for OpenAI base multilingual model`() = + testDownload( + modelVariant = WhisperKit.Builder.OPENAI_BASE, + expectedTokenizerRepo = "openai/whisper-base", + expectedEncoderDecoderRepo = "argmaxinc/whisperkit-litert", + expectedEncoderDecoderGlobFilters = + listOf( + "openai_whisper-base/AudioEncoder.tflite", + "openai_whisper-base/TextDecoder.tflite", + ), + expectedMelSpectrogramPath = "openai_whisper-base", + cached = true, + ) + @Test fun `download uses cached files when available for OpenAI small model`() = testDownload( @@ -531,11 +636,20 @@ class ArgmaxModelDownloaderImplTest { @Test fun `download combines progress from all flows correctly`() = runTest { + val configProgress = 0.2f val tokenizerProgress = 0.3f val encoderDecoderProgress = 0.6f val melSpectrogramProgress = 0.9f val expectedProgress = - (tokenizerProgress + encoderDecoderProgress + melSpectrogramProgress) / 3.0f + (configProgress + tokenizerProgress + encoderDecoderProgress + melSpectrogramProgress) / 4.0f + + // Mock config metadata + coEvery { + huggingFaceApi.getFileMetadata( + from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + filename = eq("config.json"), + ) + } returns HuggingFaceApi.FileMetadata(500L, "config.json") // Mock tokenizer metadata coEvery { @@ -570,6 +684,15 @@ class ArgmaxModelDownloaderImplTest { ) // Mock tokenizer snapshot + every { + huggingFaceApi.snapshot( + from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + globFilters = eq(listOf("config.json")), + baseDir = eq(root), + ) + } returns flowOf(HuggingFaceApi.Progress(configProgress)) + + // Mock encoder/decoder snapshot every { huggingFaceApi.snapshot( from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt new file 
mode 100644 index 0000000..52c375f --- /dev/null +++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/util/SegmentTextOnlyMessageProcessorTest.kt @@ -0,0 +1,97 @@ +package com.argmaxinc.whisperkit.util + +import com.argmaxinc.whisperkit.TranscriptionSegment +import org.junit.Assert.assertEquals +import org.junit.Test + +class SegmentTextOnlyMessageProcessorTest { + private val processor = SegmentTextOnlyMessageProcessor() + + @Test + fun `process extracts text segments correctly`() { + // Given + val rawMsg = """ + <|startoftranscript|><|0.00|> When I was 27 years old, I left a very demanding job in management consulting.<|18.48|><|18.48|> For a job that was even more demanding, teaching.<|21.84|><|endoftext|> + <|startoftranscript|><|0.00|> I went to teach seventh graders, Math, in the New York City Public Schools.<|5.88|><|5.88|> And like any teacher, I made quizzes and tests.<|8.08|><|endoftext|> + """.trimIndent() + + // When + val result = processor.process(rawMsg) + + // Then + val expectedSegments = listOf( + TranscriptionSegment(" When I was 27 years old, I left a very demanding job in management consulting."), + TranscriptionSegment(" For a job that was even more demanding, teaching."), + TranscriptionSegment(" I went to teach seventh graders, Math, in the New York City Public Schools."), + TranscriptionSegment(" And like any teacher, I made quizzes and tests."), + ) + + assertEquals(expectedSegments, result.segments) + } + + @Test + fun `process handles multiple segments with same timestamps`() { + // Given + val rawMsg = """ + <|startoftranscript|><|0.00|> not fixed, that it can change with your effort.<|4.48|><|4.48|> Dr. 
Dweck has shown that when kids read and learn<|7.44|><|7.44|> about the brain and how it changes and grows<|10.72|><|10.72|> in response to challenge, they're much more likely<|13.64|><|13.64|> to persevere when they fail because they don't believe<|18.80|><|18.80|> that failure is a permanent condition.<|21.56|><|endoftext|> + <|startoftranscript|><|0.00|> So, growth mindset is a great idea for building grit, but we need more.<|6.60|><|6.60|> And that's where I'm going to end my remarks, because that's where we are.<|9.28|><|endoftext|> + """.trimIndent() + + // When + val result = processor.process(rawMsg) + + // Then + val expectedSegments = listOf( + TranscriptionSegment(" not fixed, that it can change with your effort."), + TranscriptionSegment(" Dr. Dweck has shown that when kids read and learn"), + TranscriptionSegment(" about the brain and how it changes and grows"), + TranscriptionSegment(" in response to challenge, they're much more likely"), + TranscriptionSegment(" to persevere when they fail because they don't believe"), + TranscriptionSegment(" that failure is a permanent condition."), + TranscriptionSegment(" So, growth mindset is a great idea for building grit, but we need more."), + TranscriptionSegment(" And that's where I'm going to end my remarks, because that's where we are."), + ) + + assertEquals(expectedSegments, result.segments) + } + + @Test + fun `process handles special characters and short segments`() { + // Given + val rawMsg = "<|startoftranscript|><|0.00|> ♪♪<|2.00|><|endoftext|>" + + // When + val result = processor.process(rawMsg) + + // Then + val expectedSegments = listOf( + TranscriptionSegment(" ♪♪"), + ) + + assertEquals(expectedSegments, result.segments) + } + + @Test + fun `process handles empty input`() { + // Given + val rawMsg = "" + + // When + val result = processor.process(rawMsg) + + // Then + assertEquals(emptyList(), result.segments) + } + + @Test + fun `process handles input with no segments`() { + // Given + val 
rawMsg = "<|startoftranscript|><|endoftext|>" + + // When + val result = processor.process(rawMsg) + + // Then + assertEquals(emptyList(), result.segments) + } +} diff --git a/cpp/src/Models/TextDecoder.cpp b/cpp/src/Models/TextDecoder.cpp index 297bc94..aaa1a79 100644 --- a/cpp/src/Models/TextDecoder.cpp +++ b/cpp/src/Models/TextDecoder.cpp @@ -12,7 +12,154 @@ using namespace WhisperKit; // 'Monolithic KV Cache' ~ corresponds to the QUIC exported Whisper models -// TODO: make a metadata class for the model signature and tendor indices, +namespace WhisperKit { +constexpr const int kKvFactor = 2; +constexpr const int kLayersWhisperTiny = 4; +constexpr const int kLayersWhisperBase = 6; +constexpr const int kLayersWhisperSmall = 12; +constexpr const int kLayersWhisperMedium = 24; +constexpr const int kLayersWhisperLarge = 32; + +constexpr const char* kVariantWhisperTiny = "tiny"; +constexpr const char* kVariantWhisperBase = "base"; +constexpr const char* kVariantWhisperSmall = "small"; +constexpr const char* kVariantWhisperMedium = "medium"; +constexpr const char* kVariantWhisperLarge = "large"; +constexpr const char* kVariantNone = "none"; +} // namespace WhisperKit + +namespace { + +std::string normalize_name(const std::string& name) { + // Names were padded with null characters to ensure alignment with original exported names + // Remove extra null characters, or naive string matching will fail. 
+ auto name_copy = name.c_str(); + return std::string(name_copy); +}; + +} // namespace + +class FlatBuffersMetadata { + public: + FlatBuffersMetadata(const std::string& tflite_model_path) { + _model_file_path = tflite_model_path; + std::ifstream file(tflite_model_path, std::ios::binary | std::ios::ate); + if (!file) throw std::runtime_error("Failed to open file"); + + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + + _buffer = std::vector(size); + if (!file.read(_buffer.data(), size)) throw std::runtime_error("Failed to read file"); + + const tflite::Model* model = tflite::GetModel(_buffer.data()); + + if (!model) { + throw std::runtime_error("Model is null"); + } + if (!model->subgraphs()) { + throw std::runtime_error("Model has no subgraphs"); + } + _model = model; + parse_model_metadata(); + } + + ~FlatBuffersMetadata() { + _model = nullptr; + _input_tensor_indices.clear(); + _output_tensor_indices.clear(); + _subgraphs = nullptr; + _buffer.clear(); + _buffer.shrink_to_fit(); + } + + const std::string& get_model_file_path() const { return _model_file_path; } + + const std::unordered_map> get_input_tensor_indices(int subgraph_index = 0) const; + const std::unordered_map> get_output_tensor_indices(int subgraph_index = 0) const; + void print_metadata(); + const tflite::Model* get_model() const { return _model; } + + private: + void parse_model_metadata(); + std::string _model_file_path; + const tflite::Model* _model; + std::vector _buffer; + const ::flatbuffers::Vector<::flatbuffers::Offset>* _subgraphs; + + // name -> (tensor_index, io_index) + std::vector>> _input_tensor_indices; + std::vector>> _output_tensor_indices; +}; + +void FlatBuffersMetadata::parse_model_metadata() { + _subgraphs = _model->subgraphs(); + + // Primary subgraph is the only one we care about for now. + bool only_first_subgraph = true; + int max_subgraph_index = only_first_subgraph ? 
1 : _subgraphs->size(); + + _input_tensor_indices.resize(max_subgraph_index); + _output_tensor_indices.resize(max_subgraph_index); + + for (int i = 0; i < max_subgraph_index; i++) { + const tflite::SubGraph* subgraph = _subgraphs->Get(i); + std::unordered_map> input_tensor_indices; + std::unordered_map> output_tensor_indices; + + const auto* inputs = subgraph->inputs(); + const auto* outputs = subgraph->outputs(); + const auto* tensors = subgraph->tensors(); + + for (int i = 0; i < inputs->size(); ++i) { + int tensor_index = inputs->Get(i); + auto name = normalize_name(tensors->Get(tensor_index)->name()->str()); + input_tensor_indices[name] = std::make_pair(tensor_index, i); + } + + for (int i = 0; i < outputs->size(); ++i) { + int tensor_index = outputs->Get(i); + auto name = normalize_name(tensors->Get(tensor_index)->name()->str()); + output_tensor_indices[name] = std::make_pair(tensor_index, i); + } + _input_tensor_indices[i] = input_tensor_indices; + _output_tensor_indices[i] = output_tensor_indices; + } +} + +void FlatBuffersMetadata::print_metadata() { + std::cout << "Model file path: " << _model_file_path << std::endl; + for (int i = 0; i < _input_tensor_indices.size(); i++) { + std::cout << "Subgraph " << i << " input tensor indices:" << std::endl; + for (const auto& [name, indices] : _input_tensor_indices[i]) { + std::cout << " " << name << ": (" << indices.first << ", " << indices.second << ")" << std::endl; + } + } + for (int i = 0; i < _output_tensor_indices.size(); i++) { + std::cout << "Subgraph " << i << " output tensor indices:" << std::endl; + for (const auto& [name, indices] : _output_tensor_indices[i]) { + std::cout << " " << name << ": (" << indices.first << ", " << indices.second << ")" << std::endl; + } + } +} + +const std::unordered_map> FlatBuffersMetadata::get_input_tensor_indices( + int subgraph_index) const { + if (subgraph_index >= _input_tensor_indices.size()) { + throw std::runtime_error("Subgraph index out of bounds"); + } + return 
_input_tensor_indices[subgraph_index]; +} + +const std::unordered_map> FlatBuffersMetadata::get_output_tensor_indices( + int subgraph_index) const { + if (subgraph_index >= _output_tensor_indices.size()) { + throw std::runtime_error("Subgraph index out of bounds"); + } + return _output_tensor_indices[subgraph_index]; +} + +// TODO: make a metadata class for the model signature and tensor indices, // pass to the subclasses so we don't have to reopen the file outside of // loading the actual tflite model. TFLIte APIs require signature runner // to get this information from interpreter, which is not available if the @@ -22,7 +169,17 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) { "x", "index", "k_cache_cross", "v_cache_cross", "k_cache_self", "v_cache_self"}; const std::unordered_set expected_output_names = {"logits", "k_cache", "v_cache"}; - const auto* subgraph = model->subgraphs()->Get(0); + if (!model->subgraphs()) { + throw std::runtime_error("Model has no subgraphs"); + } + + const auto* subgraphs = model->subgraphs(); + + if (subgraphs->size() == 0) { + throw std::runtime_error("Model has no subgraphs"); + } + + const auto* subgraph = subgraphs->Get(0); const auto* inputs = subgraph->inputs(); const auto* outputs = subgraph->outputs(); @@ -31,11 +188,11 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) { } const auto* tensors = subgraph->tensors(); - std::unordered_set input_names; for (int i = 0; i < inputs->size(); ++i) { int tensor_index = inputs->Get(i); - input_names.insert(tensors->Get(tensor_index)->name()->str()); + auto name = tensors->Get(tensor_index)->name()->str(); + input_names.insert(name); } std::unordered_set output_names; @@ -51,22 +208,6 @@ bool is_exact_match_for_monolithic_kv_cache(const tflite::Model* model) { return true; } -namespace WhisperKit { -constexpr const int kKvFactor = 2; -constexpr const int kLayersWhisperTiny = 4; -constexpr const int kLayersWhisperBase = 6; -constexpr 
const int kLayersWhisperSmall = 12; -constexpr const int kLayersWhisperMedium = 24; -constexpr const int kLayersWhisperLarge = 32; - -constexpr const char* kVariantWhisperTiny = "tiny"; -constexpr const char* kVariantWhisperBase = "base"; -constexpr const char* kVariantWhisperSmall = "small"; -constexpr const char* kVariantWhisperMedium = "medium"; -constexpr const char* kVariantWhisperLarge = "large"; -constexpr const char* kVariantNone = "none"; -} // namespace WhisperKit - const int layers_for_variant(const std::string& variant) { if (variant == kVariantWhisperTiny) { return kLayersWhisperTiny; @@ -82,6 +223,25 @@ const int layers_for_variant(const std::string& variant) { return 0; } +std::unordered_set get_expected_input_names_for_variant(const char* variant) { + auto input_names_for_variant_with_layers = [](const int num_layers) -> auto{ + std::unordered_set input_names; + input_names.insert(std::string("x")); + input_names.insert(std::string("index")); + input_names.insert(std::string("k_cache_cross")); + input_names.insert(std::string("v_cache_cross")); + for (int i = 0; i < num_layers; ++i) { + input_names.insert(std::string("k_cache_self_" + std::to_string(i))); + input_names.insert(std::string("v_cache_self_" + std::to_string(i))); + } + return input_names; + }; + + int num_layers = layers_for_variant(variant); + + return input_names_for_variant_with_layers(num_layers); +} + bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model* model) { const auto* subgraph = model->subgraphs()->Get(0); const auto* inputs = subgraph->inputs(); @@ -119,30 +279,17 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model return output_names; }; - auto input_names_for_variant_with_layers = [](const int num_layers) -> auto{ - std::unordered_set input_names; - input_names.insert(std::string("x")); - input_names.insert(std::string("index")); - input_names.insert(std::string("k_cache_cross")); - 
input_names.insert(std::string("v_cache_cross")); - for (int i = 0; i < num_layers; ++i) { - input_names.insert(std::string("k_cache_self_" + std::to_string(i))); - input_names.insert(std::string("v_cache_self_" + std::to_string(i))); - } - return input_names; - }; - char* variant = const_cast(kVariantNone); - if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperTiny)) { + if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperTiny).size()) { variant = const_cast(kVariantWhisperTiny); - } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperBase)) { + } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperBase).size()) { variant = const_cast(kVariantWhisperBase); - } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperSmall)) { + } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperSmall).size()) { variant = const_cast(kVariantWhisperSmall); - } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperMedium)) { + } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperMedium).size()) { variant = const_cast(kVariantWhisperMedium); - } else if (num_inputs == calculate_num_inputs_for_variant_with_layers(kLayersWhisperLarge)) { + } else if (num_inputs == get_expected_input_names_for_variant(kVariantWhisperLarge).size()) { variant = const_cast(kVariantWhisperLarge); } @@ -173,20 +320,13 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model } } - auto expected_input_names = input_names_for_variant_with_layers(layers_for_variant(variant)); + auto expected_input_names = get_expected_input_names_for_variant(variant); auto expected_output_names = output_names_for_variant_with_layers(layers_for_variant(variant)); std::unordered_set input_names; std::unordered_set output_names; const auto* tensors = subgraph->tensors(); - auto normalize_name = [](const 
std::string& name) -> std::string { - // Names were padded with null characters to ensure alignment with original exported names - // Remove extra null characters, or naive string matching will fail. - auto name_copy = name.c_str(); - return std::string(name_copy); - }; - for (int i = 0; i < num_inputs; ++i) { auto name = normalize_name(tensors->Get(inputs->Get(i))->name()->str()); input_names.insert(name); @@ -208,27 +348,21 @@ bool is_exact_match_for_separate_kv_cache_no_alignment_heads(const tflite::Model return true; } -std::unique_ptr TextDecoderFactory::CreateFromFile(const std::string& tflite_model_path) { - std::ifstream file(tflite_model_path, std::ios::binary | std::ios::ate); - if (!file) throw std::runtime_error("Failed to open file"); - - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); - - std::vector buffer(size); - if (!file.read(buffer.data(), size)) throw std::runtime_error("Failed to read file"); - - const tflite::Model* model = tflite::GetModel(buffer.data()); - - if (!model) throw std::runtime_error("Failed to load model"); - - auto is_monolithic_kv_cache = is_exact_match_for_monolithic_kv_cache(model); +TextDecoder::~TextDecoder() {} +std::unique_ptr TextDecoderFactory::CreateFromFile(const std::string& tflite_model_path) { + auto metadata = std::make_unique(tflite_model_path); + auto is_monolithic_kv_cache = is_exact_match_for_monolithic_kv_cache(metadata->get_model()); if (is_monolithic_kv_cache) { return std::make_unique(tflite_model_path); } - auto is_separate_kv_cache_no_alignment_heads = is_exact_match_for_separate_kv_cache_no_alignment_heads(model); + auto is_separate_kv_cache_no_alignment_heads = + is_exact_match_for_separate_kv_cache_no_alignment_heads(metadata->get_model()); + + if (is_separate_kv_cache_no_alignment_heads) { + return std::make_unique(tflite_model_path); + } throw std::runtime_error("Decoder model signature not recognized"); } @@ -237,14 +371,22 @@ std::pair MonolithicKVDecoder::get_logits_tensor() 
{ return decoder_ MonolithicKVDecoder::MonolithicKVDecoder(const std::string& tflite_model_path) { _model_path = tflite_model_path; + metadata = std::make_unique(tflite_model_path); + // metadata->print_metadata(); // Note that the decoder model is not initialized here, it is initialized in the initialize method _decoder_model = std::make_unique("TextDecoder"); + if (!_decoder_model) { throw std::runtime_error("Decoder model not initialized"); } } +MonolithicKVDecoder::~MonolithicKVDecoder() { + _decoder_model.reset(); + metadata.reset(); +} + bool MonolithicKVDecoder::initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend, bool debug) { return _decoder_model->initialize(model_path, lib_dir, cache_dir, backend, debug); @@ -317,3 +459,293 @@ std::vector> MonolithicKVDecoder::get_output_ptrs() { retu int MonolithicKVDecoder::get_inference_num() { return _decoder_model->get_inference_num(); } float MonolithicKVDecoder::get_latency_sum() { return _decoder_model->get_latency_sum(); } +void MonolithicKVDecoder::dump_input_tensors() { + // Not yet implemented +} + +void MonolithicKVDecoder::dump_output_tensors() { + // Not yet implemented +} + +std::pair PerLayerKVDecoder::get_logits_tensor() { + if (decoder_outputs.empty()) { + decoder_outputs = _decoder_model->get_output_ptrs(); + } + auto logits_index = output_tensor_indices.at("logits"); + return decoder_outputs[logits_index]; +} + +PerLayerKVDecoder::PerLayerKVDecoder(const std::string& tflite_model_path) { + _model_path = tflite_model_path; + metadata = std::make_unique(tflite_model_path); + initialize_io_metadata(); + metadata.reset(); // to close the .tflite file + + // Note that the decoder model is not initialized here, it is initialized in the initialize method + _decoder_model = std::make_unique("TextDecoder"); + if (!_decoder_model) { + throw std::runtime_error("Decoder model not initialized"); + } +} + +PerLayerKVDecoder::~PerLayerKVDecoder() { + _decoder_model.reset(); + 
metadata.reset(); +} + +bool PerLayerKVDecoder::initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend, + bool debug) { + return _decoder_model->initialize(model_path, lib_dir, cache_dir, backend, debug); +} + +void PerLayerKVDecoder::uninitialize() { _decoder_model->uninitialize(); } + +void PerLayerKVDecoder::read_input_data(char* input_data, int idx) { _decoder_model->read_input_data(input_data, idx); } + +void PerLayerKVDecoder::bind_input_tensor(char* input_data, const std::string& tensor_name) { + if (tensor_name == "x") { + // get value in int32 and upcast to int64 before passing to read_input_data, which + // does a simple memcpy of all the bytes. + int32_t* input_data_int32 = reinterpret_cast(input_data); + int64_t x = static_cast(*input_data_int32); + auto x_tensor_index = input_tensor_indices["x"]; + _decoder_model->read_input_data(reinterpret_cast(&x), x_tensor_index); + return; + } + + if (tensor_name == "index") { + int* input_data_int32 = reinterpret_cast(input_data); + int64_t index = static_cast(*input_data_int32); + auto index_tensor_index = input_tensor_indices["index"]; + + _decoder_model->read_input_data(reinterpret_cast(&index), index_tensor_index); + return; + } + + if (tensor_name == "k_cache_cross") { + auto k_cache_cross_tensor_index = input_tensor_indices["k_cache_cross"]; + _decoder_model->read_input_data(input_data, k_cache_cross_tensor_index); + return; + } + + if (tensor_name == "v_cache_cross") { + auto v_cache_cross_tensor_index = input_tensor_indices["v_cache_cross"]; + _decoder_model->read_input_data(input_data, v_cache_cross_tensor_index); + return; + } + + else { + auto tensor_index = kv_cache_input_tensor_indices[tensor_name]; + _decoder_model->read_input_data(input_data, tensor_index); + return; + } + + throw std::runtime_error("Invalid tensor name"); +} + +void PerLayerKVDecoder::invoke(bool measure_time) { _decoder_model->invoke(measure_time); } + +template +void save_to_binary_file(const 
std::string& filename, const std::vector& data) { + std::ofstream outfile(filename, std::ios::binary); + if (!outfile) { + std::cerr << "Error opening file for writing: " << filename << std::endl; + return; + } + + printf("Saving to binary file: %s\n", filename.c_str()); + outfile.write(reinterpret_cast(data.data()), data.size() * sizeof(T)); + outfile.close(); +} + +void PerLayerKVDecoder::dump_input_tensors() { + printf("Dumping input tensors\n"); + auto input_ptrs = _decoder_model->get_input_ptrs(); + + for (int index = 0; index < _decoder_model->_interpreter->inputs().size(); index++) { + auto name = _decoder_model->_interpreter->GetInputName(index); + auto safe_name = normalize_name(name); + printf("safe name %s index %d\n", safe_name.c_str(), index); + if (safe_name == "x" || safe_name == "index") { + printf("Dumping input tensor: %s, index=%d\n", name, index); + auto tensor_ptr = input_ptrs[index]; + int tensor_size = tensor_ptr.second; + printf("Tensor size: %d\n", tensor_size); + int64_t* input_data = reinterpret_cast(tensor_ptr.first); + size_t num_elements = tensor_size / sizeof(int64_t); + std::vector data(num_elements); + printf("==============================================\n"); + printf("name %s index %d num_elements: %ld\n", name, index, num_elements); + for (int i = 0; i < num_elements; i++) { + data[i] = input_data[i]; + printf("name %s index %d data[%d] = %ld\n", name, index, i, data[i]); + } + printf("==============================================\n"); + + std::string _name = std::string(safe_name); + std::string filename = "/src/AXIE/debug_inputs/input_" + _name + ".bin"; + save_to_binary_file(filename, data); // float32 + } else { + std::string _name = std::string(safe_name); + + printf("In float tensor path"); + printf("Dumping input tensor: %s, index=%d\n", _name.c_str(), index); + auto tensor_ptr = input_ptrs[index]; + auto tensor_size = tensor_ptr.second; + printf("Tensor size: %d\n", tensor_size); + std::vector data(tensor_size / 
sizeof(float)); + + bool print_for_debug = true; + if (data.size() > 100) { + print_for_debug = false; + } + + for (int i = 0; i < data.size(); i++) { + data[i] = *reinterpret_cast(input_ptrs[index].first + i * sizeof(float)); + if (print_for_debug) { + printf("name %s index %d data[%d] = %f\n", _name.c_str(), index, i, data[i]); + } + } + std::string filename = "/src/AXIE/debug_inputs/input_" + _name + ".bin"; + save_to_binary_file(filename, data); + } + } +} + +void PerLayerKVDecoder::dump_output_tensors() { + printf("Dumping output tensors\n"); + auto output_ptrs = _decoder_model->get_output_ptrs(); + + for (int index = 0; index < _decoder_model->_interpreter->outputs().size(); index++) { + auto name = _decoder_model->_interpreter->GetOutputName(index); + auto safe_name = normalize_name(name); + printf("safe name %s index %d\n", safe_name.c_str(), index); + + if (safe_name == "logits") { + std::string _name = std::string(safe_name); + + printf("In float tensor path"); + printf("Dumping output tensor: %s, index=%d\n", _name.c_str(), index); + auto tensor_ptr = output_ptrs[index]; + auto tensor_size = tensor_ptr.second; + printf("Tensor size: %d\n", tensor_size); + std::vector data(tensor_size / sizeof(float)); + + bool print_for_debug = true; + if (data.size() > 100) { + print_for_debug = false; + } + for (int i = 0; i < data.size(); i++) { + data[i] = *reinterpret_cast(output_ptrs[index].first + i * sizeof(float)); + if (print_for_debug) { + printf("name %s index %d data[%d] = %f\n", _name.c_str(), index, i, data[i]); + } + } + std::string filename = "/src/AXIE/debug_inputs/output_" + _name + ".bin"; + save_to_binary_file(filename, data); + } + } +} + +void PerLayerKVDecoder::initialize_io_metadata() { + // self attention kv cache tensors + const auto& all_input_tensor_indices = metadata->get_input_tensor_indices(0); + const auto& all_output_tensor_indices = metadata->get_output_tensor_indices(0); + + // store only the relative indices within the i/o tensor 
vectors + auto logits_indices = all_output_tensor_indices.at("logits"); + output_tensor_indices["logits"] = logits_indices.second; + + auto token_indices = all_input_tensor_indices.at("x"); + input_tensor_indices["x"] = token_indices.second; + + auto index_indices = all_input_tensor_indices.at("index"); + input_tensor_indices["index"] = index_indices.second; + + auto k_cache_cross_indices = all_input_tensor_indices.at("k_cache_cross"); + input_tensor_indices["k_cache_cross"] = k_cache_cross_indices.second; + + auto v_cache_cross_indices = all_input_tensor_indices.at("v_cache_cross"); + input_tensor_indices["v_cache_cross"] = v_cache_cross_indices.second; + + for (const auto& [name, indices] : all_input_tensor_indices) { + if (name == "x" || name == "index" || name == "k_cache_cross" || name == "v_cache_cross") { + continue; + } + kv_cache_input_tensor_indices[name] = indices.second; + } + + for (const auto& [name, indices] : all_output_tensor_indices) { + if (name == "logits") { + continue; + } + + if (name.find("k_cache_") == 0 || name.find("v_cache_") == 0) { + kv_cache_output_tensor_indices[name] = indices.second; + } + } + + auto extractNumericSuffix = [](const std::string& s) -> int { + size_t pos = s.find_last_of('_'); + if (pos == std::string::npos || pos == s.length() - 1) { + throw std::invalid_argument("String does not contain a valid numeric suffix"); + } + return std::stoi(s.substr(pos + 1)); + }; + + for (const auto& [name, index] : kv_cache_output_tensor_indices) { + if (name.find("k_cache_") == 0) { + auto layer_num = extractNumericSuffix(name); + auto input_name = "k_cache_self_" + std::to_string(layer_num); + + kv_cache_io_tensor_names[input_name] = name; + } else if (name.find("v_cache_") == 0) { + auto layer_num = extractNumericSuffix(name); + auto input_name = "v_cache_self_" + std::to_string(layer_num); + + kv_cache_io_tensor_names[input_name] = name; + } + } +} + +void PerLayerKVDecoder::update_kv_cache() { + if (decoder_outputs.empty()) { + 
decoder_outputs = _decoder_model->get_output_ptrs(); + } + + for (const auto& [input_name, output_name] : kv_cache_io_tensor_names) { + auto input_tensor_index = kv_cache_input_tensor_indices[input_name]; + auto output_tensor_index = kv_cache_output_tensor_indices[output_name]; + _decoder_model->read_input_data(decoder_outputs[output_tensor_index].first, input_tensor_index); + } +} + +void PerLayerKVDecoder::initialize_kv_cache() { + if (decoder_outputs.empty()) { + decoder_outputs = _decoder_model->get_output_ptrs(); + } + + auto input_ptrs = _decoder_model->get_input_ptrs(); + + for (const auto& [name, index] : kv_cache_input_tensor_indices) { + memset(input_ptrs[index].first, 0, input_ptrs[index].second); + } + + for (const auto& [name, index] : kv_cache_output_tensor_indices) { + memset(decoder_outputs[index].first, 0, decoder_outputs[index].second); + } +} + +float PerLayerKVDecoder::get_latency_median() { return _decoder_model->get_latency_median(); } + +float PerLayerKVDecoder::get_latency_avg() { return _decoder_model->get_latency_avg(); } + +std::unique_ptr PerLayerKVDecoder::get_latency_json() { return _decoder_model->get_latency_json(); } + +std::vector> PerLayerKVDecoder::get_input_ptrs() { return _decoder_model->get_input_ptrs(); } + +std::vector> PerLayerKVDecoder::get_output_ptrs() { return _decoder_model->get_output_ptrs(); } + +int PerLayerKVDecoder::get_inference_num() { return _decoder_model->get_inference_num(); } + +float PerLayerKVDecoder::get_latency_sum() { return _decoder_model->get_latency_sum(); } diff --git a/cpp/src/Models/TextDecoder.hpp b/cpp/src/Models/TextDecoder.hpp index 59f3590..6d0a31b 100644 --- a/cpp/src/Models/TextDecoder.hpp +++ b/cpp/src/Models/TextDecoder.hpp @@ -14,12 +14,14 @@ enum DecoderKVCacheType { }; } +class FlatBuffersMetadata; + // TODO: // remove extraneous functions used for passthrough to MODEL_SUPER_CLASS // to expedite integration class TextDecoder { public: - virtual ~TextDecoder() = default; + virtual 
~TextDecoder(); virtual void initialize_kv_cache() = 0; virtual void read_input_data(char* data, int index) = 0; @@ -38,15 +40,21 @@ class TextDecoder { virtual float get_latency_median() = 0; virtual std::unique_ptr get_latency_json() = 0; + virtual void dump_input_tensors() = 0; + virtual void dump_output_tensors() = 0; + protected: + std::unique_ptr metadata; // TODO: modify to hold tflite model from tensorflow & use delegate manager std::unique_ptr _decoder_model; std::string _model_path; + std::vector> decoder_outputs; }; class MonolithicKVDecoder : public TextDecoder { public: MonolithicKVDecoder(const std::string& tflite_model_path); + ~MonolithicKVDecoder(); void initialize_kv_cache() override; bool initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend, @@ -65,8 +73,47 @@ class MonolithicKVDecoder : public TextDecoder { float get_latency_median() override; std::unique_ptr get_latency_json() override; + void dump_input_tensors() override; + void dump_output_tensors() override; +}; + +class PerLayerKVDecoder : public TextDecoder { + public: + PerLayerKVDecoder(const std::string& tflite_model_path); + ~PerLayerKVDecoder(); + void initialize_kv_cache() override; + + bool initialize(std::string model_path, std::string lib_dir, std::string cache_dir, int backend, + bool debug) override; + void uninitialize() override; + void read_input_data(char* input_data, int idx) override; + void invoke(bool measure_time = false) override; + void update_kv_cache() override; + std::vector> get_input_ptrs() override; + std::vector> get_output_ptrs() override; + void bind_input_tensor(char* input_data, const std::string& tensor_name) override; + std::pair get_logits_tensor() override; + int get_inference_num() override; + float get_latency_sum() override; + float get_latency_avg() override; + float get_latency_median() override; + std::unique_ptr get_latency_json() override; + + void dump_input_tensors() override; + void dump_output_tensors() 
override; + private: - std::vector> decoder_outputs; + void initialize_io_metadata(); + // self attention kv cache tensors + std::unordered_map kv_cache_io_tensor_names; // + std::unordered_map kv_cache_input_tensor_indices; // + std::unordered_map kv_cache_output_tensor_indices; // + + // non-kv cache tensors + std::unordered_map + input_tensor_indices; // , non-kv cache tensors + std::unordered_map + output_tensor_indices; // , non-kv cache tensors }; class TextDecoderFactory { diff --git a/cpp/src/Models/tflite_model.cpp b/cpp/src/Models/tflite_model.cpp index b0eaffc..9c9884e 100644 --- a/cpp/src/Models/tflite_model.cpp +++ b/cpp/src/Models/tflite_model.cpp @@ -502,6 +502,9 @@ vector> TFLiteModel::get_input_ptrs() { case kTfLiteInt32: input_ptr = _interpreter->typed_input_tensor(idx); break; + case kTfLiteInt64: + input_ptr = _interpreter->typed_input_tensor(idx); + break; default: fprintf(stderr, "Error: unsupported tensor type: %d\n", tensor->type); exit(-1); @@ -535,6 +538,16 @@ vector> TFLiteModel::get_output_ptrs() { return _output_ptrs; } +std::pair TFLiteModel::get_output_with_name(const std::string& name) { + for (int idx = 0; idx < _interpreter->outputs().size(); idx++) { + auto* tensor = _interpreter->tensor(_interpreter->outputs()[idx]); + if (strcmp(tensor->name, name.c_str()) == 0) { + return make_pair(reinterpret_cast(tensor->data.f), tensor->bytes); + } + } + return make_pair(nullptr, 0); +} + void TFLiteModel::invoke(bool measure_time) { chrono::time_point before_exec; if (measure_time) { diff --git a/cpp/src/Models/tflite_model.hpp b/cpp/src/Models/tflite_model.hpp index 2914b87..aa13ea0 100644 --- a/cpp/src/Models/tflite_model.hpp +++ b/cpp/src/Models/tflite_model.hpp @@ -54,6 +54,7 @@ class TFLiteModel { void read_input_data(char* input_data, int idx); std::vector> get_input_ptrs(); std::vector> get_output_ptrs(); + std::pair get_output_with_name(const std::string& name); void print_tensor_dims(); std::unique_ptr get_latency_json(); @@ -66,10 
+67,12 @@ class TFLiteModel { std::vector _latencies; + std::unique_ptr _interpreter; + protected: std::mutex _mutex; std::unique_ptr _model; - std::unique_ptr _interpreter; + flatbuffers::FlatBufferBuilder _builder; TfLiteDelegate* _delegate = nullptr; std::string _model_name; diff --git a/cpp/src/Text/Tokenizer.cpp b/cpp/src/Text/Tokenizer.cpp index 81815ba..6e42172 100644 --- a/cpp/src/Text/Tokenizer.cpp +++ b/cpp/src/Text/Tokenizer.cpp @@ -54,105 +54,21 @@ void init_special_tokens(Tokenizer *tokenizer) { tokenizer->specialTokens = special_tokens; } -void init_non_speech_tokens(Tokenizer *tokenizer) { - std::vector non_speech_tokens{"!", - "\"", - "#", - "(", - ")", - "*", - "+", - "/", - ":", - ";", - "<", - "=", - ">", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", - " (", - " \"", - "--", - " -", - " [", - " '", - " =", - " |", - " :", - " /", - " )", - " <", - " #", - " +", - " --", - " {", - " *", - " }", - " >", - " ;", - " ]", - " @", - " \\", - "))", - ">>", - " `", - " _", - " ~", - " (\"", - "---", - "(\"", - " >>", - " <<", - " ^", - "('", - " ---", - "}}", - "]]", - " >>>", - "「", - "」", - " ((", - " ))", - " [[", - "<<", - "�", - " (\'", - "((", - " �", - ")))", - " {{", - "{{", - "[[", - "-(", - ">>>", - " }}", - " 「", - "『", - "』", - " )))", - "-[", - "<|startoftranscript|>", - "<|translate|>", - "<|transcribe|>", - "<|startoflm|>", - "<|startofprev|>", - "<|nocaptions|>"}; - tokenizer->numNonSpeechTokens = non_speech_tokens.size(); +void init_non_speech_tokens(Tokenizer *tokenizer, const std::unique_ptr &config) { + std::vector non_speech_token_ids = config->at("suppress_tokens"); + + tokenizer->numNonSpeechTokens = non_speech_token_ids.size(); tokenizer->nonSpeechTokens = (int *)malloc(sizeof(int) * tokenizer->numNonSpeechTokens); for (auto i = 0; i < tokenizer->numNonSpeechTokens; i++) { - tokenizer->nonSpeechTokens[i] = tokenizer_convert_token_to_id(tokenizer, non_speech_tokens[i].c_str()); + 
tokenizer->nonSpeechTokens[i] = non_speech_token_ids[i]; } } +bool tokenizer_is_multilingual(const Tokenizer *tokenizer) { + constexpr const int ENGLISH_VOCAB_SIZE = 51864; + return tokenizer->vocabSize != ENGLISH_VOCAB_SIZE; +} + int tokenizer_convert_token_to_id(const Tokenizer *tokenizer, const char *token_string) { // Encode token CEncoding *encoding = tokenizer_encode(tokenizer->handle, token_string, false); @@ -168,7 +84,7 @@ int tokenizer_convert_token_to_id(const Tokenizer *tokenizer, const char *token_ return static_cast(id); } -Tokenizer *tokenizer_init_from_file(const char *path) { +Tokenizer *tokenizer_init_from_file(const char *path, const char *config_path) { // Dynamically allocate tokenizer memory Tokenizer *tokenizer = (Tokenizer *)malloc(sizeof(Tokenizer)); if (!tokenizer) { @@ -178,6 +94,7 @@ Tokenizer *tokenizer_init_from_file(const char *path) { // Load file to check existence and get vocabulary size. std::ifstream file(path); + std::ifstream config_file(config_path); if (!file) { LOGE("Error loading provided tokenizer JSON. 
File may not exist!\n"); return NULL; @@ -189,8 +106,13 @@ Tokenizer *tokenizer_init_from_file(const char *path) { LOGE("Error parsing the provided tokenizer JSON!"); return NULL; } - tokenizer->vocabSize = (*json_file)["model"]["vocab"].size(); - LOGI("postproc vocab size: %d\n", tokenizer->vocabSize); + + auto json_config = std::make_unique(json::parse(config_file)); + if (!json_config) { + LOGE("Error parsing the provided tokenizer config JSON!"); + return NULL; + } + tokenizer->vocabSize = (*json_config)["vocab_size"]; // Load vocabulary from tokenizer file tokenizer->handle = tokenizer_from_file(path); @@ -200,7 +122,7 @@ Tokenizer *tokenizer_init_from_file(const char *path) { } init_special_tokens(tokenizer); - init_non_speech_tokens(tokenizer); + init_non_speech_tokens(tokenizer, json_config); return tokenizer; } diff --git a/cpp/src/Text/Tokenizer.h b/cpp/src/Text/Tokenizer.h index 6aede7e..ccb1c9c 100644 --- a/cpp/src/Text/Tokenizer.h +++ b/cpp/src/Text/Tokenizer.h @@ -32,11 +32,13 @@ typedef struct { } Tokenizer; // Initialize the tokenizer -Tokenizer* tokenizer_init_from_file(const char* path); +Tokenizer* tokenizer_init_from_file(const char* path, const char* config_path); // Decode token IDs into a string char* tokenizer_decode(const Tokenizer* tokenizer, const int* tokens, int tokenCount, bool skipSpecialTokens); +bool tokenizer_is_multilingual(const Tokenizer* tokenizer); + // Convert token string to ID int tokenizer_convert_token_to_id(const Tokenizer* tokenizer, const char* tokenString); diff --git a/cpp/src/Text/post_proc.cpp b/cpp/src/Text/post_proc.cpp index fada65b..eee56ca 100644 --- a/cpp/src/Text/post_proc.cpp +++ b/cpp/src/Text/post_proc.cpp @@ -70,11 +70,16 @@ int PostProcModel::process(int idx, float* logits, int logits_size, vector& logits[_tokenizer->specialTokens.endOfTranscriptToken] = -1e9; logits[_tokenizer->specialTokens.blankToken] = -1e9; } - for (int i = 0; i < _tokenizer->numNonSpeechTokens; i++) { - auto token = 
_tokenizer->nonSpeechTokens[i]; - logits[token] = -1e9; + + // TODO: unblocking multilingual models + if (!tokenizer_is_multilingual(_tokenizer)) { + for (int i = 0; i < _tokenizer->numNonSpeechTokens; i++) { + auto token = _tokenizer->nonSpeechTokens[i]; + logits[token] = -1e9; + } + + apply_timestamp_rules(logits, logits_size, decoded_tokens); } - apply_timestamp_rules(logits, logits_size, decoded_tokens); // logits read_input_data(reinterpret_cast(logits), 0); auto inputs = get_input_ptrs(); diff --git a/cpp/src/TranscribeTask.cpp b/cpp/src/TranscribeTask.cpp index 3d2dda2..a95ad59 100644 --- a/cpp/src/TranscribeTask.cpp +++ b/cpp/src/TranscribeTask.cpp @@ -211,19 +211,31 @@ void Runtime::init() { LOGI("SoC: \tgeneric CPU (x86, arm64, etc) \n"); #endif - // TODO: this should be using std::filesystem.. std::string tokenizer_json = config.get_model_path() + "/tokenizer.json"; + std::string tokenizer_config_json = config.get_model_path() + "/config.json"; std::string melspectro_model = config.get_model_path() + "/MelSpectrogram.tflite"; std::string encoder_model = config.get_model_path() + "/AudioEncoder.tflite"; std::string decoder_model = config.get_model_path() + "/TextDecoder.tflite"; + std::vector required_files = {tokenizer_json, tokenizer_config_json, melspectro_model, encoder_model, + decoder_model}; + for (const auto& file : required_files) { + if (!std::filesystem::exists(file)) { + LOGE("File does not exist: %s", file.c_str()); + std::stringstream ss; + ss << file << " : required file not found"; + throw std::runtime_error(ss.str()); + } + } + melspectro = make_unique("mel_spectrogram"); encoder = make_unique("whisper_encoder"); decoder = TextDecoderFactory::CreateFromFile(decoder_model); // TODO move this to somewhere user accessible. 
- tokenizer = tokenizer_init_from_file(tokenizer_json.c_str()); + tokenizer = tokenizer_init_from_file(tokenizer_json.c_str(), tokenizer_config_json.c_str()); + postproc = make_unique(tokenizer); lib_dir = std::string(TRANSCRIBE_TASK_DEFAULT_LIB_DIR); @@ -344,14 +356,25 @@ void Runtime::encode_decode_postproc(float timestamp) { encoder->get_mutex()->lock(); encoder->read_input_data(melspectro_outputs[0].first, 0); encoder->get_mutex()->unlock(); - // Perform encoder inference encoder->invoke(true); - const auto& k_cache_cross = encoder_outputs[0].first; - const auto& v_cache_cross = encoder_outputs[1].first; + auto k_cache_cross = encoder->get_output_with_name("k_cache_cross"); + if (k_cache_cross.first == nullptr) { + k_cache_cross = encoder->get_output_with_name("k_cache"); + } + auto v_cache_cross = encoder->get_output_with_name("v_cache_cross"); + if (v_cache_cross.first == nullptr) { + v_cache_cross = encoder->get_output_with_name("v_cache"); + } + + if (k_cache_cross.first == nullptr || v_cache_cross.first == nullptr) { + LOGE("Failed to get k_cache_cross or v_cache_cross"); + return; + } + + decoder->bind_input_tensor(k_cache_cross.first, "k_cache_cross"); + decoder->bind_input_tensor(v_cache_cross.first, "v_cache_cross"); - decoder->bind_input_tensor(k_cache_cross, "k_cache_cross"); - decoder->bind_input_tensor(v_cache_cross, "v_cache_cross"); decoder->initialize_kv_cache(); constexpr const int MAX_DECODING_STEPS = 224; diff --git a/scripts/Dockerfile b/scripts/Dockerfile index fd73198..0bd3f91 100644 --- a/scripts/Dockerfile +++ b/scripts/Dockerfile @@ -37,7 +37,7 @@ ENV QNN_RUNTIME_ROOT=/opt/qnn-runtime ENV AXIE_ROOT=/src/AXIE ARG ANDROID_NDK_ZIP=$ANDROID_NDK_VERSION-linux.zip -ARG BAZEL_INSTALLER=bazel-6.5.0-installer-linux-x86_64.sh +ARG BAZEL_INSTALLER=bazel-7.4.1-installer-linux-x86_64.sh ARG BAZEL_DIR=/opt/bazel ARG QNN_RUNTIME=qnn-runtime-2.33.0.aar ARG QNN_TFLITE_DELEGATE=qnn-litert-delegate-2.33.0.aar diff --git a/scripts/adb_push.sh 
b/scripts/adb_push.sh index 46dd3aa..ed7260f 100755 --- a/scripts/adb_push.sh +++ b/scripts/adb_push.sh @@ -2,22 +2,32 @@ # For licensing see accompanying LICENSE file. # Copyright © 2024 Argmax, Inc. All rights reserved. +# This script is used to push the files and folders to all connected Android devices +# +# whisperkit-cli: CLI app to run transcription on Android device, with files input. +# libwhisperkit.so: main whisperkit lib +# models: all models for whisperkit in models/ folder +# test/jfk_441khz.m4a: sample test audio file for whisperkit +# scripts/run_on_android.sh: script to run the app on Android device + +# Usage: make adb-push [forced] +# or ./adb_push.sh [forced] +# If forced is provided, the script will push the files and folders to +# the Android device even if they already exist. Otherwise, it will +# skip the push if the files already exist. + CURRENT_DIR="$(dirname "$(realpath "$0")")" SOURCE_DIR="$CURRENT_DIR/.." WHISPERKIT_CLI="$SOURCE_DIR/build/android/whisperkit-cli" AXIE_TFLITE_LIB="$SOURCE_DIR/build/android/libwhisperkit.so" LOCAL_LIBS="$SOURCE_DIR/external/libs/android" -LOCAL_TINY_DIR="$SOURCE_DIR/models/openai_whisper-tiny" -LOCAL_BASE_DIR="$SOURCE_DIR/models/openai_whisper-base" -LOCAL_SMALL_DIR="$SOURCE_DIR/models/openai_whisper-small" +LOCAL_MODELS_DIR="$SOURCE_DIR/models" DEVICE_BIN_DIR="/data/local/tmp/bin" DEVICE_LIB_DIR="/data/local/tmp/lib" DEVICE_SDROOT_DIR="/sdcard/argmax/tflite" -DEVICE_TINY_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-tiny" -DEVICE_BASE_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-base" -DEVICE_SMALL_DIR="${DEVICE_SDROOT_DIR}/models/openai_whisper-small" +DEVICE_MODELS_DIR="${DEVICE_SDROOT_DIR}/models" DEVICE_INPUTS_DIR="${DEVICE_SDROOT_DIR}/inputs" EXEC_SCRIPT="$SOURCE_DIR/scripts/run_on_android.sh" @@ -88,8 +98,6 @@ do adb -s $DEVICE shell "chmod 777 $DEVICE_SDROOT_DIR/run_on_android.sh" adb -s $DEVICE shell "chmod 777 $DEVICE_BIN_DIR/whisperkit-cli" - push_if_not_exists "$LOCAL_TINY_DIR" 
"$DEVICE_TINY_DIR" $FORCED - push_if_not_exists "$LOCAL_BASE_DIR" "$DEVICE_BASE_DIR" $FORCED - push_if_not_exists "$LOCAL_SMALL_DIR" "$DEVICE_SMALL_DIR" $FORCED + push_if_not_exists "$LOCAL_MODELS_DIR" "$DEVICE_MODELS_DIR" $FORCED done diff --git a/scripts/build_tensorflow.sh b/scripts/build_tensorflow.sh index a3593b3..6b99966 100755 --- a/scripts/build_tensorflow.sh +++ b/scripts/build_tensorflow.sh @@ -22,8 +22,22 @@ export CC_OPT_FLAGS=-Wno-sign-compare # nightly tf commit needs bazel 7.4.1 USING_NIGHTLY_TF_COMMIT=1 +REQUIRED_BAZEL_VERSION="7.4.1" +BAZEL_BIN_DIR="/usr/local/lib/bazel/bin" +BAZEL_FILENAME="bazel-${REQUIRED_BAZEL_VERSION}-linux-x86_64" +BAZEL_PATH="${BAZEL_BIN_DIR}/${BAZEL_FILENAME}" + if [ "$USING_NIGHTLY_TF_COMMIT" = "1" ]; then - cd "/usr/local/lib/bazel/bin" && curl -fLO https://releases.bazel.build/7.4.1/release/bazel-7.4.1-linux-x86_64 && chmod +x bazel-7.4.1-linux-x86_64 + if [ -f "$BAZEL_PATH" ]; then + echo "Bazel $REQUIRED_BAZEL_VERSION already exists at $BAZEL_PATH. Skipping download." + else + echo "Downloading Bazel $REQUIRED_BAZEL_VERSION..." + mkdir -p "$BAZEL_BIN_DIR" + cd "$BAZEL_BIN_DIR" || exit 1 + curl -fLO "https://releases.bazel.build/${REQUIRED_BAZEL_VERSION}/release/${BAZEL_FILENAME}" + chmod +x "$BAZEL_FILENAME" + echo "Bazel $REQUIRED_BAZEL_VERSION downloaded to $BAZEL_PATH." + fi fi if [ "$PLATFORM" = "android" ]; then diff --git a/scripts/dev_env.sh b/scripts/dev_env.sh index 0314901..248b23d 100755 --- a/scripts/dev_env.sh +++ b/scripts/dev_env.sh @@ -42,7 +42,7 @@ if ! $(docker image inspect $IMAGE_NAME > /dev/null 2>&1) || $FORCE_REBUILD; the BUILD_DIR="$SOURCE_DIR/.source" echo "Checking and retrieving dependencies..." 
if command -v aria2c &> /dev/null; then - aria2c $ARIA_OPTIONS -d $BUILD_DIR https://github.com/bazelbuild/bazel/releases/download/6.5.0/bazel-6.5.0-installer-linux-x86_64.sh + aria2c $ARIA_OPTIONS -d $BUILD_DIR https://github.com/bazelbuild/bazel/releases/download/7.4.1/bazel-7.4.1-installer-linux-x86_64.sh aria2c $ARIA_OPTIONS -d $BUILD_DIR https://dl.google.com/android/repository/android-ndk-r25c-linux.zip aria2c $ARIA_OPTIONS -d $BUILD_DIR https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip aria2c $ARIA_OPTIONS -d $BUILD_DIR https://repo1.maven.org/maven2/com/qualcomm/qti/qnn-runtime/2.33.0/qnn-runtime-2.33.0.aar @@ -53,10 +53,11 @@ if ! $(docker image inspect $IMAGE_NAME > /dev/null 2>&1) || $FORCE_REBUILD; the fi if [ ! -d "$BUILD_DIR/tensorflow" ]; then echo "Cloning tensorflow..." - NIGHTLY_TF_PIN=e2ccfb8b4e0 - git clone --no-checkout --filter=blob:none https://github.com/tensorflow/tensorflow.git "$BUILD_DIR/tensorflow" + NIGHTLY_TF_PIN=e2ccfb8b4e03baee112efbcfe824dbac995afb38 + # Use shallow clone to reduce size + git clone --depth 1 --branch=nightly https://github.com/tensorflow/tensorflow.git "$BUILD_DIR/tensorflow" cd $BUILD_DIR/tensorflow - git fetch --depth=365 origin nightly + git fetch origin $NIGHTLY_TF_PIN --depth=1 git checkout $NIGHTLY_TF_PIN cd - fi diff --git a/scripts/download_models.sh b/scripts/download_models.sh index 398c429..f73430b 100755 --- a/scripts/download_models.sh +++ b/scripts/download_models.sh @@ -11,56 +11,127 @@ MODELS_DIR="$SOURCE_DIR/models" ARIA_OPTIONS="-x 8 -s 8 --continue --file-allocation=none" # Set directories +QCOM_TINY_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-tiny.en" +QCOM_BASE_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-base.en" +QCOM_SMALL_EN_MODELS_DIR="$MODELS_DIR/quic_openai_whisper-small.en" + TINY_MODELS_DIR="$MODELS_DIR/openai_whisper-tiny" BASE_MODELS_DIR="$MODELS_DIR/openai_whisper-base" -SMALL_MODELS_DIR="$MODELS_DIR/openai_whisper-small" - -function 
SAFE_MODEL_DIRECTORY(){ - if [ ! -d "${1}" ]; then - echo "mkdir ${1} .." - mkdir -p "${1}" - fi -} - -SAFE_MODEL_DIRECTORY $TINY_MODELS_DIR -SAFE_MODEL_DIRECTORY $BASE_MODELS_DIR -SAFE_MODEL_DIRECTORY $SMALL_MODELS_DIR +TINY_EN_MODELS_DIR="$MODELS_DIR/openai_whisper-tiny.en" +BASE_EN_MODELS_DIR="$MODELS_DIR/openai_whisper-base.en" # Download Whisper auxiliary models HF_ARGMAX_URL="https://huggingface.co/argmaxinc/whisperkit-litert/resolve/main" -HF_OPENAI_TINY_URL="https://huggingface.co/openai/whisper-tiny.en/resolve/main" - -if [ ! -f $TINY_MODELS_DIR/tokenizer.json ]; then - aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o tokenizer.json $HF_OPENAI_TINY_URL/tokenizer.json - aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o MelSpectrogram.tflite $HF_ARGMAX_URL/quic_openai_whisper-tiny.en/MelSpectrogram.tflite -fi -if [ ! -f $BASE_MODELS_DIR/tokenizer.json ]; then - cp $TINY_MODELS_DIR/tokenizer.json $BASE_MODELS_DIR/. - cp $TINY_MODELS_DIR/MelSpectrogram.tflite $BASE_MODELS_DIR/. -fi -if [ ! -f $SMALL_MODELS_DIR/tokenizer.json ]; then - cp $TINY_MODELS_DIR/tokenizer.json $SMALL_MODELS_DIR/. - cp $TINY_MODELS_DIR/MelSpectrogram.tflite $SMALL_MODELS_DIR/. -fi - # Download Qualcomm models HF_QUALCOMM_URL="https://huggingface.co/qualcomm" -if [ ! -f $TINY_MODELS_DIR/TextDecoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperDecoder.tflite -fi -if [ ! -f $TINY_MODELS_DIR/AudioEncoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperEncoder.tflite -fi -if [ ! -f $BASE_MODELS_DIR/TextDecoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperDecoder.tflite -fi -if [ ! 
-f $BASE_MODELS_DIR/AudioEncoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperEncoder.tflite -fi -if [ ! -f $SMALL_MODELS_DIR/TextDecoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperDecoder.tflite -fi -if [ ! -f $SMALL_MODELS_DIR/AudioEncoder.tflite ]; then - aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperEncoder.tflite -fi +HF_OPENAI_TINY_EN_URL="https://huggingface.co/openai/whisper-tiny.en/resolve/main" +HF_OPENAI_TINY_URL="https://huggingface.co/openai/whisper-tiny/resolve/main" +HF_OPENAI_BASE_EN_URL="https://huggingface.co/openai/whisper-base.en/resolve/main" +HF_OPENAI_BASE_URL="https://huggingface.co/openai/whisper-base/resolve/main" +HF_OPENAI_SMALL_EN_URL="https://huggingface.co/openai/whisper-small.en/resolve/main" +HF_OPENAI_SMALL_URL="https://huggingface.co/openai/whisper-small/resolve/main" + +HF_QCOM_TINY_EN_URL=$HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main +HF_QCOM_BASE_EN_URL=$HF_QUALCOMM_URL/Whisper-Base-En/resolve/main +HF_QCOM_SMALL_EN_URL=$HF_QUALCOMM_URL/Whisper-Small-En/resolve/main + +HF_AX_TINY_EN_URL=$HF_ARGMAX_URL/openai_whisper-tiny.en +HF_AX_TINY_URL=$HF_ARGMAX_URL/openai_whisper-tiny +HF_AX_BASE_EN_URL=$HF_ARGMAX_URL/openai_whisper-base.en +HF_AX_BASE_URL=$HF_ARGMAX_URL/openai_whisper-base + +QCOM_MODELS=( + "$HF_QCOM_TINY_EN_URL" + "$HF_QCOM_BASE_EN_URL" + "$HF_QCOM_SMALL_EN_URL" +) + +ARGMAX_MODELS=( + "$HF_AX_TINY_EN_URL" + "$HF_AX_TINY_URL" + "$HF_AX_BASE_EN_URL" + "$HF_AX_BASE_URL" +) + +ARGMAX_MODEL_DIRECTORIES=( + "$TINY_EN_MODELS_DIR" + "$TINY_MODELS_DIR" + "$BASE_EN_MODELS_DIR" + "$BASE_MODELS_DIR" +) + + +QCOM_MODEL_DIRECTORIES=( + "$QCOM_TINY_EN_MODELS_DIR" + "$QCOM_BASE_EN_MODELS_DIR" + "$QCOM_SMALL_EN_MODELS_DIR" +) + +MODEL_DIRECTORIES=( + "$QCOM_TINY_EN_MODELS_DIR" 
+ "$QCOM_BASE_EN_MODELS_DIR" + "$QCOM_SMALL_EN_MODELS_DIR" + "$TINY_MODELS_DIR" + "$BASE_MODELS_DIR" + "$TINY_EN_MODELS_DIR" + "$BASE_EN_MODELS_DIR" +) + +TOKENIZER_ENDPOINTS=( + "$HF_OPENAI_TINY_EN_URL" + "$HF_OPENAI_BASE_EN_URL" + "$HF_OPENAI_SMALL_EN_URL" + "$HF_OPENAI_TINY_URL" + "$HF_OPENAI_BASE_URL" + "$HF_OPENAI_TINY_EN_URL" + "$HF_OPENAI_BASE_EN_URL" +) + + +ARIA2_OPTS="--quiet=true --summary-interval=0 --download-result=hide --continue=true --max-connection-per-server=4" + +melspec_endpoint="${HF_ARGMAX_URL}/openai_whisper-tiny/MelSpectrogram.tflite" +aria2c $ARIA2_OPTS -d "/tmp/" -o MelSpectrogram.tflite $melspec_endpoint + +for i in "${!MODEL_DIRECTORIES[@]}"; do + model_dir="${MODEL_DIRECTORIES[$i]}" + tokenizer_endpoint="${TOKENIZER_ENDPOINTS[$i]}" + echo "Creating directory: $model_dir" + mkdir -p "$model_dir" + + echo "Downloading tokenizer.json and config.json from $tokenizer_endpoint to $model_dir" + + aria2c $ARIA2_OPTS -d "$model_dir" -o tokenizer.json "$tokenizer_endpoint/tokenizer.json" + aria2c $ARIA2_OPTS -d "$model_dir" -o config.json "$tokenizer_endpoint/config.json" + echo "Done with $model_dir" +done + +# Qualcomm models: rename to [AudioEncoder.tflite, TextDecoder.tflite] +# +# Argmax models: already named as [AudioEncoder.tflite, TextDecoder.tflite] + +for i in "${!QCOM_MODEL_DIRECTORIES[@]}"; do + model_dir="${QCOM_MODEL_DIRECTORIES[$i]}" + model_endpoint="${QCOM_MODELS[$i]}" + echo "Downloading QCOM [AudioEncoder, TextDecoder] from $model_endpoint to $model_dir" + echo "Downloading Encoder: ${model_endpoint}/WhisperEncoder.tflite" + echo "Downloading Decoder: ${model_endpoint}/WhisperDecoder.tflite" + aria2c $ARIA2_OPTS -d "$model_dir" -o TextDecoder.tflite $model_endpoint/WhisperDecoder.tflite + aria2c $ARIA2_OPTS -d "$model_dir" -o AudioEncoder.tflite $model_endpoint/WhisperEncoder.tflite + cp /tmp/MelSpectrogram.tflite "$model_dir/MelSpectrogram.tflite" + echo "Done with $model_dir" +done + +for i in 
"${!ARGMAX_MODEL_DIRECTORIES[@]}"; do + model_dir="${ARGMAX_MODEL_DIRECTORIES[$i]}" + model_endpoint="${ARGMAX_MODELS[$i]}" + echo "Downloading Argmax [AudioEncoder, TextDecoder] from $model_endpoint to $model_dir" + echo "Downloading Encoder: ${model_endpoint}/AudioEncoder.tflite" + echo "Downloading Decoder: ${model_endpoint}/TextDecoder.tflite" + echo "Downloading MelSpec: ${model_endpoint}/MelSpectrogram.tflite" + aria2c $ARIA2_OPTS -d "$model_dir" -o TextDecoder.tflite ${model_endpoint}/TextDecoder.tflite + aria2c $ARIA2_OPTS -d "$model_dir" -o AudioEncoder.tflite ${model_endpoint}/AudioEncoder.tflite + aria2c $ARIA2_OPTS -d "$model_dir" -o MelSpectrogram.tflite ${model_endpoint}/MelSpectrogram.tflite + echo "Done with $model_dir" +done