diff --git a/CLAUDE.md b/CLAUDE.md index 56b75e38..4298e12c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,7 +6,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co Java bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) via JNI, providing a high-level API for LLM inference in Java. The Java layer communicates with a native C++ library through JNI. -Current llama.cpp pinned version: **b8854** +Current llama.cpp pinned version: **b8887** ## Upgrading CUDA Version @@ -137,7 +137,7 @@ Also review the project `CMakeLists.txt` for build-system-level breaks (e.g. ren `ggml/include/ggml.h`, `ggml/include/ggml-backend.h`, `ggml/include/ggml-opt.h`, `ggml-alloc.h`, `ggml-cpu.h`, `peg-parser.h`, `base64.hpp` -**Known breaking changes by version range** (b5022 → b8841): +**Known breaking changes by version range** (b5022 → b8887): | Version | File | Change | |---------|------|--------| @@ -159,6 +159,9 @@ Also review the project `CMakeLists.txt` for build-system-level breaks (e.g. ren | ~b8841–b8854 | `common/common.h` | `common_params::clear_idle` renamed to `cache_idle_slots`; new `common_context_seq_rm_type` enum + `common_context_can_seq_rm()` replacing `common_speculative_is_compat()`; `get_model_endpoint()` → `common_get_model_endpoint()` | | ~b8841–b8854 | `tools/mtmd/mtmd.h` + `mtmd-helper.h` | `mtmd_decoder_pos` gains `z` field; `mtmd_image_tokens_get_decoder_pos()` + `mtmd_helper_image_get_decoder_pos()` gain new `pos_0` parameter | | ~b8841–b8854 | project `utils.hpp` / `server.hpp` | `server_tokens::get_text_tokens()` split: `get_tokens()` returns raw `const llama_tokens &`; new `get_text_tokens()` returns filtered copy (removes `LLAMA_TOKEN_NULL` mtmd placeholders); save/load and context-shift call sites updated to `get_tokens()` | +| ~b8854–b8887 | `common/chat.h` | `common_chat_msg_diff_to_json_oaicompat` removed; moved to `tools/server/server-chat.cpp`; project defines it locally in `server.hpp` — importing server-chat.cpp is impractical because it pulls in `convert_transcriptions_to_chatcmpl` → `get_media_marker` → `server-common.cpp` | +| ~b8854–b8887 | `common/common.h` | `common_params::reasoning_budget` and `reasoning_budget_message` moved into `common_params::sampling` sub-struct as `reasoning_budget_tokens`; update: `params_base.reasoning_budget` → `params_base.sampling.reasoning_budget_tokens` | +| ~b8854–b8887 | `common/fit.h` (new) | `llama_params_fit` and `llama_memory_breakdown_print` removed from `include/llama.h`; now `common_fit_params` / `common_memory_breakdown_print` in new `common/fit.h`; not used directly by project | ## Build Commands diff --git a/CMakeLists.txt b/CMakeLists.txt index c5ebdd14..86c365a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,7 +97,7 @@ set(GGML_AVX512 OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b8854 + GIT_TAG b8887 ) FetchContent_MakeAvailable(llama.cpp) @@ -208,10 +208,19 @@ if(NOT JNI_INCLUDE_DIRS) endif() endif() -add_library(jllama SHARED src/main/cpp/jllama.cpp src/main/cpp/server.hpp src/main/cpp/utils.hpp) +add_library(jllama SHARED + src/main/cpp/jllama.cpp + src/main/cpp/server.hpp + src/main/cpp/utils.hpp + ${llama.cpp_SOURCE_DIR}/tools/server/server-common.cpp + ${llama.cpp_SOURCE_DIR}/tools/server/server-chat.cpp) set_target_properties(jllama PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_include_directories(jllama PRIVATE src/main/cpp ${JNI_INCLUDE_DIRS} ${llama.cpp_SOURCE_DIR}/tools/mtmd) +target_include_directories(jllama PRIVATE + src/main/cpp + ${JNI_INCLUDE_DIRS} + ${llama.cpp_SOURCE_DIR}/tools/mtmd + ${llama.cpp_SOURCE_DIR}/tools/server) target_link_libraries(jllama PRIVATE llama-common mtmd llama nlohmann_json) target_compile_features(jllama PRIVATE cxx_std_11) @@ -258,7 +267,8 @@ if(BUILD_TESTING) src/test/cpp/test_server.cpp src/test/cpp/test_jni_helpers.cpp src/test/cpp/test_json_helpers.cpp - ) + ${llama.cpp_SOURCE_DIR}/tools/server/server-common.cpp + ${llama.cpp_SOURCE_DIR}/tools/server/server-chat.cpp) target_include_directories(jllama_test PRIVATE src/main/cpp @@ -266,6 +276,7 @@ if(BUILD_TESTING) ${llama.cpp_SOURCE_DIR}/tools/mtmd # jni.h / jni_md.h needed by jni_helpers.hpp (mock JNI tests, no JVM required) ${JNI_INCLUDE_DIRS} + ${llama.cpp_SOURCE_DIR}/tools/server ) target_link_libraries(jllama_test PRIVATE llama-common mtmd llama nlohmann_json GTest::gtest_main) target_compile_features(jllama_test PRIVATE cxx_std_17) diff --git a/README.md b/README.md index 1dcf44a8..a264f7df 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ ![Java 8+](https://img.shields.io/badge/Java-8%2B-informational) -[![llama.cpp b8854](https://img.shields.io/badge/llama.cpp-%23b8854-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b8854) +[![llama.cpp b8887](https://img.shields.io/badge/llama.cpp-%23b8887-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b8887) # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 77868ac0..17d7d6df 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -779,8 +779,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(ctx_server->chat_templates.get()).c_str(), - common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja, ctx_server->params_base.default_template_kwargs).c_str()); + common_chat_templates_source(ctx_server->oai_parser_opt.tmpls.get()).c_str(), + common_chat_format_example(ctx_server->oai_parser_opt.tmpls.get(), ctx_server->params_base.use_jinja, ctx_server->params_base.default_template_kwargs).c_str()); ctx_server->queue_tasks.on_new_task( std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1)); @@ -912,12 +912,12 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv *e auto document_vector = std::vector(document_array, document_array + amount_documents); free_string_array(document_array, amount_documents); - std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true); + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, nullptr, document_vector, true, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { append_task(ctx_server, tasks, SERVER_TASK_TYPE_RERANK, - format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]), i); + format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i].get_tokens()), i); } std::vector results; if (!dispatch_and_collect(env, ctx_server, std::move(tasks), results)) return nullptr; @@ -983,9 +983,9 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, static std::string detokenize(const server_context *ctx_server, const std::vector &tokens) { if (!ctx_server->is_vocab_only()) { - return tokens_to_str(ctx_server->ctx, tokens.cbegin(), tokens.cend()); + return tokens_to_str(ctx_server->ctx, tokens); } - return tokens_to_str(ctx_server->vocab, tokens.cbegin(), tokens.cend()); + return tokens_to_str(ctx_server->vocab, tokens); } JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj, @@ -1115,13 +1115,13 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv *e // Format the infill prompt std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, false, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, nullptr, prompt, false, true); data["prompt"] = format_infill(ctx_server->vocab, data.at("input_prefix"), data.at("input_suffix"), data.at("input_extra"), ctx_server->params_base.n_batch, ctx_server->params_base.n_predict, ctx_server->slots[0].n_ctx, ctx_server->params_base.spm_infill, - tokenized_prompts.empty() ? llama_tokens() : tokenized_prompts[0]); + tokenized_prompts.empty() ? llama_tokens() : tokenized_prompts[0].get_tokens()); return dispatch_completion_and_serialize(env, ctx_server, data, SERVER_TASK_TYPE_INFILL, OAICOMPAT_TYPE_NONE); @@ -1155,10 +1155,10 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn } if (force_no_oaicompat) oaicompat = OAICOMPAT_TYPE_NONE; - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, nullptr, prompt, true, true); for (const auto &tokens : tokenized_prompts) { - if (tokens.empty()) { + if (tokens.get_tokens().empty()) { env->ThrowNew(c_llama_error, "Input content cannot be empty"); return nullptr; } @@ -1168,7 +1168,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { - append_task(ctx_server, tasks, SERVER_TASK_TYPE_EMBEDDING, tokenized_prompts[i], i, oaicompat); + append_task(ctx_server, tasks, SERVER_TASK_TYPE_EMBEDDING, tokenized_prompts[i].get_tokens(), i, oaicompat); } std::vector results; diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 17614a07..e02c27bd 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -219,8 +219,8 @@ struct jllama_context { try { const auto &prompt = data.at("prompt"); // throws before ctx_server is touched - std::vector tokenized_prompts = - tokenize_input_prompts(ctx_server->vocab, prompt, true, true); + std::vector tokenized_prompts = + tokenize_input_prompts(ctx_server->vocab, nullptr, prompt, true, true); tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { @@ -228,7 +228,7 @@ struct jllama_context { task.id = ctx_server->queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = server_tokens(tokenized_prompts[i], false); + task.prompt_tokens = std::move(tokenized_prompts[i]); task.params = server_task::params_from_json_cmpl( ctx_server->ctx, ctx_server->params_base, data); task.id_selected_slot = json_value(data, "id_slot", -1); diff --git a/src/main/cpp/json_helpers.hpp b/src/main/cpp/json_helpers.hpp index e47155d9..a3736419 100644 --- a/src/main/cpp/json_helpers.hpp +++ b/src/main/cpp/json_helpers.hpp @@ -121,7 +121,7 @@ responses.push_back(result->to_json()); } if (oaicompat == OAICOMPAT_TYPE_EMBEDDING) { - return format_embeddings_response_oaicompat(body, responses, use_base64); + return format_embeddings_response_oaicompat(body, json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), responses, use_base64); } return responses; } diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 1e9604f0..61fa5826 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -1,4 +1,5 @@ #include "chat.h" +#include "server-chat.h" #include "utils.hpp" #include "arg.h" @@ -25,8 +26,6 @@ #include #include -using json = nlohmann::ordered_json; - constexpr int HTTP_POLLING_SECONDS = 1; enum stop_type { @@ -72,16 +71,7 @@ enum oaicompat_type { OAICOMPAT_TYPE_EMBEDDING, }; -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; +// error_type enum provided by server-common.h (via utils.hpp) static bool server_task_type_need_embd(server_task_type task_type) { switch (task_type) { @@ -614,6 +604,7 @@ inline std::string oaicompat_finish_reason(stop_type stop, bool has_tool_calls = return "length"; } + struct completion_token_output { llama_token tok; float prob; @@ -821,7 +812,7 @@ struct server_task_result_cmpl_final : server_task_result { json{ {"finish_reason", nullptr}, {"index", index}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + {"delta", server_chat_msg_diff_to_json_oaicompat(diff)}, }, })}, {"created", t}, @@ -989,7 +980,7 @@ struct server_task_result_cmpl_partial : server_task_result { } for (const auto &diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + add_delta(server_chat_msg_diff_to_json_oaicompat(diff)); } if (!deltas.empty()) { @@ -1058,46 +1049,7 @@ struct server_task_result_rerank : server_task_result { } }; -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string &message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json{ - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} +// format_error_response is provided by server-common.h / server-common.cpp struct server_task_result_error : server_task_result { int index = 0; @@ -1848,8 +1800,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; + server_chat_params oai_parser_opt; // Returns true when the model was loaded in vocab-only mode: // the vocabulary is available but no inference context was created. @@ -1955,15 +1906,15 @@ struct server_context { params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); } - chat_templates = common_chat_templates_init(model, params_base.chat_template); + oai_parser_opt.tmpls = common_chat_templates_init(model, params_base.chat_template); try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); + common_chat_format_example(oai_parser_opt.tmpls.get(), params.use_jinja, params.default_template_kwargs); } catch (const std::exception &e) { SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " "This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); + oai_parser_opt.tmpls = common_chat_templates_init(model, "chatml"); } std::string &mmproj_path = params_base.mmproj.path; @@ -2058,15 +2009,14 @@ struct server_context { metrics.init(); - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? mtmd_support_audio(mctx) : false, - /* enable_thinking */ params_base.reasoning_budget != 0, - }; + oai_parser_opt.use_jinja = params_base.use_jinja; + oai_parser_opt.prefill_assistant = params_base.prefill_assistant; + oai_parser_opt.reasoning_format = params_base.reasoning_format; + oai_parser_opt.allow_image = mctx ? mtmd_support_vision(mctx) : false; + oai_parser_opt.allow_audio = mctx ? mtmd_support_audio(mctx) : false; + oai_parser_opt.enable_thinking = params_base.enable_reasoning != 0 && + params_base.use_jinja && + common_chat_templates_support_enable_thinking(oai_parser_opt.tmpls.get()); } server_slot *get_slot_by_id(int id) { @@ -3267,9 +3217,11 @@ struct server_context { // check if we should process the image if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image - int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; + size_t n_tokens_out; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, static_cast(slot.n_past), + static_cast(slot.n_past), + slot.id, n_tokens_out); + int32_t n_pos = static_cast(n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); @@ -3280,7 +3232,7 @@ struct server_context { // add the image chunk to cache { - const auto &chunk = slot.prompt_tokens.find_chunk(slot.n_past); + const auto &chunk = slot.prompt_tokens.find_chunk(static_cast(slot.n_past)); slot.cache_tokens.push_back(chunk.get()); // copy } diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 0b60fe52..edbae760 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -1,17 +1,14 @@ #pragma once +// server-common.h provides: JSON_ASSERT, json, raw_buffer, json_value, +// server_grammar_trigger, server_tokens, error_type, SRV_*/SLT_* macros, +// and many utility function declarations (implemented in server-common.cpp). +#include "server-common.h" + #include "download.h" // common_remote_get_content, common_remote_params #include "base64.hpp" -#include "chat.h" #include "build-info.h" -#include "common.h" -#include "llama.h" -#include "log.h" #include "mtmd-helper.h" -#include "mtmd.h" - -#define JSON_ASSERT GGML_ASSERT -#include #include #include @@ -22,8 +19,12 @@ #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" -using json = nlohmann::ordered_json; - +// server-common.h uses slot.task->id; redefine with our simpler slot.id_task +#undef SLT_INF +#undef SLT_CNT +#undef SLT_WRN +#undef SLT_ERR +#undef SLT_DBG #define SLT_INF(slot, fmt, ...) \ LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) #define SLT_WRN(slot, fmt, ...) \ @@ -33,268 +34,11 @@ using json = nlohmann::ordered_json; #define SLT_DBG(slot, fmt, ...) \ LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -using raw_buffer = std::vector; - -template static T json_value(const json &body, const std::string &key, const T &default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), - json(default_value).type_name()); - return default_value; - } - } else { - return default_value; - } -} - -// build_info removed in b8831; use llama_build_info() from build-info.h - -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; - - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger &value) : value(value) {} - server_grammar_trigger(const json &in) { - value.type = (common_grammar_trigger_type)in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token)in.at("token").get(); - } - } - - json to_json() const { - json out{ - {"type", (int)value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int)value.token; - } - return out; - } -}; - -// -// tokenizer and input processing utils -// - -static bool json_is_array_of_numbers(const json &data) { - if (data.is_array()) { - for (const auto &e : data) { - if (!e.is_number_integer()) { - return false; - } - } - return true; - } - return false; -} - -// is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json &data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto &e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } - } - } - return false; -} - -// get value by path(key1 / key2) -static json json_get_nested_values(const std::vector &paths, const json &js) { - json result = json::object(); - - for (const std::string &path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string &k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; - } - } - return result; -} - -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_prompt, bool add_special, - bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - llama_tokens prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto &p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); - } - - return prompt_tokens; -} - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] - */ -static std::vector tokenize_input_prompts(const llama_vocab *vocab, const json &json_prompt, - bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - result.push_back(json_prompt.get()); - } else if (json_prompt.is_array()) { - // array of prompts - result.reserve(json_prompt.size()); - for (const auto &p : json_prompt) { - if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { - result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); - } else if (json_is_array_of_numbers(p)) { - // array of tokens - result.push_back(p.get()); - } else { - throw std::runtime_error( - "element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); - } - } - } else { - throw std::runtime_error( - "\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; -} - -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string &text) { - size_t len = text.size(); - if (len == 0) - return 0; - - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) - return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) - return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) - return len - i; - } - } - - // If no cut-off multi-byte character is found, return full length - return len; -} - -static bool is_valid_utf8(const std::string &str) { - const unsigned char *bytes = reinterpret_cast(str.data()); - const unsigned char *end = bytes + str.length(); - - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } - - return true; -} - // --------------------------------------------------------------------------- // Token-piece JSON serialisation helpers // @@ -485,10 +229,15 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr return embd_inp; } -// -// base64 utils (TODO: move to common in the future) -// - +// clang-format off +// ---- BEGIN COPY FROM llama.cpp tools/server/server-common.cpp --------------- +// base64_chars / is_base64 / base64_decode are declared `static` in +// server-common.cpp (internal linkage). Even though server-common.cpp is +// compiled into the same shared library, C++ static linkage makes the symbols +// invisible to every other translation unit — there is no declaration in +// server-common.h to call through. These copies are therefore unavoidable and +// must be kept in sync manually whenever llama.cpp upgrades server-common.cpp. +// Removing them is only possible if upstream moves them to a header as `inline`. static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; @@ -499,77 +248,34 @@ static inline raw_buffer base64_decode(const std::string &encoded_string) { int i = 0; int j = 0; int in_ = 0; - int in_len = encoded_string.size(); - uint8_t char_array_4[4]; uint8_t char_array_3[3]; - raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; - in_++; + char_array_4[i++] = encoded_string[in_++]; if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - + for (i = 0; i < 4; i++) char_array_4[i] = base64_chars.find(char_array_4[i]); char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); - } - + for (i = 0; i < 3; i++) ret.push_back(char_array_3[i]); i = 0; } } - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } - - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - + for (j = i; j < 4; j++) char_array_4[j] = 0; + for (j = 0; j < 4; j++) char_array_4[j] = base64_chars.find(char_array_4[j]); char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; j < i - 1; j++) { - ret.push_back(char_array_3[j]); - } + for (j = 0; j < i - 1; j++) ret.push_back(char_array_3[j]); } - return ret; } - -// -// random string / id -// - -static std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - - std::random_device rd; - std::mt19937 generator(rd()); - - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; - } - - return result; -} - -static std::string gen_chatcmplid() { return "chatcmpl-" + random_string(); } - -static std::string gen_tool_call_id() { return random_string(); } +// ---- END COPY FROM llama.cpp tools/server/server-common.cpp ----------------- +// clang-format on // Strip an exact-match flag (no value) from an argv array. // Returns a new vector of pointers (non-owning) with every occurrence removed. @@ -588,427 +294,6 @@ static std::vector strip_flag_from_argv(char **argv, int argc, const cha return out; } -// -// other common utils -// - -// TODO: reuse llama_detokenize -template static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); - } - - return ret; -} - -// Vocab-only variant: detokenize without an inference context. -template static std::string tokens_to_str(const llama_vocab *vocab, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(vocab, *begin); - } - - return ret; -} - -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); - - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - - return out; -} - -// -// OAI utils -// - -// used by /completions endpoint -static json oaicompat_completion_params_parse(const json &body) { - json llama_params; - - if (!body.contains("prompt")) { - throw std::runtime_error("\"prompt\" is required"); - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - // Handle "n" field - int n_choices = json_value(body, "n", 1); - if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); - } - - // Handle "echo" field - if (json_value(body, "echo", false)) { - throw std::runtime_error("Only no echo is supported"); - } - - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params{"best_of", "suffix"}; - for (const auto ¶m : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); - } - } - - // Copy remaining properties to llama_params - for (const auto &item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -struct oaicompat_parser_options { - bool use_jinja = false; - bool prefill_assistant = false; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; - common_chat_templates *tmpls = nullptr; - bool allow_image = false; - bool allow_audio = false; - bool enable_thinking = false; -}; - -// used by /chat/completions endpoint -static json oaicompat_chat_params_parse(json &body, /* openai api json semantics */ - const oaicompat_parser_options &opt, std::vector &out_files) { - json llama_params; - - auto tools = json_value(body, "tools", json()); - auto has_tools = tools.is_array() && !tools.empty(); - auto stream = json_value(body, "stream", false); - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - - if (!opt.use_jinja) { - if (has_tools) { - throw std::runtime_error("tools param requires --jinja flag"); - } - if (tool_choice != "auto") { - throw std::runtime_error("tool_choice param requires --jinja flag"); - } - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - auto json_schema = json_value(body, "json_schema", json()); - auto grammar = json_value(body, "grammar", std::string()); - if (!json_schema.is_null() && !grammar.empty()) { - throw std::runtime_error("Cannot use both json_schema and grammar"); - } - - // Handle "response_format" field - if (body.contains("response_format")) { - json response_format = json_value(body, "response_format", json::object()); - std::string response_type = json_value(response_format, "type", std::string()); - if (response_type == "json_object") { - json_schema = json_value(response_format, "schema", json::object()); - } else if (response_type == "json_schema") { - auto schema_wrapper = json_value(response_format, "json_schema", json::object()); - json_schema = json_value(schema_wrapper, "schema", json::object()); - } else if (!response_type.empty() && response_type != "text") { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + - response_type); - } - } - - // get input files - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - json &messages = body.at("messages"); - if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array"); - } - for (auto &msg : messages) { - std::string role = json_value(msg, "role", std::string()); - if (role != "assistant" && !msg.contains("content")) { - throw std::runtime_error("All non-assistant messages must contain 'content'"); - } - if (role == "assistant") { - if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); - } - if (!msg.contains("content")) { - continue; // avoid errors with no content - } - } - json &content = msg.at("content"); - if (content.is_string() || content.is_null()) { - continue; - } - - if (!content.is_array()) { - throw std::runtime_error("Expected 'content' to be a string or an array"); - } - - for (auto &p : content) { - std::string type = json_value(p, "type", std::string()); - if (type == "image_url") { - if (!opt.allow_image) { - throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need " - "to provide the mmproj"); - } - - json image_url = json_value(p, "image_url", json::object()); - std::string url = json_value(image_url, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back({"User-Agent", "llama.cpp/" + std::string(llama_build_info())}); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %ld bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("image_url.url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("image_url"); - - } else if (type == "input_audio") { - if (!opt.allow_audio) { - throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need " - "to provide the mmproj"); - } - - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); - std::string format = json_value(input_audio, "format", std::string()); - // while we also support flac, we don't allow it here so we matches the OAI spec - if (format != "wav" && format != "mp3") { - throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); - } - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("input_audio"); - - } else if (type != "text") { - throw std::runtime_error("unsupported content[].type"); - } - } - } - - common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.reasoning_format = opt.reasoning_format; - inputs.enable_thinking = opt.enable_thinking; - // Extract custom template kwargs from request body (JSON object with string values). - // Values are stored as JSON-serialized strings because upstream does json::parse(value). - if (body.contains("chat_template_kwargs") && body.at("chat_template_kwargs").is_object()) { - for (auto &el : body.at("chat_template_kwargs").items()) { - inputs.chat_template_kwargs[el.key()] = el.value().dump(); - } - } - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - if (body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - llama_params["parse_tool_calls"] = true; - } - - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. this can be useful to modify the reasoning process in reasoning models - bool prefill_assistant_message = - !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") { - throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); - } - - /* TODO: test this properly */ - inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; - inputs.add_generation_prompt = true; - } - - // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - chat_params.prompt += last_message.content; - } - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - if (!chat_params.grammar.empty()) { - llama_params["grammar"] = chat_params.grammar; - } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto &trigger : chat_params.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - llama_params["generation_prompt"] = chat_params.generation_prompt; - for (const auto &stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - - // Handle "n" field - int n_choices = json_value(body, "n", 1); - if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); - } - - // Handle "logprobs" field - // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may - // need to fix it in the future - if (json_value(body, "logprobs", false)) { - if (has_tools && stream) { - throw std::runtime_error("logprobs is not supported with tools + stream"); - } - llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::runtime_error("top_logprobs requires logprobs to be set to true"); - } - - // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. - // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto &item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -static json format_embeddings_response_oaicompat(const json &request, const json &embeddings, bool use_base64 = false) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto &elem : embeddings) { - json embedding_obj; - - if (use_base64) { - const auto &vec = json_value(elem, "embedding", json::array()).get>(); - const char *data_ptr = reinterpret_cast(vec.data()); - size_t data_size = vec.size() * sizeof(float); - embedding_obj = {{"embedding", base64::encode(data_ptr, data_size)}, - {"index", i++}, - {"object", "embedding"}, - {"encoding_format", "base64"}}; - } else { - embedding_obj = { - {"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}}; - } - data.push_back(embedding_obj); - - n_tokens += json_value(elem, "tokens_evaluated", 0); - } - - json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}}, - {"data", data}}; - - return res; -} - -static json format_response_rerank(const json &request, const json &ranks, bool is_tei_format, - std::vector &texts, int top_n) { - int32_t n_tokens = 0; - bool return_text = is_tei_format && json_value(request, "return_text", false); - std::vector elements; - std::string score_label = is_tei_format ? "score" : "relevance_score"; - for (const auto &rank : ranks) { - int index = json_value(rank, "index", 0); - json elem = json{ - {"index", index}, - {score_label, json_value(rank, "score", 0.0)}, - }; - n_tokens += json_value(rank, "tokens_evaluated", 0); - if (return_text) { - elem["text"] = std::move(texts[index]); - } - elements.push_back(elem); - } - - std::sort(elements.begin(), elements.end(), [score_label](const json &a, const json &b) { - return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0); - }); - - elements.resize(std::min(top_n, (int) elements.size())); - json results = elements; - - if (is_tei_format) return results; - - json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}}, - {"results", results}}; - - return res; -} - static json format_tokenizer_response(const json &tokens) { return json{{"tokens", tokens}}; } static json format_detokenized_response(const std::string &content) { return json{{"content", content}}; } @@ -1024,343 +309,14 @@ static json format_logit_bias(const std::vector &logit_bias) { return data; } -static std::string safe_json_to_str(const json &data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - -static std::vector get_token_probabilities(llama_context *ctx, int idx) { - std::vector cur; - const auto *logits = llama_get_logits_ith(ctx, idx); - - const llama_model *model = llama_get_model(ctx); - const llama_vocab *vocab = llama_model_get_vocab(model); - - const int n_vocab = llama_vocab_n_tokens(vocab); - - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - // sort tokens by logits - std::sort(cur.begin(), cur.end(), - [](const llama_token_data &a, const llama_token_data &b) { return a.logit > b.logit; }); - - // apply softmax - float max_l = cur[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; - } - - return cur; -} - -static bool are_lora_equal(const std::vector &l1, - const std::vector &l2) { - if (l1.size() != l2.size()) { - return false; - } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { - return false; - } - } - return true; -} - // parse lora config from JSON request, returned a copy of lora_base with updated scale static std::vector parse_lora_request(const std::vector &lora_base, const json &data) { std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto &entry : lora) { - entry.scale = 0.0f; + for (auto &e : lora) e.scale = 0.0f; + for (const auto &[id, scale] : parse_lora_request(data)) { // upstream: extracts id->scale map + if (id < 0 || id >= (int)lora.size()) throw std::runtime_error("invalid adapter id"); + lora[id].scale = scale; } - - // set value - for (const auto &entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - return lora; } - -// -// utils for interacting with libmtmd -// (may need to refactor in near future) -// - -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - - private: // disallow accessing these members directly, risking out-of-sync - // map a **start** position in tokens to the image chunk - std::unordered_map map_pos_to_media; - - // list of tokens - // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token - // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** - // important: for models using mrope, an image can contain multiple tokens but will use only one **position** - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // pos 0 1 2 3 4 5 6 7 8 9 - // map_pos_to_media will contain: {5, img0}, {8, img1} - - public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - server_tokens(const server_tokens &) = delete; - server_tokens &operator=(const server_tokens &) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens &&) = default; - server_tokens &operator=(server_tokens &&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token &operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks &mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); - } - } - - server_tokens(llama_tokens &tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} - - // for debugging - std::string str() const { - std::ostringstream oss; - oss << "tokens: "; - for (const auto &t : tokens) { - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image pos: "; - for (const auto &it : map_pos_to_media) { - oss << it.first << ", "; - } - return oss.str(); - } - - const mtmd::input_chunk_ptr &find_chunk(llama_pos pos) const { - auto it = map_pos_to_media.find(pos); - if (it != map_pos_to_media.end()) { - return it->second; - } else { - throw std::runtime_error("Chunk not found"); - } - } - - void push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); - } - - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk *chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const int n_pos = mtmd_input_chunk_get_n_pos(chunk); - llama_pos start_pos = tokens.size(); - for (int i = 0; i < n_pos; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_pos_to_media[start_pos] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } - } else { - GGML_ABORT("Invalid chunk type"); - } - } - - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens &inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens &get_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // returns a copy with LLAMA_TOKEN_NULL entries filtered out (mtmd image placeholders) - llama_tokens get_text_tokens() const { - llama_tokens res; - res.reserve(tokens.size()); - for (llama_token t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - res.push_back(t); - } - } - return res; - } - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t size() const { return tokens.size(); } - - bool empty() const { return tokens.empty(); } - - void clear() { tokens.clear(); } - - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - llama_token last_token = tokens[n - 1]; - // make sure we never remove tokens in the middle of an image - if (last_token == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end();) { - llama_pos pos = it->first; - if (pos >= (llama_pos)n) { - it = map_pos_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); - } - - std::string detokenize(const llama_context *ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto &t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); - } - - size_t get_common_prefix(const server_tokens &b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); - for (size_t i = 0; i < max_idx; ++i) { - auto &ai = tokens[i]; - auto &bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - GGML_ASSERT(has_mtmd); - const auto &a_chunk = find_chunk(i); - const auto &b_chunk = b.find_chunk(i); - GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop - continue; - } else { - return i; - } - } else if (ai == bi) { - continue; - } else { - return i; - } - } - return max_idx; // all tokens are equal - } - - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context *ctx) const { - const llama_model *model = llama_get_model(ctx); - const llama_vocab *vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - - for (size_t i = 0; i < tokens.size(); ++i) { - auto &t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto &chunk = find_chunk(i); - size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); - i += n_pos - 1; // will be +1 by the for loop - } catch (const std::exception &e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; - } - - // encode and decode the image chunk - int32_t process_chunk(llama_context *ctx, mtmd_context *mctx, llama_pos n_past, int32_t seq_id, - llama_pos &n_pos_out) { - auto &chunk = find_chunk(n_past); - const char *name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past = n_past; - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, chunk.get(), n_past, seq_id, n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_pos_out = n_past; - return result; - } - n_pos_out = new_n_past; - return 0; - } -}; diff --git a/src/test/cpp/test_utils.cpp b/src/test/cpp/test_utils.cpp index 78a31b52..d76fa278 100644 --- a/src/test/cpp/test_utils.cpp +++ b/src/test/cpp/test_utils.cpp @@ -237,7 +237,7 @@ TEST(FormatResponseRerank, JinaFormat_WrapperStructure) { json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9)}); std::vector texts = {"doc0", "doc1"}; - json res = format_response_rerank(request, ranks, /*is_tei=*/false, texts, /*top_n=*/2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, /*is_tei=*/false, texts, /*top_n=*/2); EXPECT_EQ(res.at("model").get(), "my-reranker"); EXPECT_EQ(res.at("object").get(), "list"); @@ -251,7 +251,7 @@ TEST(FormatResponseRerank, JinaFormat_UsesRelevanceScoreLabel) { json ranks = json::array({make_rank(0, 0.7)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, ranks, false, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 1); EXPECT_TRUE(res.at("results")[0].contains("relevance_score")); EXPECT_FALSE(res.at("results")[0].contains("score")); @@ -263,7 +263,7 @@ TEST(FormatResponseRerank, JinaFormat_SortedDescendingByScore) { json ranks = json::array({make_rank(0, 0.3), make_rank(1, 0.9), make_rank(2, 0.1)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, ranks, false, texts, 3); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 3); auto &results = res.at("results"); EXPECT_EQ(results[0].at("index").get(), 1); // highest: 0.9 @@ -276,7 +276,7 @@ TEST(FormatResponseRerank, TopN_LimitsResultCount) { json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9), make_rank(2, 0.1)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, ranks, false, texts, /*top_n=*/1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, /*top_n=*/1); EXPECT_EQ(res.at("results").size(), 1u); // The single returned result must be the highest-scoring one @@ -289,7 +289,7 @@ TEST(FormatResponseRerank, TopN_Two_KeepsTopTwo) { make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5), make_rank(3, 0.7)}); std::vector texts = {"a", "b", "c", "d"}; - json res = format_response_rerank(request, ranks, false, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 2); EXPECT_EQ(res.at("results").size(), 2u); EXPECT_EQ(res.at("results")[0].at("index").get(), 1); // 0.9 @@ -301,7 +301,7 @@ TEST(FormatResponseRerank, TopN_LargerThanCount_ReturnsAll) { json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.2)}); std::vector texts = {"x", "y"}; - json res = format_response_rerank(request, ranks, false, texts, /*top_n=*/100); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, /*top_n=*/100); EXPECT_EQ(res.at("results").size(), 2u); } @@ -311,7 +311,7 @@ TEST(FormatResponseRerank, TokenCounting_Accumulated) { json ranks = json::array({make_rank(0, 0.5, 15), make_rank(1, 0.9, 25)}); std::vector texts = {"a", "b"}; - json res = format_response_rerank(request, ranks, false, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 2); EXPECT_EQ(res.at("usage").at("prompt_tokens").get(), 40); // 15 + 25 EXPECT_EQ(res.at("usage").at("total_tokens").get(), 40); @@ -322,7 +322,7 @@ TEST(FormatResponseRerank, TeiFormat_ReturnsArrayDirectly) { json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.3)}); std::vector texts = {"x", "y"}; - json res = format_response_rerank(request, ranks, /*is_tei=*/true, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, /*is_tei=*/true, texts, 2); EXPECT_TRUE(res.is_array()); // no outer wrapper object EXPECT_EQ(res.size(), 2u); @@ -333,7 +333,7 @@ TEST(FormatResponseRerank, TeiFormat_UsesScoreLabel) { json ranks = json::array({make_rank(0, 0.8)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_TRUE(res[0].contains("score")); @@ -345,7 +345,7 @@ TEST(FormatResponseRerank, TeiFormat_ReturnText_IncludesDocumentText) { json ranks = json::array({make_rank(0, 0.9)}); std::vector texts = {"my document content"}; - json res = format_response_rerank(request, ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_TRUE(res[0].contains("text")); @@ -357,7 +357,7 @@ TEST(FormatResponseRerank, TeiFormat_NoReturnText_NoTextField) { json ranks = json::array({make_rank(0, 0.9)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_FALSE(res[0].contains("text")); @@ -368,7 +368,7 @@ TEST(FormatResponseRerank, TeiFormat_SortedDescendingByScore) { json ranks = json::array({make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, ranks, true, texts, 3); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 3); ASSERT_TRUE(res.is_array()); EXPECT_EQ(res[0].at("index").get(), 1); // 0.9 @@ -812,9 +812,11 @@ TEST(OaicompatCompletionParams, StopArray_PassedThrough) { EXPECT_EQ(res.at("stop").size(), 2u); } -TEST(OaicompatCompletionParams, NNotOne_Throws) { +TEST(OaicompatCompletionParams, NNotOne_PassedThrough) { + // upstream oaicompat_completion_params_parse no longer rejects n > 1; + // the value is forwarded to llama_params like any other field const json body = {{"prompt", "hi"}, {"n", 3}}; - EXPECT_THROW(oaicompat_completion_params_parse(body), std::runtime_error); + EXPECT_NO_THROW(oaicompat_completion_params_parse(body)); } TEST(OaicompatCompletionParams, NEqualsOne_OK) { @@ -867,7 +869,7 @@ json make_embedding_elem(const std::vector &vec, int tokens = 4) { TEST(FormatEmbeddingsResponse, SingleEmbedding_Fields) { const json request = {{"model", "test-model"}}; const json embeddings = json::array({make_embedding_elem({0.1f, 0.2f, 0.3f})}); - const json res = format_embeddings_response_oaicompat(request, embeddings); + const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("object").get(), "list"); EXPECT_EQ(res.at("model").get(), "test-model"); EXPECT_EQ(res.at("data").size(), 1u); @@ -878,7 +880,7 @@ TEST(FormatEmbeddingsResponse, SingleEmbedding_Fields) { TEST(FormatEmbeddingsResponse, TokensAccumulated) { const json request = {}; const json embeddings = json::array({make_embedding_elem({1.0f}, 3), make_embedding_elem({2.0f}, 7)}); - const json res = format_embeddings_response_oaicompat(request, embeddings); + const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("usage").at("prompt_tokens").get(), 10); EXPECT_EQ(res.at("usage").at("total_tokens").get(), 10); } @@ -890,7 +892,7 @@ TEST(FormatEmbeddingsResponse, MultipleEmbeddings_IndicesIncrement) { make_embedding_elem({0.2f}), make_embedding_elem({0.3f}), }); - const json res = format_embeddings_response_oaicompat(request, embeddings); + const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("data").size(), 3u); EXPECT_EQ(res.at("data")[0].at("index").get(), 0); EXPECT_EQ(res.at("data")[1].at("index").get(), 1); @@ -900,7 +902,7 @@ TEST(FormatEmbeddingsResponse, MultipleEmbeddings_IndicesIncrement) { TEST(FormatEmbeddingsResponse, Base64Format_EncodingFormatField) { const json request = {}; const json embeddings = json::array({make_embedding_elem({1.0f, 0.0f})}); - const json res = format_embeddings_response_oaicompat(request, embeddings, /*use_base64=*/true); + const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings, /*use_base64=*/true); const json &elem = res.at("data")[0]; EXPECT_TRUE(elem.contains("encoding_format")); EXPECT_EQ(elem.at("encoding_format").get(), "base64"); @@ -1000,42 +1002,42 @@ json make_chat_body_with_messages(const json &messages_override = json::array({ return json{{"messages", messages_override}}; } -oaicompat_parser_options make_no_jinja_opts() { - oaicompat_parser_options opt; +server_chat_params make_no_jinja_opts() { + server_chat_params opt; opt.use_jinja = false; - opt.tmpls = nullptr; + // tmpls: shared_ptr default-constructs to nullptr — no explicit set needed return opt; } } // namespace TEST(OaicompatChatParams, MissingMessages_Throws) { json body = {{"model", "x"}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, MessagesNotArray_Throws) { json body = {{"messages", "not-an-array"}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, NonAssistantMissingContent_Throws) { // user message with no "content" field json body = {{"messages", json::array({{{"role", "user"}}})}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, AssistantMissingBothContentAndToolCalls_Throws) { // assistant message must have content OR tool_calls json body = {{"messages", json::array({{{"role", "assistant"}}})}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ToolsWithoutJinja_Throws) { @@ -1043,9 +1045,9 @@ TEST(OaicompatChatParams, ToolsWithoutJinja_Throws) { {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, {"tools", json::array({{{"type", "function"}}})} }; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, NonAutoToolChoiceWithoutJinja_Throws) { @@ -1053,9 +1055,9 @@ TEST(OaicompatChatParams, NonAutoToolChoiceWithoutJinja_Throws) { {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, {"tool_choice", "none"} }; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, GrammarAndJsonSchema_Throws) { @@ -1064,9 +1066,9 @@ TEST(OaicompatChatParams, GrammarAndJsonSchema_Throws) { {"grammar", "root ::= [a-z]+"}, {"json_schema", {{"type", "object"}}} }; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, InvalidResponseFormatType_Throws) { @@ -1074,9 +1076,9 @@ TEST(OaicompatChatParams, InvalidResponseFormatType_Throws) { {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, {"response_format", {{"type", "invalid_type"}}} }; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ContentPartTypeUnsupported_Throws) { @@ -1084,9 +1086,9 @@ TEST(OaicompatChatParams, ContentPartTypeUnsupported_Throws) { {"role", "user"}, {"content", json::array({{{"type", "video_url"}, {"url", "x"}}})} }})}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ImageUrlWithoutAllowImage_Throws) { @@ -1097,18 +1099,18 @@ TEST(OaicompatChatParams, ImageUrlWithoutAllowImage_Throws) { {"image_url", {{"url", "data:image/png;base64,abc"}}} }})} }})}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); opt.allow_image = false; std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ContentNotStringOrArray_Throws) { // content is an integer — not allowed json body = {{"messages", json::array({{{"role", "user"}, {"content", 42}}})}}; - oaicompat_parser_options opt = make_no_jinja_opts(); + server_chat_params opt = make_no_jinja_opts(); std::vector files; - EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::runtime_error); + EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } // ============================================================