diff --git a/src/llm/language_model/legacy/servable.cpp b/src/llm/language_model/legacy/servable.cpp index aca6b0de22..98ac661ab9 100644 --- a/src/llm/language_model/legacy/servable.cpp +++ b/src/llm/language_model/legacy/servable.cpp @@ -83,22 +83,31 @@ absl::Status LegacyServable::parseRequest(std::shared_ptrapiHandler->isStream()) { legacyExecutionContext->lastStreamerCallbackOutput = ""; // initialize with empty string - auto callback = [& executionInProgress = legacyExecutionContext->executionInProgress, &mutex = legacyExecutionContext->mutex, &lastStreamerCallbackOutput = legacyExecutionContext->lastStreamerCallbackOutput](std::string text) { - SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Streamer callback executed with text: [{}]", text); - { - std::lock_guard lock(mutex); - lastStreamerCallbackOutput += text; - executionInProgress.notify_one(); - } - return ov::genai::StreamingStatus::RUNNING; - }; - ov::AnyMap streamerConfig; - if (legacyExecutionContext->apiHandler->getOutputParser() != nullptr && - (legacyExecutionContext->apiHandler->getOutputParser()->requiresStreamingWithSpecialTokens())) { - streamerConfig.insert(ov::genai::skip_special_tokens(false)); + } + auto callback = [& executionInProgress = legacyExecutionContext->executionInProgress, + &mutex = legacyExecutionContext->mutex, + &lastStreamerCallbackOutput = legacyExecutionContext->lastStreamerCallbackOutput, + &clientDisconnected = legacyExecutionContext->clientDisconnected, + streamMode = legacyExecutionContext->apiHandler->isStream()](std::string text) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Streamer callback executed with text: [{}]", text); + if (clientDisconnected.load()) { + executionInProgress.notify_one(); + return ov::genai::StreamingStatus::CANCEL; + } + if (streamMode) { + std::lock_guard lock(mutex); + lastStreamerCallbackOutput += text; + executionInProgress.notify_one(); } - legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback, streamerConfig); + return ov::genai::StreamingStatus::RUNNING; + }; + ov::AnyMap streamerConfig; + if (legacyExecutionContext->apiHandler->isStream() && + legacyExecutionContext->apiHandler->getOutputParser() != nullptr && + (legacyExecutionContext->apiHandler->getOutputParser()->requiresStreamingWithSpecialTokens())) { + streamerConfig.insert(ov::genai::skip_special_tokens(false)); } + legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback, streamerConfig); legacyExecutionContext->generationConfigBuilder = std::make_shared(getProperties()->baseGenerationConfig, getProperties()->toolParserName, getProperties()->enableToolGuidedGeneration, @@ -130,10 +139,11 @@ absl::Status LegacyServable::scheduleExecution(std::shared_ptr weakContext = legacyExecutionContext; legacyExecutionContext->payload.client->registerDisconnectionCallback([weakContext]() { if (auto context = weakContext.lock()) { - context->clientDisconnected = true; + context->signalDisconnection(); } }); if (legacyExecutionContext->payload.client->isDisconnected()) { + legacyExecutionContext->signalDisconnection(); return absl::CancelledError(); } properties->legacyExecutor->addRequest(legacyExecutionContext); diff --git a/src/llm/language_model/legacy/servable.hpp b/src/llm/language_model/legacy/servable.hpp index 04638b9094..19af42df85 100644 --- a/src/llm/language_model/legacy/servable.hpp +++ b/src/llm/language_model/legacy/servable.hpp @@ -37,6 +37,11 @@ struct LegacyServableExecutionContext : public GenAiServableExecutionContext { // Disconnection handling std::atomic clientDisconnected{false}; + + void signalDisconnection() { + clientDisconnected = true; + executionInProgress.notify_all(); + } }; struct LegacyServableProperties : public GenAiServableProperties { diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 5a09a24390..df6b2fc885 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -92,17 +92,25 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptrapiHandler->isStream()) { legacyExecutionContext->lastStreamerCallbackOutput = ""; // initialize with empty string - auto callback = [& executionInProgress = legacyExecutionContext->executionInProgress, &mutex = legacyExecutionContext->mutex, &lastStreamerCallbackOutput = legacyExecutionContext->lastStreamerCallbackOutput](std::string text) { - SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Streamer callback executed with text: [{}]", text); - { - std::lock_guard lock(mutex); - lastStreamerCallbackOutput += text; - executionInProgress.notify_one(); - } - return ov::genai::StreamingStatus::RUNNING; - }; - legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback); } + auto callback = [& executionInProgress = legacyExecutionContext->executionInProgress, + &mutex = legacyExecutionContext->mutex, + &lastStreamerCallbackOutput = legacyExecutionContext->lastStreamerCallbackOutput, + &clientDisconnected = legacyExecutionContext->clientDisconnected, + streamMode = legacyExecutionContext->apiHandler->isStream()](std::string text) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Streamer callback executed with text: [{}]", text); + if (clientDisconnected.load()) { + executionInProgress.notify_one(); + return ov::genai::StreamingStatus::CANCEL; + } + if (streamMode) { + std::lock_guard lock(mutex); + lastStreamerCallbackOutput += text; + executionInProgress.notify_one(); + } + return ov::genai::StreamingStatus::RUNNING; + }; + legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback); legacyExecutionContext->generationConfigBuilder = std::make_shared(getProperties()->baseGenerationConfig, getProperties()->toolParserName, getProperties()->enableToolGuidedGeneration, @@ -123,10 +131,11 @@ absl::Status VisualLanguageModelLegacyServable::scheduleExecution(std::shared_pt std::weak_ptr weakContext = legacyExecutionContext; legacyExecutionContext->payload.client->registerDisconnectionCallback([weakContext]() { if (auto context = weakContext.lock()) { - context->clientDisconnected = true; + context->signalDisconnection(); } }); if (legacyExecutionContext->payload.client->isDisconnected()) { + legacyExecutionContext->signalDisconnection(); return absl::CancelledError(); } properties->legacyExecutor->addRequest(legacyExecutionContext); diff --git a/src/llm/visual_language_model/legacy/servable.hpp b/src/llm/visual_language_model/legacy/servable.hpp index 12459b67be..8828153e7a 100644 --- a/src/llm/visual_language_model/legacy/servable.hpp +++ b/src/llm/visual_language_model/legacy/servable.hpp @@ -40,6 +40,11 @@ struct VisualLanguageModelLegacyServableExecutionContext : public GenAiServableE // Disconnection handling std::atomic clientDisconnected{false}; + + void signalDisconnection() { + clientDisconnected = true; + executionInProgress.notify_all(); + } }; struct VisualLanguageModelLegacyServableProperties : public GenAiServableProperties {