From 649ced61197a6a0202db301429eb7da98b53b1fd Mon Sep 17 00:00:00 2001 From: Bhagirath Mehta Date: Mon, 4 May 2026 06:35:49 -0500 Subject: [PATCH] Add SDK download cancellation support Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpp/live-audio-transcription/README.md | 4 + samples/cpp/live-audio-transcription/main.cpp | 11 +- sdk/cpp/include/foundry_local_manager.h | 11 +- sdk/cpp/include/model.h | 18 +- sdk/cpp/sample/main.cpp | 27 ++- sdk/cpp/src/core.h | 2 +- sdk/cpp/src/core_helpers.h | 103 ++++++++-- sdk/cpp/src/flcore_native.h | 18 +- sdk/cpp/src/foundry_local_internal_core.h | 13 +- sdk/cpp/src/foundry_local_manager.cpp | 70 +++---- sdk/cpp/src/model.cpp | 45 ++--- sdk/cpp/test/ep_test.cpp | 6 +- sdk/cpp/test/model_variant_test.cpp | 82 +++++++- sdk/cs/README.md | 12 ++ sdk/cs/src/Detail/CoreInterop.cs | 29 ++- sdk/cs/src/Detail/ModelVariant.cs | 33 ++- sdk/cs/src/FoundryLocalManager.cs | 16 +- .../DownloadCancellationTests.cs | 122 ++++++++++++ sdk/cs/test/FoundryLocal.Tests/Utils.cs | 3 +- sdk/js/README.md | 15 +- sdk/js/src/detail/coreInterop.ts | 42 +++- sdk/js/src/detail/model.ts | 7 +- sdk/js/src/detail/modelVariant.ts | 24 ++- sdk/js/src/foundryLocalManager.ts | 90 ++++++++- sdk/js/src/imodel.ts | 8 +- sdk/js/test/detail/coreInterop.test.ts | 26 +++ sdk/js/test/foundryLocalManager.test.ts | 188 +++++++++++++----- sdk/js/test/model.test.ts | 79 +++++++- sdk/python/README.md | 17 +- sdk/python/requirements.txt | 9 +- sdk/python/src/detail/core_interop.py | 57 ++++-- sdk/python/src/detail/model.py | 6 +- sdk/python/src/detail/model_variant.py | 24 ++- sdk/python/src/foundry_local_manager.py | 26 ++- sdk/python/src/imodel.py | 6 +- sdk/python/test/test_foundry_local_manager.py | 43 ++++ sdk/python/test/test_model.py | 78 ++++++++ sdk/rust/README.md | 22 ++ sdk/rust/src/detail/core_interop.rs | 148 +++++++++++++- sdk/rust/src/detail/model.rs | 19 +- sdk/rust/src/detail/model_variant.rs | 61 +++++- sdk/rust/src/foundry_local_manager.rs | 70 ++++++- 42 files changed, 1447 insertions(+), 243 deletions(-) create mode 100644 sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs create mode 100644 sdk/js/test/detail/coreInterop.test.ts diff --git a/samples/cpp/live-audio-transcription/README.md b/samples/cpp/live-audio-transcription/README.md index a9fca9774..3e8b8e8d6 100644 --- a/samples/cpp/live-audio-transcription/README.md +++ b/samples/cpp/live-audio-transcription/README.md @@ -26,3 +26,7 @@ g++ -std=c++20 main.cpp -lfoundry_local -o live-audio-transcription-example # Synthetic 440Hz sine wave (no microphone needed) ./live-audio-transcription-example --synth ``` + +Press `Ctrl+C` to request a graceful stop. The sample passes that signal to +execution-provider and model downloads so long-running downloads can be +cancelled before transcription starts. diff --git a/samples/cpp/live-audio-transcription/main.cpp b/samples/cpp/live-audio-transcription/main.cpp index 1a3341e4c..b50be93ac 100644 --- a/samples/cpp/live-audio-transcription/main.cpp +++ b/samples/cpp/live-audio-transcription/main.cpp @@ -122,7 +122,8 @@ int main(int argc, char* argv[]) { foundry_local::Manager::Create(config); auto& manager = foundry_local::Manager::Instance(); - manager.EnsureEpsDownloaded(); + auto isCancellationRequested = [] { return !g_running.load(); }; + manager.DownloadAndRegisterEps(nullptr, isCancellationRequested); auto& catalog = manager.GetCatalog(); auto* model = catalog.GetModel("nemotron-speech-streaming-en-0.6b"); @@ -131,9 +132,11 @@ int main(int argc, char* argv[]) { } std::cout << "Downloading model (if needed)..." << std::endl; - model->Download([](float pct) { - std::cout << "\rDownloading: " << pct << "% " << std::flush; - }); + model->Download( + [](float pct) { + std::cout << "\rDownloading: " << pct << "% " << std::flush; + }, + isCancellationRequested); std::cout << std::endl; std::cout << "Loading model..." << std::endl; model->Load(); diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h index 51af7f161..6c3d07081 100644 --- a/sdk/cpp/include/foundry_local_manager.h +++ b/sdk/cpp/include/foundry_local_manager.h @@ -83,15 +83,22 @@ namespace foundry_local { /// Download and register all available execution providers. /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. + /// Return true to cancel the in-progress download. /// @return Result describing which EPs were registered or failed. - EpDownloadResult DownloadAndRegisterEps(EpProgressCallback progressCallback = nullptr) const; + EpDownloadResult DownloadAndRegisterEps( + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; /// Download and register specific execution providers by name. /// @param names EP names to download (as returned by DiscoverEps). /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. + /// Return true to cancel the in-progress download. /// @return Result describing which EPs were registered or failed. EpDownloadResult DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback = nullptr) const; + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; private: explicit Manager(Configuration configuration, ILogger* logger); diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index b52fae76c..f136af4ee 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ namespace foundry_local { #endif using DownloadProgressCallback = std::function; + using CancellationCallback = std::function; class IModel { public: @@ -43,7 +45,13 @@ namespace foundry_local { virtual bool IsLoaded() const = 0; virtual bool IsCached() const = 0; virtual const std::filesystem::path& GetPath() const = 0; - virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0; + + /// Download the model to the local cache. + /// @param onProgress Optional callback receiving percentage progress. Return true to continue. + /// @param isCancellationRequested Optional callback checked on each progress update. + /// Return true to cancel the in-progress download. + virtual void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) = 0; virtual void Load() = 0; virtual void Unload() = 0; virtual void RemoveFromCache() = 0; @@ -123,7 +131,8 @@ namespace foundry_local { const ModelInfo& GetInfo() const; const std::filesystem::path& GetPath() const override; - void Download(DownloadProgressCallback onProgress = nullptr) override; + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override; void Load() override; bool IsLoaded() const override; @@ -158,8 +167,9 @@ namespace foundry_local { bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } bool IsCached() const override { return SelectedVariant().IsCached(); } const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } - void Download(DownloadProgressCallback onProgress = nullptr) override { - SelectedVariant().Download(std::move(onProgress)); + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override { + SelectedVariant().Download(std::move(onProgress), std::move(isCancellationRequested)); } void Load() override { SelectedVariant().Load(); } void Unload() override { SelectedVariant().Unload(); } diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 7c377da99..b82047800 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -4,6 +4,8 @@ #include "foundry_local.h" #include +#include +#include #include #include #include @@ -14,6 +16,18 @@ using namespace foundry_local; +namespace { +std::atomic g_cancelRequested{false}; + +void SignalHandler(int /*signum*/) { + g_cancelRequested.store(true); +} + +bool IsCancellationRequested() { + return g_cancelRequested.load(); +} +} // namespace + // --------------------------------------------------------------------------- // Logger // --------------------------------------------------------------------------- @@ -118,7 +132,8 @@ void ChatNonStreaming(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -211,7 +226,8 @@ void TranscribeAudio(Manager& manager, const std::string& alias, const std::stri PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -263,7 +279,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -376,6 +393,8 @@ int main(int argc, char* argv[]) { const std::string audioPath = (argc > 3) ? argv[3] : ""; try { + std::signal(SIGINT, SignalHandler); + StdLogger logger; Manager::Create({"SampleApp"}, &logger); auto& manager = Manager::Instance(); @@ -399,7 +418,7 @@ int main(int argc, char* argv[]) { } printf("\r %-30s %5.1f%%", epName.c_str(), percent); fflush(stdout); - }); + }, IsCancellationRequested); if (!currentEp.empty()) std::cout << "\n"; } else { std::cout << "\nNo execution providers to download.\n"; diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index eb598373d..a69f961cc 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -187,7 +187,7 @@ namespace foundry_local { std::unique_ptr responseGuard(&response, safeDeleter); if (callback != nullptr) { - execCbCmd_(&request, &response, reinterpret_cast(callback), data); + execCbCmd_(&request, &response, callback, data); } else { execCmd_(&request, &response); diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h index c46f294a2..e0cd37092 100644 --- a/sdk/cpp/src/core_helpers.h +++ b/sdk/cpp/src/core_helpers.h @@ -6,12 +6,15 @@ #pragma once +#include #include #include #include #include #include +#include #include +#include #include @@ -47,38 +50,82 @@ namespace foundry_local::detail { return core->call(command, logger, &payload, callback, userData); } + inline bool TryParseFloatToken(std::string_view token, float& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + + inline bool TryParseDoubleToken(std::string_view token, double& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + // Serialize + call with a streaming chunk handler. // Wraps the caller-supplied onChunk with the native callback boilerplate - // (null/length checks, exception capture, rethrow after the call). + // (null/length checks, exception capture, cancellation, rethrow after the call). // The errorContext string is used to prefix any core-layer error message. inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, - const std::string& payload, ILogger& logger, - const std::function& onChunk, - std::string_view errorContext) { + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { struct State { - const std::function* cb; + const std::function* cb; + CancellationCallback isCancellationRequested; + bool cancellationObserved = false; std::exception_ptr exception; - } state{&onChunk, nullptr}; + } state{&onChunk, std::move(isCancellationRequested), false, nullptr}; - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) + auto nativeCallback = [](const void* data, int32_t len, void* user) -> int32_t { + auto* st = static_cast(user); + if (!st) { return 0; + } - auto* st = static_cast(user); - if (st->exception) + if (st->exception || st->cancellationObserved) { + return 1; + } + + if (!data || len <= 0) return 0; try { + if (st->isCancellationRequested && st->isCancellationRequested()) { + st->cancellationObserved = true; + return 1; + } + std::string chunk(static_cast(data), static_cast(len)); - (*(st->cb))(chunk); + if (!(*(st->cb))(chunk)) { + st->cancellationObserved = true; + return 1; + } } catch (...) { st->exception = std::current_exception(); + return 1; } + return 0; }; - auto response = core->call(command, logger, &payload, +nativeCallback, &state); + auto response = core->call(command, logger, payload, +nativeCallback, &state); + if (state.cancellationObserved) { + throw Exception("Operation cancelled", logger); + } + if (response.HasError()) { throw Exception(std::string(errorContext) + response.error, logger); } @@ -90,6 +137,38 @@ namespace foundry_local::detail { return response; } + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + const std::function continuingOnChunk = + [&onChunk](const std::string& chunk) { + onChunk(chunk); + return true; + }; + return CallWithStreamingCallback(core, command, payload, logger, continuingOnChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + // Overload: allow Params object directly inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command, const nlohmann::json& params, ILogger& logger) { diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index 2ea792b9e..b4c95ac49 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -5,10 +5,12 @@ #include #include -#ifdef _WIN32 - #define FL_CDECL __cdecl -#else - #define FL_CDECL +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif #endif extern "C" @@ -29,8 +31,9 @@ extern "C" int32_t ErrorLength; }; - // Callback signature: int(*)(void* data, int length, void* userData) — returns 0 to continue, 1 to cancel - using UserCallbackFn = int(__cdecl*)(void*, int32_t, void*); + // Callback signature: int32_t(*)(const void* data, int length, void* userData) + // Return 0 to continue, 1 to cancel. + using UserCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); struct StreamingRequestBuffer { const void* Command; @@ -43,7 +46,8 @@ extern "C" // Exported function pointer types using execute_command_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*); - using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, + using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, + UserCallbackFn /*callback*/, void* /*userData*/); using execute_command_with_binary_fn = void(FL_CDECL*)(StreamingRequestBuffer*, ResponseBuffer*); using free_response_fn = void(FL_CDECL*)(ResponseBuffer*); diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index 368096dec..3a982b16c 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -8,11 +8,20 @@ #include #include "logger.h" +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif +#endif + namespace foundry_local { /// Native callback signature used by the core DLL interop. /// Parameters: (data, dataLength, userData). - using NativeCallbackFn = int (*)(void*, int32_t, void*); + /// Return 0 to continue, 1 to cancel the native operation. + using NativeCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); /// Value returned by IFoundryLocalCore::call(). /// On success, `data` contains the response payload and `error` is empty. @@ -40,4 +49,4 @@ namespace foundry_local { }; } // namespace Internal -} // namespace foundry_local \ No newline at end of file +} // namespace foundry_local diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp index 2c1e6177c..27f511d05 100644 --- a/sdk/cpp/src/foundry_local_manager.cpp +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include @@ -15,6 +15,7 @@ #include "foundry_local_internal_core.h" #include "foundry_local_exception.h" #include "core_interop_request.h" +#include "core_helpers.h" #include "core.h" #include "logger.h" @@ -163,39 +164,14 @@ void Manager::Cleanup() noexcept { return result; } - namespace { - struct EpCallbackContext { - EpProgressCallback* callback; - }; - - int EpProgressNativeCallback(void* data, int32_t dataLength, void* userData) { - auto* ctx = static_cast(userData); - if (!ctx || !ctx->callback || !*ctx->callback) return 0; - if (!data || dataLength <= 0) return 0; - - std::string progressStr(static_cast(data), static_cast(dataLength)); - auto sepIndex = progressStr.find('|'); - if (sepIndex != std::string::npos) { - std::string name = progressStr.substr(0, sepIndex); - // Parse percent using locale-independent std::from_chars - const auto* begin = progressStr.data() + sepIndex + 1; - const auto* end = progressStr.data() + progressStr.size(); - double percent = 0.0; - auto [ptr, ec] = std::from_chars(begin, end, percent); - if (ec == std::errc{}) { - (*ctx->callback)(name, percent); - } - } - return 0; - } - } - - EpDownloadResult Manager::DownloadAndRegisterEps(EpProgressCallback progressCallback) const { - return DownloadAndRegisterEps({}, std::move(progressCallback)); + EpDownloadResult Manager::DownloadAndRegisterEps(EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { + return DownloadAndRegisterEps({}, std::move(progressCallback), std::move(isCancellationRequested)); } EpDownloadResult Manager::DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback) const { + EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { std::string requestData; std::string* requestDataPtr = nullptr; @@ -212,16 +188,32 @@ void Manager::Cleanup() noexcept { } CoreResponse response; - if (progressCallback) { - EpCallbackContext ctx{&progressCallback}; - response = core_->call("download_and_register_eps", *logger_, - requestDataPtr, EpProgressNativeCallback, &ctx); + if (progressCallback || isCancellationRequested) { + auto onChunk = [&progressCallback](const std::string& chunk) { + if (!progressCallback) { + return; + } + + const auto sep = chunk.find('|'); + if (sep == std::string::npos) { + return; + } + + double percent = 0.0; + if (detail::TryParseDoubleToken(std::string_view(chunk).substr(sep + 1), percent)) { + progressCallback(chunk.substr(0, sep), percent); + } + }; + + response = detail::CallWithStreamingCallback(core_.get(), "download_and_register_eps", + requestDataPtr, *logger_, onChunk, + "Error downloading execution providers: ", + std::move(isCancellationRequested)); } else { response = core_->call("download_and_register_eps", *logger_, requestDataPtr); - } - - if (response.HasError()) { - throw Exception(std::string("Error downloading execution providers: ") + response.error, *logger_); + if (response.HasError()) { + throw Exception(std::string("Error downloading execution providers: ") + response.error, *logger_); + } } EpDownloadResult result; diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index e09f55414..6b0c13069 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include @@ -67,38 +69,35 @@ namespace foundry_local { return false; } - void ModelVariant::Download(DownloadProgressCallback onProgress) { + void ModelVariant::Download(DownloadProgressCallback onProgress, CancellationCallback isCancellationRequested) { if (IsCached()) { logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); return; } - if (onProgress) { - struct ProgressState { - DownloadProgressCallback* cb; - ILogger* logger; - } state{&onProgress, logger_}; - - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) - return 0; - auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast(len)); - try { - float value = std::stof(perc); - (*(st->cb))(value); + if (onProgress || isCancellationRequested) { + std::function onChunk = [&onProgress](const std::string& chunk) { + if (!onProgress) { + return true; } - catch (...) { - st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); + + std::istringstream tokens(chunk); + std::string token; + while (tokens >> token) { + float value = 0.0f; + if (TryParseFloatToken(token, value)) { + if (!onProgress(value)) { + return false; + } + } } - return 0; + return true; }; - auto response = CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, - +nativeCallback, &state); - if (response.HasError()) { - throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_); - } + const std::string payload = MakeModelParams(info_.name).dump(); + CallWithStreamingCallback(core_, "download_model", payload, *logger_, onChunk, + "Error downloading model [" + info_.name + "]: ", + std::move(isCancellationRequested)); } else { auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); diff --git a/sdk/cpp/test/ep_test.cpp b/sdk/cpp/test/ep_test.cpp index 7649b1efd..78c9ecaf6 100644 --- a/sdk/cpp/test/ep_test.cpp +++ b/sdk/cpp/test/ep_test.cpp @@ -72,7 +72,7 @@ static EpDownloadResult TestDownloadAndRegisterEps( struct EpCallbackContext { EpProgressCallback* callback; }; - auto nativeCb = [](void* data, int32_t dataLength, void* userData) -> int { + auto nativeCb = [](const void* data, int32_t dataLength, void* userData) -> int32_t { auto* ctx = static_cast(userData); if (!ctx || !ctx->callback || !*ctx->callback) return 0; if (!data || dataLength <= 0) return 0; @@ -249,9 +249,9 @@ TEST_F(DownloadAndRegisterEpsTest, CallbackInvokedWithProgressData) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback) { std::string p1 = "WebGpuExecutionProvider|25.0"; - callback(const_cast(p1.data()), static_cast(p1.size()), userData); + callback(p1.data(), static_cast(p1.size()), userData); std::string p2 = "WebGpuExecutionProvider|100.0"; - callback(const_cast(p2.data()), static_cast(p2.size()), userData); + callback(p2.data(), static_cast(p2.size()), userData); } return R"({"Success": true, "Status": "OK", "RegisteredEps": ["WebGpuExecutionProvider"], "FailedEps": []})"; }); diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index c631f8ff3..a95ecdd12 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -9,6 +9,7 @@ #include "foundry_local_exception.h" #include +#include using namespace foundry_local; using namespace foundry_local::Testing; @@ -136,7 +137,7 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { std::string progress = "50"; - int result = callback(progress.data(), static_cast(progress.size()), userData); + const int32_t result = callback(progress.data(), static_cast(progress.size()), userData); EXPECT_EQ(0, result) << "Callback should return 0 (continue), not " << result; } return ""; @@ -146,6 +147,85 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { variant.Download([&](float) { return true; }); } +TEST_F(ModelVariantTest, Download_ParsesNumericProgressTokens) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "status 12.5\nbad 37"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + std::vector progressValues; + variant.Download([&](float pct) { + progressValues.push_back(pct); + return true; + }); + + ASSERT_EQ(2u, progressValues.size()); + EXPECT_NEAR(12.5f, progressValues[0], 0.01f); + EXPECT_NEAR(37.0f, progressValues[1], 0.01f); +} + +TEST_F(ModelVariantTest, Download_WithCancellationRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download(nullptr, [] { return true; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_ProgressCallbackFalseRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download([](float) { return false; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_CancellationAfterFinalCallbackDoesNotCancelSuccessfulDownload) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "100"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + bool cancel = false; + EXPECT_NO_THROW(variant.Download([&](float) { + cancel = true; + return true; + }, [&] { return cancel; })); + EXPECT_TRUE(cancel); +} + TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { core_.OnCall("remove_cached_model", ""); auto variant = MakeVariant("test-model"); diff --git a/sdk/cs/README.md b/sdk/cs/README.md index 276ffb716..9493eea0b 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -99,6 +99,18 @@ await mgr.DownloadAndRegisterEpsAsync((epName, percent) => Console.WriteLine(); ``` +#### Cancelling model and EP downloads + +Pass a `CancellationToken` to either download API. Cancellation is observed on the next progress update. + +```csharp +// mgr and model already initialized +using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + +await mgr.DownloadAndRegisterEpsAsync(ct: cts.Token); +await model.DownloadAsync(ct: cts.Token); +``` + Catalog access no longer blocks on EP downloads. Call `DownloadAndRegisterEpsAsync` explicitly when you need hardware-accelerated execution providers. ## Quick Start diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 7239a48e4..138aa9411 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -177,6 +177,7 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, }; ResponseBuffer response = default; + Exception? callbackException = null; if (callback != null) { @@ -190,18 +191,19 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, var helperHandle = GCHandle.Alloc(helper); var helperPtr = GCHandle.ToIntPtr(helperHandle); - unsafe + try { - CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + unsafe + { + CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + } } - - helperHandle.Free(); - - if (helper.Exception != null) + finally { - throw new FoundryLocalException("Exception in callback handler. See InnerException for details", - helper.Exception); + helperHandle.Free(); } + + callbackException = helper.Exception; } else { @@ -239,6 +241,17 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, Marshal.FreeHGlobal(inputPtr!.Value); } + if (callbackException != null) + { + if (callbackException is OperationCanceledException canceledException) + { + throw canceledException; + } + + throw new FoundryLocalException("Exception in callback handler. See InnerException for details", + callbackException); + } + return result; } catch (Exception ex) when (ex is not OperationCanceledException) diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 250c601a2..c13a596bd 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -6,6 +6,8 @@ namespace Microsoft.AI.Foundry.Local; +using System.Globalization; + using Microsoft.AI.Foundry.Local.Detail; using Microsoft.Extensions.Logging; @@ -144,24 +146,41 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, }; ICoreInterop.Response? response; + var useCallbackPath = downloadProgress != null || (ct?.CanBeCanceled ?? false); - if (downloadProgress == null) - { - response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); - } - else + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { - if (float.TryParse(progressString, out var progress)) + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + if (downloadProgress == null) { - downloadProgress(progress); + return; + } + + foreach (var token in progressString.Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries)) + { + if (float.TryParse(token, + NumberStyles.Float, + CultureInfo.InvariantCulture, + out var progress)) + { + downloadProgress(progress); + } } }); response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, callback, ct).ConfigureAwait(false); } + else + { + response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); + } if (response.Error != null) { diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index b014850f6..855aed4a2 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -6,6 +6,7 @@ namespace Microsoft.AI.Foundry.Local; using System; +using System.Globalization; using System.Text.Json; using System.Threading.Tasks; @@ -373,20 +374,27 @@ private async Task DownloadAndRegisterEpsImplAsync(IEnumerable ICoreInterop.Response result; - if (progressCallback != null) + var useCallbackPath = progressCallback != null || (ct?.CanBeCanceled ?? false); + + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + var sepIndex = progressString.IndexOf('|'); if (sepIndex >= 0) { var name = progressString[..sepIndex]; if (double.TryParse(progressString[(sepIndex + 1)..], - System.Globalization.NumberStyles.Float, - System.Globalization.CultureInfo.InvariantCulture, + NumberStyles.Float, + CultureInfo.InvariantCulture, out var percent)) { - progressCallback(string.IsNullOrEmpty(name) ? "" : name, percent); + progressCallback?.Invoke(string.IsNullOrEmpty(name) ? "" : name, percent); } } }); diff --git a/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs new file mode 100644 index 000000000..e3d18fd36 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs @@ -0,0 +1,122 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using Microsoft.AI.Foundry.Local.Detail; + +using Microsoft.Extensions.Logging; + +using Moq; + +internal sealed class DownloadCancellationTests +{ + [Test] + public async Task ModelVariantDownload_WithCancellableToken_UsesCallbackPathAndPropagatesCancellation() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + using var cts = new CancellationTokenSource(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.Is(r => r != null && + r.Params != null && + r.Params.ContainsKey("Model") && + r.Params["Model"] == modelInfo.Id), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("10"); + cts.Cancel(); + callback("20"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + + OperationCanceledException? caught = null; + try + { + await model.DownloadAsync(ct: cts.Token); + } + catch (OperationCanceledException ex) + { + caught = ex; + } + + await Assert.That(caught).IsNotNull(); + coreInterop.Verify(x => x.ExecuteCommandWithCallbackAsync( + "download_model", + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Once); + coreInterop.Verify(x => x.ExecuteCommandAsync( + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Never); + } + + [Test] + public async Task ModelVariantDownload_WithMixedProgressChunk_ParsesNumericTokens() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("status 12.5\nbad 37"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + var progressValues = new List(); + + await model.DownloadAsync(progressValues.Add); + + await Assert.That(progressValues.Count).IsEqualTo(2); + await Assert.That(progressValues[0]).IsEqualTo(12.5f); + await Assert.That(progressValues[1]).IsEqualTo(37.0f); + } +} diff --git a/sdk/cs/test/FoundryLocal.Tests/Utils.cs b/sdk/cs/test/FoundryLocal.Tests/Utils.cs index f89698539..91cc6a81d 100644 --- a/sdk/cs/test/FoundryLocal.Tests/Utils.cs +++ b/sdk/cs/test/FoundryLocal.Tests/Utils.cs @@ -475,7 +475,8 @@ private static string GetRepoRoot() while (dir != null) { - if (Directory.Exists(Path.Combine(dir.FullName, ".git"))) + var gitPath = Path.Combine(dir.FullName, ".git"); + if (Directory.Exists(gitPath) || File.Exists(gitPath)) return dir.FullName; dir = dir.Parent; diff --git a/sdk/js/README.md b/sdk/js/README.md index 26471cc8c..905bc5cc9 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -77,6 +77,19 @@ await manager.downloadAndRegisterEps((epName, percent) => { process.stdout.write('\n'); ``` +#### Cancelling model and EP downloads + +Use an `AbortController` with either `downloadAndRegisterEps()` or `model.download()`. Aborting the signal rejects the in-progress download promise. + +```typescript +// manager and model already initialized +const controller = new AbortController(); +setTimeout(() => controller.abort(), 5000); + +await manager.downloadAndRegisterEps(controller.signal); +await model.download(undefined, controller.signal); +``` + Catalog access does not block on EP downloads. Call `downloadAndRegisterEps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -336,4 +349,4 @@ See `test/README.md` for details on prerequisites and setup. npm run example ``` -This runs the chat completion example in `examples/chat-completion.ts`. \ No newline at end of file +This runs the chat completion example in `examples/chat-completion.ts`. diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 72013815c..36098d4ab 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -136,9 +136,47 @@ export class CoreInterop { return this.addon.executeCommandWithBinary(command, dataStr, binBuf); } - public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { + public async executeCommandStreaming( + command: string, + params: any, + callback: (chunk: string) => void, + signal?: AbortSignal + ): Promise { + const createAbortError = (): Error => { + const error = new Error('Operation cancelled'); + error.name = 'AbortError'; + return error; + }; + + if (signal?.aborted) { + throw createAbortError(); + } + const dataStr = params ? JSON.stringify(params) : ''; - return this.addon.executeCommandStreaming(command, dataStr, callback); + let cancelled = false; + const wrappedCallback = (chunk: string) => { + if (signal?.aborted) { + cancelled = true; + throw createAbortError(); + } + + callback(chunk); + }; + + try { + const result = await this.addon.executeCommandStreaming(command, dataStr, wrappedCallback); + if (cancelled) { + throw createAbortError(); + } + + return result; + } catch (error) { + if (cancelled) { + throw createAbortError(); + } + + throw error; + } } } diff --git a/sdk/js/src/detail/model.ts b/sdk/js/src/detail/model.ts index ffd962db5..c7aa551c3 100644 --- a/sdk/js/src/detail/model.ts +++ b/sdk/js/src/detail/model.ts @@ -126,9 +126,10 @@ export class Model implements IModel { /** * Downloads the currently selected variant. * @param progressCallback - Optional callback to report download progress. + * @param signal - Optional AbortSignal. When aborted, the download will be cancelled at the next progress update. */ - public download(progressCallback?: (progress: number) => void): Promise { - return this.selectedVariant.download(progressCallback); + public download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise { + return this.selectedVariant.download(progressCallback, signal); } /** @@ -202,4 +203,4 @@ export class Model implements IModel { public createResponsesClient(baseUrl: string): ResponsesClient { return this.selectedVariant.createResponsesClient(baseUrl); } -} \ No newline at end of file +} diff --git a/sdk/js/src/detail/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts index af150bb81..4b1f8d26f 100644 --- a/sdk/js/src/detail/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -108,18 +108,30 @@ export class ModelVariant implements IModel { /** * Downloads the model variant. * @param progressCallback - Optional callback to report download progress (0-100). + * @param signal - Optional AbortSignal. When aborted, the download will be + * cancelled at the next progress update and the returned promise will reject. */ - public async download(progressCallback?: (progress: number) => void): Promise { + public async download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise { const request = { Params: { Model: this._modelInfo.id } }; - if (!progressCallback) { + if (!progressCallback && !signal) { await this.coreInterop.executeCommandAsync("download_model", request); } else { + // Use the streaming path when progress or cancellation is needed. + // Provide a no-op callback when only cancellation is requested so + // the native callback mechanism is engaged. + const cb = progressCallback ?? (() => {}); await this.coreInterop.executeCommandStreaming("download_model", request, (chunk: string) => { - const progress = parseFloat(chunk); - if (!isNaN(progress)) { - progressCallback(progress); + for (const token of chunk.split(/\s+/)) { + if (token.length === 0) { + continue; + } + + const progress = Number(token); + if (!Number.isNaN(progress)) { + cb(progress); + } } - }); + }, signal); } } diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index f3224e656..c9f15cb24 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -5,6 +5,13 @@ import { Catalog } from './catalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; import { EpInfo, EpDownloadResult } from './types.js'; +function isAbortSignal(value: unknown): value is AbortSignal { + return typeof value === 'object' + && value !== null + && 'aborted' in value + && typeof (value as AbortSignal).aborted === 'boolean'; +} + /** * The main entry point for the Foundry Local SDK. * Manages the initialization of the core system and provides access to the Catalog and ModelLoadManager. @@ -178,18 +185,38 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(): Promise; + /** + * Downloads and registers execution providers. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(signal: AbortSignal): Promise; /** * Downloads and registers execution providers. * @param names - Array of EP names to download. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[]): Promise; + /** + * Downloads and registers execution providers. + * @param names - Array of EP names to download. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: string[], signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param names - Array of EP names to download. @@ -197,15 +224,62 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param names - Array of EP names to download. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: undefined, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: undefined, progressCallback: (epName: string, percent: number) => void, signal?: AbortSignal): Promise; public async downloadAndRegisterEps( - namesOrCallback?: string[] | ((epName: string, percent: number) => void), - progressCallback?: (epName: string, percent: number) => void + namesOrCallbackOrSignal?: string[] | ((epName: string, percent: number) => void) | AbortSignal, + progressCallbackOrSignal?: ((epName: string, percent: number) => void) | AbortSignal, + maybeSignal?: AbortSignal ): Promise { + let progressCallback: ((epName: string, percent: number) => void) | undefined; let names: string[] | undefined; - if (typeof namesOrCallback === 'function') { - progressCallback = namesOrCallback; + let signal: AbortSignal | undefined; + + if (Array.isArray(namesOrCallbackOrSignal)) { + names = namesOrCallbackOrSignal; + if (typeof progressCallbackOrSignal === 'function') { + progressCallback = progressCallbackOrSignal; + signal = maybeSignal; + } else if (isAbortSignal(progressCallbackOrSignal)) { + signal = progressCallbackOrSignal; + } + } else if (typeof namesOrCallbackOrSignal === 'function') { + progressCallback = namesOrCallbackOrSignal; + if (isAbortSignal(progressCallbackOrSignal)) { + signal = progressCallbackOrSignal; + } + } else if (isAbortSignal(namesOrCallbackOrSignal)) { + signal = namesOrCallbackOrSignal; } else { - names = namesOrCallback; + if (typeof progressCallbackOrSignal === 'function') { + progressCallback = progressCallbackOrSignal; + signal = maybeSignal; + } else if (isAbortSignal(progressCallbackOrSignal)) { + signal = progressCallbackOrSignal; + } else { + signal = maybeSignal; + } } const params: { Params?: { Names: string } } = {}; @@ -235,13 +309,15 @@ export class FoundryLocalManager { progressCallback(epName || '', percent); } } - } + }, + signal ); } else { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", Object.keys(params).length > 0 ? params : undefined, - () => {} // no-op callback + () => {}, // no-op callback + signal ); } diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 7a8a79e35..a3257dec3 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -17,7 +17,13 @@ export interface IModel { get capabilities(): string | null; get supportsToolCalling(): boolean | null; - download(progressCallback?: (progress: number) => void): Promise; + /** + * Download the model to local cache if not already present. + * @param progressCallback - Optional callback for download progress (0-100). + * @param signal - Optional AbortSignal. When aborted, the download will be + * cancelled at the next progress update and the returned promise will reject. + */ + download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise; get path(): string; load(): Promise; removeFromCache(): void; diff --git a/sdk/js/test/detail/coreInterop.test.ts b/sdk/js/test/detail/coreInterop.test.ts new file mode 100644 index 000000000..28b78cb66 --- /dev/null +++ b/sdk/js/test/detail/coreInterop.test.ts @@ -0,0 +1,26 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { CoreInterop } from '../../src/detail/coreInterop.js'; + +describe('CoreInterop Tests', () => { + it('executeCommandStreaming should not reject when signal aborts after the final observed callback', async function() { + const controller = new AbortController(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: async (_command: string, _dataJson: string, callback: (chunk: string) => void) => { + callback('100'); + return 'ok'; + } + }; + + const result = await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + () => controller.abort(), + controller.signal + ); + + expect(result).to.equal('ok'); + }); +}); diff --git a/sdk/js/test/foundryLocalManager.test.ts b/sdk/js/test/foundryLocalManager.test.ts index 48adcff40..827f945b5 100644 --- a/sdk/js/test/foundryLocalManager.test.ts +++ b/sdk/js/test/foundryLocalManager.test.ts @@ -1,6 +1,7 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager } from './testUtils.js'; +import { FoundryLocalManager } from '../src/foundryLocalManager.js'; describe('Foundry Local Manager Tests', () => { it('should initialize successfully', function() { @@ -18,64 +19,153 @@ describe('Foundry Local Manager Tests', () => { }); it('downloadAndRegisterEps should call command without params when names are omitted', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; - - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: true, - Status: 'All providers registered', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: [] - })); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} }; - try { - const result = await manager.downloadAndRegisterEps(); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.be.undefined; - expect(result).to.deep.equal({ - success: true, - status: 'All providers registered', - registeredEps: ['CUDAExecutionProvider'], - failedEps: [] - }); - } finally { - manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + const result = await manager.downloadAndRegisterEps(); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.be.undefined; + expect(result).to.deep.equal({ + success: true, + status: 'All providers registered', + registeredEps: ['CUDAExecutionProvider'], + failedEps: [] + }); }); it('downloadAndRegisterEps should send Names param when subset is provided', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: false, + Status: 'Some providers failed', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: ['OpenVINOExecutionProvider'] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); + expect(result).to.deep.equal({ + success: false, + status: 'Some providers failed', + registeredEps: ['CUDAExecutionProvider'], + failedEps: ['OpenVINOExecutionProvider'] + }); + }); + + it('downloadAndRegisterEps should pass AbortSignal through to streaming interop', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + ['CUDAExecutionProvider'], + controller.signal + ); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('downloadAndRegisterEps should honor progress callback when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const progress: Array<[string, number]> = []; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + const callback = args[2] as (chunk: string) => void; + callback('CUDAExecutionProvider|42.5'); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + undefined, + (epName: string, percent: number) => progress.push([epName, percent]) + ); - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: false, - Status: 'Some providers failed', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: ['OpenVINOExecutionProvider'] - })); + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(progress).to.deep.equal([['CUDAExecutionProvider', 42.5]]); + }); + + it('downloadAndRegisterEps should pass AbortSignal when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + undefined, + controller.signal + ); - try { - const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); - expect(result).to.deep.equal({ - success: false, - status: 'Some providers failed', - registeredEps: ['CUDAExecutionProvider'], - failedEps: ['OpenVINOExecutionProvider'] - }); - } finally { - manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(calls[0][3]).to.equal(controller.signal); }); }); diff --git a/sdk/js/test/model.test.ts b/sdk/js/test/model.test.ts index 4048d9a11..6e411a30c 100644 --- a/sdk/js/test/model.test.ts +++ b/sdk/js/test/model.test.ts @@ -1,6 +1,9 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager, TEST_MODEL_ALIAS } from './testUtils.js'; +import { Model } from '../src/detail/model.js'; +import { ModelVariant } from '../src/detail/modelVariant.js'; +import { DeviceType, type ModelInfo } from '../src/types.js'; describe('Model Tests', () => { it('should verify cached models from test-data-shared', async function() { @@ -58,4 +61,78 @@ describe('Model Tests', () => { await model.unload(); expect(await model.isLoaded()).to.be.false; }); -}); \ No newline at end of file + + it('download should use streaming interop when only an AbortSignal is provided', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommand: () => { + throw new Error('download should not use executeCommand when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(undefined, controller.signal); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('download should parse numeric progress tokens and ignore status text', async function() { + const progress: number[] = []; + const fakeCoreInterop = { + executeCommand: () => { + throw new Error('download should use streaming interop when progress is provided'); + }, + executeCommandStreaming: async ( + _command: string, + _request: unknown, + callback: (chunk: string) => void + ) => { + callback('status 12.5\nbad 37'); + return ''; + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(progress.push.bind(progress)); + + expect(progress).to.deep.equal([12.5, 37]); + }); +}); diff --git a/sdk/python/README.md b/sdk/python/README.md index 2a121411e..55a6f8d17 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -108,6 +108,21 @@ manager.download_and_register_eps(progress_callback=on_progress) print() ``` +### Cancelling model and EP downloads + +Pass a `threading.Event` as `cancel_event` to either download API. Set the event from another thread or handler to cancel the in-progress download. + +```python +import threading + +# manager and model already initialized +cancel_event = threading.Event() +threading.Timer(5.0, cancel_event.set).start() + +manager.download_and_register_eps(cancel_event=cancel_event) +model.download(cancel_event=cancel_event) +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -328,4 +343,4 @@ See [test/README.md](test/README.md) for detailed test setup and structure. ```bash python examples/chat_completion.py -``` \ No newline at end of file +``` diff --git a/sdk/python/requirements.txt b/sdk/python/requirements.txt index 0f8b80d1e..32b00f502 100644 --- a/sdk/python/requirements.txt +++ b/sdk/python/requirements.txt @@ -1,9 +1,8 @@ pydantic>=2.0.0 requests>=2.32.4 openai>=2.24.0 -# Standard native binary packages from the ORT-Nightly PyPI feed. foundry-local-core==1.1.0 -onnxruntime-core==1.25.1; sys_platform != "linux" -onnxruntime-gpu==1.25.1; sys_platform == "linux" -onnxruntime-genai-core==0.13.2; sys_platform != "linux" -onnxruntime-genai-cuda==0.13.2; sys_platform == "linux" +onnxruntime-gpu==1.25.1; platform_system == "Linux" +onnxruntime-core==1.25.1; platform_system != "Linux" +onnxruntime-genai-cuda==0.13.2; platform_system == "Linux" +onnxruntime-genai-core==0.13.2; platform_system != "Linux" diff --git a/sdk/python/src/detail/core_interop.py b/sdk/python/src/detail/core_interop.py index f93b79f03..a013f7ba7 100644 --- a/sdk/python/src/detail/core_interop.py +++ b/sdk/python/src/detail/core_interop.py @@ -10,6 +10,7 @@ import logging import os import sys +import threading from dataclasses import dataclass from pathlib import Path @@ -84,6 +85,10 @@ class Response: error: Optional[str] = None +class CancelledException(Exception): + """Raised internally when a download or streaming operation is cancelled.""" + + class CallbackHelper: """Internal helper class to convert the callback from ctypes to a str and call the python callback.""" @staticmethod @@ -92,18 +97,27 @@ def callback(data_ptr, length, self_ptr): try: self = ctypes.cast(self_ptr, ctypes.POINTER(ctypes.py_object)).contents.value + # Check for cancellation before processing the callback data. + if self._cancel_event is not None and self._cancel_event.is_set(): + raise CancelledException("Operation cancelled") + # convert to a string and pass to the python callback data_bytes = ctypes.string_at(data_ptr, length) data_str = data_bytes.decode('utf-8') self._py_callback(data_str) return 0 # continue + except CancelledException as e: + if self is not None and self.exception is None: + self.exception = e + return 1 # cancel except Exception as e: if self is not None and self.exception is None: self.exception = e # keep the first only as they are likely all the same return 1 # cancel on error - def __init__(self, py_callback: Callable[[str], None]): + def __init__(self, py_callback: Callable[[str], None], cancel_event: Optional['threading.Event'] = None): self._py_callback = py_callback + self._cancel_event = cancel_event self.exception = None @@ -252,37 +266,44 @@ def __init__(self, config: Configuration): logger.info("Foundry.Local.Core initialized successfully: %s", response.data) def _execute_command(self, command: str, interop_request: InteropRequest = None, - callback: CoreInterop.CALLBACK_TYPE = None): + callback: CoreInterop.CALLBACK_TYPE = None, + cancel_event: Optional[threading.Event] = None): cmd_ptr, cmd_len, cmd_buf = CoreInterop._to_c_buffer(command) data_ptr, data_len, data_buf = CoreInterop._to_c_buffer(interop_request.to_json() if interop_request else None) req = RequestBuffer(Command=cmd_ptr, CommandLength=cmd_len, Data=data_ptr, DataLength=data_len) resp = ResponseBuffer() lib = CoreInterop._flcore_library + callback_exception = None if (callback is not None): # If a callback is provided, use the execute_command_with_callback method # We need a helper to do the initial conversion from ctypes to Python and pass it through to the # provided callback function - callback_helper = CallbackHelper(callback) + callback_helper = CallbackHelper(callback, cancel_event) callback_py_obj = ctypes.py_object(callback_helper) callback_helper_ptr = ctypes.cast(ctypes.pointer(callback_py_obj), ctypes.c_void_p) callback_fn = CoreInterop.CALLBACK_TYPE(CallbackHelper.callback) lib.execute_command_with_callback(ctypes.byref(req), ctypes.byref(resp), callback_fn, callback_helper_ptr) - - if callback_helper.exception is not None: - raise callback_helper.exception + callback_exception = callback_helper.exception else: lib.execute_command(ctypes.byref(req), ctypes.byref(resp)) req = None # Free Python reference to request - response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None - error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None - - # C# owns the memory in the response so we need to free it explicitly - lib.free_response(resp) + try: + response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None + error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None + finally: + # C# owns the memory in the response so we need to free it explicitly. + # Do this before surfacing callback exceptions so cancellation does not leak native buffers. + lib.free_response(resp) + + if callback_exception is not None: + if isinstance(callback_exception, CancelledException): + raise FoundryLocalException("Operation cancelled") + raise callback_exception return Response(data=response_str, error=error_str) @@ -303,23 +324,33 @@ def execute_command(self, command_name: str, command_input: Optional[InteropRequ return response def execute_command_with_callback(self, command_name: str, command_input: Optional[InteropRequest], - callback: Callable[[str], None]) -> Response: + callback: Callable[[str], None], + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command with a streaming callback. The ``callback`` receives incremental string data from the native layer (e.g. streaming chat tokens or download progress). + If ``cancel_event`` is provided and is set, the native call will be + cancelled at the next callback invocation and a ``FoundryLocalException`` + with message ``"Operation cancelled"`` will be raised. + Args: command_name: The native command name. command_input: Optional request parameters. callback: Called with each incremental string response. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. + + Raises: + FoundryLocalException: If the operation is cancelled or fails. """ logger.debug("Executing command with callback: %s Input: %s", command_name, command_input.params if command_input else None) - response = self._execute_command(command_name, command_input, callback) + response = self._execute_command(command_name, command_input, callback, cancel_event) return response def execute_command_with_binary(self, command_name: str, diff --git a/sdk/python/src/detail/model.py b/sdk/python/src/detail/model.py index 6d60b7a2f..a71b1dba5 100644 --- a/sdk/python/src/detail/model.py +++ b/sdk/python/src/detail/model.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -115,9 +116,10 @@ def is_loaded(self) -> bool: """Is the currently selected variant loaded in memory?""" return self._selected_variant.is_loaded - def download(self, progress_callback: Optional[Callable[[float], None]] = None) -> None: + def download(self, progress_callback: Optional[Callable[[float], None]] = None, + cancel_event: Optional[Event] = None) -> None: """Download the currently selected variant.""" - self._selected_variant.download(progress_callback) + self._selected_variant.download(progress_callback, cancel_event) def get_path(self) -> str: """Get the path to the currently selected variant.""" diff --git a/sdk/python/src/detail/model_variant.py b/sdk/python/src/detail/model_variant.py index 76efb05cd..fa1fe9c44 100644 --- a/sdk/python/src/detail/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -112,20 +113,37 @@ def is_loaded(self) -> bool: loaded_model_ids = self._model_load_manager.list_loaded() return self.id in loaded_model_ids - def download(self, progress_callback: Callable[[float], None] = None): + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None): """Download this variant to the local cache. Args: progress_callback: Optional callback receiving download progress as a percentage (0.0 to 100.0). + cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. """ request = InteropRequest(params={"Model": self.id}) - if progress_callback is None: + if progress_callback is None and cancel_event is None: response = self._core_interop.execute_command("download_model", request) else: + # Use the callback path when either progress or cancellation is needed. + # Ignore non-progress chunks so cancellation-only downloads still + # tolerate any status text emitted by the native layer. + def _on_chunk(chunk: str) -> None: + if progress_callback is None: + return + + for token in chunk.split(): + try: + progress_callback(float(token)) + except ValueError: + pass + response = self._core_interop.execute_command_with_callback( "download_model", request, - lambda pct_str: progress_callback(float(pct_str)) + _on_chunk, + cancel_event, ) logger.info("Download response: %s", response) diff --git a/sdk/python/src/foundry_local_manager.py b/sdk/python/src/foundry_local_manager.py index a649f8e56..f36782678 100644 --- a/sdk/python/src/foundry_local_manager.py +++ b/sdk/python/src/foundry_local_manager.py @@ -101,6 +101,7 @@ def download_and_register_eps( self, names: Optional[list[str]] = None, progress_callback: Optional[Callable[[str, float], None]] = None, + cancel_event: Optional[threading.Event] = None, ) -> EpDownloadResult: """Download and register execution providers. @@ -109,6 +110,8 @@ def download_and_register_eps( all discoverable EPs are downloaded. progress_callback: Optional callback ``(ep_name: str, percent: float) -> None`` invoked as each EP downloads. ``percent`` is 0-100. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. The download will be cancelled at the next progress update. Returns: ``EpDownloadResult`` describing operation status and per-EP outcomes. @@ -120,19 +123,22 @@ def download_and_register_eps( if names is not None and len(names) > 0: request = InteropRequest(params={"Names": ",".join(names)}) - if progress_callback is not None: + if progress_callback is not None or cancel_event is not None: + user_cb = progress_callback + def _on_chunk(chunk: str) -> None: - sep = chunk.find("|") - if sep >= 0: - ep_name = chunk[:sep] or "" - try: - percent = float(chunk[sep + 1:]) - progress_callback(ep_name, percent) - except ValueError: - pass + if user_cb is not None: + sep = chunk.find("|") + if sep >= 0: + ep_name = chunk[:sep] or "" + try: + percent = float(chunk[sep + 1:]) + user_cb(ep_name, percent) + except ValueError: + pass response = self._core_interop.execute_command_with_callback( - "download_and_register_eps", request, _on_chunk + "download_and_register_eps", request, _on_chunk, cancel_event ) else: response = self._core_interop.execute_command("download_and_register_eps", request) diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index f723e514a..fc63f3747 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from threading import Event from typing import Callable, List, Optional from .openai.chat_client import ChatClient @@ -76,10 +77,13 @@ def supports_tool_calling(self) -> Optional[bool]: pass @abstractmethod - def download(self, progress_callback: Callable[[float], None] = None) -> None: + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: """ Download the model to local cache if not already present. :param progress_callback: Optional callback function for download progress as a percentage (0.0 to 100.0). + :param cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. """ pass diff --git a/sdk/python/test/test_foundry_local_manager.py b/sdk/python/test/test_foundry_local_manager.py index 315288912..3abb37f64 100644 --- a/sdk/python/test/test_foundry_local_manager.py +++ b/sdk/python/test/test_foundry_local_manager.py @@ -6,6 +6,10 @@ from __future__ import annotations +import threading + +from foundry_local_sdk.foundry_local_manager import FoundryLocalManager + class _Response: def __init__(self, data=None, error=None): @@ -22,6 +26,12 @@ def execute_command(self, command_name, command_input=None): self.calls.append((command_name, command_input)) return self._responses[command_name] + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return self._responses[command_name] + class TestFoundryLocalManager: """Foundry Local Manager Tests.""" @@ -81,3 +91,36 @@ def test_download_and_register_eps_returns_result(self, manager): assert result.status == "ok" assert result.registered_eps == ["CUDAExecutionProvider"] assert result.failed_eps == [] + + def test_download_and_register_eps_uses_callback_path_when_cancel_event_is_provided(self): + fake_core = _FakeCoreInterop( + { + "download_and_register_eps": _Response( + data=( + '{"Success":true,"Status":"ok",' + '"RegisteredEps":["CUDAExecutionProvider"],"FailedEps":[]}' + ), + error=None, + ) + } + ) + manager = FoundryLocalManager.__new__(FoundryLocalManager) + manager._core_interop = fake_core + manager.catalog = type( + "_FakeCatalog", + (), + {"_invalidate_cache": staticmethod(lambda: None)}, + )() + cancel_event = threading.Event() + + result = manager.download_and_register_eps( + ["CUDAExecutionProvider"], cancel_event=cancel_event + ) + + assert result.success is True + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_and_register_eps" + assert command_input.params == {"Names": "CUDAExecutionProvider"} + assert callable(callback) + assert seen_cancel_event is cancel_event diff --git a/sdk/python/test/test_model.py b/sdk/python/test/test_model.py index e2ea15090..a01d4f93c 100644 --- a/sdk/python/test/test_model.py +++ b/sdk/python/test/test_model.py @@ -6,6 +6,12 @@ from __future__ import annotations +import threading + +from types import SimpleNamespace + +from foundry_local_sdk.detail.model_variant import ModelVariant + from .conftest import TEST_MODEL_ALIAS, AUDIO_MODEL_ALIAS @@ -86,3 +92,75 @@ def test_should_expose_supports_tool_calling(self, catalog): assert model is not None stc = model.supports_tool_calling assert stc is None or isinstance(stc, bool) + + def test_download_should_use_callback_path_when_cancel_event_is_provided(self): + """Model download should route through callback interop when cancellation is enabled.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def __init__(self): + self.calls = [] + + def execute_command(self, command_name, command_input=None): + raise AssertionError( + "download should not use execute_command when cancel_event is provided" + ) + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return _Response(data="", error=None) + + fake_core = _FakeCoreInterop() + cancel_event = threading.Event() + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = fake_core + variant._model_load_manager = None + + variant.download(cancel_event=cancel_event) + + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_model" + assert command_input.params == {"Model": "test-model-cpu:1"} + assert callable(callback) + assert seen_cancel_event is cancel_event + callback("status: starting") + + def test_download_should_parse_numeric_progress_tokens_and_ignore_status_text(self): + """Model download progress parsing should tolerate mixed native chunks.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def execute_command(self, command_name, command_input=None): + raise AssertionError("download should use callback interop when progress is provided") + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + callback("status 12.5\nbad 37") + return _Response(data="", error=None) + + progress = [] + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = _FakeCoreInterop() + variant._model_load_manager = None + + variant.download(progress_callback=progress.append) + + assert progress == [12.5, 37.0] diff --git a/sdk/rust/README.md b/sdk/rust/README.md index ce97a7dd0..d017ce5e2 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -107,6 +107,28 @@ manager.download_and_register_eps_with_progress(None, move |ep_name: &str, perce println!(); ``` +#### Cancelling model and EP downloads + +Use a shared `Arc` with the cancellable download APIs. Set the flag from another task or signal handler to stop the in-progress download. + +```rust +use std::sync::{ + Arc, + atomic::AtomicBool, +}; + +// manager and model already initialized +let cancel_flag = Arc::new(AtomicBool::new(false)); +// call cancel_flag.store(true, ...) from another task or signal handler to cancel + +manager + .download_and_register_eps_cancellable(None, Arc::clone(&cancel_flag)) + .await?; +model + .download_cancellable(None::, Arc::clone(&cancel_flag)) + .await?; +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps` when you need hardware-accelerated execution providers. ## Quick Start diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 0d17fe62d..30a621e01 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -9,6 +9,7 @@ use std::ffi::CString; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use libloading::{Library, Symbol}; @@ -143,6 +144,8 @@ unsafe fn free_native_buffer(ptr: *mut u8) { struct StreamingCallbackState<'a> { callback: &'a mut dyn FnMut(&str), buf: Vec, + cancel_flag: Option>, + cancelled_observed: bool, } impl<'a> StreamingCallbackState<'a> { @@ -150,9 +153,37 @@ impl<'a> StreamingCallbackState<'a> { Self { callback, buf: Vec::new(), + cancel_flag: None, + cancelled_observed: false, } } + fn with_cancel(callback: &'a mut dyn FnMut(&str), cancel_flag: Arc) -> Self { + Self { + callback, + buf: Vec::new(), + cancel_flag: Some(cancel_flag), + cancelled_observed: false, + } + } + + /// Records and returns `true` only when this callback invocation observes a cancellation request. + fn mark_cancelled_if_requested(&mut self) -> bool { + let cancelled = self + .cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)); + if cancelled { + self.cancelled_observed = true; + } + + cancelled + } + + fn cancellation_observed(&self) -> bool { + self.cancelled_observed + } + /// Append raw bytes, decode as much valid UTF-8 as possible, and forward /// complete text to the callback. Any trailing incomplete multi-byte /// sequence is kept in the buffer for the next call. Invalid byte @@ -225,16 +256,19 @@ unsafe extern "C" fn streaming_trampoline( // by the caller of `execute_command_with_callback` for the duration of // the native call. let state = &mut *(user_data as *mut StreamingCallbackState<'_>); + + // Check for cancellation before processing the chunk. + if state.mark_cancelled_if_requested() { + return 1; // cancel + } + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native // core's callback contract. let slice = std::slice::from_raw_parts(data, length as usize); state.push(slice); + 0 // continue })); - if result.is_err() { - 1 - } else { - 0 - } + result.unwrap_or(1) } // ── CoreInterop ────────────────────────────────────────────────────────────── @@ -452,6 +486,32 @@ impl CoreInterop { where F: FnMut(&str), { + self.execute_command_streaming_impl(command, params, &mut callback, None) + } + + /// Like [`Self::execute_command_streaming`], but accepts a cancellation + /// flag. When `cancel_flag` is set to `true`, the native call will be + /// cancelled at the next callback invocation and an error is returned. + pub fn execute_command_streaming_cancellable( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str), + { + self.execute_command_streaming_impl(command, params, &mut callback, Some(cancel_flag)) + } + + fn execute_command_streaming_impl( + &self, + command: &str, + params: Option<&Value>, + callback: &mut dyn FnMut(&str), + cancel_flag: Option>, + ) -> Result { let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { reason: format!("Invalid command string: {e}"), })?; @@ -476,8 +536,10 @@ impl CoreInterop { // Wrap the closure in a StreamingCallbackState that handles partial // UTF-8 sequences split across native callbacks. - let mut cb = |chunk: &str| callback(chunk); - let mut state = StreamingCallbackState::new(&mut cb); + let mut state = match cancel_flag { + Some(flag) => StreamingCallbackState::with_cancel(callback, flag), + None => StreamingCallbackState::new(callback), + }; let user_data = &mut state as *mut StreamingCallbackState<'_> as *mut std::ffi::c_void; // SAFETY: `request` fields point into `cmd` and `data_cstr` which are @@ -494,9 +556,19 @@ impl CoreInterop { ); } + let cancelled = state.cancellation_observed(); + // Flush any trailing partial UTF-8 bytes. state.flush(); + if cancelled { + // Free native response memory before returning the error. + Self::process_response(response).ok(); + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".to_string(), + }); + } + Self::process_response(response) } @@ -540,6 +612,36 @@ impl CoreInterop { })? } + /// Async version of [`Self::execute_command_streaming_cancellable`]. + /// + /// Accepts a shared cancellation flag (`Arc`). When the flag + /// is set to `true`, the native call will be cancelled at the next + /// callback invocation and an error is returned. + pub async fn execute_command_streaming_cancellable_async( + self: &Arc, + command: String, + params: Option, + callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming_cancellable( + &command, + params.as_ref(), + callback, + cancel_flag, + ) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + /// Async streaming variant that bridges the FFI callback into a /// [`tokio::sync::mpsc`] channel. /// @@ -702,3 +804,35 @@ impl CoreInterop { Ok(libs) } } + +#[cfg(test)] +mod tests { + use super::StreamingCallbackState; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + + #[test] + fn cancellation_request_after_callback_is_not_observed_until_next_callback() { + let cancel_flag = Arc::new(AtomicBool::new(false)); + let mut callback = |_chunk: &str| {}; + let mut state = + StreamingCallbackState::with_cancel(&mut callback, Arc::clone(&cancel_flag)); + + state.push(b"100"); + cancel_flag.store(true, Ordering::Relaxed); + + assert!(!state.cancellation_observed()); + } + + #[test] + fn cancellation_is_recorded_when_callback_observes_cancel_flag() { + let cancel_flag = Arc::new(AtomicBool::new(true)); + let mut callback = |_chunk: &str| {}; + let mut state = StreamingCallbackState::with_cancel(&mut callback, cancel_flag); + + assert!(state.mark_cancelled_if_requested()); + assert!(state.cancellation_observed()); + } +} diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index 08288aee8..5921fbcde 100644 --- a/sdk/rust/src/detail/model.rs +++ b/sdk/rust/src/detail/model.rs @@ -6,7 +6,7 @@ use std::fmt; use std::path::PathBuf; -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; use std::sync::Arc; use super::core_interop::CoreInterop; @@ -213,6 +213,23 @@ impl Model { self.selected_variant().download(progress).await } + /// Like [`Self::download`], but accepts a shared cancellation flag + /// (`Arc`). When the flag is set to `true`, the download + /// will be cancelled at the next progress callback and an error is + /// returned. + pub async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.selected_variant() + .download_cancellable(progress, cancel_flag) + .await + } + /// Return the local file-system path of the (selected) variant. pub async fn path(&self) -> Result { self.selected_variant().path().await diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index 1f8ce7d5b..a49aae214 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -5,6 +5,7 @@ use std::fmt; use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use serde_json::json; @@ -88,12 +89,54 @@ impl ModelVariant { } pub(crate) async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, None).await + } + + /// Like [`Self::download`], but accepts a shared cancellation flag. + /// When `cancel_flag` is set to `true`, the download will be cancelled at + /// the next progress callback. + pub(crate) async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, Some(cancel_flag)).await + } + + async fn download_impl( + &self, + progress: Option, + cancel_flag: Option>, + ) -> Result<()> where F: FnMut(f64) + Send + 'static, { let params = json!({ "Params": { "Model": self.info.id } }); - match progress { - Some(mut cb) => { + match (progress, cancel_flag) { + (Some(mut cb), Some(flag)) => { + let wrapper = move |chunk: &str| { + for token in chunk.split_whitespace() { + if let Ok(pct) = token.parse::() { + cb(pct); + } + } + }; + self.core + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + wrapper, + flag, + ) + .await?; + } + (Some(mut cb), None) => { let wrapper = move |chunk: &str| { for token in chunk.split_whitespace() { if let Ok(pct) = token.parse::() { @@ -105,7 +148,19 @@ impl ModelVariant { .execute_command_streaming_async("download_model".into(), Some(params), wrapper) .await?; } - None => { + (None, Some(flag)) => { + // Use a no-op callback to engage the callback mechanism + // required for cancellation checks. + self.core + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + |_: &str| {}, + flag, + ) + .await?; + } + (None, None) => { self.core .execute_command_async("download_model".into(), Some(params)) .await?; diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index 0c22ef154..a14b42b75 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ b/sdk/rust/src/foundry_local_manager.rs @@ -4,6 +4,7 @@ //! library, provides access to the model [`Catalog`], and can start / stop //! the local web service. +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex, OnceLock}; use serde_json::json; @@ -150,7 +151,19 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, ) -> Result { - self.download_and_register_eps_impl(names, None::) + self.download_and_register_eps_impl(names, None::, None) + .await + } + + /// Like [`Self::download_and_register_eps`], but accepts a shared + /// cancellation flag (`Arc`). When the flag is set to `true`, + /// the download will be cancelled at the next progress callback. + pub async fn download_and_register_eps_cancellable( + &self, + names: Option<&[&str]>, + cancel_flag: Arc, + ) -> Result { + self.download_and_register_eps_impl(names, None::, Some(cancel_flag)) .await } @@ -169,7 +182,23 @@ impl FoundryLocalManager { where F: FnMut(&str, f64) + Send + 'static, { - self.download_and_register_eps_impl(names, Some(progress_callback)) + self.download_and_register_eps_impl(names, Some(progress_callback), None) + .await + } + + /// Like [`Self::download_and_register_eps_with_progress`], but accepts a + /// shared cancellation flag (`Arc`). When the flag is set to + /// `true`, the download will be cancelled at the next progress callback. + pub async fn download_and_register_eps_with_progress_cancellable( + &self, + names: Option<&[&str]>, + progress_callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str, f64) + Send + 'static, + { + self.download_and_register_eps_impl(names, Some(progress_callback), Some(cancel_flag)) .await } @@ -177,6 +206,7 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, progress_callback: Option, + cancel_flag: Option>, ) -> Result where F: FnMut(&str, f64) + Send + 'static, @@ -186,8 +216,28 @@ impl FoundryLocalManager { _ => None, }; - let raw = match progress_callback { - Some(cb) => { + let raw = match (progress_callback, cancel_flag) { + (Some(cb), Some(flag)) => { + let mut callback = cb; + let wrapper = move |chunk: &str| { + if let Some(sep) = chunk.find('|') { + let name = &chunk[..sep]; + if let Ok(percent) = chunk[sep + 1..].parse::() { + callback(if name.is_empty() { "" } else { name }, percent); + } + } + }; + + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + wrapper, + flag, + ) + .await? + } + (Some(cb), None) => { let mut callback = cb; let wrapper = move |chunk: &str| { if let Some(sep) = chunk.find('|') { @@ -206,7 +256,17 @@ impl FoundryLocalManager { ) .await? } - None => { + (None, Some(flag)) => { + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + |_chunk: &str| {}, + flag, + ) + .await? + } + (None, None) => { self.core .execute_command_async("download_and_register_eps".into(), params) .await?