From 0fd2d50ab2ddf4bd733506a4e7e3d57a1855abe3 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 06:59:01 +0000 Subject: [PATCH 01/11] Extract extract_first_embedding_row from embed() into jni_server_helpers.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pure computation inside Java_de_kherud_llama_LlamaModel_embed — parsing out_res["embedding"] as a 2D float array, validating it is non-empty, and returning the first row — was interleaved with JNI allocation calls, making it untestable without a live JVM. Extract it as extract_first_embedding_row(const json &) in jni_server_helpers.hpp. It throws std::runtime_error for empty arrays and nlohmann::json::exception for missing/wrong-type keys; the JNI caller catches std::exception and converts to ThrowNew(c_error_oom). Six unit tests added to test_jni_server_helpers.cpp covering: single row, first-of-many rows, missing key, empty outer array, empty inner array, and a 128-element large row. All run without a JVM or model. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- src/main/cpp/jllama.cpp | 29 +++++--------- src/main/cpp/jni_server_helpers.hpp | 48 +++++++++++++++++------ src/test/cpp/test_jni_server_helpers.cpp | 49 ++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 32 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 68221dac..4c280077 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -849,36 +849,25 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, const auto out_res = result->to_json(); - // Extract "embedding" as a vector of vectors (2D array) - std::vector> embedding = out_res["embedding"].get>>(); - - // Get total number of rows in the embedding - jsize embedding_rows = embedding.size(); - - // Get total number of columns in the first row (assuming all rows are of equal length) - jsize embedding_cols = embedding_rows > 0 ? 
embedding[0].size() : 0; - - SRV_INF("Embedding has %d rows and %d columns\n", embedding_rows, embedding_cols); - - // Ensure embedding is not empty - if (embedding.empty() || embedding[0].empty()) { - env->ThrowNew(c_error_oom, "embedding array is empty"); + std::vector first_row; + try { + first_row = extract_first_embedding_row(out_res); + } catch (const std::exception &e) { + env->ThrowNew(c_error_oom, e.what()); return nullptr; } - // Extract only the first row - const std::vector &first_row = embedding[0]; // Reference to avoid copying + const jsize embedding_cols = static_cast(first_row.size()); + SRV_INF("Embedding has %d columns\n", embedding_cols); - // Create a new float array in JNI jfloatArray j_embedding = env->NewFloatArray(embedding_cols); if (j_embedding == nullptr) { env->ThrowNew(c_error_oom, "could not allocate embedding"); return nullptr; } - // Copy the first row into the JNI float array - env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, reinterpret_cast(first_row.data())); - + env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, + reinterpret_cast(first_row.data())); return j_embedding; } diff --git a/src/main/cpp/jni_server_helpers.hpp b/src/main/cpp/jni_server_helpers.hpp index 654054d1..d69a8919 100644 --- a/src/main/cpp/jni_server_helpers.hpp +++ b/src/main/cpp/jni_server_helpers.hpp @@ -22,18 +22,19 @@ // redefinition errors in any TU that already includes server.hpp directly. // // Declaration order (each function must be defined before its first caller): -// 1. get_result_error_message — used by recv_slot_task_result_impl, -// collect_task_results_impl -// 2. json_to_jstring_impl — used by recv_slot_task_result_impl, -// results_to_jstring_impl -// 3. build_completion_tasks_impl — no dependencies on helpers above -// 4. recv_slot_task_result_impl — uses 1 + 2 -// 5. collect_task_results_impl — uses 1 -// 6. results_to_json_impl — no dependencies on helpers above -// 7. results_to_jstring_impl — uses 2 + 6 -// 8. 
check_infill_support_impl — no dependencies on helpers above -// 9. rerank_results_to_json — no dependencies on helpers above -// 10. append_task — no dependencies on helpers above +// 1. get_result_error_message — used by recv_slot_task_result_impl, +// collect_task_results_impl +// 2. json_to_jstring_impl — used by recv_slot_task_result_impl, +// results_to_jstring_impl +// 3. build_completion_tasks_impl — no dependencies on helpers above +// 4. recv_slot_task_result_impl — uses 1 + 2 +// 5. collect_task_results_impl — uses 1 +// 6. results_to_json_impl — no dependencies on helpers above +// 7. results_to_jstring_impl — uses 2 + 6 +// 8. check_infill_support_impl — no dependencies on helpers above +// 9. rerank_results_to_json — no dependencies on helpers above +// 10. append_task — no dependencies on helpers above +// 11. extract_first_embedding_row — no dependencies on helpers above #include "jni.h" @@ -315,3 +316,26 @@ inline void append_task(server_context *ctx_server, } return true; } + +// --------------------------------------------------------------------------- +// extract_first_embedding_row +// +// Parses out_res["embedding"] as a 2D float array and returns the first row. +// +// Throws std::runtime_error if the outer or inner array is empty. +// Throws nlohmann::json::exception if the "embedding" key is absent or the +// value cannot be coerced to vector>. +// +// Pure computation — no JNI calls, no llama context. +// Unit-testable with any JSON literal: +// {"embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]} +// --------------------------------------------------------------------------- +[[nodiscard]] inline std::vector +extract_first_embedding_row(const json &out_res) { + // .at() throws json::out_of_range if "embedding" is absent. 
+ const auto embedding = out_res.at("embedding").get>>(); + if (embedding.empty() || embedding[0].empty()) { + throw std::runtime_error("embedding array is empty"); + } + return embedding[0]; +} diff --git a/src/test/cpp/test_jni_server_helpers.cpp b/src/test/cpp/test_jni_server_helpers.cpp index 475b8642..1c60f932 100644 --- a/src/test/cpp/test_jni_server_helpers.cpp +++ b/src/test/cpp/test_jni_server_helpers.cpp @@ -498,3 +498,52 @@ TEST_F(CollectResultsFixture, JsonToJstring_ReturnsSentinelFromStub) { EXPECT_EQ(js, reinterpret_cast(0xBEEF)); } + +// ============================================================ +// Tests for extract_first_embedding_row +// +// Pure computation — no JNI or llama context needed. +// ============================================================ + +TEST(ExtractFirstEmbeddingRow, SingleRow_ReturnsRow) { + json j = {{"embedding", {{0.1f, 0.2f, 0.3f}}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 3u); + EXPECT_FLOAT_EQ(row[0], 0.1f); + EXPECT_FLOAT_EQ(row[1], 0.2f); + EXPECT_FLOAT_EQ(row[2], 0.3f); +} + +TEST(ExtractFirstEmbeddingRow, MultipleRows_ReturnsFirstRowOnly) { + json j = {{"embedding", {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 2u); + EXPECT_FLOAT_EQ(row[0], 1.0f); + EXPECT_FLOAT_EQ(row[1], 2.0f); +} + +TEST(ExtractFirstEmbeddingRow, MissingEmbeddingKey_ThrowsJsonException) { + json j = {{"other_key", "value"}}; + EXPECT_THROW(extract_first_embedding_row(j), nlohmann::json::exception); +} + +TEST(ExtractFirstEmbeddingRow, EmptyOuterArray_ThrowsRuntimeError) { + json j = {{"embedding", json::array()}}; + EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); +} + +TEST(ExtractFirstEmbeddingRow, EmptyInnerArray_ThrowsRuntimeError) { + json j = {{"embedding", {json::array()}}}; + EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); +} + +TEST(ExtractFirstEmbeddingRow, LargeRow_AllValuesPreserved) { + std::vector 
vals(128); + for (int i = 0; i < 128; ++i) vals[i] = static_cast(i) * 0.01f; + json j = {{"embedding", {vals}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 128u); + for (int i = 0; i < 128; ++i) { + EXPECT_FLOAT_EQ(row[i], static_cast(i) * 0.01f); + } +} From 04ae5c017913d866a4f954d5c830935099009225 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 07:24:52 +0000 Subject: [PATCH 02/11] Extract parse_encoding_format_impl from handleEmbeddings into jni_server_helpers.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 9-line encoding_format validation block inside Java_de_kherud_llama_LlamaModel_handleEmbeddings mixed a JNI ThrowNew call into pure string-comparison logic. Extract it as parse_encoding_format_impl(const json &) in jni_server_helpers.hpp: Returns false — field absent or "float" → float encoding Returns true — "base64" → base64 encoding Throws std::invalid_argument → unknown value The JNI caller is reduced to a single try/catch line. Six unit tests added to test_jni_server_helpers.cpp: absent field, explicit "float", "base64", unknown string, empty string, and a check that the error message names both valid options. All run without a JVM or model. 
https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- src/main/cpp/jllama.cpp | 11 ++---- src/main/cpp/jni_server_helpers.hpp | 29 ++++++++++++++++ src/test/cpp/test_jni_server_helpers.cpp | 44 ++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 4c280077..1c556bc5 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -1180,15 +1180,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn } bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string &format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - env->ThrowNew(c_llama_error, "encoding_format must be \"float\" or \"base64\""); - return nullptr; - } - } + try { use_base64 = parse_encoding_format_impl(body); } + catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); diff --git a/src/main/cpp/jni_server_helpers.hpp b/src/main/cpp/jni_server_helpers.hpp index d69a8919..2a907d0c 100644 --- a/src/main/cpp/jni_server_helpers.hpp +++ b/src/main/cpp/jni_server_helpers.hpp @@ -35,6 +35,7 @@ // 9. rerank_results_to_json — no dependencies on helpers above // 10. append_task — no dependencies on helpers above // 11. extract_first_embedding_row — no dependencies on helpers above +// 12. parse_encoding_format_impl — no dependencies on helpers above #include "jni.h" @@ -339,3 +340,31 @@ extract_first_embedding_row(const json &out_res) { } return embedding[0]; } + +// --------------------------------------------------------------------------- +// parse_encoding_format_impl +// +// Reads the optional "encoding_format" field from `body` and returns whether +// base64 encoding was requested. 
+// +// Returns false — field absent, or value is "float" → use float encoding. +// Returns true — value is "base64" → use base64 encoding. +// Throws std::invalid_argument — value is present but neither "float" nor +// "base64", with a message suitable for forwarding to JNI ThrowNew. +// +// Pure computation — no JNI calls, no llama context. +// Unit-testable with any JSON literal. +// --------------------------------------------------------------------------- +[[nodiscard]] inline bool parse_encoding_format_impl(const json &body) { + if (!body.contains("encoding_format")) { + return false; + } + const std::string format = body.at("encoding_format").get(); + if (format == "base64") { + return true; + } + if (format == "float") { + return false; + } + throw std::invalid_argument("encoding_format must be \"float\" or \"base64\""); +} diff --git a/src/test/cpp/test_jni_server_helpers.cpp b/src/test/cpp/test_jni_server_helpers.cpp index 1c60f932..e1bacde5 100644 --- a/src/test/cpp/test_jni_server_helpers.cpp +++ b/src/test/cpp/test_jni_server_helpers.cpp @@ -547,3 +547,47 @@ TEST(ExtractFirstEmbeddingRow, LargeRow_AllValuesPreserved) { EXPECT_FLOAT_EQ(row[i], static_cast(i) * 0.01f); } } + +// ============================================================ +// Tests for parse_encoding_format_impl +// +// Pure computation — no JNI or llama context needed. 
+// ============================================================ + +TEST(ParseEncodingFormat, FieldAbsent_ReturnsFalse) { + json body = {{"model", "text-embedding-ada-002"}}; + EXPECT_FALSE(parse_encoding_format_impl(body)); +} + +TEST(ParseEncodingFormat, ExplicitFloat_ReturnsFalse) { + json body = {{"encoding_format", "float"}}; + EXPECT_FALSE(parse_encoding_format_impl(body)); +} + +TEST(ParseEncodingFormat, Base64_ReturnsTrue) { + json body = {{"encoding_format", "base64"}}; + EXPECT_TRUE(parse_encoding_format_impl(body)); +} + +TEST(ParseEncodingFormat, UnknownFormat_ThrowsInvalidArgument) { + json body = {{"encoding_format", "binary"}}; + EXPECT_THROW(parse_encoding_format_impl(body), std::invalid_argument); +} + +TEST(ParseEncodingFormat, EmptyString_ThrowsInvalidArgument) { + json body = {{"encoding_format", ""}}; + EXPECT_THROW(parse_encoding_format_impl(body), std::invalid_argument); +} + +TEST(ParseEncodingFormat, UnknownFormat_MessageMentionsValidOptions) { + json body = {{"encoding_format", "hex"}}; + try { + parse_encoding_format_impl(body); + FAIL() << "Expected std::invalid_argument"; + } catch (const std::invalid_argument &e) { + EXPECT_NE(std::string(e.what()).find("float"), std::string::npos) + << "error message should mention \"float\""; + EXPECT_NE(std::string(e.what()).find("base64"), std::string::npos) + << "error message should mention \"base64\""; + } +} From 156747468a92d844f83f61b66a604ce0e66d8b5f Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 07:28:16 +0000 Subject: [PATCH 03/11] Extract extract_embedding_prompt_impl from handleEmbeddings into jni_server_helpers.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 10-line prompt-selection block inside handleEmbeddings mixed a JNI ThrowNew with pure key-precedence logic ("input" preferred over "content", error if neither present). The oaicompat downgrade side effect was also hidden inside the block. 
Extract it as extract_embedding_prompt_impl(const json &, bool &) in jni_server_helpers.hpp: Returns body["input"] — force_no_oaicompat = false Returns body["content"] — force_no_oaicompat = true Throws std::invalid_argument — neither key present The JNI caller is reduced to a try/catch + an explicit oaicompat assignment, making the downgrade visible at the call site. Six unit tests added to test_jni_server_helpers.cpp: "input" key, "content" key, "input" priority over "content", neither key, empty body, and array-valued "input" (batch embedding). All run without a JVM or model. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- src/main/cpp/jllama.cpp | 13 ++---- src/main/cpp/jni_server_helpers.hpp | 29 +++++++++++++ src/test/cpp/test_jni_server_helpers.cpp | 54 ++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 9 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 1c556bc5..97ebfcdd 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -1168,16 +1168,11 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn json body = parse_json_params(env, jparams); + bool force_no_oaicompat = false; json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; - prompt = body.at("content"); - } else { - env->ThrowNew(c_llama_error, "\"input\" or \"content\" must be provided"); - return nullptr; - } + try { prompt = extract_embedding_prompt_impl(body, force_no_oaicompat); } + catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } + if (force_no_oaicompat) oaicompat = OAICOMPAT_TYPE_NONE; bool use_base64 = false; try { use_base64 = parse_encoding_format_impl(body); } diff --git a/src/main/cpp/jni_server_helpers.hpp b/src/main/cpp/jni_server_helpers.hpp index 2a907d0c..95fc454f 100644 --- a/src/main/cpp/jni_server_helpers.hpp +++ 
b/src/main/cpp/jni_server_helpers.hpp @@ -36,6 +36,7 @@ // 10. append_task — no dependencies on helpers above // 11. extract_first_embedding_row — no dependencies on helpers above // 12. parse_encoding_format_impl — no dependencies on helpers above +// 13. extract_embedding_prompt_impl — no dependencies on helpers above #include "jni.h" @@ -368,3 +369,31 @@ extract_first_embedding_row(const json &out_res) { } throw std::invalid_argument("encoding_format must be \"float\" or \"base64\""); } + +// --------------------------------------------------------------------------- +// extract_embedding_prompt_impl +// +// Selects the prompt value from an embedding request body using OpenAI-style +// key precedence: "input" is preferred (OAI-compatible); "content" is the +// fallback (legacy, non-OAI path). +// +// On success: returns the prompt JSON value and sets force_no_oaicompat=true +// when "content" was used (caller must downgrade oaicompat to NONE). +// Throws std::invalid_argument if neither "input" nor "content" is present, +// with a message suitable for forwarding to JNI ThrowNew. +// +// Pure computation — no JNI calls, no llama context. +// Unit-testable with any JSON literal. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline json extract_embedding_prompt_impl(const json &body, + bool &force_no_oaicompat) { + force_no_oaicompat = false; + if (body.count("input") != 0) { + return body.at("input"); + } + if (body.contains("content")) { + force_no_oaicompat = true; + return body.at("content"); + } + throw std::invalid_argument("\"input\" or \"content\" must be provided"); +} diff --git a/src/test/cpp/test_jni_server_helpers.cpp b/src/test/cpp/test_jni_server_helpers.cpp index e1bacde5..c3b06a7f 100644 --- a/src/test/cpp/test_jni_server_helpers.cpp +++ b/src/test/cpp/test_jni_server_helpers.cpp @@ -591,3 +591,57 @@ TEST(ParseEncodingFormat, UnknownFormat_MessageMentionsValidOptions) { << "error message should mention \"base64\""; } } + +// ============================================================ +// Tests for extract_embedding_prompt_impl +// +// Pure computation — no JNI or llama context needed. +// ============================================================ + +TEST(ExtractEmbeddingPrompt, InputKey_ReturnsValueAndDoesNotSetFlag) { + bool flag = true; // pre-set to verify it gets cleared + json body = {{"input", "hello world"}}; + json prompt = extract_embedding_prompt_impl(body, flag); + EXPECT_EQ(prompt, "hello world"); + EXPECT_FALSE(flag) << "force_no_oaicompat must be false when \"input\" is used"; +} + +TEST(ExtractEmbeddingPrompt, ContentKey_ReturnsValueAndSetsFlag) { + bool flag = false; + json body = {{"content", "some text"}}; + json prompt = extract_embedding_prompt_impl(body, flag); + EXPECT_EQ(prompt, "some text"); + EXPECT_TRUE(flag) << "force_no_oaicompat must be true when \"content\" is used"; +} + +TEST(ExtractEmbeddingPrompt, InputTakesPriorityOverContent) { + bool flag = false; + json body = {{"input", "from input"}, {"content", "from content"}}; + json prompt = extract_embedding_prompt_impl(body, flag); + EXPECT_EQ(prompt, "from input"); + EXPECT_FALSE(flag) << 
"\"input\" path must not set force_no_oaicompat"; +} + +TEST(ExtractEmbeddingPrompt, NeitherKey_ThrowsInvalidArgument) { + bool flag = false; + json body = {{"model", "text-embedding-ada-002"}}; + EXPECT_THROW(extract_embedding_prompt_impl(body, flag), std::invalid_argument); +} + +TEST(ExtractEmbeddingPrompt, EmptyBody_ThrowsInvalidArgument) { + bool flag = false; + EXPECT_THROW(extract_embedding_prompt_impl(json::object(), flag), std::invalid_argument); +} + +TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { + // "input" may be an array of strings (batch embedding); the function must + // return the JSON value unchanged without trying to coerce it. + bool flag = false; + json body = {{"input", {"sentence one", "sentence two"}}}; + json prompt = extract_embedding_prompt_impl(body, flag); + ASSERT_TRUE(prompt.is_array()); + ASSERT_EQ(prompt.size(), 2u); + EXPECT_EQ(prompt[0], "sentence one"); + EXPECT_EQ(prompt[1], "sentence two"); + EXPECT_FALSE(flag); +} From fdf26aa7b5bcb52e48f540e78e0d44c2ef64a950 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 07:31:39 +0000 Subject: [PATCH 04/11] Extract build_embeddings_response_json_impl from handleEmbeddings into jni_server_helpers.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 7-line response-building block at the end of handleEmbeddings — collecting task results into a JSON array and then conditionally wrapping with format_embeddings_response_oaicompat — was the only pure computation fused with the final json_to_jstring JNI call. Extract it as build_embeddings_response_json_impl in jni_server_helpers.hpp (symmetric counterpart to rerank_results_to_json): OAICOMPAT_TYPE_EMBEDDING → format_embeddings_response_oaicompat(...) any other oaicompat → bare JSON array of result objects The JNI function is reduced to a single return line. 
Also adds fake_embedding_result / make_embedding helpers to the test file and five unit tests: non-OAI single result, non-OAI multiple results, OAI float encoding (structure check), OAI base64 encoding (embedding is string not array), and OAI usage token summation across results. All run without a JVM or model. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- src/main/cpp/jllama.cpp | 11 +-- src/main/cpp/jni_server_helpers.hpp | 33 ++++++++- src/test/cpp/test_jni_server_helpers.cpp | 94 ++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 12 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 97ebfcdd..c1d1d239 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -1199,16 +1199,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn std::vector results; if (!collect_task_results(env, ctx_server, task_ids, results)) return nullptr; - json responses = json::array(); - for (const auto &result : results) { - responses.push_back(result->to_json()); - } - - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - - return json_to_jstring(env, root); + return json_to_jstring(env, build_embeddings_response_json_impl(results, body, oaicompat, use_base64)); } JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv *env, jobject obj, jstring jcontent, diff --git a/src/main/cpp/jni_server_helpers.hpp b/src/main/cpp/jni_server_helpers.hpp index 95fc454f..25de347b 100644 --- a/src/main/cpp/jni_server_helpers.hpp +++ b/src/main/cpp/jni_server_helpers.hpp @@ -35,8 +35,9 @@ // 9. rerank_results_to_json — no dependencies on helpers above // 10. append_task — no dependencies on helpers above // 11. extract_first_embedding_row — no dependencies on helpers above -// 12. parse_encoding_format_impl — no dependencies on helpers above -// 13. 
extract_embedding_prompt_impl — no dependencies on helpers above +// 12. parse_encoding_format_impl — no dependencies on helpers above +// 13. extract_embedding_prompt_impl — no dependencies on helpers above +// 14. build_embeddings_response_json_impl — no dependencies on helpers above #include "jni.h" @@ -397,3 +398,31 @@ extract_first_embedding_row(const json &out_res) { } throw std::invalid_argument("\"input\" or \"content\" must be provided"); } + +// --------------------------------------------------------------------------- +// build_embeddings_response_json_impl +// +// Collects task results into a JSON array, then formats the final response: +// - OAICOMPAT_TYPE_EMBEDDING → wraps via format_embeddings_response_oaicompat +// (adds "object":"list", "usage", and per-embedding "object":"embedding") +// - any other oaicompat → returns the bare JSON array +// +// Symmetric counterpart to rerank_results_to_json. +// +// Pure computation — no JNI calls, no llama context. +// Unit-testable with fake_ok_result mocks and any JSON body literal. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline json build_embeddings_response_json_impl( + const std::vector &results, + const json &body, + oaicompat_type oaicompat, + bool use_base64) { + json responses = json::array(); + for (const auto &result : results) { + responses.push_back(result->to_json()); + } + if (oaicompat == OAICOMPAT_TYPE_EMBEDDING) { + return format_embeddings_response_oaicompat(body, responses, use_base64); + } + return responses; +} diff --git a/src/test/cpp/test_jni_server_helpers.cpp b/src/test/cpp/test_jni_server_helpers.cpp index c3b06a7f..482d445a 100644 --- a/src/test/cpp/test_jni_server_helpers.cpp +++ b/src/test/cpp/test_jni_server_helpers.cpp @@ -54,6 +54,22 @@ struct fake_ok_result : server_task_result { json to_json() override { return {{"content", msg}}; } }; +// An embedding result whose to_json() returns the shape expected by +// format_embeddings_response_oaicompat: {"embedding": [...], "tokens_evaluated": N}. +struct fake_embedding_result : server_task_result { + std::vector vec; + int tokens_evaluated; + explicit fake_embedding_result(int id_, std::vector v, int tok = 4) + : vec(std::move(v)), tokens_evaluated(tok) { id = id_; } + json to_json() override { + return {{"embedding", vec}, {"tokens_evaluated", tokens_evaluated}}; + } +}; + +static server_task_result_ptr make_embedding(int id_, std::vector v = {0.1f, 0.2f, 0.3f}) { + return std::make_unique(id_, std::move(v)); +} + // An error result — reuses the real server_task_result_error so that // to_json() → format_error_response() → {"message": err_msg, ...} matches // the exact JSON key that collect_task_results_impl reads. 
@@ -645,3 +661,81 @@ TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { EXPECT_EQ(prompt[1], "sentence two"); EXPECT_FALSE(flag); } + +// ============================================================ +// Tests for build_embeddings_response_json_impl +// +// Pure computation — no JNI or llama context needed. +// Uses fake_embedding_result (added above) for OAI-path tests that +// need "embedding" + "tokens_evaluated" in the result JSON. +// ============================================================ + +TEST(BuildEmbeddingsResponseJson, NonOai_SingleResult_ReturnsBareArray) { + std::vector results; + results.push_back(make_embedding(1, {0.1f, 0.2f})); + + json out = build_embeddings_response_json_impl(results, json::object(), + OAICOMPAT_TYPE_NONE, false); + + ASSERT_TRUE(out.is_array()); + ASSERT_EQ(out.size(), 1u); + EXPECT_TRUE(out[0].contains("embedding")); +} + +TEST(BuildEmbeddingsResponseJson, NonOai_MultipleResults_AllInArray) { + std::vector results; + results.push_back(make_embedding(1, {0.1f})); + results.push_back(make_embedding(2, {0.2f})); + results.push_back(make_embedding(3, {0.3f})); + + json out = build_embeddings_response_json_impl(results, json::object(), + OAICOMPAT_TYPE_NONE, false); + + ASSERT_TRUE(out.is_array()); + EXPECT_EQ(out.size(), 3u); +} + +TEST(BuildEmbeddingsResponseJson, OaiFloat_WrapsWithOaiStructure) { + std::vector results; + results.push_back(make_embedding(1, {0.5f, 0.6f, 0.7f})); + + json body = {{"model", "text-embedding-ada-002"}}; + json out = build_embeddings_response_json_impl(results, body, + OAICOMPAT_TYPE_EMBEDDING, false); + + EXPECT_TRUE(out.is_object()) << "OAI response must be an object"; + EXPECT_EQ(out.value("object", ""), "list"); + EXPECT_TRUE(out.contains("data")) << "OAI response must have \"data\""; + EXPECT_TRUE(out.contains("usage")) << "OAI response must have \"usage\""; + EXPECT_EQ(out.value("model", ""), "text-embedding-ada-002"); + + ASSERT_TRUE(out["data"].is_array()); + 
ASSERT_EQ(out["data"].size(), 1u); + EXPECT_EQ(out["data"][0].value("object", ""), "embedding"); +} + +TEST(BuildEmbeddingsResponseJson, OaiBase64_EmbeddingEncodedAsString) { + std::vector results; + results.push_back(make_embedding(1, {1.0f, 2.0f})); + + json out = build_embeddings_response_json_impl(results, json::object(), + OAICOMPAT_TYPE_EMBEDDING, /*use_base64=*/true); + + ASSERT_TRUE(out["data"].is_array()); + ASSERT_EQ(out["data"].size(), 1u); + // base64 path stores embedding as a string, not an array + EXPECT_TRUE(out["data"][0]["embedding"].is_string()) + << "base64 embedding must be serialised as a string"; +} + +TEST(BuildEmbeddingsResponseJson, OaiUsage_TokensSummedAcrossResults) { + std::vector results; + results.push_back(std::make_unique(1, std::vector{0.1f}, /*tok=*/3)); + results.push_back(std::make_unique(2, std::vector{0.2f}, /*tok=*/5)); + + json out = build_embeddings_response_json_impl(results, json::object(), + OAICOMPAT_TYPE_EMBEDDING, false); + + EXPECT_EQ(out["usage"].value("prompt_tokens", 0), 8) + << "usage.prompt_tokens must be sum of tokens_evaluated across all results"; +} From 69b7e709bbc8f69d3dbe2b3a8da555b026547603 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 07:59:07 +0000 Subject: [PATCH 05/11] Reorganise C++ helpers: json_helpers.hpp + unified jni_helpers.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Establish a clear semantic boundary between pure JSON transforms and JNI bridge code: json_helpers.hpp (new) Pure data-transformation functions — zero JNI, zero llama state. Moved from jni_server_helpers.hpp: get_result_error_message, results_to_json, rerank_results_to_json, build_embeddings_response_json, extract_first_embedding_row, parse_encoding_format, extract_embedding_prompt. Newly extracted from jllama.cpp: is_infill_request, parse_slot_prompt_similarity, parse_positive_int_config. All functions are independently unit-testable without a live JVM. 
jni_helpers.hpp (rewritten) Merges the former jni_helpers.hpp (Layer A: handle management) and jni_server_helpers.hpp (Layer B: server orchestration) into one file. Layer B includes json_helpers.hpp so bridge helpers can call transforms directly. results_to_jstring_impl delegates to results_to_json. jllama.cpp Drops include of jni_server_helpers.hpp (superseded by jni_helpers.hpp). Call sites updated: *_impl suffix dropped for functions moved to json_helpers.hpp; requestCompletion uses is_infill_request; configureParallelInference uses parse_slot_prompt_similarity / parse_positive_int_config with a single try/catch. Tests test_json_helpers.cpp (new): ~60 tests covering all json_helpers.hpp functions, using fake result types — no JVM, no model required. test_jni_helpers.cpp: merged with relevant tests from the deleted test_jni_server_helpers.cpp; three fixtures cover mock JNI, server response queue, and JNI array access. test_jni_server_helpers.cpp deleted (content fully migrated). CMakeLists.txt: replaces test_jni_server_helpers.cpp with test_json_helpers.cpp in the jllama_test executable. 
https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CMakeLists.txt | 2 +- src/main/cpp/jllama.cpp | 58 +- src/main/cpp/jni_helpers.hpp | 307 ++++++++-- src/main/cpp/jni_server_helpers.hpp | 428 ------------- src/main/cpp/json_helpers.hpp | 243 ++++++++ src/test/cpp/test_jni_helpers.cpp | 511 ++++++++++------ src/test/cpp/test_jni_server_helpers.cpp | 741 ----------------------- src/test/cpp/test_json_helpers.cpp | 469 ++++++++++++++ src/test/cpp/test_server.cpp | 2 +- 9 files changed, 1324 insertions(+), 1437 deletions(-) delete mode 100644 src/main/cpp/jni_server_helpers.hpp create mode 100644 src/main/cpp/json_helpers.hpp delete mode 100644 src/test/cpp/test_jni_server_helpers.cpp create mode 100644 src/test/cpp/test_json_helpers.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 19e45d52..2decc648 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,7 +242,7 @@ if(BUILD_TESTING) src/test/cpp/test_utils.cpp src/test/cpp/test_server.cpp src/test/cpp/test_jni_helpers.cpp - src/test/cpp/test_jni_server_helpers.cpp + src/test/cpp/test_json_helpers.cpp ) target_include_directories(jllama_test PRIVATE diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index c1d1d239..c6c8eaa5 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -2,12 +2,11 @@ #include "arg.h" #include "json-schema-to-grammar.h" -#include "jni_helpers.hpp" #include "llama.h" #include "log.h" #include "nlohmann/json.hpp" #include "server.hpp" -#include "jni_server_helpers.hpp" +#include "jni_helpers.hpp" #include #include @@ -147,7 +146,7 @@ static void throw_invalid_request(JNIEnv *env, const std::exception &e) { } /** - * Convenience wrapper around build_completion_tasks_impl (jni_server_helpers.hpp) + * Convenience wrapper around build_completion_tasks_impl (jni_helpers.hpp) * that supplies the module-level globals so call sites need no boilerplate. 
*/ [[nodiscard]] static bool build_completion_tasks(JNIEnv *env, server_context *ctx_server, @@ -212,7 +211,7 @@ static int require_single_task_id(JNIEnv *env, } /** - * Convenience wrapper around recv_slot_task_result_impl (jni_server_helpers.hpp). + * Convenience wrapper around recv_slot_task_result_impl (jni_helpers.hpp). * Caller must have already registered task_id with add_waiting_task_id() and * posted the task; this wrapper covers recv → check → return. */ @@ -221,7 +220,7 @@ static int require_single_task_id(JNIEnv *env, } /** - * Convenience wrapper around collect_task_results_impl (jni_server_helpers.hpp) + * Convenience wrapper around collect_task_results_impl (jni_helpers.hpp) * that supplies the module-level globals so call sites need no boilerplate. */ [[nodiscard]] static bool collect_task_results(JNIEnv *env, @@ -232,7 +231,7 @@ static int require_single_task_id(JNIEnv *env, } /** - * Convenience wrapper around results_to_jstring_impl (jni_server_helpers.hpp). + * Convenience wrapper around results_to_jstring_impl (jni_helpers.hpp). * Serialises results to a jstring (single object or JSON array). */ [[nodiscard]] static jstring results_to_jstring( @@ -242,7 +241,7 @@ static int require_single_task_id(JNIEnv *env, } /** - * Convenience wrapper around json_to_jstring_impl (jni_server_helpers.hpp). + * Convenience wrapper around json_to_jstring_impl (jni_helpers.hpp). * Serialises any json value to a JNI string via dump() + NewStringUTF. */ [[nodiscard]] static jstring json_to_jstring(JNIEnv *env, const json &j) { @@ -770,11 +769,9 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv json data = parse_json_params(env, jparams); - server_task_type type = SERVER_TASK_TYPE_COMPLETION; - - if (data.contains("input_prefix") || data.contains("input_suffix")) { - type = SERVER_TASK_TYPE_INFILL; - } + const server_task_type type = is_infill_request(data) + ? 
SERVER_TASK_TYPE_INFILL + : SERVER_TASK_TYPE_COMPLETION; auto completion_id = gen_chatcmplid(); std::vector tasks; @@ -1170,12 +1167,12 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn bool force_no_oaicompat = false; json prompt; - try { prompt = extract_embedding_prompt_impl(body, force_no_oaicompat); } + try { prompt = extract_embedding_prompt(body, force_no_oaicompat); } catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } if (force_no_oaicompat) oaicompat = OAICOMPAT_TYPE_NONE; bool use_base64 = false; - try { use_base64 = parse_encoding_format_impl(body); } + try { use_base64 = parse_encoding_format(body); } catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); @@ -1199,7 +1196,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn std::vector results; if (!collect_task_results(env, ctx_server, task_ids, results)) return nullptr; - return json_to_jstring(env, build_embeddings_response_json_impl(results, body, oaicompat, use_base64)); + return json_to_jstring(env, build_embeddings_response_json(results, body, oaicompat, use_base64)); } JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv *env, jobject obj, jstring jcontent, @@ -1276,27 +1273,20 @@ JNIEXPORT jboolean JNICALL Java_de_kherud_llama_LlamaModel_configureParallelInfe json config = parse_json_params(env, jconfig); - if (config.contains("slot_prompt_similarity")) { - float similarity = config["slot_prompt_similarity"].get(); - if (similarity < 0.0f || similarity > 1.0f) { - env->ThrowNew(c_llama_error, "slot_prompt_similarity must be between 0.0 and 1.0"); - return JNI_FALSE; + try { + if (auto v = parse_slot_prompt_similarity(config)) { + ctx_server->slot_prompt_similarity = *v; } - ctx_server->slot_prompt_similarity = similarity; - } - - 
auto apply_thread_count = [&](const char *key, int &target) -> bool { - if (!config.contains(key)) return true; - int v = config[key].get(); - if (v <= 0) { - env->ThrowNew(c_llama_error, (std::string(key) + " must be greater than 0").c_str()); - return false; + if (auto v = parse_positive_int_config(config, "n_threads")) { + ctx_server->params_base.cpuparams.n_threads = *v; } - target = v; - return true; - }; - if (!apply_thread_count("n_threads", ctx_server->params_base.cpuparams.n_threads)) return JNI_FALSE; - if (!apply_thread_count("n_threads_batch", ctx_server->params_base.cpuparams_batch.n_threads)) return JNI_FALSE; + if (auto v = parse_positive_int_config(config, "n_threads_batch")) { + ctx_server->params_base.cpuparams_batch.n_threads = *v; + } + } catch (const std::exception &e) { + env->ThrowNew(c_llama_error, e.what()); + return JNI_FALSE; + } return JNI_TRUE; } diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 87c634dc..11247bec 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -1,12 +1,51 @@ #pragma once -// jni_helpers.hpp — JNI utility helpers for jllama.cpp +// jni_helpers.hpp — JNI bridge helpers for jllama.cpp. // -// Extracted from jllama.cpp so that the core logic can be tested without a -// running JVM. The single public entry point is get_server_context_impl(), -// which validates the Java-side model handle and returns the native -// server_context pointer. All module-level globals are passed explicitly so -// the function is self-contained and unit-testable with mock JNI environments. +// This file is the single project-side helper header for all JNI bridge code. +// It was formed by merging the former jni_helpers.hpp (handle management) and +// the former jni_server_helpers.hpp (server orchestration) into one coherent file. 
+// +// Two layers live here: +// +// Layer A — JNI handle management (no server.hpp required): +// jllama_context struct, get_server_context_impl, get_jllama_context_impl, +// require_single_task_id_impl, require_json_field_impl, +// jint_array_to_tokens_impl +// +// Layer B — JNI + server orchestration (server.hpp must precede this header): +// json_to_jstring_impl, results_to_jstring_impl, +// build_completion_tasks_impl, recv_slot_task_result_impl, +// collect_task_results_impl, check_infill_support_impl, append_task +// +// Pure JSON transforms (no JNI, no llama state) live in json_helpers.hpp, +// which is included at the bottom of this file so all bridge helpers can +// call them directly. +// +// IMPORTANT — include order for Layer B: +// server.hpp must be included by the including translation unit BEFORE this +// header. server.hpp has no include guard, so including it here would cause +// redefinition errors in any TU that already includes server.hpp directly. +// +// All parameters are passed explicitly (no module-level globals) so every +// function can be exercised in unit tests using a mock JNIEnv. +// +// Declaration order (each function must be defined before its first caller): +// Layer A: +// 1. jllama_context struct +// 2. get_server_context_impl +// 3. get_jllama_context_impl +// 4. require_single_task_id_impl +// 5. require_json_field_impl +// 6. jint_array_to_tokens_impl +// Layer B (needs server.hpp in TU): +// 7. json_to_jstring_impl +// 8. build_completion_tasks_impl +// 9. recv_slot_task_result_impl — uses get_result_error_message (json_helpers), json_to_jstring_impl +// 10. collect_task_results_impl — uses get_result_error_message (json_helpers) +// 11. results_to_jstring_impl — uses results_to_json (json_helpers), json_to_jstring_impl +// 12. check_infill_support_impl +// 13. 
append_task #include "jni.h" #include "nlohmann/json.hpp" @@ -15,11 +54,17 @@ #include #include #include +#include -// Forward declaration — callers that need the full definition must include -// server.hpp themselves. +// Forward declaration — Layer A helpers only hold/cast pointers to +// server_context; they never dereference it, so a full definition is not +// needed here. TUs that call Layer B functions must include server.hpp first. struct server_context; +// =========================================================================== +// Layer A — JNI handle management +// =========================================================================== + // --------------------------------------------------------------------------- // jllama_context // @@ -29,9 +74,9 @@ struct server_context; // between thread teardown and JVM shutdown. // --------------------------------------------------------------------------- struct jllama_context { - server_context *server = nullptr; + server_context *server = nullptr; std::thread worker; - bool vocab_only = false; + bool vocab_only = false; // Signals that the worker thread has entered start_loop() and is ready. // Without this, terminate() can race with start_loop() setting running=true. std::atomic worker_ready{false}; @@ -45,13 +90,11 @@ struct jllama_context { // // On success: returns a non-null server_context*. // On failure: throws "Model is not loaded" via JNI and returns nullptr. -// -// Parameters are passed explicitly (no module-level globals) so the function -// can be exercised from unit tests using a mock JNIEnv. 
// --------------------------------------------------------------------------- -[[nodiscard]] inline server_context *get_server_context_impl(JNIEnv *env, jobject obj, - jfieldID field_id, - jclass error_class) { +[[nodiscard]] inline server_context *get_server_context_impl(JNIEnv *env, + jobject obj, + jfieldID field_id, + jclass error_class) { const jlong handle = env->GetLongField(obj, field_id); if (handle == 0) { env->ThrowNew(error_class, "Model is not loaded"); @@ -63,24 +106,19 @@ struct jllama_context { // --------------------------------------------------------------------------- // get_jllama_context_impl // -// Like get_server_context_impl, but returns the jllama_context wrapper -// itself instead of its inner server_context. Used ONLY by the delete -// path, which must call `delete jctx` and therefore needs the outer struct, -// not just its .server member. +// Like get_server_context_impl but returns the jllama_context wrapper itself. +// Used ONLY by the delete path, which must call `delete jctx`. // // Intentionally does NOT throw on null: a zero handle means the model was // already deleted (or never fully initialised), which is a valid no-op for -// a destructor-style call. All other callers should use -// get_server_context_impl instead, which does throw. -// -// On success: returns a non-null jllama_context*. -// On null handle: returns nullptr silently (no JNI exception is thrown). +// a destructor-style call. 
// --------------------------------------------------------------------------- -[[nodiscard]] inline jllama_context *get_jllama_context_impl(JNIEnv *env, jobject obj, - jfieldID field_id) { +[[nodiscard]] inline jllama_context *get_jllama_context_impl(JNIEnv *env, + jobject obj, + jfieldID field_id) { const jlong handle = env->GetLongField(obj, field_id); if (handle == 0) { - return nullptr; // already deleted or never initialised — silent no-op + return nullptr; } return reinterpret_cast(handle); // NOLINT(*-no-int-to-ptr) } @@ -90,19 +128,11 @@ struct jllama_context { // // Validates that exactly one task was created after dispatch and returns its // ID. Returns 0 (with a JNI exception pending) when the count is not 1. -// -// Used by requestCompletion and requestChatCompletion, which hand the returned -// ID back to the Java caller for streaming consumption via -// receiveCompletionJson. Both functions are restricted to single-prompt, -// single-task invocations. -// -// On success: returns the single task id (> 0 in practice). -// On failure: throws via JNI, returns 0. // --------------------------------------------------------------------------- [[nodiscard]] inline int require_single_task_id_impl( - JNIEnv *env, + JNIEnv *env, const std::unordered_set &task_ids, - jclass error_class) { + jclass error_class) { if (task_ids.size() != 1) { env->ThrowNew(error_class, "multitasking currently not supported"); return 0; @@ -115,17 +145,11 @@ struct jllama_context { // // Checks that `data` contains the given key. Returns true if present. // On missing key: throws " is required" via JNI and returns false. -// -// Extracted from the repeated pattern in handleInfill: -// if (!data.contains("input_prefix")) { ThrowNew(...); return nullptr; } -// if (!data.contains("input_suffix")) { ThrowNew(...); return nullptr; } -// -// Parameters are explicit so the function can be unit-tested without a real JVM. 
// --------------------------------------------------------------------------- -[[nodiscard]] inline bool require_json_field_impl(JNIEnv *env, +[[nodiscard]] inline bool require_json_field_impl(JNIEnv *env, const nlohmann::json &data, - const char *field, - jclass error_class) { + const char *field, + jclass error_class) { if (data.contains(field)) { return true; } @@ -137,18 +161,191 @@ struct jllama_context { // --------------------------------------------------------------------------- // jint_array_to_tokens_impl // -// Reads a Java int array into a std::vector and releases the -// JNI array elements with JNI_ABORT (read-only — no writeback needed). -// -// Extracted from the identical 4-line pattern repeated in decodeBytes and -// handleDetokenize. Parameters are explicit so the function is unit-testable -// without a real JVM. +// Reads a Java int array into a std::vector and releases the JNI +// array elements with JNI_ABORT (read-only — no writeback needed). // --------------------------------------------------------------------------- [[nodiscard]] inline std::vector jint_array_to_tokens_impl( JNIEnv *env, jintArray array) { - const jsize length = env->GetArrayLength(array); - jint *elements = env->GetIntArrayElements(array, nullptr); + const jsize length = env->GetArrayLength(array); + jint *elements = env->GetIntArrayElements(array, nullptr); std::vector tokens(elements, elements + length); env->ReleaseIntArrayElements(array, elements, JNI_ABORT); return tokens; } + +// =========================================================================== +// Layer B — JNI + server orchestration +// (server.hpp must be included by the TU before this header) +// =========================================================================== + +// json_helpers.hpp provides get_result_error_message, results_to_json, and +// the other pure JSON transforms used by the functions below. 
+#include "json_helpers.hpp" + +// --------------------------------------------------------------------------- +// json_to_jstring_impl +// +// Serialises any json value to a JNI string via dump() + NewStringUTF. +// --------------------------------------------------------------------------- +[[nodiscard]] inline jstring json_to_jstring_impl(JNIEnv *env, const json &j) { + std::string s = j.dump(); + return env->NewStringUTF(s.c_str()); +} + +// --------------------------------------------------------------------------- +// build_completion_tasks_impl +// +// Reads data["prompt"], tokenises it, and appends one server_task per prompt +// token sequence to `tasks`. task_type and oaicompat are caller-specified. +// +// IMPORTANT: data["prompt"] is read before any ctx_server member is accessed, +// so passing ctx_server=nullptr is safe in tests that exercise the error path +// (missing "prompt" key). +// +// On success: `tasks` is populated, returns true. +// On error: throws via JNI using error_class, returns false. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline bool build_completion_tasks_impl( + JNIEnv *env, + server_context *ctx_server, + const json &data, + const std::string &completion_id, + server_task_type task_type, + oaicompat_type oaicompat, + std::vector &tasks, + jclass error_class) { + try { + const auto &prompt = data.at("prompt"); // throws before ctx_server is touched + + std::vector tokenized_prompts = + tokenize_input_prompts(ctx_server->vocab, prompt, true, true); + + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(task_type); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = server_tokens(tokenized_prompts[i], false); + task.params = server_task::params_from_json_cmpl( + ctx_server->ctx, ctx_server->params_base, data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + + tasks.push_back(std::move(task)); + } + } catch (const std::exception &e) { + const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); + env->ThrowNew(error_class, err.dump().c_str()); + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// recv_slot_task_result_impl +// +// Receives a single slot-action result from the response queue, checks for +// an error, and returns the result JSON as a JNI string. +// +// On success: returns a new jstring containing result->to_json().dump(). +// On error: removes the waiting task id, throws via JNI, returns nullptr. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline jstring recv_slot_task_result_impl(JNIEnv *env, + server_response &queue, + int task_id, + jclass error_class) { + server_task_result_ptr result = queue.recv(task_id); + queue.remove_waiting_task_id(task_id); + if (result->is_error()) { + env->ThrowNew(error_class, get_result_error_message(result).c_str()); + return nullptr; + } + return json_to_jstring_impl(env, result->to_json()); +} + +// --------------------------------------------------------------------------- +// collect_task_results_impl +// +// Precondition: each ID in task_ids has already been registered with +// queue.add_waiting_task_id() (or add_waiting_tasks()). +// +// On success: appends all results to `out`, removes waiting ids, returns true. +// On error: removes waiting ids, throws via JNI, returns false. +// --------------------------------------------------------------------------- +[[nodiscard]] inline bool collect_task_results_impl( + JNIEnv *env, + server_response &queue, + const std::unordered_set &task_ids, + std::vector &out, + jclass error_class) { + out.reserve(task_ids.size()); + for (size_t i = 0; i < task_ids.size(); i++) { + server_task_result_ptr result = queue.recv(task_ids); + if (result->is_error()) { + queue.remove_waiting_task_ids(task_ids); + env->ThrowNew(error_class, get_result_error_message(result).c_str()); + return false; + } + out.push_back(std::move(result)); + } + queue.remove_waiting_task_ids(task_ids); + return true; +} + +// --------------------------------------------------------------------------- +// results_to_jstring_impl +// +// Serialises a vector of task results to a jstring by delegating JSON +// construction to results_to_json (json_helpers.hpp) and serialisation to +// json_to_jstring_impl. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline jstring results_to_jstring_impl( + JNIEnv *env, + const std::vector &results) { + return json_to_jstring_impl(env, results_to_json(results)); +} + +// --------------------------------------------------------------------------- +// check_infill_support_impl +// +// Checks that the model vocabulary has all three fill-in-the-middle (FIM) +// tokens (prefix, suffix, middle). Returns true if infill is supported. +// On failure: throws via JNI and returns false. +// --------------------------------------------------------------------------- +[[nodiscard]] inline bool check_infill_support_impl(JNIEnv *env, + const llama_vocab *vocab, + jclass error_class) { + std::string err; + if (llama_vocab_fim_pre(vocab) == LLAMA_TOKEN_NULL) { err += "prefix token is missing. "; } + if (llama_vocab_fim_suf(vocab) == LLAMA_TOKEN_NULL) { err += "suffix token is missing. "; } + if (llama_vocab_fim_mid(vocab) == LLAMA_TOKEN_NULL) { err += "middle token is missing. "; } + if (!err.empty()) { + env->ThrowNew(error_class, ("Infill is not supported by this model: " + err).c_str()); + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// append_task +// +// Constructs a server_task of the given type and appends it to `tasks`. +// The caller is responsible for pre-computing `prompt_tokens`. +// `oaicompat` defaults to NONE so rerank call sites need no explicit argument. 
+// --------------------------------------------------------------------------- +inline void append_task(server_context *ctx_server, + std::vector &tasks, + server_task_type type, + llama_tokens prompt_tokens, + size_t index, + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE) { + server_task task(type); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = index; + task.prompt_tokens = server_tokens(prompt_tokens, false); + task.params.oaicompat = oaicompat; + tasks.push_back(std::move(task)); +} diff --git a/src/main/cpp/jni_server_helpers.hpp b/src/main/cpp/jni_server_helpers.hpp deleted file mode 100644 index 25de347b..00000000 --- a/src/main/cpp/jni_server_helpers.hpp +++ /dev/null @@ -1,428 +0,0 @@ -#pragma once - -// jni_server_helpers.hpp — JNI helpers that need server.hpp types. -// -// Kept separate from jni_helpers.hpp intentionally: jni_helpers.hpp has a -// deliberately minimal include surface (only jni.h + stdlib) so it can be -// unit-tested without the full llama.cpp stack. Any helper that must reach -// into server.hpp types belongs here instead. -// -// Public entry points: -// build_completion_tasks_impl — tokenise and build a server_task vector -// collect_task_results_impl — drain results from the response queue -// recv_slot_task_result_impl — recv + check a single slot-action result -// append_task — construct and push a single server_task -// -// All parameters are explicit (no module-level globals) so each function can -// be exercised in unit tests using local server objects and a mock JNIEnv. -// -// IMPORTANT — include order: -// server.hpp must be included by the including translation unit BEFORE this -// header. server.hpp has no include guard, so including it here would cause -// redefinition errors in any TU that already includes server.hpp directly. -// -// Declaration order (each function must be defined before its first caller): -// 1. 
get_result_error_message — used by recv_slot_task_result_impl, -// collect_task_results_impl -// 2. json_to_jstring_impl — used by recv_slot_task_result_impl, -// results_to_jstring_impl -// 3. build_completion_tasks_impl — no dependencies on helpers above -// 4. recv_slot_task_result_impl — uses 1 + 2 -// 5. collect_task_results_impl — uses 1 -// 6. results_to_json_impl — no dependencies on helpers above -// 7. results_to_jstring_impl — uses 2 + 6 -// 8. check_infill_support_impl — no dependencies on helpers above -// 9. rerank_results_to_json — no dependencies on helpers above -// 10. append_task — no dependencies on helpers above -// 11. extract_first_embedding_row — no dependencies on helpers above -// 12. parse_encoding_format_impl — no dependencies on helpers above -// 13. extract_embedding_prompt_impl — no dependencies on helpers above -// 14. build_embeddings_response_json_impl — no dependencies on helpers above - -#include "jni.h" - -#include -#include - -// --------------------------------------------------------------------------- -// get_result_error_message -// -// Extracts the human-readable error string from a failed task result. -// Equivalent to result->to_json()["message"].get(), which -// appears verbatim in five places: -// -// receiveCompletionJson, embed, handleRerank (in jllama.cpp) -// collect_task_results_impl, recv_slot_task_result_impl (in this header) -// --------------------------------------------------------------------------- -[[nodiscard]] inline std::string get_result_error_message( - const server_task_result_ptr &result) { - return result->to_json()["message"].get(); -} - -// --------------------------------------------------------------------------- -// json_to_jstring_impl -// -// Serialises any json value to a JNI string via dump() + NewStringUTF. 
-// Extracted from the repeated two-line pattern: -// -// std::string response_str = some_json.dump(); -// return env->NewStringUTF(response_str.c_str()); -// -// Used by recv_slot_task_result_impl, results_to_jstring_impl, -// receiveCompletionJson, handleRerank, handleEmbeddings, -// handleTokenize, and handleDetokenize. -// --------------------------------------------------------------------------- -[[nodiscard]] inline jstring json_to_jstring_impl(JNIEnv *env, const json &j) { - std::string s = j.dump(); - return env->NewStringUTF(s.c_str()); -} - -// --------------------------------------------------------------------------- -// build_completion_tasks_impl -// -// Reads data["prompt"], tokenises it, and appends one server_task per prompt -// token sequence to `tasks`. task_type and oaicompat are caller-specified, -// covering all six JNI call sites: -// requestCompletion → COMPLETION or INFILL / NONE (type from caller) -// handleCompletions → COMPLETION / NONE -// handleCompletionsOai → COMPLETION / COMPLETION -// handleChatCompletions → COMPLETION / CHAT -// requestChatCompletion → COMPLETION / NONE (template already applied) -// handleInfill → INFILL / NONE -// -// IMPORTANT: data["prompt"] is read in its own statement before any -// ctx_server member is accessed, so passing ctx_server=nullptr is safe in -// tests that only exercise the error path (missing "prompt" key). -// -// On success: `tasks` is populated, returns true. -// On error: throws via JNI using error_class, returns false. 
-// --------------------------------------------------------------------------- -[[nodiscard]] inline bool build_completion_tasks_impl(JNIEnv *env, - server_context *ctx_server, - const json &data, - const std::string &completion_id, - server_task_type task_type, - oaicompat_type oaicompat, - std::vector &tasks, - jclass error_class) { - try { - const auto &prompt = data.at("prompt"); // throws before ctx_server is touched - - std::vector tokenized_prompts = - tokenize_input_prompts(ctx_server->vocab, prompt, true, true); - - tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(task_type); - task.id = ctx_server->queue_tasks.get_new_id(); - task.index = i; - - task.prompt_tokens = server_tokens(tokenized_prompts[i], false); - task.params = server_task::params_from_json_cmpl( - ctx_server->ctx, ctx_server->params_base, data); - task.id_selected_slot = json_value(data, "id_slot", -1); - - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - - tasks.push_back(std::move(task)); - } - } catch (const std::exception &e) { - const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); - env->ThrowNew(error_class, err.dump().c_str()); - return false; - } - return true; -} - -// --------------------------------------------------------------------------- -// recv_slot_task_result_impl -// -// Receives a single slot-action result from the response queue, checks for -// an error, and returns the result JSON as a JNI string. -// -// Used by all four handleSlotAction switch cases (LIST / SAVE / RESTORE / -// ERASE). The caller is responsible for constructing the task, registering -// the task ID with queue.add_waiting_task_id(), and posting it to the task -// queue; this helper only covers the recv → check → return leg. -// -// On success: returns a new jstring containing result->to_json().dump(). 
-// On error: removes the waiting task id, throws via JNI, returns nullptr. -// --------------------------------------------------------------------------- -[[nodiscard]] inline jstring recv_slot_task_result_impl(JNIEnv *env, - server_response &queue, - int task_id, - jclass error_class) { - server_task_result_ptr result = queue.recv(task_id); - queue.remove_waiting_task_id(task_id); - if (result->is_error()) { - env->ThrowNew(error_class, get_result_error_message(result).c_str()); - return nullptr; - } - return json_to_jstring_impl(env, result->to_json()); -} - -// --------------------------------------------------------------------------- -// collect_task_results_impl -// -// Precondition: each ID in task_ids has already been registered with -// queue.add_waiting_task_id() (or add_waiting_tasks()) so that -// remove_waiting_task_ids() performs correct cleanup. -// -// On success: appends all results to `out`, removes waiting ids, returns true. -// On error: removes waiting ids, throws via JNI using error_class, -// returns false. The caller must return nullptr (or equivalent -// sentinel) immediately — the JNI exception is already pending. 
-// --------------------------------------------------------------------------- -[[nodiscard]] inline bool collect_task_results_impl(JNIEnv *env, - server_response &queue, - const std::unordered_set &task_ids, - std::vector &out, - jclass error_class) { - out.reserve(task_ids.size()); - for (size_t i = 0; i < task_ids.size(); i++) { - server_task_result_ptr result = queue.recv(task_ids); - if (result->is_error()) { - queue.remove_waiting_task_ids(task_ids); - env->ThrowNew(error_class, get_result_error_message(result).c_str()); - return false; - } - out.push_back(std::move(result)); - } - queue.remove_waiting_task_ids(task_ids); - return true; -} - -// --------------------------------------------------------------------------- -// results_to_json_impl -// -// Converts a vector of task results to a json value without touching JNI. -// -// When there is exactly one result, the top-level JSON is that result's object -// directly. When there are multiple results, they are wrapped in a JSON array. -// This mirrors the OpenAI API convention used by handleCompletions, -// handleCompletionsOai, handleChatCompletions, and handleInfill. -// -// Separated from results_to_jstring_impl so the construction logic is -// unit-testable without any JNI mock. The caller is responsible for checking -// that `results` is non-empty before calling (an empty vector produces an -// empty JSON array). 
-// --------------------------------------------------------------------------- -[[nodiscard]] inline json results_to_json_impl( - const std::vector &results) { - if (results.size() == 1) { - return results[0]->to_json(); - } - json arr = json::array(); - for (const auto &res : results) { - arr.push_back(res->to_json()); - } - return arr; -} - -// --------------------------------------------------------------------------- -// results_to_jstring_impl -// -// Serialises a vector of task results to a jstring by delegating JSON -// construction to results_to_json_impl and serialisation to -// json_to_jstring_impl. -// --------------------------------------------------------------------------- -[[nodiscard]] inline jstring results_to_jstring_impl( - JNIEnv *env, - const std::vector &results) { - return json_to_jstring_impl(env, results_to_json_impl(results)); -} - -// --------------------------------------------------------------------------- -// check_infill_support_impl -// -// Checks that the model vocabulary has all three fill-in-the-middle (FIM) -// tokens (prefix, suffix, middle). Returns true if infill is supported. -// On failure: populates a descriptive error message, throws via JNI using -// error_class, and returns false. -// -// Extracted from the 10-line compatibility block in handleInfill so it can -// be unit-tested independently of the JNI dispatch layer. -// --------------------------------------------------------------------------- -// --------------------------------------------------------------------------- -// rerank_results_to_json -// -// Converts a collected vector of rerank task results to a JSON array. -// Each element contains the original document text (looked up via the -// result's "index" field), the index, and the relevance score. -// -// Pure computation — no JNI calls, no llama context required. -// Unit-testable with any vector of server_task_result_ptr and strings. 
-// --------------------------------------------------------------------------- -[[nodiscard]] inline json rerank_results_to_json( - const std::vector &results, - const std::vector &documents) { - json arr = json::array(); - for (const auto &result : results) { - const auto out = result->to_json(); - int index = out["index"].get(); - float score = out["score"].get(); - arr.push_back({ - {"document", documents[index]}, - {"index", index}, - {"score", score} - }); - } - return arr; -} - -// --------------------------------------------------------------------------- -// append_task -// -// Constructs a server_task of the given type and appends it to `tasks`. -// Captures the repeated 5–6-line block that appears in embed (single task), -// handleEmbeddings (loop), and handleRerank (loop): -// -// server_task task(type); -// task.id = ctx_server->queue_tasks.get_new_id(); -// task.index = index; -// task.prompt_tokens = server_tokens(prompt_tokens, false); -// task.params.oaicompat = oaicompat; -// tasks.push_back(std::move(task)); -// -// The caller is responsible for pre-computing `prompt_tokens` (e.g. via -// format_rerank() for rerank tasks). Taken by value because server_tokens -// constructor requires a non-const lvalue reference. `oaicompat` defaults -// to NONE so the rerank call site needs no explicit argument. -// -// Unit-testable without JNI: takes only C++ objects, no JNIEnv calls. 
-// --------------------------------------------------------------------------- -inline void append_task(server_context *ctx_server, - std::vector &tasks, - server_task_type type, - llama_tokens prompt_tokens, - size_t index, - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE) { - server_task task(type); - task.id = ctx_server->queue_tasks.get_new_id(); - task.index = index; - task.prompt_tokens = server_tokens(prompt_tokens, false); - task.params.oaicompat = oaicompat; - tasks.push_back(std::move(task)); -} - -[[nodiscard]] inline bool check_infill_support_impl(JNIEnv *env, - const llama_vocab *vocab, - jclass error_class) { - std::string err; - if (llama_vocab_fim_pre(vocab) == LLAMA_TOKEN_NULL) { err += "prefix token is missing. "; } - if (llama_vocab_fim_suf(vocab) == LLAMA_TOKEN_NULL) { err += "suffix token is missing. "; } - if (llama_vocab_fim_mid(vocab) == LLAMA_TOKEN_NULL) { err += "middle token is missing. "; } - if (!err.empty()) { - env->ThrowNew(error_class, ("Infill is not supported by this model: " + err).c_str()); - return false; - } - return true; -} - -// --------------------------------------------------------------------------- -// extract_first_embedding_row -// -// Parses out_res["embedding"] as a 2D float array and returns the first row. -// -// Throws std::runtime_error if the outer or inner array is empty. -// Throws nlohmann::json::exception if the "embedding" key is absent or the -// value cannot be coerced to vector>. -// -// Pure computation — no JNI calls, no llama context. -// Unit-testable with any JSON literal: -// {"embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]} -// --------------------------------------------------------------------------- -[[nodiscard]] inline std::vector -extract_first_embedding_row(const json &out_res) { - // .at() throws json::out_of_range if "embedding" is absent. 
- const auto embedding = out_res.at("embedding").get>>(); - if (embedding.empty() || embedding[0].empty()) { - throw std::runtime_error("embedding array is empty"); - } - return embedding[0]; -} - -// --------------------------------------------------------------------------- -// parse_encoding_format_impl -// -// Reads the optional "encoding_format" field from `body` and returns whether -// base64 encoding was requested. -// -// Returns false — field absent, or value is "float" → use float encoding. -// Returns true — value is "base64" → use base64 encoding. -// Throws std::invalid_argument — value is present but neither "float" nor -// "base64", with a message suitable for forwarding to JNI ThrowNew. -// -// Pure computation — no JNI calls, no llama context. -// Unit-testable with any JSON literal. -// --------------------------------------------------------------------------- -[[nodiscard]] inline bool parse_encoding_format_impl(const json &body) { - if (!body.contains("encoding_format")) { - return false; - } - const std::string format = body.at("encoding_format").get(); - if (format == "base64") { - return true; - } - if (format == "float") { - return false; - } - throw std::invalid_argument("encoding_format must be \"float\" or \"base64\""); -} - -// --------------------------------------------------------------------------- -// extract_embedding_prompt_impl -// -// Selects the prompt value from an embedding request body using OpenAI-style -// key precedence: "input" is preferred (OAI-compatible); "content" is the -// fallback (legacy, non-OAI path). -// -// On success: returns the prompt JSON value and sets force_no_oaicompat=true -// when "content" was used (caller must downgrade oaicompat to NONE). -// Throws std::invalid_argument if neither "input" nor "content" is present, -// with a message suitable for forwarding to JNI ThrowNew. -// -// Pure computation — no JNI calls, no llama context. -// Unit-testable with any JSON literal. 
-// --------------------------------------------------------------------------- -[[nodiscard]] inline json extract_embedding_prompt_impl(const json &body, - bool &force_no_oaicompat) { - force_no_oaicompat = false; - if (body.count("input") != 0) { - return body.at("input"); - } - if (body.contains("content")) { - force_no_oaicompat = true; - return body.at("content"); - } - throw std::invalid_argument("\"input\" or \"content\" must be provided"); -} - -// --------------------------------------------------------------------------- -// build_embeddings_response_json_impl -// -// Collects task results into a JSON array, then formats the final response: -// - OAICOMPAT_TYPE_EMBEDDING → wraps via format_embeddings_response_oaicompat -// (adds "object":"list", "usage", and per-embedding "object":"embedding") -// - any other oaicompat → returns the bare JSON array -// -// Symmetric counterpart to rerank_results_to_json. -// -// Pure computation — no JNI calls, no llama context. -// Unit-testable with fake_ok_result mocks and any JSON body literal. -// --------------------------------------------------------------------------- -[[nodiscard]] inline json build_embeddings_response_json_impl( - const std::vector &results, - const json &body, - oaicompat_type oaicompat, - bool use_base64) { - json responses = json::array(); - for (const auto &result : results) { - responses.push_back(result->to_json()); - } - if (oaicompat == OAICOMPAT_TYPE_EMBEDDING) { - return format_embeddings_response_oaicompat(body, responses, use_base64); - } - return responses; -} diff --git a/src/main/cpp/json_helpers.hpp b/src/main/cpp/json_helpers.hpp new file mode 100644 index 00000000..e47155d9 --- /dev/null +++ b/src/main/cpp/json_helpers.hpp @@ -0,0 +1,243 @@ +#pragma once + +// json_helpers.hpp — Pure JSON transformation helpers. 
+// +// Every function in this file is pure data transformation: +// - input: nlohmann::json values, server_task_result_ptr, or plain C++ types +// - output: nlohmann::json, std::vector, std::optional, or plain C++ types +// - zero JNI calls (no JNIEnv*, jclass, jstring, …) +// - zero llama state (no llama_context*, llama_vocab*, server_context*) +// +// All functions are unit-testable with JSON literals and fake result objects; +// no JVM and no loaded model are required. +// +// IMPORTANT — include order: +// server.hpp (and transitively utils.hpp) must be included by the including +// translation unit BEFORE this header. That header defines: +// server_task_result_ptr, oaicompat_type, OAICOMPAT_TYPE_EMBEDDING, +// format_embeddings_response_oaicompat, and the `json` type alias. +// server.hpp has no include guard, so pulling it in here would cause +// redefinition errors in any TU that already includes it directly. +// +// Declaration order: +// 1. get_result_error_message — used by nothing above it +// 2. results_to_json — used by nothing above it +// 3. rerank_results_to_json — used by nothing above it +// 4. build_embeddings_response_json — used by nothing above it +// 5. extract_first_embedding_row — used by nothing above it +// 6. parse_encoding_format — used by nothing above it +// 7. extract_embedding_prompt — used by nothing above it +// 8. is_infill_request — used by nothing above it +// 9. parse_slot_prompt_similarity — used by nothing above it +// 10. parse_positive_int_config — used by nothing above it + +#include "nlohmann/json.hpp" + +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// get_result_error_message +// +// Extracts the human-readable error string from a failed task result. +// Equivalent to result->to_json()["message"].get(). 
+// +// Used by recv_slot_task_result_impl and collect_task_results_impl in +// jni_helpers.hpp, and directly in receiveCompletionJson, embed, and +// handleRerank in jllama.cpp. +// --------------------------------------------------------------------------- +[[nodiscard]] inline std::string get_result_error_message( + const server_task_result_ptr &result) { + return result->to_json()["message"].get(); +} + +// --------------------------------------------------------------------------- +// results_to_json +// +// Converts a vector of task results to a single json value. +// +// One result → the result's JSON object directly (no wrapping array). +// Many results → a JSON array of each result's JSON object. +// Empty vector → empty JSON array. +// +// This mirrors the OpenAI API convention used by handleCompletions, +// handleCompletionsOai, handleChatCompletions, and handleInfill. +// --------------------------------------------------------------------------- +[[nodiscard]] inline json results_to_json( + const std::vector &results) { + if (results.size() == 1) { + return results[0]->to_json(); + } + json arr = json::array(); + for (const auto &res : results) { + arr.push_back(res->to_json()); + } + return arr; +} + +// --------------------------------------------------------------------------- +// rerank_results_to_json +// +// Converts a collected vector of rerank task results to a JSON array. +// Each element contains the original document text (looked up via the +// result's "index" field), the index, and the relevance score. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline json rerank_results_to_json( + const std::vector &results, + const std::vector &documents) { + json arr = json::array(); + for (const auto &result : results) { + const auto out = result->to_json(); + int index = out["index"].get(); + float score = out["score"].get(); + arr.push_back({ + {"document", documents[index]}, + {"index", index}, + {"score", score} + }); + } + return arr; +} + +// --------------------------------------------------------------------------- +// build_embeddings_response_json +// +// Collects task results into a JSON array, then formats the final response: +// - OAICOMPAT_TYPE_EMBEDDING → wraps via format_embeddings_response_oaicompat +// (adds "object":"list", "usage", and per-embedding "object":"embedding") +// - any other oaicompat → returns the bare JSON array +// +// Symmetric counterpart to rerank_results_to_json. +// --------------------------------------------------------------------------- +[[nodiscard]] inline json build_embeddings_response_json( + const std::vector &results, + const json &body, + oaicompat_type oaicompat, + bool use_base64) { + json responses = json::array(); + for (const auto &result : results) { + responses.push_back(result->to_json()); + } + if (oaicompat == OAICOMPAT_TYPE_EMBEDDING) { + return format_embeddings_response_oaicompat(body, responses, use_base64); + } + return responses; +} + +// --------------------------------------------------------------------------- +// extract_first_embedding_row +// +// Parses out_res["embedding"] as a 2D float array and returns the first row. +// +// Throws std::runtime_error if the outer or inner array is empty. +// Throws nlohmann::json::exception if the "embedding" key is absent or the +// value cannot be coerced to vector>. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline std::vector +extract_first_embedding_row(const json &out_res) { + // .at() throws json::out_of_range if "embedding" is absent. + const auto embedding = out_res.at("embedding").get>>(); + if (embedding.empty() || embedding[0].empty()) { + throw std::runtime_error("embedding array is empty"); + } + return embedding[0]; +} + +// --------------------------------------------------------------------------- +// parse_encoding_format +// +// Reads the optional "encoding_format" field from `body`. +// +// Returns false — field absent, or value is "float" → use float encoding. +// Returns true — value is "base64" → use base64 encoding. +// Throws std::invalid_argument — value is present but neither "float" nor +// "base64", with a message suitable for forwarding to JNI ThrowNew. +// --------------------------------------------------------------------------- +[[nodiscard]] inline bool parse_encoding_format(const json &body) { + if (!body.contains("encoding_format")) { + return false; + } + const std::string format = body.at("encoding_format").get(); + if (format == "base64") { return true; } + if (format == "float") { return false; } + throw std::invalid_argument("encoding_format must be \"float\" or \"base64\""); +} + +// --------------------------------------------------------------------------- +// extract_embedding_prompt +// +// Selects the prompt value from an embedding request body using OAI-style +// key precedence: "input" is preferred (OAI path); "content" is the fallback +// (legacy non-OAI path). +// +// On success: returns the prompt JSON value. Sets force_no_oaicompat=true +// when "content" was used — the caller must downgrade oaicompat to NONE. +// Throws std::invalid_argument if neither "input" nor "content" is present. 
+// --------------------------------------------------------------------------- +[[nodiscard]] inline json extract_embedding_prompt(const json &body, + bool &force_no_oaicompat) { + force_no_oaicompat = false; + if (body.count("input") != 0) { + return body.at("input"); + } + if (body.contains("content")) { + force_no_oaicompat = true; + return body.at("content"); + } + throw std::invalid_argument("\"input\" or \"content\" must be provided"); +} + +// --------------------------------------------------------------------------- +// is_infill_request +// +// Returns true if the request data contains "input_prefix" or "input_suffix", +// indicating that the caller wants fill-in-the-middle (infill) inference +// rather than plain completion. +// --------------------------------------------------------------------------- +[[nodiscard]] inline bool is_infill_request(const json &data) { + return data.contains("input_prefix") || data.contains("input_suffix"); +} + +// --------------------------------------------------------------------------- +// parse_slot_prompt_similarity +// +// Reads the optional "slot_prompt_similarity" field from `config`. +// +// Returns empty optional — field absent, no change needed. +// Returns float — validated value in [0.0, 1.0]. +// Throws std::invalid_argument — present but outside [0.0, 1.0]. +// --------------------------------------------------------------------------- +[[nodiscard]] inline std::optional +parse_slot_prompt_similarity(const json &config) { + if (!config.contains("slot_prompt_similarity")) { + return std::nullopt; + } + const float v = config["slot_prompt_similarity"].get(); + if (v < 0.0f || v > 1.0f) { + throw std::invalid_argument("slot_prompt_similarity must be between 0.0 and 1.0"); + } + return v; +} + +// --------------------------------------------------------------------------- +// parse_positive_int_config +// +// Reads an optional integer field `key` from `config` and validates it is > 0. 
+// +// Returns empty optional — field absent, no change needed. +// Returns int — validated value > 0. +// Throws std::invalid_argument(" must be greater than 0") — present but ≤ 0. +// --------------------------------------------------------------------------- +[[nodiscard]] inline std::optional +parse_positive_int_config(const json &config, const char *key) { + if (!config.contains(key)) { + return std::nullopt; + } + const int v = config[key].get(); + if (v <= 0) { + throw std::invalid_argument(std::string(key) + " must be greater than 0"); + } + return v; +} diff --git a/src/test/cpp/test_jni_helpers.cpp b/src/test/cpp/test_jni_helpers.cpp index eed31666..4ff0a7c4 100644 --- a/src/test/cpp/test_jni_helpers.cpp +++ b/src/test/cpp/test_jni_helpers.cpp @@ -1,257 +1,261 @@ -// Tests for get_server_context_impl() in jni_helpers.hpp. +// Tests for jni_helpers.hpp. // -// The function relies on JNIEnv, which is normally only available when a JVM -// is running. To keep this test self-contained (no JVM), we exploit the fact -// that JNIEnv_ in C++ mode is a thin class whose every method dispatches -// through a JNINativeInterface_ function-pointer table. We zero-initialise -// that table, then patch the two slots we actually call (GetLongField and -// ThrowNew) with small lambda-backed stubs. +// This file covers all functions in jni_helpers.hpp — both Layer A (JNI handle +// management) and Layer B (JNI + server orchestration). // -// Covered scenarios: -// - handle == 0 → ThrowNew("Model is not loaded") called, nullptr returned -// - handle != 0 → no throw, correct server_context* returned -// - ThrowNew is NOT called when the handle is valid +// Pure JSON transform tests live in test_json_helpers.cpp. 
+// +// Layer A tests (no server.hpp needed for the functions under test, but +// server.hpp is included here for Layer B and to satisfy the TU convention): +// get_server_context_impl, get_jllama_context_impl, +// require_single_task_id_impl, require_json_field_impl, +// jint_array_to_tokens_impl +// +// Layer B tests (need server.hpp + mock JNIEnv + pre-seeded server_response): +// json_to_jstring_impl, results_to_jstring_impl, +// build_completion_tasks_impl, recv_slot_task_result_impl, +// collect_task_results_impl +// +// JNIEnv is mocked via a zero-filled JNINativeInterface_ table with only the +// slots exercised by each test patched. server_response is used directly: +// results are pre-seeded via send() before recv() is called, so the condvar +// is satisfied immediately without blocking. #include #include +#include +#include +#include #include -// jni_helpers.hpp is the unit under test; it includes jni.h which defines -// JNIEnv_ and JNINativeInterface_. +// server.hpp must precede jni_helpers.hpp (no include guard in server.hpp). +#include "server.hpp" #include "jni_helpers.hpp" // ============================================================ -// Minimal mock JNI environment +// Shared fake result types // ============================================================ namespace { -// Mutable globals written by the stub functions, read by the tests. -static jlong g_mock_handle = 0; -static bool g_throw_called = false; -static std::string g_throw_message; +struct fake_ok_result : server_task_result { + std::string msg; + explicit fake_ok_result(int id_, std::string m) : msg(std::move(m)) { id = id_; } + json to_json() override { return {{"content", msg}}; } +}; -// Stub that satisfies the JNINativeInterface_::GetLongField signature. 
-static jlong JNICALL stub_GetLongField(JNIEnv * /*env*/, jobject /*obj*/, - jfieldID /*id*/) { - return g_mock_handle; +static server_task_result_ptr make_error(int id_, const std::string &msg) { + auto r = std::make_unique(); + r->id = id_; + r->err_msg = msg; + r->err_type = ERROR_TYPE_SERVER; + return r; } -// Stub that satisfies the JNINativeInterface_::ThrowNew signature. -static jint JNICALL stub_ThrowNew(JNIEnv * /*env*/, jclass /*clazz*/, - const char *msg) { +static server_task_result_ptr make_ok(int id_, const std::string &msg = "ok") { + return std::make_unique(id_, msg); +} + +// ============================================================ +// Mock JNI environment helpers +// ============================================================ + +// State captured by stubs — reset in each fixture's SetUp(). +static bool g_throw_called = false; +static std::string g_throw_message; +static std::string g_new_string_utf_value; +static jlong g_mock_handle = 0; + +static jstring g_new_string_utf_sentinel = reinterpret_cast(0xBEEF); + +static jint JNICALL stub_ThrowNew(JNIEnv *, jclass, const char *msg) { g_throw_called = true; g_throw_message = msg ? msg : ""; return 0; } +static jlong JNICALL stub_GetLongField(JNIEnv *, jobject, jfieldID) { + return g_mock_handle; +} +static jstring JNICALL stub_NewStringUTF(JNIEnv *, const char *utf) { + g_new_string_utf_value = utf ? utf : ""; + return g_new_string_utf_sentinel; +} -// Build a JNIEnv that routes GetLongField and ThrowNew through our stubs. -// All other slots remain null; any unexpected call will crash, acting as an -// assertion that we only touch the two operations we intend to. +// Minimal env: ThrowNew + GetLongField + NewStringUTF. 
JNIEnv *make_mock_env(JNINativeInterface_ &table, JNIEnv_ &env_obj) { std::memset(&table, 0, sizeof(table)); - table.GetLongField = stub_GetLongField; table.ThrowNew = stub_ThrowNew; + table.GetLongField = stub_GetLongField; + table.NewStringUTF = stub_NewStringUTF; env_obj.functions = &table; return &env_obj; } -// Convenience: reset all mock state before each test. +// Base fixture: resets all mock state. struct MockJniFixture : ::testing::Test { JNINativeInterface_ table{}; JNIEnv_ env_obj{}; - JNIEnv *env = nullptr; - - // Dummy field/class handles — their values are never dereferenced by the - // stubs, so any non-null sentinel is fine. - jfieldID dummy_field = reinterpret_cast(0x1); - jclass dummy_class = reinterpret_cast(0x2); + JNIEnv *env = nullptr; + jfieldID dummy_field = reinterpret_cast(0x1); + jclass dummy_class = reinterpret_cast(0x2); void SetUp() override { - env = make_mock_env(table, env_obj); - g_mock_handle = 0; - g_throw_called = false; + env = make_mock_env(table, env_obj); + g_mock_handle = 0; + g_throw_called = false; g_throw_message.clear(); + g_new_string_utf_value.clear(); } }; +// Extends MockJniFixture with a fresh server_response queue. 
+struct ServerFixture : MockJniFixture { + server_response queue; +}; + } // namespace // ============================================================ -// Test: null handle → ThrowNew + nullptr +// get_server_context_impl // ============================================================ -TEST_F(MockJniFixture, NullHandle_ThrowsAndReturnsNullptr) { - g_mock_handle = 0; // model not loaded +TEST_F(MockJniFixture, GetServerContext_NullHandle_ThrowsAndReturnsNull) { + g_mock_handle = 0; server_context *result = - get_server_context_impl(env, /*obj=*/nullptr, dummy_field, dummy_class); + get_server_context_impl(env, nullptr, dummy_field, dummy_class); - EXPECT_EQ(result, nullptr) - << "Expected nullptr when the model handle is 0"; - EXPECT_TRUE(g_throw_called) - << "Expected ThrowNew to be called for a null handle"; + EXPECT_EQ(result, nullptr); + EXPECT_TRUE(g_throw_called); EXPECT_EQ(g_throw_message, "Model is not loaded"); } -// ============================================================ -// Test: valid handle → correct server_context* returned, no throw -// ============================================================ - -TEST_F(MockJniFixture, ValidHandle_ReturnsServerContextAndDoesNotThrow) { - // Build a minimal jllama_context on the stack. We only need server to be - // set; worker and worker_ready are never touched by the helper. - // server_context is forward-declared in jni_helpers.hpp, so we can legally - // hold a pointer to it without a full definition. 
+TEST_F(MockJniFixture, GetServerContext_ValidHandle_ReturnsServerContextNoThrow) { server_context *sentinel = reinterpret_cast(0xDEADBEEF); jllama_context fake_ctx; fake_ctx.server = sentinel; - - g_mock_handle = reinterpret_cast(&fake_ctx); + g_mock_handle = reinterpret_cast(&fake_ctx); server_context *result = - get_server_context_impl(env, /*obj=*/nullptr, dummy_field, dummy_class); + get_server_context_impl(env, nullptr, dummy_field, dummy_class); - EXPECT_EQ(result, sentinel) - << "Expected the server pointer embedded in jllama_context"; - EXPECT_FALSE(g_throw_called) - << "ThrowNew must not be called for a valid handle"; + EXPECT_EQ(result, sentinel); + EXPECT_FALSE(g_throw_called); } -// ============================================================ -// Test: ThrowNew message is exactly "Model is not loaded" -// (guards against future typo regressions) -// ============================================================ - -TEST_F(MockJniFixture, NullHandle_ErrorMessageIsExact) { +TEST_F(MockJniFixture, GetServerContext_ErrorMessageIsExact) { g_mock_handle = 0; - (void)get_server_context_impl(env, nullptr, dummy_field, dummy_class); - ASSERT_TRUE(g_throw_called); EXPECT_EQ(g_throw_message, "Model is not loaded"); } -// ============================================================ -// Test: ThrowNew is NOT called when handle is valid -// ============================================================ - -TEST_F(MockJniFixture, ValidHandle_NeverCallsThrowNew) { +TEST_F(MockJniFixture, GetServerContext_ValidHandle_NeverCallsThrowNew) { server_context *sentinel = reinterpret_cast(0xCAFEBABE); jllama_context fake_ctx; fake_ctx.server = sentinel; - g_mock_handle = reinterpret_cast(&fake_ctx); - + g_mock_handle = reinterpret_cast(&fake_ctx); (void)get_server_context_impl(env, nullptr, dummy_field, dummy_class); - EXPECT_FALSE(g_throw_called); } // ============================================================ -// Tests for get_jllama_context_impl() -// -// Key contract 
differences from get_server_context_impl: -// - Returns the jllama_context* wrapper itself, NOT its inner .server -// - Returns nullptr SILENTLY on null handle (no ThrowNew) -// - Used only by the delete path, where null == valid "already gone" no-op +// get_jllama_context_impl // ============================================================ -TEST_F(MockJniFixture, GetJllamaContext_NullHandle_ReturnsNullptrWithoutThrow) { +TEST_F(MockJniFixture, GetJllamaContext_NullHandle_ReturnsNullWithoutThrow) { g_mock_handle = 0; - jllama_context *result = - get_jllama_context_impl(env, /*obj=*/nullptr, dummy_field); + jllama_context *result = get_jllama_context_impl(env, nullptr, dummy_field); - EXPECT_EQ(result, nullptr) - << "Expected nullptr when the model handle is 0"; - EXPECT_FALSE(g_throw_called) - << "get_jllama_context_impl must NOT throw on null handle (delete is a no-op)"; + EXPECT_EQ(result, nullptr); + EXPECT_FALSE(g_throw_called); } -TEST_F(MockJniFixture, GetJllamaContext_ValidHandle_ReturnsWrapperAndDoesNotThrow) { +TEST_F(MockJniFixture, GetJllamaContext_ValidHandle_ReturnsWrapper) { jllama_context fake_ctx; - fake_ctx.server = nullptr; // .server content is irrelevant for this test - - g_mock_handle = reinterpret_cast(&fake_ctx); + fake_ctx.server = nullptr; + g_mock_handle = reinterpret_cast(&fake_ctx); - jllama_context *result = - get_jllama_context_impl(env, /*obj=*/nullptr, dummy_field); + jllama_context *result = get_jllama_context_impl(env, nullptr, dummy_field); - EXPECT_EQ(result, &fake_ctx) - << "Expected the jllama_context wrapper pointer itself, not the inner .server"; - EXPECT_FALSE(g_throw_called) - << "ThrowNew must not be called for a valid handle"; + EXPECT_EQ(result, &fake_ctx); + EXPECT_FALSE(g_throw_called); } TEST_F(MockJniFixture, GetJllamaContext_ReturnsWrapperNotInnerServer) { - // Verify that the returned pointer is the outer struct, not .server, - // which is what distinguishes this helper from get_server_context_impl. 
server_context *sentinel = reinterpret_cast(0xDEADBEEF); jllama_context fake_ctx; fake_ctx.server = sentinel; + g_mock_handle = reinterpret_cast(&fake_ctx); - g_mock_handle = reinterpret_cast(&fake_ctx); - - jllama_context *result = - get_jllama_context_impl(env, /*obj=*/nullptr, dummy_field); + jllama_context *result = get_jllama_context_impl(env, nullptr, dummy_field); - EXPECT_EQ(result, &fake_ctx) - << "Must return the outer jllama_context wrapper"; - EXPECT_NE(static_cast(result), static_cast(sentinel)) - << "Must NOT return the inner .server — use get_server_context_impl for that"; + EXPECT_EQ(result, &fake_ctx); + EXPECT_NE(static_cast(result), static_cast(sentinel)); } -TEST_F(MockJniFixture, GetJllamaContext_ContractComparison_GetServerContextThrowsWhereGetJllamaContextDoesNot) { - // Regression guard: get_server_context_impl throws on null, but - // get_jllama_context_impl must not. Both are tested with the same - // zero handle so any future merge of the two helpers breaks this test. 
+TEST_F(MockJniFixture, GetJllamaContext_NullHandle_WhileGetServerContextThrows) { g_mock_handle = 0; - server_context *sc = get_server_context_impl(env, nullptr, dummy_field, dummy_class); - EXPECT_TRUE(g_throw_called) << "get_server_context_impl should throw on null"; - EXPECT_EQ(sc, nullptr); + (void)get_server_context_impl(env, nullptr, dummy_field, dummy_class); + EXPECT_TRUE(g_throw_called); - // Reset and test the delete-path helper g_throw_called = false; - jllama_context *jc = get_jllama_context_impl(env, nullptr, dummy_field); - EXPECT_FALSE(g_throw_called) << "get_jllama_context_impl must NOT throw on null"; - EXPECT_EQ(jc, nullptr); + (void)get_jllama_context_impl(env, nullptr, dummy_field); + EXPECT_FALSE(g_throw_called); } // ============================================================ -// Tests for require_single_task_id_impl() +// require_single_task_id_impl // ============================================================ TEST_F(MockJniFixture, RequireSingleTaskId_ExactlyOne_ReturnsIdNoThrow) { std::unordered_set ids = {42}; - int result = require_single_task_id_impl(env, ids, dummy_class); - EXPECT_EQ(result, 42); + EXPECT_EQ(require_single_task_id_impl(env, ids, dummy_class), 42); EXPECT_FALSE(g_throw_called); } TEST_F(MockJniFixture, RequireSingleTaskId_Empty_ReturnsZeroAndThrows) { std::unordered_set ids; - int result = require_single_task_id_impl(env, ids, dummy_class); - EXPECT_EQ(result, 0); + EXPECT_EQ(require_single_task_id_impl(env, ids, dummy_class), 0); EXPECT_TRUE(g_throw_called); EXPECT_EQ(g_throw_message, "multitasking currently not supported"); } TEST_F(MockJniFixture, RequireSingleTaskId_Multiple_ReturnsZeroAndThrows) { std::unordered_set ids = {1, 2, 3}; - int result = require_single_task_id_impl(env, ids, dummy_class); - EXPECT_EQ(result, 0); + EXPECT_EQ(require_single_task_id_impl(env, ids, dummy_class), 0); EXPECT_TRUE(g_throw_called); - EXPECT_EQ(g_throw_message, "multitasking currently not supported"); } // 
============================================================ -// Tests for jint_array_to_tokens_impl() -// -// Needs GetArrayLength, GetIntArrayElements, and -// ReleaseIntArrayElements stubs. GetIntArrayElements returns a -// pointer to a static buffer. ReleaseIntArrayElements is a no-op. +// require_json_field_impl +// ============================================================ + +TEST_F(MockJniFixture, RequireJsonField_PresentField_ReturnsTrueNoThrow) { + nlohmann::json data = {{"input_prefix", "hello"}}; + EXPECT_TRUE(require_json_field_impl(env, data, "input_prefix", dummy_class)); + EXPECT_FALSE(g_throw_called); +} + +TEST_F(MockJniFixture, RequireJsonField_MissingField_ReturnsFalseAndThrows) { + nlohmann::json data = {{"other", 1}}; + EXPECT_FALSE(require_json_field_impl(env, data, "input_prefix", dummy_class)); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "\"input_prefix\" is required"); +} + +TEST_F(MockJniFixture, RequireJsonField_EmptyJson_ReturnsFalseAndThrows) { + EXPECT_FALSE(require_json_field_impl( + env, nlohmann::json::object(), "input_suffix", dummy_class)); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "\"input_suffix\" is required"); +} + +// ============================================================ +// jint_array_to_tokens_impl // ============================================================ namespace { @@ -261,18 +265,11 @@ static jsize g_array_length = 0; static bool g_release_called = false; static jint g_release_mode = -1; -static jsize JNICALL stub_GetArrayLength(JNIEnv * /*env*/, jarray /*arr*/) { - return g_array_length; -} -static jint *JNICALL stub_GetIntArrayElements(JNIEnv * /*env*/, - jintArray /*arr*/, - jboolean * /*isCopy*/) { +static jsize JNICALL stub_GetArrayLength(JNIEnv *, jarray) { return g_array_length; } +static jint *JNICALL stub_GetIntArrayElements(JNIEnv *, jintArray, jboolean *) { return g_array_data; } -static void JNICALL stub_ReleaseIntArrayElements(JNIEnv * /*env*/, - jintArray 
/*arr*/, - jint * /*elems*/, - jint mode) { +static void JNICALL stub_ReleaseIntArrayElements(JNIEnv *, jintArray, jint *, jint mode) { g_release_called = true; g_release_mode = mode; } @@ -296,7 +293,7 @@ struct ArrayFixture : ::testing::Test { g_release_called = false; g_release_mode = -1; std::memset(g_array_data, 0, sizeof(g_array_data)); - g_array_length = 0; + g_array_length = 0; } }; @@ -304,22 +301,16 @@ struct ArrayFixture : ::testing::Test { TEST_F(ArrayFixture, JintArrayToTokens_EmptyArray_ReturnsEmptyVector) { g_array_length = 0; - auto tokens = jint_array_to_tokens_impl(env, nullptr); - EXPECT_TRUE(tokens.empty()); EXPECT_TRUE(g_release_called); EXPECT_EQ(g_release_mode, JNI_ABORT); } TEST_F(ArrayFixture, JintArrayToTokens_ThreeElements_CopiedCorrectly) { - g_array_data[0] = 10; - g_array_data[1] = 20; - g_array_data[2] = 30; + g_array_data[0] = 10; g_array_data[1] = 20; g_array_data[2] = 30; g_array_length = 3; - auto tokens = jint_array_to_tokens_impl(env, nullptr); - ASSERT_EQ(tokens.size(), 3u); EXPECT_EQ(tokens[0], 10); EXPECT_EQ(tokens[1], 20); @@ -327,49 +318,215 @@ TEST_F(ArrayFixture, JintArrayToTokens_ThreeElements_CopiedCorrectly) { } TEST_F(ArrayFixture, JintArrayToTokens_ReleasesWithAbortFlag) { - // JNI_ABORT means no writeback — required since we only read the array. - g_array_length = 1; - g_array_data[0] = 42; - + g_array_length = 1; g_array_data[0] = 42; (void)jint_array_to_tokens_impl(env, nullptr); - EXPECT_TRUE(g_release_called); - EXPECT_EQ(g_release_mode, JNI_ABORT) - << "must use JNI_ABORT (no writeback) for read-only array access"; + EXPECT_EQ(g_release_mode, JNI_ABORT); } // ============================================================ -// Tests for require_json_field_impl() -// -// Uses the ThrowNew stub from MockJniFixture to verify that the -// function throws (or does not throw) correctly. 
+// json_to_jstring_impl // ============================================================ -TEST_F(MockJniFixture, RequireJsonField_PresentField_ReturnsTrueNoThrow) { - nlohmann::json data = {{"input_prefix", "hello"}, {"other", 1}}; +TEST_F(MockJniFixture, JsonToJstring_Object_RoundTrips) { + json j = {{"key", "value"}, {"n", 42}}; + jstring js = json_to_jstring_impl(env, j); + EXPECT_NE(js, nullptr); + json parsed = json::parse(g_new_string_utf_value); + EXPECT_TRUE(parsed.is_object()); + EXPECT_EQ(parsed.value("key", ""), "value"); + EXPECT_EQ(parsed.value("n", 0), 42); +} + +TEST_F(MockJniFixture, JsonToJstring_Array_RoundTrips) { + json j = json::array({1, 2, 3}); + jstring js = json_to_jstring_impl(env, j); + EXPECT_NE(js, nullptr); + json parsed = json::parse(g_new_string_utf_value); + EXPECT_TRUE(parsed.is_array()); + ASSERT_EQ(parsed.size(), 3u); +} + +TEST_F(MockJniFixture, JsonToJstring_ReturnsSentinel) { + jstring js = json_to_jstring_impl(env, {{"ok", true}}); + EXPECT_EQ(js, reinterpret_cast(0xBEEF)); +} + +// ============================================================ +// results_to_jstring_impl +// ============================================================ + +TEST_F(MockJniFixture, ResultsToJstring_SingleResult_ReturnsBareObject) { + std::vector results; + results.push_back(make_ok(1, "hello")); + + jstring js = results_to_jstring_impl(env, results); + + EXPECT_NE(js, nullptr); + json parsed = json::parse(g_new_string_utf_value); + EXPECT_TRUE(parsed.is_object()); + EXPECT_EQ(parsed.value("content", ""), "hello"); +} + +TEST_F(MockJniFixture, ResultsToJstring_MultipleResults_ReturnsArray) { + std::vector results; + results.push_back(make_ok(2, "first")); + results.push_back(make_ok(3, "second")); - bool ok = require_json_field_impl(env, data, "input_prefix", dummy_class); + jstring js = results_to_jstring_impl(env, results); - EXPECT_TRUE(ok); + EXPECT_NE(js, nullptr); + json parsed = json::parse(g_new_string_utf_value); + 
EXPECT_TRUE(parsed.is_array()); + ASSERT_EQ(parsed.size(), 2u); + EXPECT_EQ(parsed[0].value("content", ""), "first"); + EXPECT_EQ(parsed[1].value("content", ""), "second"); +} + +TEST_F(MockJniFixture, ResultsToJstring_EmptyVector_ReturnsEmptyArray) { + std::vector results; + jstring js = results_to_jstring_impl(env, results); + EXPECT_NE(js, nullptr); + json parsed = json::parse(g_new_string_utf_value); + EXPECT_TRUE(parsed.is_array()); + EXPECT_TRUE(parsed.empty()); +} + +// ============================================================ +// collect_task_results_impl +// ============================================================ + +TEST_F(ServerFixture, CollectResults_SingleOk_ReturnsTrueAndFillsOut) { + queue.add_waiting_task_id(1); + queue.send(make_ok(1, "hello")); + + std::unordered_set ids = {1}; + std::vector out; + + EXPECT_TRUE(collect_task_results_impl(env, queue, ids, out, dummy_class)); + ASSERT_EQ(out.size(), 1u); + EXPECT_EQ(out[0]->to_json()["content"], "hello"); EXPECT_FALSE(g_throw_called); } -TEST_F(MockJniFixture, RequireJsonField_MissingField_ReturnsFalseAndThrows) { - nlohmann::json data = {{"other", 1}}; +TEST_F(ServerFixture, CollectResults_SingleError_ReturnsFalseAndThrows) { + queue.add_waiting_task_id(2); + queue.send(make_error(2, "something went wrong")); + + std::unordered_set ids = {2}; + std::vector out; - bool ok = require_json_field_impl(env, data, "input_prefix", dummy_class); + EXPECT_FALSE(collect_task_results_impl(env, queue, ids, out, dummy_class)); + EXPECT_TRUE(out.empty()); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "something went wrong"); +} + +TEST_F(ServerFixture, CollectResults_MultipleOk_AllCollected) { + for (int i = 10; i < 13; ++i) { queue.add_waiting_task_id(i); queue.send(make_ok(i)); } + + std::unordered_set ids = {10, 11, 12}; + std::vector out; + + EXPECT_TRUE(collect_task_results_impl(env, queue, ids, out, dummy_class)); + EXPECT_EQ(out.size(), 3u); + EXPECT_FALSE(g_throw_called); +} + 
+TEST_F(ServerFixture, CollectResults_SecondError_StopsAndThrows) { + queue.add_waiting_task_id(20); queue.send(make_ok(20)); + queue.add_waiting_task_id(21); queue.send(make_error(21, "task 21 failed")); + + std::unordered_set ids = {20, 21}; + std::vector out; + + EXPECT_FALSE(collect_task_results_impl(env, queue, ids, out, dummy_class)); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "task 21 failed"); +} + +TEST_F(ServerFixture, CollectResults_SuccessPath_WaitingIdsRemoved) { + queue.add_waiting_task_id(30); queue.send(make_ok(30)); + std::unordered_set ids = {30}; + std::vector out; + (void)collect_task_results_impl(env, queue, ids, out, dummy_class); + EXPECT_FALSE(queue.waiting_task_ids.count(30)); +} + +TEST_F(ServerFixture, CollectResults_ErrorPath_WaitingIdsRemoved) { + queue.add_waiting_task_id(40); queue.send(make_error(40, "err")); + std::unordered_set ids = {40}; + std::vector out; + (void)collect_task_results_impl(env, queue, ids, out, dummy_class); + EXPECT_FALSE(queue.waiting_task_ids.count(40)); +} + +// ============================================================ +// recv_slot_task_result_impl +// ============================================================ + +TEST_F(ServerFixture, RecvSlotResult_Success_ReturnsNonNullNoThrow) { + queue.add_waiting_task_id(50); queue.send(make_ok(50, "slot-ok")); + + jstring result = recv_slot_task_result_impl(env, queue, 50, dummy_class); + + EXPECT_NE(result, nullptr); + EXPECT_FALSE(g_throw_called); + EXPECT_NE(g_new_string_utf_value.find("slot-ok"), std::string::npos); +} + +TEST_F(ServerFixture, RecvSlotResult_Error_ReturnsNullAndThrows) { + queue.add_waiting_task_id(51); queue.send(make_error(51, "slot operation failed")); + + jstring result = recv_slot_task_result_impl(env, queue, 51, dummy_class); + + EXPECT_EQ(result, nullptr); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "slot operation failed"); +} + +TEST_F(ServerFixture, RecvSlotResult_Success_WaitingIdRemoved) { + 
queue.add_waiting_task_id(52); queue.send(make_ok(52)); + (void)recv_slot_task_result_impl(env, queue, 52, dummy_class); + EXPECT_FALSE(queue.waiting_task_ids.count(52)); +} + +TEST_F(ServerFixture, RecvSlotResult_Error_WaitingIdRemoved) { + queue.add_waiting_task_id(53); queue.send(make_error(53, "err")); + (void)recv_slot_task_result_impl(env, queue, 53, dummy_class); + EXPECT_FALSE(queue.waiting_task_ids.count(53)); +} + +// ============================================================ +// build_completion_tasks_impl — error path only +// (success path requires a live server_context with vocab/ctx) +// ============================================================ + +TEST_F(MockJniFixture, BuildTasks_MissingPrompt_ReturnsFalseAndThrows) { + json data = {{"n_predict", 1}}; + std::vector tasks; + + bool ok = build_completion_tasks_impl(env, /*ctx_server=*/nullptr, data, + "test-cmpl-id", + SERVER_TASK_TYPE_COMPLETION, + OAICOMPAT_TYPE_NONE, + tasks, dummy_class); EXPECT_FALSE(ok); EXPECT_TRUE(g_throw_called); - EXPECT_EQ(g_throw_message, "\"input_prefix\" is required"); + EXPECT_TRUE(tasks.empty()); } -TEST_F(MockJniFixture, RequireJsonField_EmptyJson_ReturnsFalseAndThrows) { - nlohmann::json data = nlohmann::json::object(); +TEST_F(MockJniFixture, BuildTasks_MissingPrompt_InfillTypeHasSameBehaviour) { + json data = {{"input_prefix", "def f():"}, {"input_suffix", "return 1"}}; + std::vector tasks; - bool ok = require_json_field_impl(env, data, "input_suffix", dummy_class); + bool ok = build_completion_tasks_impl(env, nullptr, data, "infill-id", + SERVER_TASK_TYPE_INFILL, + OAICOMPAT_TYPE_NONE, + tasks, dummy_class); EXPECT_FALSE(ok); EXPECT_TRUE(g_throw_called); - EXPECT_EQ(g_throw_message, "\"input_suffix\" is required"); + EXPECT_TRUE(tasks.empty()); } diff --git a/src/test/cpp/test_jni_server_helpers.cpp b/src/test/cpp/test_jni_server_helpers.cpp deleted file mode 100644 index 482d445a..00000000 --- a/src/test/cpp/test_jni_server_helpers.cpp +++ /dev/null @@ 
-1,741 +0,0 @@ -// Tests for jni_server_helpers.hpp: -// - build_completion_tasks_impl -// - collect_task_results_impl -// - recv_slot_task_result_impl (added by Finding 5) -// -// build_completion_tasks_impl needs: -// - JNIEnv — used only for ThrowNew on the error path -// - server_context* — NOT accessed when "prompt" is absent (exception thrown -// first), so nullptr is safe for error-path tests. -// - json data — provided inline in the test -// -// collect_task_results_impl and recv_slot_task_result_impl need: -// - server_response (from server.hpp) — provides recv() / remove_waiting_task_ids() -// - JNIEnv — used only for ThrowNew on the error path -// -// server_response is used directly: we pre-seed it with results via send() -// before calling collect_task_results_impl(). Because recv() checks the -// queue under a mutex+condvar, pre-seeding lets us call recv() from the same -// thread without blocking. -// -// JNIEnv is mocked with the same stub technique used in test_jni_helpers.cpp: -// a zero-filled JNINativeInterface_ table with only GetLongField (unused here) -// and ThrowNew patched so we can observe whether an exception was raised. -// -// Covered scenarios: -// - single success result → out filled, no throw, returns true -// - single error result → out empty, ThrowNew called with correct message, returns false -// - multiple success results → all collected in order, returns true -// - first result ok, second is error → cleanup, ThrowNew, returns false -// - waiting ids are removed on success (remove_waiting_task_ids called) -// - waiting ids are removed on error (remove_waiting_task_ids called) - -#include - -#include -#include -#include -#include - -// server.hpp must come before jni_server_helpers.hpp (no include guard in server.hpp). 
-#include "server.hpp" -#include "jni_server_helpers.hpp" - -// ============================================================ -// Minimal concrete server_task_result subtypes for testing -// ============================================================ - -namespace { - -// A success result whose to_json() returns {"content": ""}. -struct fake_ok_result : server_task_result { - std::string msg; - explicit fake_ok_result(int id_, std::string m) : msg(std::move(m)) { id = id_; } - json to_json() override { return {{"content", msg}}; } -}; - -// An embedding result whose to_json() returns the shape expected by -// format_embeddings_response_oaicompat: {"embedding": [...], "tokens_evaluated": N}. -struct fake_embedding_result : server_task_result { - std::vector vec; - int tokens_evaluated; - explicit fake_embedding_result(int id_, std::vector v, int tok = 4) - : vec(std::move(v)), tokens_evaluated(tok) { id = id_; } - json to_json() override { - return {{"embedding", vec}, {"tokens_evaluated", tokens_evaluated}}; - } -}; - -static server_task_result_ptr make_embedding(int id_, std::vector v = {0.1f, 0.2f, 0.3f}) { - return std::make_unique(id_, std::move(v)); -} - -// An error result — reuses the real server_task_result_error so that -// to_json() → format_error_response() → {"message": err_msg, ...} matches -// the exact JSON key that collect_task_results_impl reads. 
-static server_task_result_ptr make_error(int id_, const std::string &msg) { - auto r = std::make_unique(); - r->id = id_; - r->err_msg = msg; - r->err_type = ERROR_TYPE_SERVER; - return r; -} - -static server_task_result_ptr make_ok(int id_, const std::string &msg = "ok") { - return std::make_unique(id_, msg); -} - -// ============================================================ -// Mock JNI environment (same pattern as test_jni_helpers.cpp) -// ============================================================ - -static bool g_throw_called = false; -static std::string g_throw_message; - -// NewStringUTF stub: stores the string so tests can inspect it, returns a -// non-null sentinel so callers can distinguish success from nullptr (error). -static std::string g_new_string_utf_value; -static jstring g_new_string_utf_sentinel = reinterpret_cast(0xBEEF); - -static jint JNICALL stub_ThrowNew(JNIEnv *, jclass, const char *msg) { - g_throw_called = true; - g_throw_message = msg ? msg : ""; - return 0; -} - -static jlong JNICALL stub_GetLongField(JNIEnv *, jobject, jfieldID) { return 0; } - -static jstring JNICALL stub_NewStringUTF(JNIEnv *, const char *utf) { - g_new_string_utf_value = utf ? utf : ""; - return g_new_string_utf_sentinel; -} - -JNIEnv *make_mock_env(JNINativeInterface_ &table, JNIEnv_ &env_obj) { - std::memset(&table, 0, sizeof(table)); - table.ThrowNew = stub_ThrowNew; - table.GetLongField = stub_GetLongField; // unused but avoids a null slot crash - table.NewStringUTF = stub_NewStringUTF; - env_obj.functions = &table; - return &env_obj; -} - -// Test fixture: fresh mock env + fresh server_response per test. 
-struct CollectResultsFixture : ::testing::Test { - JNINativeInterface_ table{}; - JNIEnv_ env_obj{}; - JNIEnv *env = nullptr; - jclass dummy_eclass = reinterpret_cast(0x1); - - server_response queue; - - void SetUp() override { - env = make_mock_env(table, env_obj); - g_throw_called = false; - g_throw_message.clear(); - g_new_string_utf_value.clear(); - } -}; - -} // namespace - -// ============================================================ -// Tests for get_result_error_message() -// -// No JNI needed — the function only calls result->to_json() and -// performs a JSON key lookup, both of which are pure C++. -// ============================================================ - -TEST(GetResultErrorMessage, ErrorResult_ReturnsMessageString) { - server_task_result_ptr r = make_error(1, "something went wrong"); - EXPECT_EQ(get_result_error_message(r), "something went wrong"); -} - -TEST(GetResultErrorMessage, DifferentMessage_ReturnsCorrectString) { - server_task_result_ptr r = make_error(2, "out of memory"); - EXPECT_EQ(get_result_error_message(r), "out of memory"); -} - -// ============================================================ -// Single success result -// ============================================================ - -TEST_F(CollectResultsFixture, SingleOkResult_ReturnsTrueAndFillsOut) { - queue.add_waiting_task_id(1); - queue.send(make_ok(1, "hello")); - - std::unordered_set ids = {1}; - std::vector out; - - bool ok = collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - EXPECT_TRUE(ok); - EXPECT_EQ(out.size(), 1u); - EXPECT_EQ(out[0]->to_json()["content"], "hello"); - EXPECT_FALSE(g_throw_called); -} - -// ============================================================ -// Single error result -// ============================================================ - -TEST_F(CollectResultsFixture, SingleErrorResult_ReturnsFalseAndThrows) { - queue.add_waiting_task_id(2); - queue.send(make_error(2, "something went wrong")); - - std::unordered_set ids = {2}; 
- std::vector out; - - bool ok = collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - EXPECT_FALSE(ok); - EXPECT_TRUE(out.empty()) << "out must not be populated on error"; - EXPECT_TRUE(g_throw_called); - EXPECT_EQ(g_throw_message, "something went wrong"); -} - -// ============================================================ -// Multiple success results -// ============================================================ - -TEST_F(CollectResultsFixture, MultipleOkResults_AllCollected) { - for (int i = 10; i < 13; ++i) { - queue.add_waiting_task_id(i); - queue.send(make_ok(i, "msg" + std::to_string(i))); - } - - std::unordered_set ids = {10, 11, 12}; - std::vector out; - - bool ok = collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - EXPECT_TRUE(ok); - EXPECT_EQ(out.size(), 3u); - EXPECT_FALSE(g_throw_called); -} - -// ============================================================ -// First ok, second error — error path cleans up remaining ids -// ============================================================ - -TEST_F(CollectResultsFixture, SecondResultIsError_StopsAndThrows) { - queue.add_waiting_task_id(20); - queue.add_waiting_task_id(21); - queue.send(make_ok(20)); - queue.send(make_error(21, "task 21 failed")); - - std::unordered_set ids = {20, 21}; - std::vector out; - - bool ok = collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - EXPECT_FALSE(ok); - EXPECT_TRUE(g_throw_called); - EXPECT_EQ(g_throw_message, "task 21 failed"); -} - -// ============================================================ -// Waiting ids are removed from the queue on the success path -// ============================================================ - -TEST_F(CollectResultsFixture, SuccessPath_WaitingIdsRemovedAfterCollect) { - queue.add_waiting_task_id(30); - queue.send(make_ok(30)); - - std::unordered_set ids = {30}; - std::vector out; - (void)collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - // After collect, the id must no longer 
be in the waiting set. - // We verify indirectly: sending a second result for id=30 should - // NOT be returned by a subsequent recv for a different id — the - // simplest check is that waiting_task_ids no longer contains 30. - EXPECT_FALSE(queue.waiting_task_ids.count(30)) - << "remove_waiting_task_ids must clear the id on success"; -} - -// ============================================================ -// Waiting ids are removed from the queue on the error path -// ============================================================ - -TEST_F(CollectResultsFixture, ErrorPath_WaitingIdsRemovedAfterError) { - queue.add_waiting_task_id(40); - queue.send(make_error(40, "err")); - - std::unordered_set ids = {40}; - std::vector out; - (void)collect_task_results_impl(env, queue, ids, out, dummy_eclass); - - EXPECT_FALSE(queue.waiting_task_ids.count(40)) - << "remove_waiting_task_ids must clear the id on error"; -} - -// ============================================================ -// Tests for build_completion_tasks_impl -// -// Only the error path is unit-testable here: the function reads -// data["prompt"] in its own statement BEFORE accessing ctx_server, so -// passing nullptr for ctx_server is safe when "prompt" is absent. -// -// The success path requires a live server_context (llama vocab + context -// pointers) and is covered by LlamaModelTest Java integration tests. -// ============================================================ - -TEST_F(CollectResultsFixture, BuildTasks_MissingPrompt_ReturnsFalseAndThrows) { - json data = {{"n_predict", 1}}; // deliberately no "prompt" key - std::string completion_id = "test-cmpl-id"; - std::vector tasks; - - // ctx_server is nullptr — safe because data.at("prompt") throws before - // any ctx_server member is accessed. 
- bool ok = build_completion_tasks_impl(env, - /*ctx_server=*/nullptr, - data, completion_id, - SERVER_TASK_TYPE_COMPLETION, - OAICOMPAT_TYPE_NONE, - tasks, - dummy_eclass); - - EXPECT_FALSE(ok) << "Missing 'prompt' must return false"; - EXPECT_TRUE(g_throw_called) << "ThrowNew must be called for missing 'prompt'"; - EXPECT_TRUE(tasks.empty()) << "tasks must remain empty on error"; -} - -TEST_F(CollectResultsFixture, BuildTasks_MissingPrompt_TaskTypeDoesNotAffectErrorBehaviour) { - // The error path is identical regardless of task_type / oaicompat. - // Verify INFILL behaves the same way. - json data = {{"input_prefix", "def f():"}, {"input_suffix", "return 1"}}; - std::string completion_id = "infill-cmpl-id"; - std::vector tasks; - - bool ok = build_completion_tasks_impl(env, - /*ctx_server=*/nullptr, - data, completion_id, - SERVER_TASK_TYPE_INFILL, - OAICOMPAT_TYPE_NONE, - tasks, - dummy_eclass); - - EXPECT_FALSE(ok); - EXPECT_TRUE(g_throw_called); - EXPECT_TRUE(tasks.empty()); -} - -// ============================================================ -// Tests for recv_slot_task_result_impl -// -// Pre-seed the server_response queue (same technique as collect tests): -// calling send() before recv() satisfies the condvar predicate immediately, -// so recv() returns without blocking even from the same thread. -// -// NewStringUTF is stubbed to return a sentinel non-null jstring and capture -// the serialised JSON so we can verify the success path. 
-// ============================================================ - -TEST_F(CollectResultsFixture, RecvSlotResult_SuccessResult_ReturnsNonNullAndDoesNotThrow) { - queue.add_waiting_task_id(50); - queue.send(make_ok(50, "slot-ok")); - - jstring result = recv_slot_task_result_impl(env, queue, 50, dummy_eclass); - - EXPECT_NE(result, nullptr) << "success result must return non-null jstring"; - EXPECT_FALSE(g_throw_called) << "ThrowNew must not be called on success"; - // The stub captures the serialised JSON passed to NewStringUTF. - EXPECT_FALSE(g_new_string_utf_value.empty()) - << "NewStringUTF must be called with the result JSON"; - EXPECT_NE(g_new_string_utf_value.find("slot-ok"), std::string::npos) - << "result JSON must contain the content from the fake result"; -} - -TEST_F(CollectResultsFixture, RecvSlotResult_ErrorResult_ReturnsNullAndThrows) { - queue.add_waiting_task_id(51); - queue.send(make_error(51, "slot operation failed")); - - jstring result = recv_slot_task_result_impl(env, queue, 51, dummy_eclass); - - EXPECT_EQ(result, nullptr) << "error result must return nullptr"; - EXPECT_TRUE(g_throw_called) << "ThrowNew must be called on error"; - EXPECT_EQ(g_throw_message, "slot operation failed"); -} - -TEST_F(CollectResultsFixture, RecvSlotResult_WaitingIdRemovedAfterSuccess) { - queue.add_waiting_task_id(52); - queue.send(make_ok(52)); - - (void)recv_slot_task_result_impl(env, queue, 52, dummy_eclass); - - EXPECT_FALSE(queue.waiting_task_ids.count(52)) - << "remove_waiting_task_id must clear the id on success"; -} - -TEST_F(CollectResultsFixture, RecvSlotResult_WaitingIdRemovedAfterError) { - queue.add_waiting_task_id(53); - queue.send(make_error(53, "err")); - - (void)recv_slot_task_result_impl(env, queue, 53, dummy_eclass); - - EXPECT_FALSE(queue.waiting_task_ids.count(53)) - << "remove_waiting_task_id must clear the id on error"; -} - -// ============================================================ -// Tests for results_to_jstring_impl -// -// Verifies 
that the serialisation helper produces the right shape: -// - single result → bare JSON object -// - multiple results → JSON array -// -// NewStringUTF is stubbed via the fixture's mock env; g_new_string_utf_value -// captures whatever string was passed to it. -// ============================================================ - -TEST_F(CollectResultsFixture, ResultsToJstring_SingleResult_ReturnsBareObject) { - std::vector results; - results.push_back(make_ok(1, "hello")); - - jstring js = results_to_jstring_impl(env, results); - - EXPECT_NE(js, nullptr); - EXPECT_FALSE(g_new_string_utf_value.empty()); - - // The top-level JSON must be an object (not an array). - json parsed = json::parse(g_new_string_utf_value); - EXPECT_TRUE(parsed.is_object()) - << "single result must serialise as a bare JSON object, got: " - << g_new_string_utf_value; - EXPECT_EQ(parsed.value("content", ""), "hello"); -} - -TEST_F(CollectResultsFixture, ResultsToJstring_MultipleResults_ReturnsArray) { - std::vector results; - results.push_back(make_ok(2, "first")); - results.push_back(make_ok(3, "second")); - - jstring js = results_to_jstring_impl(env, results); - - EXPECT_NE(js, nullptr); - json parsed = json::parse(g_new_string_utf_value); - EXPECT_TRUE(parsed.is_array()) - << "multiple results must serialise as a JSON array, got: " - << g_new_string_utf_value; - ASSERT_EQ(parsed.size(), 2u); - EXPECT_EQ(parsed[0].value("content", ""), "first"); - EXPECT_EQ(parsed[1].value("content", ""), "second"); -} - -TEST_F(CollectResultsFixture, ResultsToJstring_EmptyVector_ReturnsEmptyArray) { - std::vector results; - - jstring js = results_to_jstring_impl(env, results); - - EXPECT_NE(js, nullptr); - json parsed = json::parse(g_new_string_utf_value); - EXPECT_TRUE(parsed.is_array()); - EXPECT_TRUE(parsed.empty()); -} - -// ============================================================ -// Tests for results_to_json_impl -// -// Pure-logic counterpart to results_to_jstring_impl — no JNI needed. 
-// ============================================================ - -TEST(ResultsToJsonImpl, SingleResult_ReturnsObjectDirectly) { - std::vector results; - results.push_back(make_ok(1, "only")); - - json out = results_to_json_impl(results); - - EXPECT_TRUE(out.is_object()); - EXPECT_EQ(out.value("content", ""), "only"); -} - -TEST(ResultsToJsonImpl, MultipleResults_ReturnsArray) { - std::vector results; - results.push_back(make_ok(1, "a")); - results.push_back(make_ok(2, "b")); - - json out = results_to_json_impl(results); - - EXPECT_TRUE(out.is_array()); - ASSERT_EQ(out.size(), 2u); - EXPECT_EQ(out[0].value("content", ""), "a"); - EXPECT_EQ(out[1].value("content", ""), "b"); -} - -TEST(ResultsToJsonImpl, EmptyVector_ReturnsEmptyArray) { - std::vector results; - - json out = results_to_json_impl(results); - - EXPECT_TRUE(out.is_array()); - EXPECT_TRUE(out.empty()); -} - -// ============================================================ -// Tests for json_to_jstring_impl -// -// Verifies that any json value is serialised correctly via -// dump() + NewStringUTF. The stub captures the string so tests -// can round-trip parse it. 
-// ============================================================ - -TEST_F(CollectResultsFixture, JsonToJstring_Object_RoundTrips) { - json j = {{"key", "value"}, {"n", 42}}; - - jstring js = json_to_jstring_impl(env, j); - - EXPECT_NE(js, nullptr); - EXPECT_FALSE(g_new_string_utf_value.empty()); - json parsed = json::parse(g_new_string_utf_value); - EXPECT_TRUE(parsed.is_object()); - EXPECT_EQ(parsed.value("key", ""), "value"); - EXPECT_EQ(parsed.value("n", 0), 42); -} - -TEST_F(CollectResultsFixture, JsonToJstring_Array_RoundTrips) { - json j = json::array({1, 2, 3}); - - jstring js = json_to_jstring_impl(env, j); - - EXPECT_NE(js, nullptr); - json parsed = json::parse(g_new_string_utf_value); - EXPECT_TRUE(parsed.is_array()); - ASSERT_EQ(parsed.size(), 3u); - EXPECT_EQ(parsed[0], 1); - EXPECT_EQ(parsed[2], 3); -} - -TEST_F(CollectResultsFixture, JsonToJstring_ReturnsSentinelFromStub) { - // The mock NewStringUTF returns the 0xBEEF sentinel — verify the - // function propagates it unchanged so callers get a non-null jstring. - json j = {{"ok", true}}; - - jstring js = json_to_jstring_impl(env, j); - - EXPECT_EQ(js, reinterpret_cast(0xBEEF)); -} - -// ============================================================ -// Tests for extract_first_embedding_row -// -// Pure computation — no JNI or llama context needed. 
-// ============================================================ - -TEST(ExtractFirstEmbeddingRow, SingleRow_ReturnsRow) { - json j = {{"embedding", {{0.1f, 0.2f, 0.3f}}}}; - auto row = extract_first_embedding_row(j); - ASSERT_EQ(row.size(), 3u); - EXPECT_FLOAT_EQ(row[0], 0.1f); - EXPECT_FLOAT_EQ(row[1], 0.2f); - EXPECT_FLOAT_EQ(row[2], 0.3f); -} - -TEST(ExtractFirstEmbeddingRow, MultipleRows_ReturnsFirstRowOnly) { - json j = {{"embedding", {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}}}; - auto row = extract_first_embedding_row(j); - ASSERT_EQ(row.size(), 2u); - EXPECT_FLOAT_EQ(row[0], 1.0f); - EXPECT_FLOAT_EQ(row[1], 2.0f); -} - -TEST(ExtractFirstEmbeddingRow, MissingEmbeddingKey_ThrowsJsonException) { - json j = {{"other_key", "value"}}; - EXPECT_THROW(extract_first_embedding_row(j), nlohmann::json::exception); -} - -TEST(ExtractFirstEmbeddingRow, EmptyOuterArray_ThrowsRuntimeError) { - json j = {{"embedding", json::array()}}; - EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); -} - -TEST(ExtractFirstEmbeddingRow, EmptyInnerArray_ThrowsRuntimeError) { - json j = {{"embedding", {json::array()}}}; - EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); -} - -TEST(ExtractFirstEmbeddingRow, LargeRow_AllValuesPreserved) { - std::vector vals(128); - for (int i = 0; i < 128; ++i) vals[i] = static_cast(i) * 0.01f; - json j = {{"embedding", {vals}}}; - auto row = extract_first_embedding_row(j); - ASSERT_EQ(row.size(), 128u); - for (int i = 0; i < 128; ++i) { - EXPECT_FLOAT_EQ(row[i], static_cast(i) * 0.01f); - } -} - -// ============================================================ -// Tests for parse_encoding_format_impl -// -// Pure computation — no JNI or llama context needed. 
-// ============================================================ - -TEST(ParseEncodingFormat, FieldAbsent_ReturnsFalse) { - json body = {{"model", "text-embedding-ada-002"}}; - EXPECT_FALSE(parse_encoding_format_impl(body)); -} - -TEST(ParseEncodingFormat, ExplicitFloat_ReturnsFalse) { - json body = {{"encoding_format", "float"}}; - EXPECT_FALSE(parse_encoding_format_impl(body)); -} - -TEST(ParseEncodingFormat, Base64_ReturnsTrue) { - json body = {{"encoding_format", "base64"}}; - EXPECT_TRUE(parse_encoding_format_impl(body)); -} - -TEST(ParseEncodingFormat, UnknownFormat_ThrowsInvalidArgument) { - json body = {{"encoding_format", "binary"}}; - EXPECT_THROW(parse_encoding_format_impl(body), std::invalid_argument); -} - -TEST(ParseEncodingFormat, EmptyString_ThrowsInvalidArgument) { - json body = {{"encoding_format", ""}}; - EXPECT_THROW(parse_encoding_format_impl(body), std::invalid_argument); -} - -TEST(ParseEncodingFormat, UnknownFormat_MessageMentionsValidOptions) { - json body = {{"encoding_format", "hex"}}; - try { - parse_encoding_format_impl(body); - FAIL() << "Expected std::invalid_argument"; - } catch (const std::invalid_argument &e) { - EXPECT_NE(std::string(e.what()).find("float"), std::string::npos) - << "error message should mention \"float\""; - EXPECT_NE(std::string(e.what()).find("base64"), std::string::npos) - << "error message should mention \"base64\""; - } -} - -// ============================================================ -// Tests for extract_embedding_prompt_impl -// -// Pure computation — no JNI or llama context needed. 
-// ============================================================ - -TEST(ExtractEmbeddingPrompt, InputKey_ReturnsValueAndDoesNotSetFlag) { - bool flag = true; // pre-set to verify it gets cleared - json body = {{"input", "hello world"}}; - json prompt = extract_embedding_prompt_impl(body, flag); - EXPECT_EQ(prompt, "hello world"); - EXPECT_FALSE(flag) << "force_no_oaicompat must be false when \"input\" is used"; -} - -TEST(ExtractEmbeddingPrompt, ContentKey_ReturnsValueAndSetsFlag) { - bool flag = false; - json body = {{"content", "some text"}}; - json prompt = extract_embedding_prompt_impl(body, flag); - EXPECT_EQ(prompt, "some text"); - EXPECT_TRUE(flag) << "force_no_oaicompat must be true when \"content\" is used"; -} - -TEST(ExtractEmbeddingPrompt, InputTakesPriorityOverContent) { - bool flag = false; - json body = {{"input", "from input"}, {"content", "from content"}}; - json prompt = extract_embedding_prompt_impl(body, flag); - EXPECT_EQ(prompt, "from input"); - EXPECT_FALSE(flag) << "\"input\" path must not set force_no_oaicompat"; -} - -TEST(ExtractEmbeddingPrompt, NeitherKey_ThrowsInvalidArgument) { - bool flag = false; - json body = {{"model", "text-embedding-ada-002"}}; - EXPECT_THROW(extract_embedding_prompt_impl(body, flag), std::invalid_argument); -} - -TEST(ExtractEmbeddingPrompt, EmptyBody_ThrowsInvalidArgument) { - bool flag = false; - EXPECT_THROW(extract_embedding_prompt_impl(json::object(), flag), std::invalid_argument); -} - -TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { - // "input" may be an array of strings (batch embedding); the function must - // return the JSON value unchanged without trying to coerce it. 
- bool flag = false; - json body = {{"input", {"sentence one", "sentence two"}}}; - json prompt = extract_embedding_prompt_impl(body, flag); - ASSERT_TRUE(prompt.is_array()); - ASSERT_EQ(prompt.size(), 2u); - EXPECT_EQ(prompt[0], "sentence one"); - EXPECT_EQ(prompt[1], "sentence two"); - EXPECT_FALSE(flag); -} - -// ============================================================ -// Tests for build_embeddings_response_json_impl -// -// Pure computation — no JNI or llama context needed. -// Uses fake_embedding_result (added above) for OAI-path tests that -// need "embedding" + "tokens_evaluated" in the result JSON. -// ============================================================ - -TEST(BuildEmbeddingsResponseJson, NonOai_SingleResult_ReturnsBareArray) { - std::vector results; - results.push_back(make_embedding(1, {0.1f, 0.2f})); - - json out = build_embeddings_response_json_impl(results, json::object(), - OAICOMPAT_TYPE_NONE, false); - - ASSERT_TRUE(out.is_array()); - ASSERT_EQ(out.size(), 1u); - EXPECT_TRUE(out[0].contains("embedding")); -} - -TEST(BuildEmbeddingsResponseJson, NonOai_MultipleResults_AllInArray) { - std::vector results; - results.push_back(make_embedding(1, {0.1f})); - results.push_back(make_embedding(2, {0.2f})); - results.push_back(make_embedding(3, {0.3f})); - - json out = build_embeddings_response_json_impl(results, json::object(), - OAICOMPAT_TYPE_NONE, false); - - ASSERT_TRUE(out.is_array()); - EXPECT_EQ(out.size(), 3u); -} - -TEST(BuildEmbeddingsResponseJson, OaiFloat_WrapsWithOaiStructure) { - std::vector results; - results.push_back(make_embedding(1, {0.5f, 0.6f, 0.7f})); - - json body = {{"model", "text-embedding-ada-002"}}; - json out = build_embeddings_response_json_impl(results, body, - OAICOMPAT_TYPE_EMBEDDING, false); - - EXPECT_TRUE(out.is_object()) << "OAI response must be an object"; - EXPECT_EQ(out.value("object", ""), "list"); - EXPECT_TRUE(out.contains("data")) << "OAI response must have \"data\""; - 
EXPECT_TRUE(out.contains("usage")) << "OAI response must have \"usage\""; - EXPECT_EQ(out.value("model", ""), "text-embedding-ada-002"); - - ASSERT_TRUE(out["data"].is_array()); - ASSERT_EQ(out["data"].size(), 1u); - EXPECT_EQ(out["data"][0].value("object", ""), "embedding"); -} - -TEST(BuildEmbeddingsResponseJson, OaiBase64_EmbeddingEncodedAsString) { - std::vector results; - results.push_back(make_embedding(1, {1.0f, 2.0f})); - - json out = build_embeddings_response_json_impl(results, json::object(), - OAICOMPAT_TYPE_EMBEDDING, /*use_base64=*/true); - - ASSERT_TRUE(out["data"].is_array()); - ASSERT_EQ(out["data"].size(), 1u); - // base64 path stores embedding as a string, not an array - EXPECT_TRUE(out["data"][0]["embedding"].is_string()) - << "base64 embedding must be serialised as a string"; -} - -TEST(BuildEmbeddingsResponseJson, OaiUsage_TokensSummedAcrossResults) { - std::vector results; - results.push_back(std::make_unique(1, std::vector{0.1f}, /*tok=*/3)); - results.push_back(std::make_unique(2, std::vector{0.2f}, /*tok=*/5)); - - json out = build_embeddings_response_json_impl(results, json::object(), - OAICOMPAT_TYPE_EMBEDDING, false); - - EXPECT_EQ(out["usage"].value("prompt_tokens", 0), 8) - << "usage.prompt_tokens must be sum of tokens_evaluated across all results"; -} diff --git a/src/test/cpp/test_json_helpers.cpp b/src/test/cpp/test_json_helpers.cpp new file mode 100644 index 00000000..4398ce4d --- /dev/null +++ b/src/test/cpp/test_json_helpers.cpp @@ -0,0 +1,469 @@ +// Tests for json_helpers.hpp. +// +// Every function in json_helpers.hpp is pure JSON transformation with no JNI +// and no llama state. Tests for functions that only take nlohmann::json +// arguments need zero setup. Tests for functions that take +// server_task_result_ptr use lightweight fake result objects defined below; +// they need server.hpp for the type definitions but never load a model. 
+// +// Covered functions: +// get_result_error_message +// results_to_json +// rerank_results_to_json +// build_embeddings_response_json +// extract_first_embedding_row +// parse_encoding_format +// extract_embedding_prompt +// is_infill_request +// parse_slot_prompt_similarity +// parse_positive_int_config + +#include + +#include +#include +#include + +// server.hpp must precede json_helpers.hpp (defines server_task_result_ptr, +// oaicompat_type, format_embeddings_response_oaicompat, and the json alias). +#include "server.hpp" +#include "json_helpers.hpp" + +// ============================================================ +// Minimal fake result types +// ============================================================ + +namespace { + +// Error result — reuses the real server_task_result_error so that +// to_json() → format_error_response() → {"message": msg, ...} matches the +// exact JSON key that get_result_error_message reads. +static server_task_result_ptr make_error(int id_, const std::string &msg) { + auto r = std::make_unique(); + r->id = id_; + r->err_msg = msg; + r->err_type = ERROR_TYPE_SERVER; + return r; +} + +// Generic success result: to_json() returns {"content": msg}. +struct fake_ok_result : server_task_result { + std::string msg; + explicit fake_ok_result(int id_, std::string m) : msg(std::move(m)) { id = id_; } + json to_json() override { return {{"content", msg}}; } +}; + +static server_task_result_ptr make_ok(int id_, const std::string &msg = "ok") { + return std::make_unique(id_, msg); +} + +// Embedding result: to_json() returns the shape expected by +// format_embeddings_response_oaicompat. 
+struct fake_embedding_result : server_task_result { + std::vector vec; + int tokens_evaluated; + explicit fake_embedding_result(int id_, std::vector v, int tok = 4) + : vec(std::move(v)), tokens_evaluated(tok) { id = id_; } + json to_json() override { + return {{"embedding", vec}, {"tokens_evaluated", tokens_evaluated}}; + } +}; + +static server_task_result_ptr make_embedding(int id_, + std::vector v = {0.1f, 0.2f, 0.3f}) { + return std::make_unique(id_, std::move(v)); +} + +} // namespace + +// ============================================================ +// get_result_error_message +// ============================================================ + +TEST(GetResultErrorMessage, ErrorResult_ReturnsMessageString) { + auto r = make_error(1, "something went wrong"); + EXPECT_EQ(get_result_error_message(r), "something went wrong"); +} + +TEST(GetResultErrorMessage, DifferentMessage_ReturnsCorrectString) { + auto r = make_error(2, "out of memory"); + EXPECT_EQ(get_result_error_message(r), "out of memory"); +} + +// ============================================================ +// results_to_json +// ============================================================ + +TEST(ResultsToJson, SingleResult_ReturnsObjectDirectly) { + std::vector results; + results.push_back(make_ok(1, "only")); + + json out = results_to_json(results); + + EXPECT_TRUE(out.is_object()); + EXPECT_EQ(out.value("content", ""), "only"); +} + +TEST(ResultsToJson, MultipleResults_ReturnsArray) { + std::vector results; + results.push_back(make_ok(1, "a")); + results.push_back(make_ok(2, "b")); + + json out = results_to_json(results); + + EXPECT_TRUE(out.is_array()); + ASSERT_EQ(out.size(), 2u); + EXPECT_EQ(out[0].value("content", ""), "a"); + EXPECT_EQ(out[1].value("content", ""), "b"); +} + +TEST(ResultsToJson, EmptyVector_ReturnsEmptyArray) { + std::vector results; + json out = results_to_json(results); + EXPECT_TRUE(out.is_array()); + EXPECT_TRUE(out.empty()); +} + +// 
============================================================ +// rerank_results_to_json +// ============================================================ + +namespace { +struct fake_rerank_result : server_task_result { + int index; float score; + fake_rerank_result(int id_, int idx, float sc) : index(idx), score(sc) { id = id_; } + json to_json() override { return {{"index", index}, {"score", score}}; } +}; +static server_task_result_ptr make_rerank(int id_, int idx, float sc) { + return std::make_unique(id_, idx, sc); +} +} // namespace + +TEST(RerankResultsToJson, TwoResults_CorrectShape) { + std::vector results; + results.push_back(make_rerank(1, 0, 0.9f)); + results.push_back(make_rerank(2, 1, 0.4f)); + std::vector docs = {"doc A", "doc B"}; + + json out = rerank_results_to_json(results, docs); + + ASSERT_TRUE(out.is_array()); + ASSERT_EQ(out.size(), 2u); + EXPECT_EQ(out[0].value("document", ""), "doc A"); + EXPECT_EQ(out[0].value("index", -1), 0); + EXPECT_FLOAT_EQ(out[0].value("score", 0.0f), 0.9f); + EXPECT_EQ(out[1].value("document", ""), "doc B"); +} + +TEST(RerankResultsToJson, EmptyResults_ReturnsEmptyArray) { + std::vector results; + json out = rerank_results_to_json(results, {}); + EXPECT_TRUE(out.is_array()); + EXPECT_TRUE(out.empty()); +} + +// ============================================================ +// build_embeddings_response_json +// ============================================================ + +TEST(BuildEmbeddingsResponseJson, NonOai_SingleResult_ReturnsBareArray) { + std::vector results; + results.push_back(make_embedding(1, {0.1f, 0.2f})); + + json out = build_embeddings_response_json(results, json::object(), + OAICOMPAT_TYPE_NONE, false); + + ASSERT_TRUE(out.is_array()); + ASSERT_EQ(out.size(), 1u); + EXPECT_TRUE(out[0].contains("embedding")); +} + +TEST(BuildEmbeddingsResponseJson, NonOai_MultipleResults_AllInArray) { + std::vector results; + results.push_back(make_embedding(1, {0.1f})); + results.push_back(make_embedding(2, {0.2f})); 
+ results.push_back(make_embedding(3, {0.3f})); + + json out = build_embeddings_response_json(results, json::object(), + OAICOMPAT_TYPE_NONE, false); + + ASSERT_TRUE(out.is_array()); + EXPECT_EQ(out.size(), 3u); +} + +TEST(BuildEmbeddingsResponseJson, OaiFloat_WrapsWithOaiStructure) { + std::vector results; + results.push_back(make_embedding(1, {0.5f, 0.6f, 0.7f})); + json body = {{"model", "text-embedding-ada-002"}}; + + json out = build_embeddings_response_json(results, body, + OAICOMPAT_TYPE_EMBEDDING, false); + + EXPECT_TRUE(out.is_object()); + EXPECT_EQ(out.value("object", ""), "list"); + EXPECT_TRUE(out.contains("data")); + EXPECT_TRUE(out.contains("usage")); + EXPECT_EQ(out.value("model", ""), "text-embedding-ada-002"); + ASSERT_TRUE(out["data"].is_array()); + ASSERT_EQ(out["data"].size(), 1u); + EXPECT_EQ(out["data"][0].value("object", ""), "embedding"); +} + +TEST(BuildEmbeddingsResponseJson, OaiBase64_EmbeddingEncodedAsString) { + std::vector results; + results.push_back(make_embedding(1, {1.0f, 2.0f})); + + json out = build_embeddings_response_json(results, json::object(), + OAICOMPAT_TYPE_EMBEDDING, /*use_base64=*/true); + + ASSERT_TRUE(out["data"].is_array()); + EXPECT_TRUE(out["data"][0]["embedding"].is_string()) + << "base64 embedding must be serialised as a string"; +} + +TEST(BuildEmbeddingsResponseJson, OaiUsage_TokensSummedAcrossResults) { + std::vector results; + results.push_back(std::make_unique(1, std::vector{0.1f}, 3)); + results.push_back(std::make_unique(2, std::vector{0.2f}, 5)); + + json out = build_embeddings_response_json(results, json::object(), + OAICOMPAT_TYPE_EMBEDDING, false); + + EXPECT_EQ(out["usage"].value("prompt_tokens", 0), 8) + << "usage.prompt_tokens must be sum of tokens_evaluated across all results"; +} + +// ============================================================ +// extract_first_embedding_row +// ============================================================ + +TEST(ExtractFirstEmbeddingRow, SingleRow_ReturnsRow) { 
+ json j = {{"embedding", {{0.1f, 0.2f, 0.3f}}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 3u); + EXPECT_FLOAT_EQ(row[0], 0.1f); + EXPECT_FLOAT_EQ(row[1], 0.2f); + EXPECT_FLOAT_EQ(row[2], 0.3f); +} + +TEST(ExtractFirstEmbeddingRow, MultipleRows_ReturnsFirstRowOnly) { + json j = {{"embedding", {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 2u); + EXPECT_FLOAT_EQ(row[0], 1.0f); + EXPECT_FLOAT_EQ(row[1], 2.0f); +} + +TEST(ExtractFirstEmbeddingRow, MissingEmbeddingKey_ThrowsJsonException) { + json j = {{"other_key", "value"}}; + EXPECT_THROW(extract_first_embedding_row(j), nlohmann::json::exception); +} + +TEST(ExtractFirstEmbeddingRow, EmptyOuterArray_ThrowsRuntimeError) { + json j = {{"embedding", json::array()}}; + EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); +} + +TEST(ExtractFirstEmbeddingRow, EmptyInnerArray_ThrowsRuntimeError) { + json j = {{"embedding", {json::array()}}}; + EXPECT_THROW(extract_first_embedding_row(j), std::runtime_error); +} + +TEST(ExtractFirstEmbeddingRow, LargeRow_AllValuesPreserved) { + std::vector vals(128); + for (int i = 0; i < 128; ++i) vals[i] = static_cast(i) * 0.01f; + json j = {{"embedding", {vals}}}; + auto row = extract_first_embedding_row(j); + ASSERT_EQ(row.size(), 128u); + for (int i = 0; i < 128; ++i) { + EXPECT_FLOAT_EQ(row[i], static_cast(i) * 0.01f); + } +} + +// ============================================================ +// parse_encoding_format +// ============================================================ + +TEST(ParseEncodingFormat, FieldAbsent_ReturnsFalse) { + EXPECT_FALSE(parse_encoding_format({{"model", "x"}})); +} + +TEST(ParseEncodingFormat, ExplicitFloat_ReturnsFalse) { + EXPECT_FALSE(parse_encoding_format({{"encoding_format", "float"}})); +} + +TEST(ParseEncodingFormat, Base64_ReturnsTrue) { + EXPECT_TRUE(parse_encoding_format({{"encoding_format", "base64"}})); +} + +TEST(ParseEncodingFormat, 
UnknownFormat_ThrowsInvalidArgument) { + EXPECT_THROW(parse_encoding_format({{"encoding_format", "binary"}}), + std::invalid_argument); +} + +TEST(ParseEncodingFormat, EmptyString_ThrowsInvalidArgument) { + EXPECT_THROW(parse_encoding_format({{"encoding_format", ""}}), + std::invalid_argument); +} + +TEST(ParseEncodingFormat, ErrorMessage_MentionsBothValidOptions) { + try { + parse_encoding_format({{"encoding_format", "hex"}}); + FAIL() << "Expected std::invalid_argument"; + } catch (const std::invalid_argument &e) { + const std::string msg(e.what()); + EXPECT_NE(msg.find("float"), std::string::npos); + EXPECT_NE(msg.find("base64"), std::string::npos); + } +} + +// ============================================================ +// extract_embedding_prompt +// ============================================================ + +TEST(ExtractEmbeddingPrompt, InputKey_ReturnsValueAndDoesNotSetFlag) { + bool flag = true; // pre-set to verify it gets cleared + json prompt = extract_embedding_prompt({{"input", "hello world"}}, flag); + EXPECT_EQ(prompt, "hello world"); + EXPECT_FALSE(flag); +} + +TEST(ExtractEmbeddingPrompt, ContentKey_ReturnsValueAndSetsFlag) { + bool flag = false; + json prompt = extract_embedding_prompt({{"content", "some text"}}, flag); + EXPECT_EQ(prompt, "some text"); + EXPECT_TRUE(flag); +} + +TEST(ExtractEmbeddingPrompt, InputTakesPriorityOverContent) { + bool flag = false; + json prompt = extract_embedding_prompt( + {{"input", "from input"}, {"content", "from content"}}, flag); + EXPECT_EQ(prompt, "from input"); + EXPECT_FALSE(flag); +} + +TEST(ExtractEmbeddingPrompt, NeitherKey_ThrowsInvalidArgument) { + bool flag = false; + EXPECT_THROW(extract_embedding_prompt({{"model", "x"}}, flag), + std::invalid_argument); +} + +TEST(ExtractEmbeddingPrompt, EmptyBody_ThrowsInvalidArgument) { + bool flag = false; + EXPECT_THROW(extract_embedding_prompt(json::object(), flag), + std::invalid_argument); +} + +TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { + 
bool flag = false; + json prompt = extract_embedding_prompt( + {{"input", {"sentence one", "sentence two"}}}, flag); + ASSERT_TRUE(prompt.is_array()); + ASSERT_EQ(prompt.size(), 2u); + EXPECT_EQ(prompt[0], "sentence one"); + EXPECT_EQ(prompt[1], "sentence two"); + EXPECT_FALSE(flag); +} + +// ============================================================ +// is_infill_request +// ============================================================ + +TEST(IsInfillRequest, HasInputPrefix_ReturnsTrue) { + EXPECT_TRUE(is_infill_request({{"input_prefix", "def f():"}})); +} + +TEST(IsInfillRequest, HasInputSuffix_ReturnsTrue) { + EXPECT_TRUE(is_infill_request({{"input_suffix", "return 1"}})); +} + +TEST(IsInfillRequest, HasBoth_ReturnsTrue) { + EXPECT_TRUE(is_infill_request( + {{"input_prefix", "def f():"}, {"input_suffix", "return 1"}})); +} + +TEST(IsInfillRequest, HasNeither_ReturnsFalse) { + EXPECT_FALSE(is_infill_request({{"prompt", "hello"}})); +} + +TEST(IsInfillRequest, EmptyBody_ReturnsFalse) { + EXPECT_FALSE(is_infill_request(json::object())); +} + +// ============================================================ +// parse_slot_prompt_similarity +// ============================================================ + +TEST(ParseSlotPromptSimilarity, FieldAbsent_ReturnsEmpty) { + EXPECT_FALSE(parse_slot_prompt_similarity({{"other", 1}}).has_value()); +} + +TEST(ParseSlotPromptSimilarity, Zero_ReturnsZero) { + auto v = parse_slot_prompt_similarity({{"slot_prompt_similarity", 0.0f}}); + ASSERT_TRUE(v.has_value()); + EXPECT_FLOAT_EQ(*v, 0.0f); +} + +TEST(ParseSlotPromptSimilarity, Half_ReturnsHalf) { + auto v = parse_slot_prompt_similarity({{"slot_prompt_similarity", 0.5f}}); + ASSERT_TRUE(v.has_value()); + EXPECT_FLOAT_EQ(*v, 0.5f); +} + +TEST(ParseSlotPromptSimilarity, One_ReturnsOne) { + auto v = parse_slot_prompt_similarity({{"slot_prompt_similarity", 1.0f}}); + ASSERT_TRUE(v.has_value()); + EXPECT_FLOAT_EQ(*v, 1.0f); +} + +TEST(ParseSlotPromptSimilarity, 
TooLow_ThrowsInvalidArgument) { + EXPECT_THROW( + parse_slot_prompt_similarity({{"slot_prompt_similarity", -0.1f}}), + std::invalid_argument); +} + +TEST(ParseSlotPromptSimilarity, TooHigh_ThrowsInvalidArgument) { + EXPECT_THROW( + parse_slot_prompt_similarity({{"slot_prompt_similarity", 1.1f}}), + std::invalid_argument); +} + +// ============================================================ +// parse_positive_int_config +// ============================================================ + +TEST(ParsePositiveIntConfig, FieldAbsent_ReturnsEmpty) { + EXPECT_FALSE(parse_positive_int_config({{"other", 1}}, "n_threads").has_value()); +} + +TEST(ParsePositiveIntConfig, ValidOne_ReturnsOne) { + auto v = parse_positive_int_config({{"n_threads", 1}}, "n_threads"); + ASSERT_TRUE(v.has_value()); + EXPECT_EQ(*v, 1); +} + +TEST(ParsePositiveIntConfig, ValidLarge_ReturnsValue) { + auto v = parse_positive_int_config({{"n_threads", 128}}, "n_threads"); + ASSERT_TRUE(v.has_value()); + EXPECT_EQ(*v, 128); +} + +TEST(ParsePositiveIntConfig, Zero_ThrowsInvalidArgument) { + EXPECT_THROW(parse_positive_int_config({{"n_threads", 0}}, "n_threads"), + std::invalid_argument); +} + +TEST(ParsePositiveIntConfig, Negative_ThrowsInvalidArgument) { + EXPECT_THROW(parse_positive_int_config({{"n_threads", -4}}, "n_threads"), + std::invalid_argument); +} + +TEST(ParsePositiveIntConfig, ErrorMessage_ContainsKeyName) { + try { + parse_positive_int_config({{"n_threads_batch", 0}}, "n_threads_batch"); + FAIL() << "Expected std::invalid_argument"; + } catch (const std::invalid_argument &e) { + EXPECT_NE(std::string(e.what()).find("n_threads_batch"), std::string::npos); + } +} diff --git a/src/test/cpp/test_server.cpp b/src/test/cpp/test_server.cpp index 193417b3..98d1b0f9 100644 --- a/src/test/cpp/test_server.cpp +++ b/src/test/cpp/test_server.cpp @@ -17,7 +17,7 @@ // - stop_type_to_str (enum → string mapping for all stop types) // - oaicompat_finish_reason (extracted helper: stop_type + tool_calls → OAI 
finish_reason) // -// collect_task_results_impl() is tested in test_jni_server_helpers.cpp. +// collect_task_results_impl() is tested in test_jni_helpers.cpp. #include From bbde4d693c7fdcc282aa4469ef750d0b143a0af6 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 08:01:44 +0000 Subject: [PATCH 06/11] Update CLAUDE.md: document json_helpers.hpp / jni_helpers.hpp architecture Adds a "Native Helper Architecture" subsection under Architecture that describes the semantic split between the two helper headers, the TU include-order rule, the full function inventory for each header, and guidance for adding new helpers. Updates the Testing section with a C++ unit test table and cmake invocation for running jllama_test without a live JVM or model file. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CLAUDE.md | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index a62ca232..89c29998 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -197,8 +197,67 @@ clang-format -i src/main/cpp/*.cpp src/main/cpp/*.hpp # Format C++ code - `jllama.cpp` — JNI implementation bridging Java calls to llama.cpp. - `server.hpp` — Inference server logic (adapted from llama.cpp's server). - `utils.hpp` — Helper utilities. +- `json_helpers.hpp` — Pure JSON transformation helpers (no JNI, no llama state). Independently unit-testable. +- `jni_helpers.hpp` — JNI bridge helpers (handle management + server orchestration). Includes `json_helpers.hpp`. - Uses `nlohmann/json` for JSON deserialization of parameters. +### Native Helper Architecture + +The project C++ helpers follow a strict semantic split: + +**`json_helpers.hpp`** — Pure data transforms. +- Input: `nlohmann::json`, `server_task_result_ptr`, plain C++ types. +- Output: `json`, `std::vector`, `std::optional`, plain C++ types. +- Zero JNI calls (`JNIEnv*` never appears). 
+- Zero llama state (`llama_context*`, `llama_vocab*`, `server_context*` never appear). +- Functions are named without `_impl` suffix — they are the canonical implementation. +- Testable with JSON literals and fake result objects; no JVM and no loaded model required. +- Requires `server.hpp` to be included by the translation unit first (TU convention — `server.hpp` has no include guard). + +Functions: `get_result_error_message`, `results_to_json`, `rerank_results_to_json`, +`build_embeddings_response_json`, `extract_first_embedding_row`, `parse_encoding_format`, +`extract_embedding_prompt`, `is_infill_request`, `parse_slot_prompt_similarity`, +`parse_positive_int_config`. + +**`jni_helpers.hpp`** — JNI bridge helpers, split into two layers: + +*Layer A* (no `server.hpp` required): handle management. +- `jllama_context` struct — owns `server_context*` and background worker thread. +- `get_server_context_impl` — reads Java `ctx` handle, throws on null. +- `get_jllama_context_impl` — like above but returns the wrapper (delete path only). +- `require_single_task_id_impl` — validates exactly one task ID was created. +- `require_json_field_impl` — throws `" is required"` if key is absent. +- `jint_array_to_tokens_impl` — reads a Java `int[]` into `std::vector`. + +*Layer B* (requires `server.hpp` in the TU before `jni_helpers.hpp`): server orchestration. +Includes `json_helpers.hpp` so all bridge helpers can call transforms directly. +- `json_to_jstring_impl` — serialises any `json` value to a JNI string. +- `build_completion_tasks_impl` — tokenises prompt and populates `server_task` vector. +- `recv_slot_task_result_impl` — receives one slot result, throws on error. +- `collect_task_results_impl` — receives all results for a task-id set, throws on error. +- `results_to_jstring_impl` — delegates to `results_to_json` then `json_to_jstring_impl`. +- `check_infill_support_impl` — validates FIM prefix/suffix/middle tokens present. 
+- `append_task` — constructs and appends a `server_task` of a given type. + +Functions with `_impl` suffix have a thin module-level wrapper in `jllama.cpp`; functions +without the suffix (in `json_helpers.hpp`) are called directly. + +**Include order rule:** +``` +// In jllama.cpp and any TU that uses Layer B helpers: +#include "server.hpp" // must come first — no include guard +#include "jni_helpers.hpp" // includes json_helpers.hpp internally +``` + +**Adding a new pure transform** (e.g. a new JSON field parser): +- Add it to `json_helpers.hpp`. No JNI, no llama types. +- Add tests to `src/test/cpp/test_json_helpers.cpp`. + +**Adding a new JNI bridge helper:** +- Add it to `jni_helpers.hpp` in the appropriate layer. +- If it needs `server.hpp` types, put it in Layer B (after the `json_helpers.hpp` include). +- Add tests to `src/test/cpp/test_jni_helpers.cpp`. + ### Parameter Flow Java parameters are serialized to JSON strings and passed to native code, which deserializes them using nlohmann/json. This avoids complex JNI field mapping for the many llama.cpp parameters. @@ -213,7 +272,8 @@ Docker-based cross-compilation scripts are in `.github/dockcross/` for ARM/Andro ## Testing -Tests require a model file. The CI downloads models from HuggingFace: +### Java tests +Require a model file. The CI downloads models from HuggingFace: - **LlamaModel tests**: CodeLlama-7B-GGUF (`codellama-7b.Q2_K.gguf`) - **RerankingModel tests**: Jina-Reranker model @@ -221,6 +281,23 @@ Set the model path via system property or environment variable (see test files f Test files are in `src/test/java/de/kherud/llama/` and `src/test/java/examples/`. +### C++ unit tests +No JVM or model file required. Built as `jllama_test` via CMake when `BUILD_TESTING=ON`. 
+ +| File | What it tests | +|------|---------------| +| `test_json_helpers.cpp` | All functions in `json_helpers.hpp` — pure JSON transforms, using fake result objects | +| `test_jni_helpers.cpp` | All functions in `jni_helpers.hpp` — mock `JNIEnv`, pre-seeded `server_response` queue | +| `test_server.cpp` | Selected `server.hpp` internals (result types, error formatting, routing helpers) | +| `test_utils.cpp` | Utilities from `utils.hpp` | + +Run C++ tests: +```bash +cmake -B build -DBUILD_TESTING=ON +cmake --build build --config Release +ctest --test-dir build --output-on-failure +``` + ## Key Constraints - **Java 11+** required. From ab2fb44422853f35feb0227eada524973e22a769 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 08:47:21 +0000 Subject: [PATCH 07/11] Slim embed() and handleEmbeddings(); add embedding_to_jfloat_array_impl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Items 2 and 3 of the separation-of-concerns backlog: Item 3 — embed() manual recv replaced by collect_task_results embed() was the only dispatch site that opened the result queue directly instead of going through the shared collect_task_results wrapper. It also called remove_waiting_task_id() on two separate code paths (error and stop), creating asymmetric cleanup risk. Now: dispatch_tasks → collect_task_results → results[0]->to_json(). Consistent with handleRerank/handleEmbeddings/handleCompletions. Item 2a — embedding_to_jfloat_array_impl (jni_helpers.hpp) The JNI float array allocation and SetFloatArrayRegion call in embed() are pure JNI mechanics, not domain logic. Extracted to embedding_to_jfloat_array_impl: takes a vector, allocates the jfloatArray, copies the data, throws OOM on allocation failure. embed() is now a thin caller. Five unit tests added to test_jni_helpers.cpp (FloatArrayFixture). 
Item 2b — handleEmbeddings() double try-catch merged extract_embedding_prompt and parse_encoding_format were wrapped in two separate try { ... } catch { ThrowNew; return nullptr; } blocks. Merged into one block; semantics are identical since the two calls are independent and both propagate as c_llama_error. Also corrected the error class used when extract_first_embedding_row throws in embed(): was c_error_oom (wrong — not an allocation failure), now c_llama_error. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CLAUDE.md | 1 + src/main/cpp/jllama.cpp | 50 +++++++--------------- src/main/cpp/jni_helpers.hpp | 24 +++++++++++ src/test/cpp/test_jni_helpers.cpp | 70 ++++++++++++++++++++++++++++++- 4 files changed, 108 insertions(+), 37 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 89c29998..78f09ec7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -238,6 +238,7 @@ Includes `json_helpers.hpp` so all bridge helpers can call transforms directly. - `results_to_jstring_impl` — delegates to `results_to_json` then `json_to_jstring_impl`. - `check_infill_support_impl` — validates FIM prefix/suffix/middle tokens present. - `append_task` — constructs and appends a `server_task` of a given type. +- `embedding_to_jfloat_array_impl` — converts `std::vector` to a Java `jfloatArray`; throws OOM on allocation failure. Functions with `_impl` suffix have a thin module-level wrapper in `jllama.cpp`; functions without the suffix (in `json_helpers.hpp`) are called directly. 
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index c6c8eaa5..fd0d0317 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -826,46 +826,22 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); std::vector tasks; - append_task(ctx_server, tasks, SERVER_TASK_TYPE_EMBEDDING, tokens, 0); const auto task_ids = dispatch_tasks(ctx_server, tasks); - const auto id_task = *task_ids.begin(); - - server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - - if (result->is_error()) { - ctx_server->queue_results.remove_waiting_task_id(id_task); - env->ThrowNew(c_llama_error, get_result_error_message(result).c_str()); - return nullptr; - } - - if (result->is_stop()) { - ctx_server->queue_results.remove_waiting_task_id(id_task); - } - - const auto out_res = result->to_json(); + std::vector results; + if (!collect_task_results(env, ctx_server, task_ids, results)) return nullptr; std::vector first_row; try { - first_row = extract_first_embedding_row(out_res); + first_row = extract_first_embedding_row(results[0]->to_json()); } catch (const std::exception &e) { - env->ThrowNew(c_error_oom, e.what()); - return nullptr; - } - - const jsize embedding_cols = static_cast(first_row.size()); - SRV_INF("Embedding has %d columns\n", embedding_cols); - - jfloatArray j_embedding = env->NewFloatArray(embedding_cols); - if (j_embedding == nullptr) { - env->ThrowNew(c_error_oom, "could not allocate embedding"); + env->ThrowNew(c_llama_error, e.what()); return nullptr; } - env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, - reinterpret_cast(first_row.data())); - return j_embedding; + SRV_INF("Embedding has %d columns\n", static_cast(first_row.size())); + return embedding_to_jfloat_array_impl(env, first_row, c_error_oom); } JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv *env, jobject obj, jstring jprompt, @@ -1167,13 
+1143,15 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn bool force_no_oaicompat = false; json prompt; - try { prompt = extract_embedding_prompt(body, force_no_oaicompat); } - catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } - if (force_no_oaicompat) oaicompat = OAICOMPAT_TYPE_NONE; - bool use_base64 = false; - try { use_base64 = parse_encoding_format(body); } - catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } + try { + prompt = extract_embedding_prompt(body, force_no_oaicompat); + use_base64 = parse_encoding_format(body); + } catch (const std::exception &e) { + env->ThrowNew(c_llama_error, e.what()); + return nullptr; + } + if (force_no_oaicompat) oaicompat = OAICOMPAT_TYPE_NONE; std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 11247bec..f017869c 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -46,6 +46,7 @@ // 11. results_to_jstring_impl — uses results_to_json (json_helpers), json_to_jstring_impl // 12. check_infill_support_impl // 13. append_task +// 14. embedding_to_jfloat_array_impl #include "jni.h" #include "nlohmann/json.hpp" @@ -349,3 +350,26 @@ inline void append_task(server_context *ctx_server, task.params.oaicompat = oaicompat; tasks.push_back(std::move(task)); } + +// --------------------------------------------------------------------------- +// embedding_to_jfloat_array_impl +// +// Converts a float vector to a Java jfloatArray. +// +// On success: returns a new jfloatArray filled with the embedding values. +// On allocation failure: throws via JNI with oom_class and returns nullptr. 
+// ---------------------------------------------------------------------------
+[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl(
+    JNIEnv *env,
+    const std::vector<float> &values,
+    jclass oom_class) {
+  const jsize len = static_cast<jsize>(values.size());
+  jfloatArray arr = env->NewFloatArray(len);
+  if (arr == nullptr) {
+    env->ThrowNew(oom_class, "could not allocate embedding");
+    return nullptr;
+  }
+  env->SetFloatArrayRegion(arr, 0, len,
+                           reinterpret_cast<const jfloat *>(values.data()));
+  return arr;
+}
diff --git a/src/test/cpp/test_jni_helpers.cpp b/src/test/cpp/test_jni_helpers.cpp
index 4ff0a7c4..f6b972b3 100644
--- a/src/test/cpp/test_jni_helpers.cpp
+++ b/src/test/cpp/test_jni_helpers.cpp
@@ -14,7 +14,7 @@
 // Layer B tests (need server.hpp + mock JNIEnv + pre-seeded server_response):
 // json_to_jstring_impl, results_to_jstring_impl,
 // build_completion_tasks_impl, recv_slot_task_result_impl,
-// collect_task_results_impl
+// collect_task_results_impl, embedding_to_jfloat_array_impl
 //
 // JNIEnv is mocked via a zero-filled JNINativeInterface_ table with only the
 // slots exercised by each test patched. server_response is used directly:
@@ -33,6 +33,8 @@
 #include "server.hpp"
 #include "jni_helpers.hpp"
+// embedding_to_jfloat_array_impl is also tested in this file (see bottom).
+ // ============================================================ // Shared fake result types // ============================================================ @@ -530,3 +532,69 @@ TEST_F(MockJniFixture, BuildTasks_MissingPrompt_InfillTypeHasSameBehaviour) { EXPECT_TRUE(g_throw_called); EXPECT_TRUE(tasks.empty()); } + +// ============================================================ +// embedding_to_jfloat_array_impl +// ============================================================ + +namespace { + +static bool g_float_new_called = false; +static jsize g_float_alloc_size = -1; +static jsize g_float_copied_size = -1; + +static jfloatArray JNICALL stub_NewFloatArray(JNIEnv *, jsize n) { + g_float_new_called = true; + g_float_alloc_size = n; + return reinterpret_cast(0xF1); +} +static void JNICALL stub_SetFloatArrayRegion(JNIEnv *, jfloatArray, jsize, jsize n, const jfloat *) { + g_float_copied_size = n; +} + +struct FloatArrayFixture : MockJniFixture { + void SetUp() override { + MockJniFixture::SetUp(); + g_float_new_called = false; + g_float_alloc_size = -1; + g_float_copied_size = -1; + table.NewFloatArray = stub_NewFloatArray; + table.SetFloatArrayRegion = stub_SetFloatArrayRegion; + } +}; + +} // namespace + +TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_ReturnsSentinel) { + std::vector v = {1.0f, 2.0f, 3.0f}; + auto *result = embedding_to_jfloat_array_impl(env, v, dummy_class); + EXPECT_EQ(result, reinterpret_cast(0xF1)); +} + +TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_AllocatesCorrectSize) { + std::vector v = {0.1f, 0.2f}; + embedding_to_jfloat_array_impl(env, v, dummy_class); + EXPECT_EQ(g_float_alloc_size, 2); +} + +TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_CopiesAllElements) { + std::vector v(5, 0.5f); + embedding_to_jfloat_array_impl(env, v, dummy_class); + EXPECT_EQ(g_float_copied_size, 5); +} + +TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_EmptyVector_AllocatesZeroLen) { + std::vector v; + embedding_to_jfloat_array_impl(env, v, 
dummy_class); + EXPECT_EQ(g_float_alloc_size, 0); + EXPECT_FALSE(g_throw_called); +} + +TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_AllocFails_ThrowsOomAndReturnsNull) { + table.NewFloatArray = [](JNIEnv *, jsize) -> jfloatArray { return nullptr; }; + std::vector v = {1.0f}; + auto *result = embedding_to_jfloat_array_impl(env, v, dummy_class); + EXPECT_EQ(result, nullptr); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "could not allocate embedding"); +} From 5a7b03ce7e7781873263a51e328c8bd6199ef131 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 09:00:14 +0000 Subject: [PATCH 08/11] Extract dispatch_completion_and_serialize, request_completion_task_id, tokens_to_jint_array_impl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three deduplication items applied: Finding 1 — dispatch_completion_and_serialize (4 call sites → 1) handleCompletions, handleCompletionsOai, handleChatCompletions, and handleInfill each repeated the same 6-line pipeline: gen_chatcmplid → build_completion_tasks → dispatch_tasks → collect_and_serialize Extracted to dispatch_completion_and_serialize(env, ctx_server, data, task_type, oaicompat). Each caller reduces to a single return statement. Finding 2 — request_completion_task_id (2 call sites → 1) requestCompletion and requestChatCompletion repeated the same pipeline but returned the task ID for streaming instead of serialising results: gen_chatcmplid → build_completion_tasks → dispatch_tasks → require_single_task_id Extracted to request_completion_task_id(env, ctx_server, data, task_type, oaicompat). Finding 3 — tokens_to_jint_array_impl (jni_helpers.hpp) encode() had an inline 8-line NewIntArray + null-check + SetIntArrayRegion block, symmetric with embedding_to_jfloat_array_impl. Extracted to tokens_to_jint_array_impl. encode() reduces to two lines. Five unit tests added to test_jni_helpers.cpp (IntArrayFixture). 
The embedding dispatch sites (embed, handleRerank, handleEmbeddings) are intentionally not touched — each uses a different JSON builder so they cannot share a single wrapper. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CLAUDE.md | 1 + src/main/cpp/jllama.cpp | 116 +++++++++++++++--------------- src/main/cpp/jni_helpers.hpp | 25 +++++++ src/test/cpp/test_jni_helpers.cpp | 69 +++++++++++++++++- 4 files changed, 150 insertions(+), 61 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 78f09ec7..2157abff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -239,6 +239,7 @@ Includes `json_helpers.hpp` so all bridge helpers can call transforms directly. - `check_infill_support_impl` — validates FIM prefix/suffix/middle tokens present. - `append_task` — constructs and appends a `server_task` of a given type. - `embedding_to_jfloat_array_impl` — converts `std::vector` to a Java `jfloatArray`; throws OOM on allocation failure. +- `tokens_to_jint_array_impl` — converts `std::vector` to a Java `jintArray`; throws OOM on allocation failure. Functions with `_impl` suffix have a thin module-level wrapper in `jllama.cpp`; functions without the suffix (in `json_helpers.hpp`) are called directly. diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index fd0d0317..a4a653b4 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -272,6 +272,50 @@ static int require_single_task_id(JNIEnv *env, return results_to_jstring(env, results); } +/** + * Build completion tasks from `data`, dispatch them, collect all results, and + * serialise to a JNI string. Used by handleCompletions, handleCompletionsOai, + * handleChatCompletions, and handleInfill — all of which follow exactly this + * pipeline and differ only in task_type and oaicompat. + * + * On error (build or collect fails): a JNI exception is already pending; + * returns nullptr so the caller can propagate it. 
+ */ +[[nodiscard]] static jstring dispatch_completion_and_serialize( + JNIEnv *env, + server_context *ctx_server, + const json &data, + server_task_type task_type, + oaicompat_type oaicompat) { + auto completion_id = gen_chatcmplid(); + std::vector tasks; + if (!build_completion_tasks(env, ctx_server, data, completion_id, + task_type, oaicompat, tasks)) return nullptr; + const auto task_ids = dispatch_tasks(ctx_server, tasks); + return collect_and_serialize(env, ctx_server, task_ids); +} + +/** + * Build completion tasks from `data`, dispatch them, and return the single + * task ID to the Java caller for streaming via receiveCompletionJson. + * Used by requestCompletion and requestChatCompletion. + * + * On error: a JNI exception is already pending; returns 0. + */ +[[nodiscard]] static int request_completion_task_id( + JNIEnv *env, + server_context *ctx_server, + const json &data, + server_task_type task_type, + oaicompat_type oaicompat) { + auto completion_id = gen_chatcmplid(); + std::vector tasks; + if (!build_completion_tasks(env, ctx_server, data, completion_id, + task_type, oaicompat, tasks)) return 0; + const auto task_ids = dispatch_tasks(ctx_server, tasks); + return require_single_task_id(env, task_ids); +} + /** * Convert a Java string to a std::string */ @@ -773,15 +817,7 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv ? 
SERVER_TASK_TYPE_INFILL : SERVER_TASK_TYPE_COMPLETION; - auto completion_id = gen_chatcmplid(); - std::vector tasks; - // oaicompat_model is already populated by params_from_json_cmpl inside the helper - if (!build_completion_tasks(env, ctx_server, data, completion_id, - type, OAICOMPAT_TYPE_NONE, tasks)) return 0; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return require_single_task_id(env, task_ids); + return request_completion_task_id(env, ctx_server, data, type, OAICOMPAT_TYPE_NONE); } JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) { @@ -900,14 +936,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleChatCompletions( json data; if (!parse_oai_chat_params(env, ctx_server, body, data)) return nullptr; - auto completion_id = gen_chatcmplid(); - std::vector tasks; - if (!build_completion_tasks(env, ctx_server, data, completion_id, - SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_CHAT, tasks)) return nullptr; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return collect_and_serialize(env, ctx_server, task_ids); + return dispatch_completion_and_serialize(env, ctx_server, data, + SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_CHAT); } JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChatCompletion(JNIEnv *env, jobject obj, @@ -920,14 +950,8 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChatCompletion(JNI json data; if (!parse_oai_chat_params(env, ctx_server, body, data)) return 0; - auto completion_id = gen_chatcmplid(); - std::vector tasks; - if (!build_completion_tasks(env, ctx_server, data, completion_id, - SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_NONE, tasks)) return 0; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return require_single_task_id(env, task_ids); + return request_completion_task_id(env, ctx_server, data, + SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_NONE); } JNIEXPORT jintArray JNICALL 
Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { @@ -936,17 +960,7 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, const std::string c_prompt = parse_jstring(env, jprompt); llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true); - jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions) - - jintArray java_tokens = env->NewIntArray(token_size); - if (java_tokens == nullptr) { - env->ThrowNew(c_error_oom, "could not allocate token memory"); - return nullptr; - } - - env->SetIntArrayRegion(java_tokens, 0, token_size, reinterpret_cast(tokens.data())); - - return java_tokens; + return tokens_to_jint_array_impl(env, tokens, c_error_oom); } /** @@ -1045,14 +1059,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletions(JNIE json data = parse_json_params(env, jparams); - auto completion_id = gen_chatcmplid(); - std::vector tasks; - if (!build_completion_tasks(env, ctx_server, data, completion_id, - SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_NONE, tasks)) return nullptr; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return collect_and_serialize(env, ctx_server, task_ids); + return dispatch_completion_and_serialize(env, ctx_server, data, + SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_NONE); } JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(JNIEnv *env, jobject obj, @@ -1070,14 +1078,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(J return nullptr; } - auto completion_id = gen_chatcmplid(); - std::vector tasks; - if (!build_completion_tasks(env, ctx_server, data, completion_id, - SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_COMPLETION, tasks)) return nullptr; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return collect_and_serialize(env, ctx_server, task_ids); + return dispatch_completion_and_serialize(env, ctx_server, data, + 
SERVER_TASK_TYPE_COMPLETION, OAICOMPAT_TYPE_COMPLETION); } /** @@ -1111,14 +1113,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv *e ctx_server->params_base.spm_infill, tokenized_prompts.empty() ? llama_tokens() : tokenized_prompts[0]); - auto completion_id = gen_chatcmplid(); - std::vector tasks; - if (!build_completion_tasks(env, ctx_server, data, completion_id, - SERVER_TASK_TYPE_INFILL, OAICOMPAT_TYPE_NONE, tasks)) return nullptr; - - const auto task_ids = dispatch_tasks(ctx_server, tasks); - - return collect_and_serialize(env, ctx_server, task_ids); + return dispatch_completion_and_serialize(env, ctx_server, data, + SERVER_TASK_TYPE_INFILL, OAICOMPAT_TYPE_NONE); } JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEnv *env, jobject obj, diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index f017869c..001d7d0e 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -47,6 +47,7 @@ // 12. check_infill_support_impl // 13. append_task // 14. embedding_to_jfloat_array_impl +// 15. tokens_to_jint_array_impl #include "jni.h" #include "nlohmann/json.hpp" @@ -373,3 +374,27 @@ inline void append_task(server_context *ctx_server, reinterpret_cast(values.data())); return arr; } + +// --------------------------------------------------------------------------- +// tokens_to_jint_array_impl +// +// Converts a token vector to a Java jintArray. Symmetric with +// embedding_to_jfloat_array_impl for the int (token) case. +// +// On success: returns a new jintArray filled with the token values. +// On allocation failure: throws via JNI with oom_class and returns nullptr. 
+// ---------------------------------------------------------------------------
+[[nodiscard]] inline jintArray tokens_to_jint_array_impl(
+    JNIEnv *env,
+    const std::vector<int32_t> &tokens,
+    jclass oom_class) {
+  const jsize len = static_cast<jsize>(tokens.size());
+  jintArray arr = env->NewIntArray(len);
+  if (arr == nullptr) {
+    env->ThrowNew(oom_class, "could not allocate token memory");
+    return nullptr;
+  }
+  env->SetIntArrayRegion(arr, 0, len,
+                         reinterpret_cast<const jint *>(tokens.data()));
+  return arr;
+}
diff --git a/src/test/cpp/test_jni_helpers.cpp b/src/test/cpp/test_jni_helpers.cpp
index f6b972b3..9cec259f 100644
--- a/src/test/cpp/test_jni_helpers.cpp
+++ b/src/test/cpp/test_jni_helpers.cpp
@@ -14,7 +14,8 @@
 // Layer B tests (need server.hpp + mock JNIEnv + pre-seeded server_response):
 // json_to_jstring_impl, results_to_jstring_impl,
 // build_completion_tasks_impl, recv_slot_task_result_impl,
-// collect_task_results_impl, embedding_to_jfloat_array_impl
+// collect_task_results_impl, embedding_to_jfloat_array_impl,
+// tokens_to_jint_array_impl
 //
 // JNIEnv is mocked via a zero-filled JNINativeInterface_ table with only the
 // slots exercised by each test patched.
server_response is used directly: @@ -598,3 +599,69 @@ TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_AllocFails_ThrowsOomAndReturnsN EXPECT_TRUE(g_throw_called); EXPECT_EQ(g_throw_message, "could not allocate embedding"); } + +// ============================================================ +// tokens_to_jint_array_impl +// ============================================================ + +namespace { + +static bool g_int_new_called = false; +static jsize g_int_alloc_size = -1; +static jsize g_int_copied_size = -1; + +static jintArray JNICALL stub_NewIntArray(JNIEnv *, jsize n) { + g_int_new_called = true; + g_int_alloc_size = n; + return reinterpret_cast(0xF2); +} +static void JNICALL stub_SetIntArrayRegion(JNIEnv *, jintArray, jsize, jsize n, const jint *) { + g_int_copied_size = n; +} + +struct IntArrayFixture : MockJniFixture { + void SetUp() override { + MockJniFixture::SetUp(); + g_int_new_called = false; + g_int_alloc_size = -1; + g_int_copied_size = -1; + table.NewIntArray = stub_NewIntArray; + table.SetIntArrayRegion = stub_SetIntArrayRegion; + } +}; + +} // namespace + +TEST_F(IntArrayFixture, TokensToJintArray_ReturnsSentinel) { + std::vector v = {1, 2, 3}; + auto *result = tokens_to_jint_array_impl(env, v, dummy_class); + EXPECT_EQ(result, reinterpret_cast(0xF2)); +} + +TEST_F(IntArrayFixture, TokensToJintArray_AllocatesCorrectSize) { + std::vector v = {10, 20}; + tokens_to_jint_array_impl(env, v, dummy_class); + EXPECT_EQ(g_int_alloc_size, 2); +} + +TEST_F(IntArrayFixture, TokensToJintArray_CopiesAllElements) { + std::vector v(7, 42); + tokens_to_jint_array_impl(env, v, dummy_class); + EXPECT_EQ(g_int_copied_size, 7); +} + +TEST_F(IntArrayFixture, TokensToJintArray_EmptyVector_AllocatesZeroLen) { + std::vector v; + tokens_to_jint_array_impl(env, v, dummy_class); + EXPECT_EQ(g_int_alloc_size, 0); + EXPECT_FALSE(g_throw_called); +} + +TEST_F(IntArrayFixture, TokensToJintArray_AllocFails_ThrowsOomAndReturnsNull) { + table.NewIntArray = [](JNIEnv *, jsize) 
-> jintArray { return nullptr; }; + std::vector v = {1}; + auto *result = tokens_to_jint_array_impl(env, v, dummy_class); + EXPECT_EQ(result, nullptr); + EXPECT_TRUE(g_throw_called); + EXPECT_EQ(g_throw_message, "could not allocate token memory"); +} From 1b87aa55db2ce4ccac2e95887967527121587470 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 09:14:20 +0000 Subject: [PATCH 09/11] Disable GGML_NATIVE to fix AVX-512 crash on AMD runners MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GGML_NATIVE=ON (the default for non-cross builds) causes MSVC to run FindSIMD.cmake, which probes the build machine via CheckCSourceRuns and then does a plain set(GGML_AVX512 ON) — a local CMake variable that shadows the cache in that scope. This meant our existing set(GGML_AVX512 OFF CACHE BOOL "" FORCE) was silently overridden inside FindSIMD's scope, and /arch:AVX512 was still being added to the compiler flags on the Intel Xeon build machine. The resulting jllama.dll contained AVX-512 instructions (EVEX prefix 0x62) that crashed with EXCEPTION_ILLEGAL_INSTRUCTION (0xc000001d) on the AMD EPYC 7763 GitHub Actions test runner, which has no AVX-512. Fix: also set GGML_NATIVE=OFF FORCE so FindSIMD.cmake is never executed. GGML then reads the individual GGML_AVX* cache variables as set here. GGML_AVX2 is set ON explicitly (it is GGML's default, but explicit is clearer) and GGML_AVX512 stays OFF, producing an AVX2 binary that runs correctly on any x86-64 CPU built since 2013. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CMakeLists.txt | 52 ++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2decc648..c744446e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,40 +53,38 @@ if(WIN32) set(LLAMA_BUILD_BORINGSSL ON CACHE BOOL "" FORCE) endif() -# Disable AVX-512 for all builds. 
+# Instruction-set policy: target AVX2, disable AVX-512, disable native-CPU detection. # -# AVX-512 causes problems in several well-known scenarios, so we turn it off -# unconditionally to keep the distributed library safe and broadly compatible. +# GGML_NATIVE (default ON) compiles the library for the exact CPU that runs the +# build. On MSVC it runs FindSIMD.cmake, which probes the build machine with +# CheckCSourceRuns and then does a plain set(GGML_AVX512 ON) — a local variable +# that shadows the cache — before the /arch:AVX512 flag is added. This means +# a plain GGML_AVX512=OFF FORCE in the parent scope is not enough: FindSIMD +# silently re-enables AVX-512 inside its own scope. Turning GGML_NATIVE OFF +# prevents FindSIMD from running at all, so the individual GGML_AVX* cache +# variables are respected. # -# 1. Many CPUs simply do not support AVX-512 at all. Loading a binary that -# contains AVX-512 instructions on such a machine crashes the JVM immediately -# with SIGILL (illegal instruction). A real example from CI: the AMD EPYC -# 7763 used by GitHub Actions runners has no AVX-512, while the Intel Xeon -# Platinum 8370C used in the manylinux build container does. A library built -# on the Xeon would crash on the EPYC. +# AVX-512 is disabled for several reasons: +# +# 1. Many CPUs do not support AVX-512. Loading such a binary crashes the JVM +# immediately with SIGILL. Real CI example: AMD EPYC 7763 (test runner, no +# AVX-512) vs Intel Xeon Platinum 8370C (build machine, has AVX-512). # # 2. Intel disabled AVX-512 on 12th-gen (Alder Lake) and later desktop CPUs. -# Alder Lake mixes big P-cores (which have AVX-512 hardware) with small -# E-cores (which do not). Because the OS scheduler can move a thread between -# core types at any time, running AVX-512 instructions is unsafe. Intel -# first locked it out via BIOS/microcode updates, then began physically fusing -# it off in silicon on new chips. 
All current shipping Alder Lake, Raptor -# Lake, and newer Intel desktop CPUs have AVX-512 permanently disabled. +# The hybrid P+E-core design makes AVX-512 unsafe at the OS scheduler level; +# Intel has since fused it off in silicon on all current desktop chips. # -# 3. Even CPUs that do support AVX-512 often throttle their clock speed while -# executing it. Older Intel server chips (Skylake-X era) can drop hundreds -# of MHz across all cores — even cores not running AVX-512. AMD Ryzen 9000 -# (Zen 5) drops roughly 10 % in frequency when AVX-512 is active and takes -# more than 100 ms to ramp back up after the workload ends. +# 3. CPUs that do support AVX-512 often throttle clock speed while executing it, +# and take 100+ ms to ramp back up — a net loss for bursty inference workloads. # -# 4. AVX-512's high instantaneous power draw can destabilize systems that are -# running near their voltage limits, leading to random crashes. +# 4. High instantaneous power draw from AVX-512 can destabilise systems running +# near their voltage limits. # -# Keeping AVX-512 off produces a binary that runs correctly on any x86-64 -# machine with at least AVX2 support, which covers virtually all hardware -# built since 2013. FORCE prevents llama.cpp's own cmake defaults from -# re-enabling it regardless of the build flags passed on the command line. -set(GGML_AVX512 OFF CACHE BOOL "" FORCE) +# AVX2 (default ON in GGML) covers virtually all x86-64 hardware built since +# 2013 and is set explicitly here for clarity. 
+set(GGML_NATIVE OFF CACHE BOOL "" FORCE) +set(GGML_AVX2 ON CACHE BOOL "" FORCE) +set(GGML_AVX512 OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git From 8c5f81e165d5f259f4f7b84cdc06747030885313 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 09:48:28 +0000 Subject: [PATCH 10/11] Enable GGML_BMI2 to fix __pdep_u64 linker error on Windows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With GGML_NATIVE=OFF and only /arch:AVX2, MSVC no longer implicitly enables BMI2 intrinsics (previously /arch:AVX512 pulled them in for free). GGML's quants.c wraps _pdep_u64 in an inline __pdep_u64 helper gated on #if defined(__BMI2__) || defined(GGML_BMI2). Without that define the wrapper is absent and the linker cannot resolve the symbol referenced by _ggml_vec_dot_iq1_m_q8_K. GGML_BMI2=ON adds -D__BMI2__ -DGGML_BMI2 for MSVC (no extra /arch flag needed — MSVC exposes BMI2 intrinsics under /arch:AVX2 once the macros are defined). BMI2 is safe to require alongside AVX2: every CPU with AVX2 also has BMI2 (Intel Haswell 2013+, AMD Ryzen 2017+). https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index c744446e..0bb4f5e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,7 @@ endif() # 2013 and is set explicitly here for clarity. set(GGML_NATIVE OFF CACHE BOOL "" FORCE) set(GGML_AVX2 ON CACHE BOOL "" FORCE) +set(GGML_BMI2 ON CACHE BOOL "" FORCE) set(GGML_AVX512 OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp From cd2c374f8fdc92cdaaa4a9c2373b56e25f9f0e3c Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 10:10:45 +0000 Subject: [PATCH 11/11] Target full Haswell baseline: add SSE42, AVX, FMA, F16C flags The previous set (AVX2 + BMI2) was incomplete for GCC/Clang builds. 
GGML's ggml-cpu cmake uses independent `if` statements on non-MSVC, so each ISA feature needs its own flag: -msse4.2, -mavx, -mfma, -mf16c. Without them, GGML code paths gated on #if defined(GGML_FMA) or #if defined(GGML_F16C) were silently not compiled on Linux/macOS. On MSVC the elseif chain picks /arch:AVX2 and bundles GGML_FMA + GGML_F16C automatically, so SSE42/AVX/FMA/F16C flags are no-ops there but harmless. The full set now matches GGML's own "haswell" named variant in GGML_CPU_ALL_VARIANTS: SSE42 + AVX + AVX2 + BMI2 + FMA + F16C. This covers Intel Haswell (2013+) and AMD Ryzen/EPYC (2017+). Also rewrote the block comment to explain the Haswell target, the MSVC vs GCC/Clang difference, and why each flag is explicit. https://claude.ai/code/session_0197uhVYKafh2feJkLuq7KiT --- CMakeLists.txt | 48 +++++++++++++++++++++++------------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0bb4f5e8..bf37f7e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,38 +53,36 @@ if(WIN32) set(LLAMA_BUILD_BORINGSSL ON CACHE BOOL "" FORCE) endif() -# Instruction-set policy: target AVX2, disable AVX-512, disable native-CPU detection. +# Instruction-set policy: target the "Haswell" baseline (x86-64-v3). # -# GGML_NATIVE (default ON) compiles the library for the exact CPU that runs the -# build. On MSVC it runs FindSIMD.cmake, which probes the build machine with -# CheckCSourceRuns and then does a plain set(GGML_AVX512 ON) — a local variable -# that shadows the cache — before the /arch:AVX512 flag is added. This means -# a plain GGML_AVX512=OFF FORCE in the parent scope is not enough: FindSIMD -# silently re-enables AVX-512 inside its own scope. Turning GGML_NATIVE OFF -# prevents FindSIMD from running at all, so the individual GGML_AVX* cache -# variables are respected. 
+# This set of flags matches GGML's own "haswell" named variant in +# GGML_CPU_ALL_VARIANTS and covers every x86-64 CPU since: +# - Intel Haswell (2013) +# - AMD Ryzen / EPYC (2017) # -# AVX-512 is disabled for several reasons: +# GGML_NATIVE is OFF so the build never probes the build machine's CPU. +# Without this, MSVC runs FindSIMD.cmake which shadow-sets GGML_AVX512=ON +# via a local variable that bypasses our CACHE FORCE, and GCC/Clang uses +# -march=native which embeds whatever the build machine supports. # -# 1. Many CPUs do not support AVX-512. Loading such a binary crashes the JVM -# immediately with SIGILL. Real CI example: AMD EPYC 7763 (test runner, no -# AVX-512) vs Intel Xeon Platinum 8370C (build machine, has AVX-512). +# The individual flags are set explicitly because with GGML_NATIVE=OFF +# they all default to OFF. On MSVC the elseif chain in ggml-cpu cmake +# picks the highest level (/arch:AVX2) and bundles FMA + F16C defines +# automatically; SSE42, AVX, FMA, F16C have no additional effect there +# but are needed for GCC/Clang where each flag independently adds its +# -m flag and GGML_* preprocessor define. # -# 2. Intel disabled AVX-512 on 12th-gen (Alder Lake) and later desktop CPUs. -# The hybrid P+E-core design makes AVX-512 unsafe at the OS scheduler level; -# Intel has since fused it off in silicon on all current desktop chips. -# -# 3. CPUs that do support AVX-512 often throttle clock speed while executing it, -# and take 100+ ms to ramp back up — a net loss for bursty inference workloads. -# -# 4. High instantaneous power draw from AVX-512 can destabilise systems running -# near their voltage limits. -# -# AVX2 (default ON in GGML) covers virtually all x86-64 hardware built since -# 2013 and is set explicitly here for clarity. +# AVX-512 stays OFF: +# - Many CPUs lack it (AMD EPYC 7763, all Intel desktop since Alder Lake). +# - MSVC's /arch:AVX512 applies to the entire TU — no per-function gating. 
+# - Frequency throttling and power draw make it a net loss for bursty work. set(GGML_NATIVE OFF CACHE BOOL "" FORCE) +set(GGML_SSE42 ON CACHE BOOL "" FORCE) +set(GGML_AVX ON CACHE BOOL "" FORCE) set(GGML_AVX2 ON CACHE BOOL "" FORCE) set(GGML_BMI2 ON CACHE BOOL "" FORCE) +set(GGML_FMA ON CACHE BOOL "" FORCE) +set(GGML_F16C ON CACHE BOOL "" FORCE) set(GGML_AVX512 OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp