diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp
index ae2492abd7..8398011fbe 100644
--- a/src/llm/apis/openai_completions.cpp
+++ b/src/llm/apis/openai_completions.cpp
@@ -17,7 +17,9 @@
 #include "openai_completions.hpp"
 
 #include
+#include
 #include
+#include
 #include "src/port/rapidjson_stringbuffer.hpp"
 #include "src/port/rapidjson_writer.hpp"
 #include
@@ -44,6 +46,57 @@ namespace ovms {
 
 constexpr size_t DEFAULT_MAX_STOP_WORDS = 16;  // same as deep-seek
 
+namespace {
+
+ov::genai::JsonContainer rapidJsonValueToJsonContainer(const rapidjson::Value& value) {
+    if (value.IsNull()) {
+        return ov::genai::JsonContainer(nullptr);
+    }
+    if (value.IsBool()) {
+        return ov::genai::JsonContainer(value.GetBool());
+    }
+    if (value.IsInt()) {
+        return ov::genai::JsonContainer(value.GetInt());
+    }
+    if (value.IsUint()) {
+        return ov::genai::JsonContainer(static_cast<int64_t>(value.GetUint()));
+    }
+    if (value.IsInt64()) {
+        return ov::genai::JsonContainer(value.GetInt64());
+    }
+    if (value.IsUint64()) {
+        auto uintValue = value.GetUint64();
+        if (uintValue <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
+            return ov::genai::JsonContainer(static_cast<int64_t>(uintValue));
+        }
+        return ov::genai::JsonContainer(static_cast<double>(uintValue));
+    }
+    if (value.IsDouble()) {
+        return ov::genai::JsonContainer(value.GetDouble());
+    }
+    if (value.IsString()) {
+        return ov::genai::JsonContainer(std::string(value.GetString(), value.GetStringLength()));
+    }
+    if (value.IsArray()) {
+        ov::genai::JsonContainer arrayContainer = ov::genai::JsonContainer::array();
+        for (const auto& item : value.GetArray()) {
+            arrayContainer.push_back(rapidJsonValueToJsonContainer(item));
+        }
+        return arrayContainer;
+    }
+    if (value.IsObject()) {
+        ov::genai::JsonContainer objectContainer = ov::genai::JsonContainer::object();
+        for (auto member = value.MemberBegin(); member != value.MemberEnd(); ++member) {
+            const std::string key(member->name.GetString(), member->name.GetStringLength());
+            objectContainer[key] = rapidJsonValueToJsonContainer(member->value);
+        }
+        return objectContainer;
+    }
+    throw std::invalid_argument("Unsupported JSON value type");
+}
+
+}  // namespace
+
 absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() {
     // prompt: string
     auto it = doc.FindMember("prompt");
@@ -439,6 +492,27 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() {
     return absl::OkStatus();
 }
 
+absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIChatCompletionsHandler::parseToolsToJsonContainer() {
+    auto it = doc.FindMember("tools");
+    if (it == doc.MemberEnd() || it->value.IsNull()) {
+        return std::nullopt;
+    }
+    try {
+        return rapidJsonValueToJsonContainer(it->value);
+    } catch (const std::exception& e) {
+        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Direct tools conversion to JsonContainer failed: {}. Falling back to JSON string conversion.", e.what());
+        try {
+            rapidjson::StringBuffer toolsBuffer;
+            rapidjson::Writer<rapidjson::StringBuffer> toolsWriter(toolsBuffer);
+            it->value.Accept(toolsWriter);
+            return ov::genai::JsonContainer::from_json_string(toolsBuffer.GetString());
+        } catch (const std::exception& fallbackEx) {
+            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Fallback tools conversion failed: {}", fallbackEx.what());
+            return absl::InvalidArgumentError(absl::StrCat("Invalid tools payload: ", fallbackEx.what()));
+        }
+    }
+}
+
 const bool OpenAIChatCompletionsHandler::areToolsAvailable() const {
     return !request.toolNameSchemaMap.empty();
 }
@@ -843,6 +917,7 @@ absl::Status OpenAIChatCompletionsHandler::parseRequest(std::optional
 
 void updateUsage(CompletionUsageStatistics& usage, const std::vector<int64_t>& generatedIds, bool echoPrompt) {
     OVMS_PROFILE_FUNCTION();
+    SPDLOG_INFO("Echo prompt: {}", echoPrompt);
     usage.completionTokens += generatedIds.size();
     if (echoPrompt)
         usage.completionTokens -= usage.promptTokens;
@@ -1049,7 +1124,7 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai
     return jsonResponse.ToString();
 }
 
-std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai::VLMDecodedResults& results, size_t completionTokens) {
+std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai::VLMDecodedResults& results) {
     OVMS_PROFILE_FUNCTION();
     OpenAiJsonResponse jsonResponse;
     jsonResponse.StartObject();
@@ -1057,10 +1132,28 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai
     // choices: array of size N, where N is related to n request parameter
     jsonResponse.StartArray("choices");
     int index = 0;
-    usage.completionTokens = completionTokens;
     for (int i = 0; i < results.texts.size(); i++) {
         const std::string& text = results.texts[i];
         SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated text: {}", text);
+
+        // Workaround to use OVMS unary parsers: get tokens from string
+        // This way we have detokenized text from GenAI and calculate tokens, to further convert back to text again, in parseOutputIfNeeded...
+        auto result = tokenizer.encode(text);
+        auto& input_ids = result.input_ids;
+        if (input_ids.get_shape().size() != 2)
+            throw std::runtime_error("input_ids should have 2 dimensions");
+        if (input_ids.get_shape()[0] != 1)
+            throw std::runtime_error("input_ids should have 1 batch size");
+        if (input_ids.get_element_type() != ov::element::i64)
+            throw std::runtime_error("input_ids should have i64 element type");
+
+        int64_t* input_ids_data = reinterpret_cast<int64_t*>(input_ids.data());
+        std::vector<int64_t> tokens(input_ids_data, input_ids_data + input_ids.get_shape()[1]);
+
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated tokens: {}", tokens);
+        updateUsage(usage, tokens, request.echo);
+        ParsedOutput parsedOutput = parseOutputIfNeeded(tokens);
+
         jsonResponse.StartObject();
         // finish_reason: string; always "stop" for this method
         jsonResponse.FinishReason("stop");
@@ -1068,16 +1161,10 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai
         jsonResponse.Index(index++);
 
         // logprobs: object/null; Log probability information for the choice. TODO
-        // message: object
         if (endpoint == Endpoint::CHAT_COMPLETIONS) {
-            jsonResponse.StartObject("message");
-            jsonResponse.String("content", text);
-            jsonResponse.String("role", "assistant");  // TODO - hardcoded
-            // TODO: tools_call
-            // TODO: function_call (deprecated)
-            jsonResponse.EndObject();
+            jsonResponse.MessageObject(parsedOutput);
         } else if (endpoint == Endpoint::COMPLETIONS) {
-            jsonResponse.String("text", text);
+            jsonResponse.Text(parsedOutput);
         }
 
         // finish message object
diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp
index 7292d99c01..9e48701d7f 100644
--- a/src/llm/apis/openai_completions.hpp
+++ b/src/llm/apis/openai_completions.hpp
@@ -26,6 +26,7 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -34,6 +35,7 @@
 #pragma warning(push)
 #pragma warning(disable : 6001 4324 6385 6386)
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #pragma warning(pop)
 #include "../io_processing/output_parser.hpp"
 #include "openai_request.hpp"
@@ -115,12 +117,13 @@ class OpenAIChatCompletionsHandler {
     absl::Status parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, std::optional<uint32_t> maxModelLength, std::optional<std::string> allowedLocalMediaPath = std::nullopt, std::optional<std::vector<std::string>> allowedMediaDomains = std::nullopt);
     absl::Status parseMessages(std::optional<std::string> allowedLocalMediaPath = std::nullopt, std::optional<std::vector<std::string>> allowedMediaDomains = std::nullopt);
     absl::Status parseTools();
+    absl::StatusOr<std::optional<ov::genai::JsonContainer>> parseToolsToJsonContainer();
     const bool areToolsAvailable() const;
 
     std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs);
     std::string serializeUnaryResponse(const ov::genai::EncodedResults& results);
     // VLMDecodedResults does not contain tokens that we can count, so we need to pass completionTokens in order to provide correct usage statistics
-    std::string serializeUnaryResponse(const ov::genai::VLMDecodedResults& results, size_t completionTokens);
+    std::string serializeUnaryResponse(const ov::genai::VLMDecodedResults& results);
     std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason);
     std::string serializeStreamingUsageChunk();
 };
diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp
index 3c463a5aa9..02d1fd0fd8 100644
--- a/src/llm/servable.cpp
+++ b/src/llm/servable.cpp
@@ -182,8 +182,17 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptr
     ov::genai::ChatHistory& chatHistory = executionContext->apiHandler->getChatHistory();
     constexpr bool add_generation_prompt = true;  // confirm it should be hardcoded
+    auto toolsStatus = executionContext->apiHandler->parseToolsToJsonContainer();
+    if (!toolsStatus.ok()) {
+        return toolsStatus.status();
+    }
+    const auto& tools = toolsStatus.value();
     try {
-        inputText = getProperties()->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        if (tools.has_value()) {
+            inputText = getProperties()->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
+        } else {
+            inputText = getProperties()->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        }
     } catch (const std::exception& e) {
         SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what());
         return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one.");
diff --git a/src/llm/visual_language_model/continuous_batching/servable.cpp b/src/llm/visual_language_model/continuous_batching/servable.cpp
index 6a38d3e4bb..bd292ada00 100644
--- a/src/llm/visual_language_model/continuous_batching/servable.cpp
+++ b/src/llm/visual_language_model/continuous_batching/servable.cpp
@@ -93,7 +93,16 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptr
-        vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        auto toolsStatus = vlmExecutionContext->apiHandler->parseToolsToJsonContainer();
+        if (!toolsStatus.ok()) {
+            return toolsStatus.status();
+        }
+        const auto& tools = toolsStatus.value();
+        if (tools.has_value()) {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
+        } else {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        }
     } else {
         return absl::InvalidArgumentError("Unsupported endpoint");
     }
diff --git a/src/llm/visual_language_model/legacy/legacy_executor.cpp b/src/llm/visual_language_model/legacy/legacy_executor.cpp
index 1e1ce255cb..a21c799cec 100644
--- a/src/llm/visual_language_model/legacy/legacy_executor.cpp
+++ b/src/llm/visual_language_model/legacy/legacy_executor.cpp
@@ -16,7 +16,7 @@
 #include "legacy_executor.hpp"
 
 #include "servable.hpp"
-#include "vector"
+#include <vector>
 
 namespace ovms {
 VisualLanguageModelLegacyExecutor::VisualLanguageModelLegacyExecutor(std::shared_ptr pipe) {
diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp
index 5a09a24390..fed304bf80 100644
--- a/src/llm/visual_language_model/legacy/servable.cpp
+++ b/src/llm/visual_language_model/legacy/servable.cpp
@@ -81,7 +81,9 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptr
     executionContext->apiHandler = std::make_shared<OpenAIChatCompletionsHandler>(*legacyExecutionContext->payload.parsedJson,
         legacyExecutionContext->endpoint,
         std::chrono::system_clock::now(),
-        getProperties()->tokenizer);
+        getProperties()->tokenizer,
+        getProperties()->toolParserName,
+        getProperties()->reasoningParserName);
     auto& config = ovms::Config::instance();
     auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->maxModelLength, config.getServerSettings().allowedLocalMediaPath, config.getServerSettings().allowedMediaDomains);
@@ -101,7 +103,12 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptr
-        legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback);
+        ov::AnyMap streamerConfig;
+        if (legacyExecutionContext->apiHandler->getOutputParser() != nullptr &&
+            (legacyExecutionContext->apiHandler->getOutputParser()->requiresStreamingWithSpecialTokens())) {
+            streamerConfig.insert(ov::genai::skip_special_tokens(false));
+        }
+        legacyExecutionContext->textStreamer = std::make_shared(getProperties()->tokenizer, callback, streamerConfig);
     }
     legacyExecutionContext->generationConfigBuilder = std::make_shared(getProperties()->baseGenerationConfig,
         getProperties()->toolParserName,
@@ -150,13 +157,7 @@ absl::Status VisualLanguageModelLegacyServable::prepareCompleteResponse(std::sha
     if (legacyExecutionContext->payload.client->isDisconnected()) {
         return absl::CancelledError();
     }
-    size_t completionTokens = 0;
-    for (std::string text : legacyExecutionContext->results.texts) {
-        auto tokensTensor = properties->tokenizer.encode(text, ov::genai::add_special_tokens(false)).input_ids;
-        completionTokens += tokensTensor.get_size();
-    }
-    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Generated tokens number: {}", completionTokens);
-    executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results, completionTokens);
+    executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results);
     SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Complete unary response: {}", executionContext->response);
     return absl::OkStatus();
 }
@@ -252,7 +253,16 @@ absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptr
-        vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        auto toolsStatus = vlmExecutionContext->apiHandler->parseToolsToJsonContainer();
+        if (!toolsStatus.ok()) {
+            return toolsStatus.status();
+        }
+        const auto& tools = toolsStatus.value();
+        if (tools.has_value()) {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
+        } else {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt);
+        }
     } else {
         return absl::InvalidArgumentError("Unsupported endpoint");
     }
diff --git a/src/llm/visual_language_model/legacy/servable_initializer.cpp b/src/llm/visual_language_model/legacy/servable_initializer.cpp
index 5fddcbc98e..ec8bfd327a 100644
--- a/src/llm/visual_language_model/legacy/servable_initializer.cpp
+++ b/src/llm/visual_language_model/legacy/servable_initializer.cpp
@@ -53,6 +53,14 @@ Status VisualLanguageModelLegacyServableInitializer::initialize(std::shared_ptr<
     if (std::filesystem::exists(modelGenerationConfigPath)) {
         properties->baseGenerationConfig = ov::genai::GenerationConfig(modelGenerationConfigPath.string());
     }
+
+    if (nodeOptions.has_tool_parser()) {
+        properties->toolParserName = nodeOptions.tool_parser();
+    }
+
+    if (nodeOptions.has_reasoning_parser()) {
+        properties->reasoningParserName = nodeOptions.reasoning_parser();
+    }
     properties->schedulerConfig.max_num_batched_tokens = nodeOptions.max_num_batched_tokens();
     properties->schedulerConfig.cache_size = nodeOptions.cache_size();
     properties->schedulerConfig.dynamic_split_fuse = nodeOptions.dynamic_split_fuse();
@@ -90,6 +98,7 @@ Status VisualLanguageModelLegacyServableInitializer::initialize(std::shared_ptr<
     }
     properties->bestOfLimit = nodeOptions.best_of_limit();
     properties->maxModelLength = parseMaxModelLength(parsedModelsPath);
+    properties->enableToolGuidedGeneration = nodeOptions.enable_tool_guided_generation();
 
     return StatusCode::OK;
 }
diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp
index fec2009867..0ddb241fa3 100644
--- a/src/test/http_openai_handler_test.cpp
+++ b/src/test/http_openai_handler_test.cpp
@@ -1175,6 +1175,170 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParseRequestWithTools_Provided1_ChoiceNone)
     assertRequestWithTools(providedTools, toolsChoice, expectedJson);
 }
 
+TEST_F(HttpOpenAIHandlerParsingTest, ParseRequestWithTools_ParsesToolsJsonContainerOnDemand) {
+    std::string json = R"({
+        "model": "llama",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the weather?"
+            }
+        ],
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string"
+                            }
+                        },
+                        "required": ["location"]
+                    }
+                }
+            }
+        ]
+    })";
+    doc.Parse(json.c_str());
+    ASSERT_FALSE(doc.HasParseError());
+    std::optional<uint32_t> maxTokensLimit;
+    uint32_t bestOfLimit = 0;
+    std::optional<uint32_t> maxModelLength;
+    std::shared_ptr<ovms::OpenAIChatCompletionsHandler> apiHandler =
+        std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer);
+
+    ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus());
+    auto toolsStatus = apiHandler->parseToolsToJsonContainer();
+    ASSERT_TRUE(toolsStatus.ok());
+    const auto& tools = toolsStatus.value();
+    ASSERT_TRUE(tools.has_value());
+    EXPECT_TRUE(tools->is_array());
+    ASSERT_EQ(tools->size(), 1);
+    ASSERT_TRUE((*tools)[0]["function"]["name"].as_string().has_value());
+    EXPECT_EQ((*tools)[0]["function"]["name"].as_string().value(), "get_weather");
+}
+
+TEST_F(HttpOpenAIHandlerParsingTest, OutputParserInitializationDependsOnParserNames) {
+    std::string json = R"({
+        "model": "llama",
+        "messages": [
+            {
+                "role": "user",
+                "content": "hello"
+            }
+        ]
+    })";
+    doc.Parse(json.c_str());
+    ASSERT_FALSE(doc.HasParseError());
+
+    auto withoutParserNames = std::make_shared<ovms::OpenAIChatCompletionsHandler>(
+        doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer);
+    EXPECT_EQ(withoutParserNames->getOutputParser(), nullptr);
+
+    auto withParserNames = std::make_shared<ovms::OpenAIChatCompletionsHandler>(
+        doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer, "llama3", "");
+    EXPECT_NE(withParserNames->getOutputParser(), nullptr);
+}
+
+TEST_F(HttpOpenAIHandlerParsingTest, SerializeUnaryResponseVLMDecodedResultsWithToolParser) {
+    std::string json = R"({
+        "model": "llama",
+        "messages": [
+            {
+                "role": "user",
+                "content": "hello"
+            }
+        ],
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string"
+                            }
+                        },
+                        "required": ["location"]
+                    }
+                }
+            }
+        ]
+    })";
+    doc.Parse(json.c_str());
+    ASSERT_FALSE(doc.HasParseError());
+
+    uint32_t maxTokensLimit = 64;
+    uint32_t bestOfLimit = 0;
+    std::optional<uint32_t> maxModelLength;
+
+    auto apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(
+        doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer, "hermes3", "");
+
+    ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus());
+
+    ov::genai::VLMDecodedResults results;
+    results.texts.push_back(
+        "I will call a tool.<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}</tool_call>");
+
+    std::string serialized = apiHandler->serializeUnaryResponse(results);
+
+    rapidjson::Document responseDoc;
+    responseDoc.Parse(serialized.c_str());
+    ASSERT_FALSE(responseDoc.HasParseError());
+    ASSERT_TRUE(responseDoc.IsObject());
+
+    ASSERT_TRUE(responseDoc.HasMember("choices"));
+    ASSERT_TRUE(responseDoc["choices"].IsArray());
+    ASSERT_EQ(responseDoc["choices"].Size(), 1);
+
+    const auto& choice = responseDoc["choices"][0];
+    ASSERT_TRUE(choice.IsObject());
+    ASSERT_TRUE(choice.HasMember("finish_reason"));
+    ASSERT_TRUE(choice["finish_reason"].IsString());
+    EXPECT_STREQ(choice["finish_reason"].GetString(), "stop");
+
+    ASSERT_TRUE(choice.HasMember("message"));
+    ASSERT_TRUE(choice["message"].IsObject());
+    const auto& message = choice["message"];
+
+    ASSERT_TRUE(message.HasMember("content"));
+    ASSERT_TRUE(message["content"].IsString());
+    EXPECT_STREQ(message["content"].GetString(), "I will call a tool.");
+
+    ASSERT_TRUE(message.HasMember("tool_calls"));
+    ASSERT_TRUE(message["tool_calls"].IsArray());
+    ASSERT_EQ(message["tool_calls"].Size(), 1);
+
+    const auto& toolCall = message["tool_calls"][0];
+    ASSERT_TRUE(toolCall.IsObject());
+    ASSERT_TRUE(toolCall.HasMember("id"));
+    ASSERT_TRUE(toolCall["id"].IsString());
+    EXPECT_GT(std::string(toolCall["id"].GetString()).size(), 0);
+    ASSERT_TRUE(toolCall.HasMember("function"));
+    ASSERT_TRUE(toolCall["function"].IsObject());
+    ASSERT_TRUE(toolCall["function"].HasMember("name"));
+    EXPECT_STREQ(toolCall["function"]["name"].GetString(), "get_weather");
+    ASSERT_TRUE(toolCall["function"].HasMember("arguments"));
+    EXPECT_STREQ(toolCall["function"]["arguments"].GetString(), "{\"location\":\"Paris\"}");
+
+    ASSERT_TRUE(responseDoc.HasMember("object"));
+    EXPECT_STREQ(responseDoc["object"].GetString(), "chat.completion");
+    ASSERT_TRUE(responseDoc.HasMember("model"));
+    EXPECT_STREQ(responseDoc["model"].GetString(), "llama");
+
+    ASSERT_TRUE(responseDoc.HasMember("usage"));
+    ASSERT_TRUE(responseDoc["usage"].IsObject());
+    ASSERT_TRUE(responseDoc["usage"].HasMember("completion_tokens"));
+    EXPECT_GT(responseDoc["usage"]["completion_tokens"].GetInt(), 0);
+}
+
 // Provide get_weather1, get_weather2, get_weather3 but take only first one - get_weather1
 TEST_F(HttpOpenAIHandlerParsingTest, ParseRequestWithTools_Provided3_ChoiceFirst) {
     std::string providedTools = R"(
diff --git a/src/test/llm/visual_language_model/complete_flow_test.cpp b/src/test/llm/visual_language_model/complete_flow_test.cpp
index 4afbd99f31..4dc22d6fa3 100644
--- a/src/test/llm/visual_language_model/complete_flow_test.cpp
+++ b/src/test/llm/visual_language_model/complete_flow_test.cpp
@@ -260,6 +260,50 @@ TEST_P(VLMServableExecutionTestParameterized, UnaryRestrictedTagUsed) {
         ovms::StatusCode::MEDIAPIPE_EXECUTION_ERROR);
 }
 
+TEST_P(VLMServableExecutionTestParameterized, unaryBasicWithTools) {
+    auto modelName = GetParam();
+    std::vector<std::pair<std::string, std::string>> fields = {
+        {"temperature", "0.0"},
+        {"stream", "false"},
+        {"max_tokens", "5"},
+        {"ignore_eos", "true"},
+        {"tool_choice", R"("auto")"},
+        {"tools", R"([
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather by city",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string"
+                            }
+                        },
+                        "required": ["city"]
+                    }
+                }
+            }
+        ])"}};
+    std::string requestBody = createRequestBody(modelName, fields);
+
+    ASSERT_EQ(
+        handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer, multiPartParser),
+        ovms::StatusCode::OK);
+
+    parsedResponse.Parse(response.c_str());
+    ASSERT_TRUE(parsedResponse.IsObject());
+    ASSERT_TRUE(parsedResponse.HasMember("choices"));
+    ASSERT_TRUE(parsedResponse["choices"].IsArray());
+    ASSERT_EQ(parsedResponse["choices"].Capacity(), 1);
+    ASSERT_TRUE(parsedResponse["choices"][0].HasMember("message"));
+    ASSERT_TRUE(parsedResponse["choices"][0]["message"].IsObject());
+    ASSERT_TRUE(parsedResponse["choices"][0]["message"]["content"].IsString());
+    EXPECT_STREQ(parsedResponse["object"].GetString(), "chat.completion");
+    EXPECT_STREQ(parsedResponse["model"].GetString(), modelName.c_str());
+}
+
 // Stream flow
 
 TEST_P(VLMServableExecutionTestParameterized, streamBasic) {
@@ -361,6 +405,47 @@ TEST_P(VLMServableExecutionTestParameterized, streamRestrictedTagUsed) {
         ovms::StatusCode::PARTIAL_END);
 }
 
+TEST_P(VLMServableExecutionTestParameterized, streamBasicWithTools) {
+    auto modelName = GetParam();
+    std::vector<std::pair<std::string, std::string>> fields = {
+        {"temperature", "0.0"},
+        {"stream", "true"},
+        {"max_tokens", "5"},
+        {"ignore_eos", "true"},
+        {"tool_choice", R"("auto")"},
+        {"tools", R"([
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather by city",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string"
+                            }
+                        },
+                        "required": ["city"]
+                    }
+                }
+            }
+        ])"}};
+    std::string requestBody = createRequestBody(modelName, fields);
+
+    std::vector<std::string> responses;
+    EXPECT_CALL(*writer, PartialReply(::testing::_))
+        .WillRepeatedly([&responses](std::string response) {
+            responses.push_back(response);
+        });
+    EXPECT_CALL(*writer, PartialReplyEnd()).Times(1);
+
+    ASSERT_EQ(
+        handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer, multiPartParser),
+        ovms::StatusCode::PARTIAL_END);
+    ASSERT_FALSE(responses.empty());
+}
+
 INSTANTIATE_TEST_SUITE_P(
     VLMServableExecutionTests,
     VLMServableExecutionTestParameterized,
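
Illustration (not part of the patch): the fallback branch of parseToolsToJsonContainer() round-trips the request's "tools" member through a JSON string before calling ov::genai::JsonContainer::from_json_string. The sketch below shows that serialization step in isolation using only RapidJSON; the request body literal and the main() harness are assumptions made for the example.

#include <iostream>
#include <string>

#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

int main() {
    // Hypothetical request body carrying a "tools" array, mirroring the new tests.
    const char* request = R"({"tools":[{"type":"function","function":{"name":"get_weather"}}]})";

    rapidjson::Document doc;
    doc.Parse(request);
    if (doc.HasParseError() || !doc.HasMember("tools"))
        return 1;

    // Serialize only the "tools" member back to a compact JSON string; in the patch this
    // string is then handed to ov::genai::JsonContainer::from_json_string as the fallback path.
    rapidjson::StringBuffer toolsBuffer;
    rapidjson::Writer<rapidjson::StringBuffer> toolsWriter(toolsBuffer);
    doc["tools"].Accept(toolsWriter);

    std::cout << toolsBuffer.GetString() << std::endl;  // [{"type":"function","function":{"name":"get_weather"}}]
    return 0;
}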