Skip to content
This repository was archived by the owner on Jan 24, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions android/whisperkit/detekt-baseline.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<ManuallySuppressedIssues/>
<CurrentIssues>
<ID>LargeClass:ArgmaxModelDownloaderImplTest.kt$ArgmaxModelDownloaderImplTest</ID>
<ID>LongMethod:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$private suspend fun FlowCollector&lt;Progress&gt;.downloadFilesWithRetry( from: Repo, revision: String, files: List&lt;String&gt;, baseDir: File, )</ID>
<ID>ThrowsCount:WhisperKit.kt$WhisperKit.Builder$@Throws(WhisperKitException::class) fun build(): WhisperKit</ID>
<ID>TooGenericExceptionCaught:KtorHuggingFaceApiImpl.kt$KtorHuggingFaceApiImpl$e: Exception</ID>
<ID>TooGenericExceptionCaught:WhisperKitImpl.kt$WhisperKitImpl$e: Exception</ID>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,38 @@ interface HuggingFaceApi {
* Retrieves a list of file names from a HuggingFace repository that match the specified glob patterns.
*
* @param from The repository to search in
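* @param revision The revision/branch/commit to use. Defaults to "main". Because it precedes
* [globFilters], callers that omit it must pass [globFilters] as a named argument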
* @param globFilters List of glob patterns to filter files. If empty, all files are returned
* @return List of file names that match the filters
*/
suspend fun getFileNames(
from: Repo,
revision: String = "main",
globFilters: List<String> = listOf(),
): List<String>

/**
* Retrieves detailed information about a model from a HuggingFace repository.
*
* [revision] is the last parameter and has a default, so a call that passes only [from]
* reads the "main" revision.
*
* @param from The repository containing the model, needs to be type [RepoType.MODELS]
* @param revision The revision/branch/commit to use. Defaults to "main"
* @return [ModelInfo] object containing model details
* @throws IllegalArgumentException if the repository type is not [RepoType.MODELS]
*/
suspend fun getModelInfo(from: Repo): ModelInfo
suspend fun getModelInfo(from: Repo, revision: String = "main"): ModelInfo

/**
* Retrieves metadata for a specific file from a HuggingFace repository.
* This is useful for checking file sizes before downloading.
*
* Because [revision] has a default but precedes the required [filename], the default can only
* be used when [filename] is passed by name, e.g. `getFileMetadata(repo, filename = "config.json")`.
* A positional call such as `getFileMetadata(repo, "config.json")` binds the string to [revision],
* so it does not select this overload; it resolves to the glob-filter overload instead.
*
* @param from The repository containing the file
* @param revision The revision/branch/commit to use. Defaults to "main"
* @param filename The name of the file to get metadata for
* @return [FileMetadata] object containing file information
*/
suspend fun getFileMetadata(
from: Repo,
revision: String = "main",
filename: String,
): FileMetadata

Expand All @@ -76,11 +81,13 @@ interface HuggingFaceApi {
* This is useful for checking file sizes before downloading multiple files.
*
* @param from The repository containing the files
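* @param revision The revision/branch/commit to use. Defaults to "main". Because it precedes
* [globFilters], callers that omit it must pass [globFilters] as a named argument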
* @param globFilters List of glob patterns to filter files. If empty, all files are returned
* @return List of FileMetadata objects for files that match the filters
*/
suspend fun getFileMetadata(
from: Repo,
revision: String = "main",
globFilters: List<String> = listOf(),
): List<FileMetadata>

Expand All @@ -90,13 +97,15 @@ interface HuggingFaceApi {
* Progress is reported through a Flow of [Progress] objects.
*
* @param from The repository to download from
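* @param revision The revision/branch/commit to use. Defaults to "main". [globFilters] and [baseDir]
* follow it and have no defaults, so callers that omit it must pass both as named arguments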
* @param globFilters List of glob patterns to filter which files to download
* @param baseDir The local directory where files will be downloaded
* @return Flow of [Progress] objects indicating download progress
* @throws IllegalStateException if a file download fails after the maximum number of retry attempts
*/
fun snapshot(
from: Repo,
revision: String = "main",
globFilters: List<String>,
baseDir: File,
): Flow<Progress>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,32 @@ internal class KtorHuggingFaceApiImpl(

/**
* Lists file names of [from] at [revision] by fetching the model info and passing
* [globFilters] to `ModelInfo.fileNames`.
*
* The lookup goes through [getModelInfo], which requires `from.type` to be
* [RepoType.MODELS] and throws [IllegalArgumentException] otherwise.
*/
override suspend fun getFileNames(
from: Repo,
revision: String,
globFilters: List<String>,
): List<String> {
return getModelInfo(from).fileNames(globFilters)
// Revision-aware lookup: the file list is taken from the model info of `revision`.
return getModelInfo(from, revision).fileNames(globFilters)
}

/**
* Fetches the model info of [from] at [revision] via `getHuggingFaceModel`.
*
* The request URL and the returned model info are both logged at info level.
*
* @throws IllegalArgumentException if `from.type` is not [RepoType.MODELS]
*/
override suspend fun getModelInfo(from: Repo): ModelInfo {
override suspend fun getModelInfo(from: Repo, revision: String): ModelInfo {
// Only model repositories are supported; `require` raises IllegalArgumentException otherwise.
require(from.type == RepoType.MODELS) {
"$from needs to have type RepoType.MODELS"
}
return getHuggingFaceModel("/api/${from.type.typeName}/${from.id}")
var url = "/api/${from.type.typeName}/${from.id}"
// "main" keeps the plain model endpoint; any other revision is addressed through a
// "/revision/<revision>" suffix.
// NOTE(review): `revision` is interpolated as-is; a revision containing "/" presumably
// needs URL-encoding first — confirm against the HF API.
if (revision != "main") {
url += "/revision/$revision"
}
logger.info("Calling HF API at url '$url'")
val result = getHuggingFaceModel(url)
// The whole result object is written to the log via its toString().
logger.info("Got model info: $result")
return result
}

override suspend fun getFileMetadata(
from: Repo,
revision: String,
filename: String,
): FileMetadata {
val response = client.httpClient.head("/${from.id}/resolve/main/$filename")
val response = client.httpClient.head("/${from.id}/resolve/$revision/$filename")
val size =
response.headers["X-Linked-Size"]?.toLongOrNull()
?: response.headers["Content-Length"]?.toLongOrNull() ?: 0L
Expand All @@ -82,11 +91,12 @@ internal class KtorHuggingFaceApiImpl(

/**
* Returns metadata for every file of [from] at [revision] that matches [globFilters].
*
* File names are resolved first with [getFileNames]; the single-file [getFileMetadata]
* overload is then called once per name, one after another, so the result keeps the
* order of the resolved names.
*/
override suspend fun getFileMetadata(
from: Repo,
revision: String,
globFilters: List<String>,
): List<FileMetadata> {
val files = getFileNames(from, globFilters)
val files = getFileNames(from, revision, globFilters)
return files.map { filename ->
getFileMetadata(from, filename)
// Use the same revision for the metadata lookup as for the file listing.
getFileMetadata(from, revision, filename)
}
}

Expand All @@ -113,24 +123,29 @@ internal class KtorHuggingFaceApiImpl(
*/
override fun snapshot(
from: Repo,
revision: String,
globFilters: List<String>,
baseDir: File,
): Flow<Progress> {
// All work sits inside the flow builder, so nothing below runs until the returned flow is collected.
return flow {
// Make sure the download directory exists; the boolean result of mkdirs() is not checked.
baseDir.mkdirs()
getFileNames(from, globFilters).let { filesToDownload ->
// Resolve the file list for the requested revision before downloading anything.
getFileNames(from, revision, globFilters).let { filesToDownload ->
if (filesToDownload.isEmpty()) {
logger.info("No files to download, finish immediately")
// Nothing matched: log the repo id, revision and filters, then report completion (1.0f).
logger.info(
"No files to download, finish immediately, for Repo(${from.id}, " +
"$revision) and glob filters: $globFilters",
)
emit(Progress(1.0f))
} else {
downloadFilesWithRetry(from, filesToDownload, baseDir)
// Pass the revision on so metadata lookups and downloads target the same revision.
downloadFilesWithRetry(from, revision, filesToDownload, baseDir)
}
}
// flowOn moves the block above (listing and downloading) onto ioDispatcher.
}.flowOn(ioDispatcher)
}

private suspend fun FlowCollector<Progress>.downloadFilesWithRetry(
from: Repo,
revision: String,
files: List<String>,
baseDir: File,
) {
Expand All @@ -139,7 +154,7 @@ internal class KtorHuggingFaceApiImpl(
var totalBytes = 0L
val fileSizes = mutableMapOf<String, Long>()
files.forEach { file ->
val metadata = getFileMetadata(from, file)
val metadata = getFileMetadata(from, revision, file)
fileSizes[file] = metadata.size
totalBytes += metadata.size
}
Expand All @@ -154,10 +169,11 @@ internal class KtorHuggingFaceApiImpl(
val targetFile = File(baseDir, file)
targetFile.parentFile?.mkdirs()
var retryCount = 0
val url = "/${from.id}/resolve/$revision/$file"
while (true) {
try {
logger.info("Retry attempt $retryCount for $file")
client.httpClient.prepareGet("/${from.id}/resolve/main/$file")
client.httpClient.prepareGet(url)
.execute { response ->
val channel = response.bodyAsChannel()
targetFile.outputStream().use { output ->
Expand Down
Loading
Loading