diff --git a/android/whisperkit/detekt-baseline.xml b/android/whisperkit/detekt-baseline.xml index 4c4b90c..9ef856f 100644 --- a/android/whisperkit/detekt-baseline.xml +++ b/android/whisperkit/detekt-baseline.xml @@ -3,6 +3,7 @@ LargeClass:ArgmaxModelDownloaderImplTest.kt$ArgmaxModelDownloaderImplTest + LongMethod:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$private suspend fun FlowCollector<Progress>.downloadFilesWithRetry( from: Repo, revision: String, files: List<String>, baseDir: File, ) ThrowsCount:WhisperKit.kt$WhisperKit.Builder$@Throws(WhisperKitException::class) fun build(): WhisperKit TooGenericExceptionCaught:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$e: Exception TooGenericExceptionCaught:WhisperKitImpl.kt$WhisperKitImpl$e: Exception diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/HuggingFaceApi.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/HuggingFaceApi.kt index f7703bd..514bafe 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/HuggingFaceApi.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/HuggingFaceApi.kt @@ -41,11 +41,13 @@ interface HuggingFaceApi { * Retrieves a list of file names from a HuggingFace repository that match the specified glob patterns. * * @param from The repository to search in + * @param revision The revision/branch/commit to use. Defaults to "main" * @param globFilters List of glob patterns to filter files. If empty, all files are returned * @return List of file names that match the filters */ suspend fun getFileNames( from: Repo, + revision: String = "main", globFilters: List = listOf(), ): List @@ -53,21 +55,24 @@ interface HuggingFaceApi { * Retrieves detailed information about a model from a HuggingFace repository. * * @param from The repository containing the model, needs to be type [RepoType.MODELS] + * @param revision The revision/branch/commit to use. Defaults to "main" * @return [ModelInfo] object containing model details * @throws IllegalArgumentException if the repository type is not [RepoType.MODELS] */ - suspend fun getModelInfo(from: Repo): ModelInfo + suspend fun getModelInfo(from: Repo, revision: String = "main"): ModelInfo /** * Retrieves metadata for a specific file from a HuggingFace repository. * This is useful for checking file sizes before downloading. * * @param from The repository containing the file + * @param revision The revision/branch/commit to use. Defaults to "main" * @param filename The name of the file to get metadata for * @return FileMetadata object containing file information */ suspend fun getFileMetadata( from: Repo, + revision: String = "main", filename: String, ): FileMetadata @@ -76,11 +81,13 @@ interface HuggingFaceApi { * This is useful for checking file sizes before downloading multiple files. * * @param from The repository containing the files + * @param revision The revision/branch/commit to use. Defaults to "main" * @param globFilters List of glob patterns to filter files. If empty, all files are returned * @return List of FileMetadata objects for files that match the filters */ suspend fun getFileMetadata( from: Repo, + revision: String = "main", globFilters: List = listOf(), ): List @@ -90,6 +97,7 @@ interface HuggingFaceApi { * Progress is reported through a Flow of [Progress] objects. * * @param from The repository to download from + * @param revision The revision/branch/commit to use. Defaults to "main" * @param globFilters List of glob patterns to filter which files to download * @param baseDir The local directory where files will be downloaded * @return Flow of [Progress] objects indicating download progress @@ -97,6 +105,7 @@ interface HuggingFaceApi { */ fun snapshot( from: Repo, + revision: String = "main", globFilters: List, baseDir: File, ): Flow diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImpl.kt index 5c66e1e..077e219 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImpl.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImpl.kt @@ -52,23 +52,32 @@ internal class KtorHuggingFaceApiImpl( override suspend fun getFileNames( from: Repo, + revision: String, globFilters: List, ): List { - return getModelInfo(from).fileNames(globFilters) + return getModelInfo(from, revision).fileNames(globFilters) } - override suspend fun getModelInfo(from: Repo): ModelInfo { + override suspend fun getModelInfo(from: Repo, revision: String): ModelInfo { require(from.type == RepoType.MODELS) { "$from needs to have type RepoType.MODELS" } - return getHuggingFaceModel("/api/${from.type.typeName}/${from.id}") + var url = "/api/${from.type.typeName}/${from.id}" + if (revision != "main") { + url += "/revision/$revision" + } + logger.info("Calling HF API at url '$url'") + val result = getHuggingFaceModel(url) + logger.info("Got model info: $result") + return result } override suspend fun getFileMetadata( from: Repo, + revision: String, filename: String, ): FileMetadata { - val response = client.httpClient.head("/${from.id}/resolve/main/$filename") + val response = client.httpClient.head("/${from.id}/resolve/$revision/$filename") val size = response.headers["X-Linked-Size"]?.toLongOrNull() ?: response.headers["Content-Length"]?.toLongOrNull() ?: 0L @@ -82,11 +91,12 @@ internal class KtorHuggingFaceApiImpl( override suspend fun getFileMetadata( from: Repo, + revision: String, globFilters: List, ): List { - val files = getFileNames(from, globFilters) + val files = getFileNames(from, revision, globFilters) return files.map { filename -> - getFileMetadata(from, filename) + getFileMetadata(from, revision, filename) } } @@ -113,17 +123,21 @@ internal class KtorHuggingFaceApiImpl( */ override fun snapshot( from: Repo, + revision: String, globFilters: List, baseDir: File, ): Flow { return flow { baseDir.mkdirs() - getFileNames(from, globFilters).let { filesToDownload -> + getFileNames(from, revision, globFilters).let { filesToDownload -> if (filesToDownload.isEmpty()) { - logger.info("No files to download, finish immediately") + logger.info( + "No files to download, finish immediately, for Repo(${from.id}, " + + "$revision) and glob filters: $globFilters", + ) emit(Progress(1.0f)) } else { - downloadFilesWithRetry(from, filesToDownload, baseDir) + downloadFilesWithRetry(from, revision, filesToDownload, baseDir) } } }.flowOn(ioDispatcher) @@ -131,6 +145,7 @@ internal class KtorHuggingFaceApiImpl( private suspend fun FlowCollector.downloadFilesWithRetry( from: Repo, + revision: String, files: List, baseDir: File, ) { @@ -139,7 +154,7 @@ internal class KtorHuggingFaceApiImpl( var totalBytes = 0L val fileSizes = mutableMapOf() files.forEach { file -> - val metadata = getFileMetadata(from, file) + val metadata = getFileMetadata(from, revision, file) fileSizes[file] = metadata.size totalBytes += metadata.size } @@ -154,10 +169,11 @@ internal class KtorHuggingFaceApiImpl( val targetFile = File(baseDir, file) targetFile.parentFile?.mkdirs() var retryCount = 0 + val url = "/${from.id}/resolve/$revision/$file" while (true) { try { logger.info("Retry attempt $retryCount for $file") - client.httpClient.prepareGet("/${from.id}/resolve/main/$file") + client.httpClient.prepareGet(url) .execute { response -> val channel = response.bodyAsChannel() targetFile.outputStream().use { output -> diff --git a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt index 59c04ed..e83075e 100644 --- a/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt +++ b/android/whisperkit/src/main/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImpl.kt @@ -68,73 +68,75 @@ internal class ArgmaxModelDownloaderImpl( ), ) : ArgmaxModelDownloader { companion object { - private const val TOKENIZER_REPO = "TOKENIZER_REPO" - private const val CONFIG_REPO = "TOKENIZER_REPO" - private const val ENCODER_DECODER_REPO = "ENCODER_DECODER_REPO" - - // dir path under argmaxinc/whisperkit-litert to look up for MelSpectrogram.tflite - private const val FEATURE_EXTRACTOR_PATH = "FEATURE_EXTRACTOR_PATH" + /** + * Configuration for a model variant containing repository names and revisions. + * Each Pair contains (repository_name, revision) where: + * - repository_name: HuggingFace repo like "openai/whisper-tiny.en" or "qualcomm/Whisper-Tiny-En" + * - revision: branch/tag/commit hash like "main" or "8309cf4d4c30c69132f4f5e83ca8dcb7c17407ae" + * + * @property config Repository and revision for config.json file + * @property tokenizer Repository and revision for tokenizer.json file + * @property encoderDecoder Repository and revision for encoder/decoder model files + * @property featureExtractorPath Path within argmaxinc/whisperkit-litert repo for MelSpectrogram.tflite + */ + private data class ModelConfig( + val config: Pair, + val tokenizer: Pair, + val encoderDecoder: Pair, + val featureExtractorPath: String, + ) @OptIn(ExperimentalWhisperKit::class) - private val modelConfigs = - mapOf( - WhisperKit.Builder.OPENAI_TINY_EN to - mapOf( - CONFIG_REPO to "openai/whisper-tiny.en", - TOKENIZER_REPO to "openai/whisper-tiny.en", - ENCODER_DECODER_REPO to "openai_whisper-tiny.en", - FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny.en", - ), - WhisperKit.Builder.OPENAI_BASE_EN to - mapOf( - CONFIG_REPO to "openai/whisper-base.en", - TOKENIZER_REPO to "openai/whisper-base.en", - ENCODER_DECODER_REPO to "openai_whisper-base.en", - FEATURE_EXTRACTOR_PATH to "openai_whisper-base.en", - ), - WhisperKit.Builder.OPENAI_TINY to - mapOf( - CONFIG_REPO to "openai/whisper-tiny", - TOKENIZER_REPO to "openai/whisper-tiny", - ENCODER_DECODER_REPO to "openai_whisper-tiny", - FEATURE_EXTRACTOR_PATH to "openai_whisper-tiny", - ), - WhisperKit.Builder.OPENAI_BASE to - mapOf( - CONFIG_REPO to "openai/whisper-base", - TOKENIZER_REPO to "openai/whisper-base", - ENCODER_DECODER_REPO to "openai_whisper-base", - FEATURE_EXTRACTOR_PATH to "openai_whisper-base", - ), - WhisperKit.Builder.OPENAI_SMALL_EN to - mapOf( - CONFIG_REPO to "openai/whisper-small.en", - TOKENIZER_REPO to "openai/whisper-small.en", - ENCODER_DECODER_REPO to "openai_whisper-small.en", - FEATURE_EXTRACTOR_PATH to "openai_whisper-small.en", - ), - WhisperKit.Builder.QUALCOMM_TINY_EN to - mapOf( - CONFIG_REPO to "openai/whisper-tiny.en", - TOKENIZER_REPO to "openai/whisper-tiny.en", - ENCODER_DECODER_REPO to "qualcomm/Whisper-Tiny-En", - FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-tiny.en", - ), - WhisperKit.Builder.QUALCOMM_BASE_EN to - mapOf( - CONFIG_REPO to "openai/whisper-base.en", - TOKENIZER_REPO to "openai/whisper-base.en", - ENCODER_DECODER_REPO to "qualcomm/Whisper-Base-En", - FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-base.en", - ), - WhisperKit.Builder.QUALCOMM_SMALL_EN to - mapOf( - CONFIG_REPO to "openai/whisper-small.en", - TOKENIZER_REPO to "openai/whisper-small.en", - ENCODER_DECODER_REPO to "qualcomm/Whisper-Small-En", - FEATURE_EXTRACTOR_PATH to "quic_openai_whisper-small.en", - ), - ) + private val modelConfigs = mapOf( + WhisperKit.Builder.OPENAI_TINY_EN to ModelConfig( + config = "openai/whisper-tiny.en" to "main", + tokenizer = "openai/whisper-tiny.en" to "main", + encoderDecoder = "openai_whisper-tiny.en" to "main", + featureExtractorPath = "openai_whisper-tiny.en", + ), + WhisperKit.Builder.OPENAI_BASE_EN to ModelConfig( + config = "openai/whisper-base.en" to "main", + tokenizer = "openai/whisper-base.en" to "main", + encoderDecoder = "openai_whisper-base.en" to "main", + featureExtractorPath = "openai_whisper-base.en", + ), + WhisperKit.Builder.OPENAI_TINY to ModelConfig( + config = "openai/whisper-tiny" to "main", + tokenizer = "openai/whisper-tiny" to "main", + encoderDecoder = "openai_whisper-tiny" to "main", + featureExtractorPath = "openai_whisper-tiny", + ), + WhisperKit.Builder.OPENAI_BASE to ModelConfig( + config = "openai/whisper-base" to "main", + tokenizer = "openai/whisper-base" to "main", + encoderDecoder = "openai_whisper-base" to "main", + featureExtractorPath = "openai_whisper-base", + ), + WhisperKit.Builder.OPENAI_SMALL_EN to ModelConfig( + config = "openai/whisper-small.en" to "main", + tokenizer = "openai/whisper-small.en" to "main", + encoderDecoder = "openai_whisper-small.en" to "main", + featureExtractorPath = "openai_whisper-small.en", + ), + WhisperKit.Builder.QUALCOMM_TINY_EN to ModelConfig( + config = "openai/whisper-tiny.en" to "main", + tokenizer = "openai/whisper-tiny.en" to "main", + encoderDecoder = "qualcomm/Whisper-Tiny-En" to "8309cf4d4c30c69132f4f5e83ca8dcb7c17407ae", + featureExtractorPath = "quic_openai_whisper-tiny.en", + ), + WhisperKit.Builder.QUALCOMM_BASE_EN to ModelConfig( + config = "openai/whisper-base.en" to "main", + tokenizer = "openai/whisper-base.en" to "main", + encoderDecoder = "qualcomm/Whisper-Base-En" to "4bc89f2f841ee034383a543b954a432febf10ccc", + featureExtractorPath = "quic_openai_whisper-base.en", + ), + WhisperKit.Builder.QUALCOMM_SMALL_EN to ModelConfig( + config = "openai/whisper-small.en" to "main", + tokenizer = "openai/whisper-small.en" to "main", + encoderDecoder = "qualcomm/Whisper-Small-En" to "9a356b7e31999f9141b0c54b4a6514ce2fe27597", + featureExtractorPath = "quic_openai_whisper-small.en", + ), + ) } /** @@ -187,22 +189,23 @@ internal class ArgmaxModelDownloaderImpl( // Clean up model directories after all downloads are complete if (!variant.startsWith("qualcomm/")) { // For OpenAI models, clean up the model directory - File(root, config[ENCODER_DECODER_REPO]!!).deleteRecursively() + File(root, config.encoderDecoder.first).deleteRecursively() } // Clean up feature extractor directory - File(root, config[FEATURE_EXTRACTOR_PATH]!!).deleteRecursively() + File(root, config.featureExtractorPath).deleteRecursively() } } @OptIn(ExperimentalCoroutinesApi::class) private fun downloadConfig( - config: Map, + config: ModelConfig, root: File, ): Flow { return flow { emit( huggingFaceApi.getFileMetadata( - from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS), + from = Repo(config.config.first, RepoType.MODELS), + revision = config.config.second, filename = "config.json", ), ) @@ -212,7 +215,8 @@ internal class ArgmaxModelDownloaderImpl( flowOf(HuggingFaceApi.Progress(1.0f)) } else { huggingFaceApi.snapshot( - from = Repo(config[CONFIG_REPO]!!, RepoType.MODELS), + from = Repo(config.config.first, RepoType.MODELS), + revision = config.config.second, globFilters = listOf("config.json"), baseDir = root, ) @@ -222,13 +226,14 @@ internal class ArgmaxModelDownloaderImpl( @OptIn(ExperimentalCoroutinesApi::class) private fun downloadTokenizer( - config: Map, + config: ModelConfig, root: File, ): Flow { return flow { emit( huggingFaceApi.getFileMetadata( - from = Repo(config[TOKENIZER_REPO]!!, RepoType.MODELS), + from = Repo(config.tokenizer.first, RepoType.MODELS), + revision = config.tokenizer.second, filename = "tokenizer.json", ), ) @@ -238,7 +243,8 @@ internal class ArgmaxModelDownloaderImpl( flowOf(HuggingFaceApi.Progress(1.0f)) } else { huggingFaceApi.snapshot( - from = Repo(config[TOKENIZER_REPO]!!, RepoType.MODELS), + from = Repo(config.tokenizer.first, RepoType.MODELS), + revision = config.tokenizer.second, globFilters = listOf("tokenizer.json"), baseDir = root, ) @@ -249,7 +255,7 @@ internal class ArgmaxModelDownloaderImpl( @OptIn(ExperimentalCoroutinesApi::class) private fun downloadEncoderDecoder( variant: String, - config: Map, + config: ModelConfig, root: File, ): Flow { return if (variant.startsWith("qualcomm/")) { @@ -261,13 +267,14 @@ internal class ArgmaxModelDownloaderImpl( @OptIn(ExperimentalCoroutinesApi::class) private fun downloadQualcommEncoderDecoder( - config: Map, + config: ModelConfig, root: File, ): Flow { return flow { emit( huggingFaceApi.getFileMetadata( - from = Repo(config[ENCODER_DECODER_REPO]!!, RepoType.MODELS), + from = Repo(config.encoderDecoder.first, RepoType.MODELS), + revision = config.encoderDecoder.second, globFilters = listOf("WhisperEncoder.tflite", "WhisperDecoder.tflite"), ), ) @@ -283,7 +290,8 @@ internal class ArgmaxModelDownloaderImpl( flowOf(HuggingFaceApi.Progress(1.0f)) } else { huggingFaceApi.snapshot( - from = Repo(config[ENCODER_DECODER_REPO]!!, RepoType.MODELS), + from = Repo(config.encoderDecoder.first, RepoType.MODELS), + revision = config.encoderDecoder.second, globFilters = listOf("WhisperEncoder.tflite", "WhisperDecoder.tflite"), baseDir = root, ).onCompletion { @@ -296,14 +304,15 @@ internal class ArgmaxModelDownloaderImpl( @OptIn(ExperimentalCoroutinesApi::class) private fun downloadArgmaxEncoderDecoder( - config: Map, + config: ModelConfig, root: File, ): Flow { - val modelDir = config[ENCODER_DECODER_REPO]!! + val modelDir = config.encoderDecoder.first return flow { emit( huggingFaceApi.getFileMetadata( from = Repo("argmaxinc/whisperkit-litert", RepoType.MODELS), + revision = config.encoderDecoder.second, globFilters = listOf( "$modelDir/AudioEncoder.tflite", @@ -343,14 +352,14 @@ internal class ArgmaxModelDownloaderImpl( @OptIn(ExperimentalCoroutinesApi::class) private fun downloadFeatureExtractor( - config: Map, + config: ModelConfig, root: File, ): Flow { return flow { emit( huggingFaceApi.getFileMetadata( from = Repo("argmaxinc/whisperkit-litert", RepoType.MODELS), - filename = "${config[FEATURE_EXTRACTOR_PATH]!!}/MelSpectrogram.tflite", + filename = "${config.featureExtractorPath}/MelSpectrogram.tflite", ), ) }.flatMapLatest { metadata -> @@ -360,10 +369,10 @@ internal class ArgmaxModelDownloaderImpl( } else { huggingFaceApi.snapshot( from = Repo("argmaxinc/whisperkit-litert", RepoType.MODELS), - globFilters = listOf("${config[FEATURE_EXTRACTOR_PATH]!!}/MelSpectrogram.tflite"), + globFilters = listOf("${config.featureExtractorPath}/MelSpectrogram.tflite"), baseDir = root, ).onCompletion { - val modelDir = config[FEATURE_EXTRACTOR_PATH]!! + val modelDir = config.featureExtractorPath File(root, "$modelDir/MelSpectrogram.tflite").renameTo( File( root, diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImplTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImplTest.kt index bf4d6d8..e3c748e 100644 --- a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImplTest.kt +++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/huggingface/KtorHuggingFaceApiImplTest.kt @@ -163,7 +163,7 @@ internal class KtorHuggingFaceApiImplTest { val repo = Repo("test-repo", RepoType.MODELS) // download test1.txt and test2.txt - api.snapshot(repo, listOf("test*"), testDir).test { + api.snapshot(repo, "main", listOf("test*"), testDir).test { // Verify first progress (after test1.txt) val firstProgress = awaitItem() assertTrue(firstProgress.fractionCompleted < 1.0f) @@ -203,7 +203,7 @@ internal class KtorHuggingFaceApiImplTest { UnconfinedTestDispatcher(testScheduler), ) val repo = Repo("test-repo", RepoType.MODELS) - api.snapshot(repo, listOf("test2.txt"), testDir).test(timeout = 10.seconds) { + api.snapshot(repo, "main", listOf("test2.txt"), testDir).test(timeout = 10.seconds) { awaitItem() // first progress val error = awaitError() // exception thrown assertTrue(error is IllegalStateException) @@ -239,7 +239,7 @@ internal class KtorHuggingFaceApiImplTest { val repo = Repo("test-repo", RepoType.MODELS) val progressValues = mutableListOf() - api.snapshot(repo, listOf("test2.txt"), testDir).test { + api.snapshot(repo, "main", listOf("test2.txt"), testDir).test { val firstProgress = awaitItem() progressValues.add(firstProgress.fractionCompleted) verify { mockLogger.info("Retry attempt 0 for test2.txt") } @@ -276,7 +276,7 @@ internal class KtorHuggingFaceApiImplTest { ) val repo = Repo("test-repo", RepoType.MODELS) val globFilters = listOf("test*") - assertEquals(listOf("test1.txt", "test2.txt"), api.getFileNames(repo, globFilters)) + assertEquals(listOf("test1.txt", "test2.txt"), api.getFileNames(repo, "main", globFilters)) } @Test @@ -318,11 +318,16 @@ internal class KtorHuggingFaceApiImplTest { ) val repo = Repo("test-repo", RepoType.MODELS) - api.snapshot(repo, listOf("nonexistent*"), testDir).test { + api.snapshot(repo, "main", listOf("nonexistent*"), testDir).test { // Should emit a single progress with 1.0f val progress = awaitItem() assertTrue(progress.isDone) - verify { mockLogger.info("No files to download, finish immediately") } + verify { + mockLogger.info( + "No files to download, finish immediately, " + + "for Repo(test-repo, main) and glob filters: [nonexistent*]", + ) + } awaitComplete() } } @@ -335,11 +340,16 @@ internal class KtorHuggingFaceApiImplTest { UnconfinedTestDispatcher(testScheduler), ) val repo = Repo("test-repo", RepoType.MODELS) - api.snapshot(repo, emptyList(), testDir).test { + api.snapshot(repo, "main", emptyList(), testDir).test { val progress = awaitItem() awaitComplete() assertTrue(progress.isDone) - verify { mockLogger.info("No files to download, finish immediately") } + verify { + mockLogger.info( + "No files to download, finish immediately, " + + "for Repo(test-repo, main) and glob filters: []", + ) + } } } @@ -352,7 +362,7 @@ internal class KtorHuggingFaceApiImplTest { ) val repo = Repo("test-repo", RepoType.MODELS) val nestedDir = File(testDir, "nested/path") - api.snapshot(repo, listOf("test1.txt"), nestedDir).test { + api.snapshot(repo, "main", listOf("test1.txt"), nestedDir).test { awaitItem() // progress awaitComplete() } diff --git a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt index 28c1cbc..5259f57 100644 --- a/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt +++ b/android/whisperkit/src/test/java/com/argmaxinc/whisperkit/network/ArgmaxModelDownloaderImplTest.kt @@ -49,6 +49,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), filename = eq("config.json"), ) } returns HuggingFaceApi.FileMetadata(500L, "config.json") @@ -57,6 +58,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), filename = eq("tokenizer.json"), ) } returns HuggingFaceApi.FileMetadata(1000L, "tokenizer.json") @@ -74,6 +76,7 @@ class ArgmaxModelDownloaderImplTest { Repo("argmaxinc/whisperkit-litert", RepoType.MODELS) }, ), + revision = any(), globFilters = eq(expectedEncoderDecoderGlobFilters), ) } returns @@ -91,6 +94,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo("argmaxinc/whisperkit-litert", RepoType.MODELS)), + revision = eq("main"), filename = eq("$expectedMelSpectrogramPath/MelSpectrogram.tflite"), ) } returns @@ -162,6 +166,7 @@ class ArgmaxModelDownloaderImplTest { verify(exactly = 0) { huggingFaceApi.snapshot( from = any(), + revision = any(), globFilters = any(), baseDir = any(), ) @@ -235,6 +240,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("config.json")), baseDir = eq(root), ) @@ -244,6 +250,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("tokenizer.json")), baseDir = eq(root), ) @@ -253,6 +260,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo(expectedEncoderDecoderRepo, RepoType.MODELS)), + revision = any(), globFilters = eq(expectedEncoderDecoderGlobFilters), baseDir = eq(root), ) @@ -262,6 +270,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo("argmaxinc/whisperkit-litert", RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("$expectedMelSpectrogramPath/MelSpectrogram.tflite")), baseDir = eq(root), ) @@ -301,6 +310,7 @@ class ArgmaxModelDownloaderImplTest { verify(exactly = 1) { huggingFaceApi.snapshot( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("config.json")), baseDir = eq(root), ) @@ -308,6 +318,7 @@ class ArgmaxModelDownloaderImplTest { verify(exactly = 1) { huggingFaceApi.snapshot( from = eq(Repo(expectedTokenizerRepo, RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("tokenizer.json")), baseDir = eq(root), ) @@ -325,6 +336,7 @@ class ArgmaxModelDownloaderImplTest { Repo("argmaxinc/whisperkit-litert", RepoType.MODELS) }, ), + revision = any(), globFilters = eq(expectedEncoderDecoderGlobFilters), baseDir = eq(root), ) @@ -332,6 +344,7 @@ class ArgmaxModelDownloaderImplTest { verify(exactly = 1) { huggingFaceApi.snapshot( from = eq(Repo("argmaxinc/whisperkit-litert", RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("$expectedMelSpectrogramPath/MelSpectrogram.tflite")), baseDir = eq(root), ) @@ -647,6 +660,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + revision = eq("main"), filename = eq("config.json"), ) } returns HuggingFaceApi.FileMetadata(500L, "config.json") @@ -655,6 +669,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + revision = eq("main"), filename = eq("tokenizer.json"), ) } returns HuggingFaceApi.FileMetadata(1000L, "tokenizer.json") @@ -663,6 +678,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo("qualcomm/Whisper-Tiny-En", RepoType.MODELS)), + revision = any(), globFilters = eq(listOf("WhisperEncoder.tflite", "WhisperDecoder.tflite")), ) } returns @@ -675,6 +691,7 @@ class ArgmaxModelDownloaderImplTest { coEvery { huggingFaceApi.getFileMetadata( from = eq(Repo("argmaxinc/whisperkit-litert", RepoType.MODELS)), + revision = eq("main"), filename = eq("quic_openai_whisper-tiny.en/MelSpectrogram.tflite"), ) } returns @@ -687,6 +704,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("config.json")), baseDir = eq(root), ) @@ -696,6 +714,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo("openai/whisper-tiny.en", RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("tokenizer.json")), baseDir = eq(root), ) @@ -705,6 +724,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo("qualcomm/Whisper-Tiny-En", RepoType.MODELS)), + revision = any(), globFilters = eq(listOf("WhisperEncoder.tflite", "WhisperDecoder.tflite")), baseDir = eq(root), ) @@ -714,6 +734,7 @@ class ArgmaxModelDownloaderImplTest { every { huggingFaceApi.snapshot( from = eq(Repo("argmaxinc/whisperkit-litert", RepoType.MODELS)), + revision = eq("main"), globFilters = eq(listOf("quic_openai_whisper-tiny.en/MelSpectrogram.tflite")), baseDir = eq(root), )