-
Notifications
You must be signed in to change notification settings - Fork 239
Finish reason=tool_call support #3990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c13dd14
8d2af26
968fee5
b848781
e7d8f86
0b4f016
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -848,6 +848,30 @@ void updateUsage(CompletionUsageStatistics& usage, const std::vector<int64_t>& g | |
| usage.completionTokens -= usage.promptTokens; | ||
| } | ||
|
|
||
| static std::optional<std::string> mapFinishReason(ov::genai::GenerationFinishReason finishReason, bool hasToolCalls) { | ||
| // GenerationFinishReason::TOOL_CALLS is not available in GenAI yet. | ||
| // Use feature detection based on presence of tool calls as a workaround until GenAI exposes TOOL_CALLS. | ||
| if (hasToolCalls && finishReason == ov::genai::GenerationFinishReason::STOP) { | ||
| return "tool_calls"; | ||
| } | ||
|
Comment on lines
+854
to
+856
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could be prone to behavior changes (if the model generates tool calls + something more) in the future, but looks like it's fine for now. |
||
| switch (finishReason) { | ||
| case ov::genai::GenerationFinishReason::STOP: | ||
| return "stop"; | ||
| case ov::genai::GenerationFinishReason::LENGTH: | ||
| return "length"; | ||
| default: | ||
| return std::nullopt; | ||
| } | ||
| } | ||
|
|
||
| static bool hasToolCallsInStreamingDelta(const rapidjson::Document& delta) { | ||
| if (!delta.HasMember("delta") || !delta["delta"].IsObject()) { | ||
| return false; | ||
| } | ||
| const auto& deltaObj = delta["delta"]; | ||
| return deltaObj.HasMember("tool_calls") && deltaObj["tool_calls"].IsArray(); | ||
| } | ||
|
|
||
| ParsedOutput OpenAIChatCompletionsHandler::parseOutputIfNeeded(const std::vector<int64_t>& generatedIds) { | ||
| OVMS_PROFILE_FUNCTION(); | ||
| ParsedOutput parsedOutput; | ||
|
|
@@ -878,22 +902,13 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vect | |
| // finish_reason: string; | ||
| // "stop" => natural stop point due to stopping criteria | ||
| // "length" => due to reaching max_tokens parameter | ||
| // "tool_calls" => generation stopped due to generated tool calls | ||
|
|
||
| std::string finishReason; | ||
| switch (generationOutput.finish_reason) { | ||
| case ov::genai::GenerationFinishReason::STOP: | ||
| finishReason = "stop"; | ||
| break; | ||
| case ov::genai::GenerationFinishReason::LENGTH: | ||
| finishReason = "length"; | ||
| break; | ||
| default: | ||
| finishReason = "unknown"; | ||
| std::optional<std::string> finishReason = mapFinishReason(generationOutput.finish_reason, !parsedOutput.toolCalls.empty()); | ||
| if (!finishReason.has_value()) { | ||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Unknown finish reason: {}", static_cast<int>(generationOutput.finish_reason)); | ||
| break; | ||
| } | ||
| jsonResponse.FinishReason(finishReason); | ||
|
|
||
| jsonResponse.FinishReason(finishReason.value_or("unknown")); | ||
| // index: integer; Choice index, only n=1 supported anyway | ||
| jsonResponse.Index(index++); | ||
|
|
||
|
|
@@ -1005,8 +1020,9 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai | |
| updateUsage(usage, tokens, request.echo); | ||
| ParsedOutput parsedOutput = parseOutputIfNeeded(tokens); | ||
| jsonResponse.StartObject(); | ||
| // finish_reason: string; always "stop" for this method | ||
| jsonResponse.FinishReason("stop"); | ||
| // finish_reason: "stop" in regular scenario, "tool_calls" if output contains tool calls | ||
| auto finishReason = mapFinishReason(ov::genai::GenerationFinishReason::STOP, !parsedOutput.toolCalls.empty()); | ||
| jsonResponse.FinishReason(finishReason.value_or("unknown")); | ||
| // index: integer; Choice index, only n=1 supported anyway | ||
| jsonResponse.Index(index++); | ||
|
|
||
|
|
@@ -1058,11 +1074,13 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const ov::genai | |
| jsonResponse.StartArray("choices"); | ||
| int index = 0; | ||
| usage.completionTokens = completionTokens; | ||
|
|
||
| for (int i = 0; i < results.texts.size(); i++) { | ||
| const std::string& text = results.texts[i]; | ||
| SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated text: {}", text); | ||
| jsonResponse.StartObject(); | ||
| // finish_reason: string; always "stop" for this method | ||
| // tool_calls from VLM legacy pipeline are unsupported due to lack of tokens in API, so finish reason cannot be tool_call | ||
| jsonResponse.FinishReason("stop"); | ||
| // index: integer; Choice index, only n=1 supported anyway | ||
| jsonResponse.Index(index++); | ||
|
|
@@ -1121,6 +1139,7 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str | |
|
|
||
| Value choices(kArrayType); | ||
| Value choice(kObjectType); | ||
| bool hasToolCalls = false; | ||
|
|
||
| // choices: array of size N, where N is related to n request parameter | ||
| choices.SetArray(); | ||
|
|
@@ -1129,19 +1148,9 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str | |
| // "stop" => natural stop point due to stopping criteria | ||
| // "length" => due to reaching max_tokens parameter | ||
| // "content_filter" => when produced restricted output (not supported) | ||
| // "tool_calls" => generation stopped and waiting for tool output (not supported) | ||
| // "tool_calls" => generation stopped and waiting for tool output | ||
| // "function_call" => deprecated | ||
| // null - natural scenario when the generation has not completed yet | ||
| switch (finishReason) { | ||
| case ov::genai::GenerationFinishReason::STOP: | ||
| choice.AddMember("finish_reason", "stop", allocator); | ||
| break; | ||
| case ov::genai::GenerationFinishReason::LENGTH: | ||
| choice.AddMember("finish_reason", "length", allocator); | ||
| break; | ||
| default: | ||
| choice.AddMember("finish_reason", Value(), allocator); | ||
| } | ||
| // index: integer; Choice index, only n=1 supported anyway | ||
| choice.AddMember("index", 0, allocator); | ||
| // logprobs: object/null; Log probability information for the choice. TODO | ||
|
|
@@ -1155,6 +1164,7 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str | |
| if (delta->HasMember("delta")) { | ||
| // Deep copy the "delta" member value into the choice object | ||
| choice.AddMember("delta", Value((*delta)["delta"], allocator), allocator); | ||
| hasToolCalls = hasToolCallsInStreamingDelta(*delta); | ||
| } | ||
|
|
||
| } else { | ||
|
|
@@ -1167,6 +1177,13 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str | |
| choice.AddMember("text", Value(chunkResponse.c_str(), allocator), allocator); | ||
| } | ||
|
|
||
| auto serializedFinishReason = mapFinishReason(finishReason, hasToolCalls); | ||
| if (serializedFinishReason.has_value()) { | ||
| choice.AddMember("finish_reason", Value(serializedFinishReason.value().c_str(), allocator), allocator); | ||
| } else { | ||
| choice.AddMember("finish_reason", Value(rapidjson::kNullType), allocator); | ||
| } | ||
|
|
||
| choices.PushBack(choice, allocator); | ||
| doc.AddMember("choices", choices, allocator); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,8 @@ | |
| #include "test_utils.hpp" | ||
| #include "platform_utils.hpp" | ||
|
|
||
| const std::string llama3TokenizerPathForHandlerTests = getGenericFullPathForSrcTest("/ovms/src/test/llm_testing/unsloth/Llama-3.1-8B-Instruct"); | ||
|
|
||
| class HttpOpenAIHandlerTest : public ::testing::Test { | ||
| protected: | ||
| ovms::Server& server = ovms::Server::instance(); | ||
|
|
@@ -402,6 +404,167 @@ class HttpOpenAIHandlerParsingTest : public ::testing::Test { | |
| } | ||
| }; | ||
|
|
||
| static std::vector<int64_t> createLlama3ToolCallTokens(ov::genai::Tokenizer& tokenizer) { | ||
| std::string toolCall = "<|python_tag|>" | ||
| R"({"name": "example_tool", "parameters": {"arg1": "value1", "arg2": 42}})"; | ||
| auto generatedTensor = tokenizer.encode(toolCall, ov::genai::add_special_tokens(true)).input_ids; | ||
| std::vector<int64_t> generatedTokens(generatedTensor.data<int64_t>(), generatedTensor.data<int64_t>() + generatedTensor.get_size()); | ||
| return generatedTokens; | ||
| } | ||
|
|
||
| TEST_F(HttpOpenAIHandlerParsingTest, serializeStreamingChunkReturnsIntermediateNullAndFinallyToolCallsFinishReason) { | ||
| std::shared_ptr<ov::genai::Tokenizer> llama3Tokenizer = std::make_shared<ov::genai::Tokenizer>(llama3TokenizerPathForHandlerTests); | ||
| std::string json = R"({ | ||
| "model": "llama", | ||
| "stream": true, | ||
| "messages": [{"role": "user", "content": "What is weather?"}], | ||
| "tools": [{ | ||
| "type": "function", | ||
| "function": { | ||
| "name": "get_humidity", | ||
| "parameters": { | ||
| "type": "object", | ||
| "properties": { | ||
| "location": {"type": "string"} | ||
| } | ||
| } | ||
| } | ||
| }] | ||
| })"; | ||
| doc.Parse(json.c_str()); | ||
| ASSERT_FALSE(doc.HasParseError()); | ||
|
|
||
| auto apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *llama3Tokenizer, "llama3"); | ||
| uint32_t maxTokensLimit = 100; | ||
| uint32_t bestOfLimit = 0; | ||
| std::optional<uint32_t> maxModelLength; | ||
| ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); | ||
|
|
||
| std::vector<std::pair<std::string, ov::genai::GenerationFinishReason>> stream = { | ||
| {"<|python_tag|>", ov::genai::GenerationFinishReason::NONE}, | ||
| {"{\"", ov::genai::GenerationFinishReason::NONE}, | ||
| {"name", ov::genai::GenerationFinishReason::NONE}, | ||
| {"\":", ov::genai::GenerationFinishReason::NONE}, | ||
| {" \"", ov::genai::GenerationFinishReason::NONE}, | ||
| {"get", ov::genai::GenerationFinishReason::NONE}, | ||
| {"_humidity", ov::genai::GenerationFinishReason::NONE}, | ||
| {"\",", ov::genai::GenerationFinishReason::NONE}, | ||
| {" \"", ov::genai::GenerationFinishReason::NONE}, | ||
| {"parameters", ov::genai::GenerationFinishReason::NONE}, | ||
| {"\":", ov::genai::GenerationFinishReason::NONE}, | ||
| {" {\"", ov::genai::GenerationFinishReason::NONE}, | ||
| {"location", ov::genai::GenerationFinishReason::NONE}, | ||
| {"\":", ov::genai::GenerationFinishReason::NONE}, | ||
| {" \"", ov::genai::GenerationFinishReason::NONE}, | ||
| {"Paris\"}}", ov::genai::GenerationFinishReason::STOP}, | ||
| }; | ||
|
|
||
| std::vector<std::string> serializedChunks; | ||
| for (const auto& [chunk, finishReason] : stream) { | ||
| std::string serialized = apiHandler->serializeStreamingChunk(chunk, finishReason); | ||
| if (!serialized.empty()) { | ||
| serializedChunks.push_back(serialized); | ||
| } | ||
| } | ||
| ASSERT_FALSE(serializedChunks.empty()); | ||
| const std::string& lastChunk = serializedChunks.back(); | ||
| ASSERT_NE(lastChunk.find("\"tool_calls\""), std::string::npos) << lastChunk; | ||
| ASSERT_NE(lastChunk.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << lastChunk; | ||
| // Verify that intermediate chunks with NONE finish_reason are serialized correctly | ||
| ASSERT_GE(serializedChunks.size(), 2u); | ||
| for (size_t i = 0; i + 1 < serializedChunks.size(); ++i) { | ||
| const std::string& chunkStr = serializedChunks[i]; | ||
| ASSERT_NE(chunkStr.find("\"finish_reason\":null"), std::string::npos) << chunkStr; | ||
| } | ||
| } | ||
|
|
||
| TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseGenerationOutputReturnsToolCallsFinishReason) { | ||
| std::shared_ptr<ov::genai::Tokenizer> llama3Tokenizer = std::make_shared<ov::genai::Tokenizer>(llama3TokenizerPathForHandlerTests); | ||
| std::string json = R"({ | ||
| "model": "llama", | ||
| "stream": false, | ||
| "messages": [{"role": "user", "content": "What is weather?"}], | ||
| "tools": [{ | ||
| "type": "function", | ||
| "function": { | ||
| "name": "example_tool", | ||
| "parameters": {"type": "object"} | ||
| } | ||
| }] | ||
| })"; | ||
| doc.Parse(json.c_str()); | ||
| ASSERT_FALSE(doc.HasParseError()); | ||
|
|
||
| auto apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *llama3Tokenizer, "llama3"); | ||
| uint32_t maxTokensLimit = 100; | ||
| uint32_t bestOfLimit = 0; | ||
| std::optional<uint32_t> maxModelLength; | ||
| ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); | ||
|
|
||
| ov::genai::GenerationOutput generationOutput; | ||
| generationOutput.generated_ids = createLlama3ToolCallTokens(*llama3Tokenizer); | ||
| generationOutput.finish_reason = ov::genai::GenerationFinishReason::STOP; // Change it once GenAI introduces tool_calls finish reason | ||
| std::string serialized = apiHandler->serializeUnaryResponse(std::vector<ov::genai::GenerationOutput>{generationOutput}); | ||
|
|
||
| ASSERT_NE(serialized.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << serialized; | ||
| ASSERT_NE(serialized.find("\"tool_calls\":[{"), std::string::npos) << serialized; | ||
| } | ||
|
|
||
| TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseEncodedResultsReturnsToolCallsFinishReason) { | ||
| std::shared_ptr<ov::genai::Tokenizer> llama3Tokenizer = std::make_shared<ov::genai::Tokenizer>(llama3TokenizerPathForHandlerTests); | ||
| std::string json = R"({ | ||
| "model": "llama", | ||
| "stream": false, | ||
| "messages": [{"role": "user", "content": "What is weather?"}], | ||
| "tools": [{ | ||
| "type": "function", | ||
| "function": { | ||
| "name": "example_tool", | ||
| "parameters": {"type": "object"} | ||
| } | ||
| }] | ||
| })"; | ||
| doc.Parse(json.c_str()); | ||
| ASSERT_FALSE(doc.HasParseError()); | ||
|
|
||
| auto apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *llama3Tokenizer, "llama3"); | ||
| uint32_t maxTokensLimit = 100; | ||
| uint32_t bestOfLimit = 0; | ||
| std::optional<uint32_t> maxModelLength; | ||
| ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); | ||
|
|
||
| ov::genai::EncodedResults results; | ||
| results.tokens = {createLlama3ToolCallTokens(*llama3Tokenizer)}; | ||
| std::string serialized = apiHandler->serializeUnaryResponse(results); | ||
|
|
||
| ASSERT_NE(serialized.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << serialized; | ||
| ASSERT_NE(serialized.find("\"tool_calls\":[{"), std::string::npos) << serialized; | ||
| } | ||
|
|
||
| // This is unsupported, once we have tool calling for VLM legacy pipeline, change the test | ||
| TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseVLMSupportsToolCallsFinishReason_Unsupported) { | ||
| std::string json = R"({ | ||
| "model": "llama", | ||
| "stream": false, | ||
| "messages": [{"role": "user", "content": "What is weather?"}] | ||
| })"; | ||
| doc.Parse(json.c_str()); | ||
| ASSERT_FALSE(doc.HasParseError()); | ||
|
|
||
| auto apiHandler = std::make_shared<ovms::OpenAIChatCompletionsHandler>(doc, ovms::Endpoint::CHAT_COMPLETIONS, std::chrono::system_clock::now(), *tokenizer); | ||
| uint32_t maxTokensLimit = 100; | ||
| uint32_t bestOfLimit = 0; | ||
| std::optional<uint32_t> maxModelLength; | ||
| ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); | ||
|
|
||
| ov::genai::VLMDecodedResults results; | ||
| results.texts = {"dummy"}; | ||
| std::string serialized = apiHandler->serializeUnaryResponse(results, 1); | ||
|
|
||
| // ASSERT_NE(serialized.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << serialized; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove or add more context.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My next PR introduces tool calling for the VLM legacy pipeline; I don't think we need any more context — it is above the test, check.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, can it be removed?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test will be changed in the next PR (qwen3-vl support), so please let me keep it so I don't forget to switch to the commented-out assertion. |
||
| ASSERT_NE(serialized.find("\"finish_reason\":\"stop\""), std::string::npos) << serialized; | ||
| } | ||
|
|
||
| TEST_F(HttpOpenAIHandlerParsingTest, ParsingMessagesSucceedsBase64) { | ||
| std::string json = R"({ | ||
| "model": "llama", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know of any task for that in GenAI scope?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@apaniukov is there a task for that in GenAI?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CVS-181410