From e4e2ed92d0a97a86718a7481b71679f603556a93 Mon Sep 17 00:00:00 2001
From: "alexey.sokolov"
Date: Mon, 27 Jan 2025 20:06:59 +0700
Subject: [PATCH 1/9] begin upgrading to llama.cpp b4513

---
 CMakeLists.txt          |    2 +-
 src/main/cpp/jllama.cpp |   27 +-
 src/main/cpp/server.hpp | 4114 ++++++++++++++++++++++-----------------
 src/main/cpp/utils.hpp  | 1082 +++++-----
 4 files changed, 2988 insertions(+), 2237 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ec2b84a6..a083ea10 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,7 +23,7 @@ FetchContent_MakeAvailable(json)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG b3751
+    GIT_TAG b4513
 )
 FetchContent_MakeAvailable(llama.cpp)

diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index 07eef014..8ae54811 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -1,5 +1,6 @@
 #include "jllama.h"

+#include "log.h"
 #include "llama.h"
 #include "nlohmann/json.hpp"
 #include "server.hpp"
@@ -354,7 +355,8 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)

 JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams)
 {
-    gpt_params params;
+    common_params params;
+    common_init();

     auto *ctx_server = new server_context();

@@ -364,16 +366,16 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo

     if (json_value(json_params, "disable_log", false))
     {
-        log_disable();
+        common_log_disable();
     }
     else
     {
         log_enable();
     }

-    if (!params.system_prompt.empty())
+    if (!params.prompt.empty())
     {
-        ctx_server->system_prompt_set(params.system_prompt);
+        ctx_server->system_prompt_set(params.prompt);
     }

     if (params.model_alias == "unknown")
@@ -383,14 +385,10 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo

     llama_numa_init(params.numa);

-    LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
-
-    LOG_INFO("system info", {
-                                {"n_threads", params.cpuparams.n_threads},
-                                {"n_threads_batch", params.cpuparams_batch.n_threads},
-                                {"total_threads", std::thread::hardware_concurrency()},
-                                {"system_info", llama_print_system_info()},
-                            });
+    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+    LOG_INF("\n");
+    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    LOG_INF("\n");

     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

@@ -417,9 +415,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
     {
         if (!ctx_server->validate_model_chat_template())
         {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This "
-                      "may cause the model to output suboptimal responses",
-                      {});
+            LOG_ERR("%s", "The chat template that comes with this model is not yet supported, falling back to chatml. 
This " + "may cause the model to output suboptimal responses"); params.chat_template = "chatml"; } } diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 029721c1..4ea0a18e 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -1,9 +1,11 @@ #include "utils.hpp" #include "common.h" -#include "sampling.h" -#include "grammar-parser.h" +#include "json-schema-to-grammar.h" #include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" #include "nlohmann/json.hpp" @@ -11,39 +13,45 @@ #include #include #include +#include +#include +#include #include -#include #include -#include -#include +#include #include -#include +#include using json = nlohmann::ordered_json; -enum stop_type -{ - STOP_TYPE_FULL, - STOP_TYPE_PARTIAL, +constexpr int HTTP_POLLING_SECONDS = 1; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, }; // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_DONE_PROMPT, SLOT_STATE_GENERATING, }; -enum server_state -{ - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded }; -enum server_task_type -{ +enum server_task_type { SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_METRICS, @@ -53,66 +61,1074 @@ enum server_task_type SERVER_TASK_TYPE_SET_LORA, }; -enum server_task_cmpl_type { - SERVER_TASK_CMPL_TYPE_NORMAL, - SERVER_TASK_CMPL_TYPE_EMBEDDING, - SERVER_TASK_CMPL_TYPE_INFILL, -}; +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string 
oaicompat_model; + std::string oaicompat_cmpl_id; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, 
"timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = 
std::max(params.speculative.n_min, 2); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + params.sampling.grammar = json_schema_to_grammar(schema); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + 
params.antiprompt.push_back(word); + } + } + } + } + + { + const auto & samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + std::vector sampler_names; + for (const auto & name : *samplers) { + if (name.is_string()) { + sampler_names.emplace_back(name); + } + } + params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); + } else if (samplers->is_string()){ + std::string sampler_string; + for (const auto & name : *samplers) { + sampler_string += name; + } + params.sampling.samplers = common_sampler_types_from_chars(sampler_string); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + json to_json() const { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? 
"top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? 
"" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", json { + {"content", content}, + {"role", "assistant"} + } + }}; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()} + }; + + json ret = json { + {"choices", json::array({choice})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return ret; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // 
non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 0; + std::time_t t = std::time(0); + json choices; + + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json { + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json { + {"content", content}, + }}, + }}); + } + + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"} + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return std::vector({ret}); + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? 
to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; -struct server_task -{ - int id = -1; // to be filled by server_queue - int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; - server_task_type type; - json data; + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); - bool infill = false; - bool embedding = false; + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, - // utility function - static std::unordered_set get_list_id(const 
std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, -struct server_task_result -{ - int id = -1; + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, - json data; + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, - bool stop; - bool error; + { "slots", slots_data }, + }; + } }; -struct slot_params -{ - bool stream = true; - bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = - 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; - std::vector antiprompt; + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; - json input_prefix; - json input_suffix; +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } }; -struct server_slot -{ +struct server_slot { int id; int id_task = -1; + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + // the index relative to completion multi-task request size_t index = 0; @@ -124,56 +1140,45 @@ struct server_slot int64_t t_last_used = -1; // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_past = 0; - int32_t n_decoded = 0; + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; int32_t n_remaining = -1; - int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - int32_t n_prompt_tokens = 0; + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; + // input prompt tokens + llama_tokens prompt_tokens; + + size_t last_nl_pos = 0; - // when a task is submitted, we first tokenize the prompt and store it here - std::vector prompt_tokens; + std::string generated_text; + llama_tokens generated_tokens; + + llama_tokens cache_tokens; - std::string generated_text; - std::vector 
cache_tokens; std::vector generated_token_probs; - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; - bool infill = false; - bool embedding = false; bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool oaicompat = false; + bool has_new_line = false; + bool truncated = false; + stop_type stop; - std::string oaicompat_model; std::string stopping_word; // sampling json json_schema; - struct gpt_sampler_params sparams; - struct gpt_sampler * smpl = nullptr; + struct common_sampler * smpl = nullptr; llama_token sampled; - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width - - int32_t n_past_se = 0; // self-extend - // stats - size_t n_sent_text = 0; // number of sent text character - size_t n_sent_token_probs = 0; + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -183,40 +1188,43 @@ struct server_slot std::function callback_on_release; - void reset() - { - n_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - n_sent_token_probs = 0; - cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; - ga_i = 0; - n_past_se = 0; - + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + + generated_tokens.clear(); generated_token_probs.clear(); } - bool has_budget(gpt_params &global_params) - { - if (params.n_predict == -1 && global_params.n_predict == -1) - { + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) - { + if (params.n_predict != -1) { n_remaining = params.n_predict - n_decoded; - } - else if (global_params.n_predict != -1) - { + } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -227,70 +1235,64 @@ struct server_slot return state != SLOT_STATE_IDLE; } - void add_token_string(const completion_token_output &token) - { + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + + void add_token(const completion_token_output & token) { if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); return; } generated_token_probs.push_back(token); } - void release() - { + void release() { if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; - LOG_INFO("slot released", { - {"id_slot", id}, - {"id_task", id_task}, - {"n_past", n_past}, - {"truncated", truncated}, - }); callback_on_release(id); } } - json get_formated_timings() const - { - return json{ - {"prompt_n", 
n_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, - - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + return timings; } - size_t find_stopping_strings(const std::string &text, const size_t last_token_size, const stop_type type) - { + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) - { + for (const std::string & word : params.antiprompt) { size_t pos; - if (type == STOP_TYPE_FULL) - { - const size_t tmp = word.size() + last_token_size; + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; pos = text.find(word, from_pos); - } - else - { + } else { + // otherwise, partial stop pos = find_partial_stop_string(word, text); } - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_TYPE_FULL) - { - stopped_word = true; - stopping_word = word; + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; has_next_token = false; } stop_pos = pos; @@ -300,91 +1302,79 @@ struct server_slot return stop_pos; } - void print_timings() const - { - char buffer[512]; - - double t_token = t_prompt_processing / n_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - snprintf(buffer, 512, - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"n_prompt_tokens_processed", n_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - - snprintf(buffer, 512, - "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", 
t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + } + + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, + {"params", params.to_json()}, + {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; } }; -struct server_metrics -{ +struct server_metrics { int64_t t_start = 0; uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; + uint64_t t_prompt_processing = 0; - uint64_t n_tokens_predicted = 0; + uint64_t n_tokens_predicted = 0; uint64_t t_tokens_generation = 0; uint64_t n_decode_total = 0; uint64_t n_busy_slots_total = 0; - void init() - { + void init() { t_start = ggml_time_us(); } - void on_prompt_eval(const server_slot &slot) - { + void on_prompt_eval(const server_slot & slot) { n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; } - void on_prediction(const server_slot &slot) - { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; } void on_decoded(const std::vector & slots) { @@ -398,14 +1388,13 @@ struct server_metrics void reset_bucket() { n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; } }; -struct server_queue -{ +struct server_queue { int id = 0; bool running; @@ -417,21 +1406,18 @@ struct server_queue std::condition_variable condition_tasks; // 
callback functions - std::function callback_new_task; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); - if (task.id == -1) - { - task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); - } + GGML_ASSERT(task.id != -1); + QUE_DBG("new task, id = %d, front = %d\n", task.id, front); if (front) { queue_tasks.push_front(std::move(task)); } else { - queue_tasks.push_back(std::move(task)); + queue_tasks.push_back(std::move(task)); } condition_tasks.notify_one(); return task.id; @@ -443,8 +1429,8 @@ struct server_queue for (auto & task : tasks) { if (task.id == -1) { task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); if (front) { queue_tasks.push_front(std::move(task)); } else { @@ -456,31 +1442,27 @@ struct server_queue } // Add a new task, but defer until one slot is available - void defer(server_task task) - { + void defer(server_task task) { std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); condition_tasks.notify_one(); } - // Get the next id for creating anew task - int get_new_id() - { + // Get the next id for creating a new task + int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; - LOG_VERBOSE("new task id", {{"new_id", new_id}}); return new_id; } // Register function to process a new task - void on_new_task(std::function callback) - { + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) - { + void on_update_slots(std::function callback) { callback_update_slots = std::move(callback); } @@ -495,8 +1477,7 @@ struct server_queue } // end the start_loop routine - void terminate() - { + void terminate() { std::unique_lock lock(mutex_tasks); running = false; condition_tasks.notify_all(); @@ -509,48 +1490,43 @@ struct server_queue * - Check if multitask is finished * - Update all slots */ - void start_loop() - { + void start_loop() { running = true; - while (true) - { - LOG_VERBOSE("new task may arrive", {}); + while (true) { + QUE_DBG("%s", "processing new tasks\n"); - while (true) - { + while (true) { std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) - { + if (queue_tasks.empty()) { lock.unlock(); break; } server_task task = queue_tasks.front(); queue_tasks.pop_front(); lock.unlock(); - LOG_VERBOSE("callback_new_task", {{"id_task", task.id}}); - callback_new_task(task); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready - LOG_VERBOSE("callback_update_slots", {}); + QUE_DBG("%s", "update slots\n"); callback_update_slots(); - LOG_VERBOSE("wait for new task", {}); + QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) - { - if (!running) - { - LOG_VERBOSE("ending start_loop", {}); + if (queue_tasks.empty()) { + if (!running) { + QUE_DBG("%s", "terminate\n"); return; } condition_tasks.wait(lock, [&]{ return (!queue_tasks.empty() || !running); }); - } + } } } } @@ -560,47 +1536,81 @@ struct server_response { // for keeping track of all tasks waiting for the 
result std::unordered_set waiting_task_ids; - // the main result queue - std::vector queue_results; + // the main result queue (using ptr for polymorphism) + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) - { - LOG_VERBOSE("waiting for task id", {{"id_task", id_task}}); + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.insert(id_task); } void add_waiting_tasks(const std::vector & tasks) { - for (const auto & t : tasks) { - add_waiting_task_id(t.id); + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); } } // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) - { - LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.erase(id_task); } + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + // This function blocks the thread until there is a response for one of the id_tasks - server_task_result recv(const std::unordered_set & id_tasks) { - while (true) - { + server_task_result_ptr recv(const std::unordered_set & id_tasks) { + while (true) { std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&] { return !queue_results.empty(); }); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); - for (int i = 0; i < (int)queue_results.size(); i++) - { - if (id_tasks.find(queue_results[i].id) != id_tasks.end()) { - server_task_result res = queue_results[i]; + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{ + return !queue_results.empty(); + }); + if (!cr_res) { + return nullptr; + } + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -611,21 +1621,21 @@ struct server_response { } // single-task version of recv() - server_task_result recv(int id_task) { + server_task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } 
// Send a new result to a waiting id_task - void send(server_task_result & result) { - LOG_VERBOSE("send new result", {{"id_task", result.id}}); + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { - if (result.id == id_task) - { - LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); - queue_results.push_back(std::move(result)); + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); condition_results.notify_all(); return; } @@ -633,33 +1643,35 @@ struct server_response { } }; -struct server_context -{ - llama_model *model = nullptr; - llama_context *ctx = nullptr; - std::vector lora_adapters; +struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + const llama_vocab * vocab = nullptr; + + llama_model * model_dft = nullptr; - gpt_params params; + llama_context_params cparams_dft; llama_batch batch = {}; bool clean_kv_cache = true; - bool add_bos_token = true; + bool add_bos_token = true; bool has_eos_token = false; int32_t n_ctx; // total context for all clients / slots - // system prompt - bool system_need_update = false; - - std::string system_prompt; - std::vector system_tokens; - // slots / clients std::vector slots; json default_generation_settings_for_props; - server_queue queue_tasks; + server_queue queue_tasks; server_response queue_results; server_metrics metrics; @@ -667,101 +1679,124 @@ struct server_context // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - ~server_context() - { - if (ctx) - { - llama_free(ctx); - ctx = nullptr; - } + ~server_context() { + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; - if (model) - { - llama_free_model(model); - model = nullptr; - } + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; - // Clear any sampling context - for (server_slot &slot : slots) - { - if (slot.smpl != nullptr) { - gpt_sampler_free(slot.smpl); - } + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); } llama_batch_free(batch); } - bool load_model(const gpt_params ¶ms_) - { - params = params_; + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.c_str()); - // dedicate one sequence to the system prompt - params.n_parallel += 1; + params_base = params; - llama_init_result llama_init = llama_init_from_gpt_params(params); + llama_init = common_init_from_params(params_base); - model = llama_init.model; - ctx = llama_init.context; - lora_adapters = llama_init.lora_adapters; - params.n_parallel -= 1; // but be sneaky about it - if (model == nullptr) - { - LOG_ERROR("unable to load model", {{"model", params.model}}); + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); return false; } + vocab = llama_model_get_vocab(model); + n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_add_bos_token(model); - has_eos_token = !llama_add_eos_token(model); + add_bos_token = llama_vocab_get_add_bos(vocab); + 
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + + if (!params_base.speculative.model.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // force F16 KV cache for the draft model for extra performance + cparams_dft.type_k = GGML_TYPE_F16; + cparams_dft.type_v = GGML_TYPE_F16; + } return true; } - bool validate_model_chat_template() const - { + bool validate_builtin_chat_template() const { llama_chat_message chat[] = {{"user", "test"}}; - - const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); - - return res > 0; + const char * tmpl = llama_model_chat_template(model); + const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); + return chat_res > 0; } - void init() - { - const int32_t n_ctx_slot = n_ctx / params.n_parallel; + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - for (int i = 0; i < params.n_parallel; i++) - { + for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params.n_predict; + slot.n_predict = params_base.n_predict; - LOG_INFO("new slot", {{"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx}}); + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - const int ga_n = params.grp_attn_n; - const int ga_w = params.grp_attn_w; - - if (ga_n != 1) - { - GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT - GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT - // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT - // GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } - LOG_INFO("slot self-extend", {{"id_slot", slot.id}, {"ga_n", ga_n}, {"ga_w", ga_w}}); + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } } - slot.ga_i = 0; - slot.ga_n = ga_n; - slot.ga_w = ga_w; + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.sparams = params.sparams; + slot.params.sampling = 
params_base.sampling; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -772,8 +1807,7 @@ struct server_context slots.push_back(slot); } - default_generation_settings_for_props = get_formated_generation(slots.front()); - default_generation_settings_for_props["seed"] = -1; + default_generation_settings_for_props = slots[0].to_json(); // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) @@ -781,71 +1815,15 @@ struct server_context const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } metrics.init(); } - std::vector tokenize(const json &json_prompt, bool add_special) const - { - // TODO: currently, we tokenize using special tokens by default - // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) - // but it's better compared to completely ignoring ChatML and other chat templates - const bool TMP_FORCE_SPECIAL = true; - - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; - - if (json_prompt.is_array()) - { - bool first = true; - for (const auto &p : json_prompt) - { - if (p.is_string()) - { - auto s = p.template get(); - - std::vector p; - if (first) - { - p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - first = false; - } - else - { - p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } - else - { - if (first) - { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } - else - { - auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - } - - return prompt_tokens; - } - - server_slot *get_slot_by_id(int id) - { - for (server_slot &slot : slots) - { - if (slot.id == id) - { + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { return &slot; } } @@ -853,309 +1831,100 @@ struct server_context return nullptr; } - server_slot *get_available_slot(const std::string &prompt) - { - server_slot *ret = nullptr; + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) - { - int max_lcp_len = 0; + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; float similarity = 0; - for (server_slot &slot : slots) - { + for (server_slot & slot : slots) { // skip the slot if it is not available if (slot.is_processing()) { continue; } - // skip the slot if it does not contains prompt - if (!slot.prompt.is_string()) - { + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { continue; } - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); - - // length of the current slot's prompt - int slot_prompt_len = slot_prompt.size(); - - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = common_part(slot_prompt, prompt); + // length of the Longest Common Subsequence between the 
current slot's prompt and the input prompt + int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); - // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcp_len) / slot_prompt_len; + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) - { - max_lcp_len = lcp_len; + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; ret = &slot; } } - if (ret != nullptr) - { - LOG_VERBOSE("selected slot by lcp similarity", { - {"id_slot", ret->id}, - {"max_lcp_len", max_lcp_len}, - {"similarity", similarity}, - }); + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); } } // find the slot that has been least recently used - if (ret == nullptr) - { + if (ret == nullptr) { int64_t t_last = ggml_time_us(); - for (server_slot &slot : slots) - { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (slot.t_last_used < t_last) - { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) - { - LOG_VERBOSE("selected slot by lru", { - {"id_slot", ret->id}, - {"t_last", t_last}, - }); - } - } - - return ret; - } - - bool launch_slot_with_task(server_slot &slot, const server_task &task) - { - slot_params default_params; - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - auto default_sparams = params.sparams; - const auto & data = task.data; - - slot.oaicompat = false; - slot.oaicompat_model = ""; - - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", false); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = 
json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - - if (slot.params.cache_prompt && slot.ga_n != 1) - { - LOG_WARNING("cache_prompt is not supported with group-attention", {}); - slot.params.cache_prompt = false; - } - - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) - { - // Might be better to reject the request with a 400 ? - LOG_WARNING("Max tokens to predict exceeds server configuration", - { - {"params.n_predict", slot.params.n_predict}, - {"slot.n_predict", slot.n_predict}, - }); - slot.params.n_predict = slot.n_predict; - } - - // infill - slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); - slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); - - // get prompt - if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) { - const auto &prompt = data.find("prompt"); - if (prompt == data.end()) - { - send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - if ((prompt->is_string()) || (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || - (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) - { - slot.prompt = *prompt; - } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { - slot.prompt = prompt->at(0); - } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - - // penalize user-provided tokens - { - slot.sparams.penalty_prompt_tokens.clear(); - slot.sparams.use_penalty_prompt_tokens = false; - - const auto &penalty_prompt = data.find("penalty_prompt"); - - if (penalty_prompt != data.end()) - { - if (penalty_prompt->is_string()) - { - const auto penalty_prompt_string = penalty_prompt->get(); - slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); - - if (slot.params.n_predict > 0) - { - slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + - slot.params.n_predict); - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; } - else if (penalty_prompt->is_array()) - { - const auto n_tokens = penalty_prompt->size(); - slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); - - const int n_vocab = llama_n_vocab(model); - for (const auto &penalty_token : *penalty_prompt) - { - if (penalty_token.is_number_integer()) - { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) - { - slot.sparams.penalty_prompt_tokens.push_back(tok); - } - } - } - slot.sparams.use_penalty_prompt_tokens = true; - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); + // select the current slot if the criteria match + if 
(slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; } } - } - - { - slot.sparams.logit_bias.clear(); - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY}); + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); } + } - const auto &logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) - { - const int n_vocab = llama_n_vocab(model); - for (const auto &el : *logit_bias) - { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) - { - float bias; - if (el[1].is_number()) - { - bias = el[1].get(); - } - else if (el[1].is_boolean() && !el[1].get()) - { - bias = -INFINITY; - } - else - { - continue; - } + return ret; + } - if (el[0].is_number_integer()) - { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) - { - slot.sparams.logit_bias.push_back({tok, bias}); - } - } - else if (el[0].is_string()) - { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) - { - slot.sparams.logit_bias.push_back({tok, bias}); - } - } - } - } - } + bool launch_slot_with_task(server_slot & slot, const server_task & task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = task.params.lora; } - { - slot.params.antiprompt.clear(); + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - const auto &stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) - { - for (const auto &word : *stop) - { - if (!word.empty()) - { - slot.params.antiprompt.push_back(word); - } - } - } + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? 
+ slot.params.n_predict = slot.n_predict; + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } - { - const auto & samplers = data.find("samplers"); - if (samplers != data.end() && samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false); - } else { - slot.sparams.samplers = default_sparams.samplers; - } + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); } { if (slot.smpl != nullptr) { - gpt_sampler_free(slot.smpl); + common_sampler_free(slot.smpl); } - slot.smpl = gpt_sampler_init(model, slot.sparams); + slot.smpl = common_sampler_init(model, slot.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -1163,495 +1932,403 @@ struct server_context } } - slot.state = SLOT_STATE_PROCESSING_PROMPT; - slot.prompt_tokens.clear(); + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; - LOG_INFO("slot is processing task", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - }); + SLT_INF(slot, "%s", "processing task\n"); return true; } - void kv_cache_clear() - { - LOG_VERBOSE("clearing KV cache", {}); + void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } - void system_prompt_update() - { - LOG_VERBOSE("system prompt update", { - {"system_prompt", system_prompt}, - }); - - kv_cache_clear(); - system_tokens.clear(); - - if (!system_prompt.empty()) - { - system_tokens = ::llama_tokenize(ctx, system_prompt, true); - - const int32_t n_batch = llama_n_batch(ctx); - const int32_t n_tokens_prompt = system_tokens.size(); - - for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); - - llama_batch_clear(batch); - - for (int32_t j = 0; j < n_tokens; ++j) { - llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); - } - - if (llama_decode(ctx, batch) != 0) { - LOG_ERROR("llama_decode() failed", {}); - return; - } - } - - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i <= params.n_parallel; ++i) - { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); - } - } - - system_need_update = false; - } - - bool system_prompt_set(const std::string &sys_prompt) - { - system_prompt = sys_prompt; - - LOG_VERBOSE("system prompt process", { - {"system_prompt", system_prompt}, - }); - - // release all slots - for (server_slot &slot : slots) - { - slot.release(); - } - - system_need_update = true; - return true; - } - - bool process_token(completion_token_output &result, server_slot &slot) - { + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); + const std::string token_str = result.text_to_send; slot.sampled = result.tok; - // search stop word and delete it slot.generated_text += token_str; + if (slot.params.return_tokens) { + 
slot.generated_tokens.push_back(result.tok); + } slot.has_next_token = true; // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) - { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) - { - // continuation byte: 10xxxxxx - continue; - } - if ((c & 0xE0) == 0xC0) - { - // 2-byte character: 110xxxxx ... - incomplete = i < 2; - } - else if ((c & 0xF0) == 0xE0) - { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } - else if ((c & 0xF8) == 0xF0) - { - // 4-byte character: 11110xxx ... - incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - if (!incomplete) - { + // search stop word and delete it + if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; + bool send_text = true; - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); - if (stop_pos != std::string::npos) - { - is_stop_full = true; - slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } - else - { - is_stop_full = false; - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) - { + if (send_text) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache + } else { + result.text_to_send = ""; } - slot.add_token_string(result); - if (slot.params.stream) - { + slot.add_token(result); + if (slot.params.stream) { send_partial_response(slot, result); } } - if (incomplete) - { + if (incomplete) { slot.has_next_token = true; } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) - { - slot.stopped_limit = true; + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - LOG_VERBOSE("stopped by limit", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_decoded", slot.n_decoded}, - {"n_predict", slot.params.n_predict}, - }); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } - if (llama_token_is_eog(model, result.tok)) - { - slot.stopped_eos = true; + if (slot.has_new_line) { + // if we have already seen a new line, we stop after a certain time limit + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d 
ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; - LOG_VERBOSE("eos token found", {}); - } - - auto n_ctx_train = llama_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 - && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { - LOG_WARNING("n_predict is not set and self-context extend is disabled." - " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", - { - {"id_slot", slot.id}, - {"params.n_predict", slot.params.n_predict}, - {"slot.n_prompt_tokens", slot.n_prompt_tokens}, - {"slot.n_decoded", slot.n_decoded}, - {"slot.n_predict", slot.n_predict}, - {"n_slots", params.n_parallel}, - {"slot.n_ctx", slot.n_ctx}, - {"n_ctx", n_ctx}, - {"n_ctx_train", n_ctx_train}, - {"ga_n", slot.ga_n}, - }); - slot.truncated = true; - slot.stopped_limit = true; + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. 
" + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); } - LOG_VERBOSE("next token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue } - json get_formated_generation(const server_slot & slot) const { - std::vector samplers; - samplers.reserve(slot.sparams.samplers.size()); - for (const auto & sampler : slot.sparams.samplers) { - samplers.emplace_back(gpt_sampler_type_to_str(sampler)); - } - - return json{{"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, // Server configured n_predict - {"model", params.model_alias}, - {"seed", slot.sparams.seed}, - {"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typ_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"max_tokens", slot.params.n_predict}, // User configured n_predict - {"n_keep", slot.params.n_keep}, - {"n_discard", slot.params.n_discard}, - {"ignore_eos", slot.sparams.ignore_eos}, - {"stream", slot.params.stream}, - //{"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"samplers", samplers}, - }; + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_detokenize(ctx, {cur_p->data[i].id}, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability 
for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_detokenize(ctx, {cur[i].id}, special), + cur[i].p + }); + } + } } - void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) - { + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { send_error(task.id, error, type); } - void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) - { + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { send_error(slot.id_task, error, type); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - LOG_ERROR("task error", { - {"id_task", id_task}, - {"error", error}, - }); + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - server_task_result res; - res.id = id_task; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_partial_response(server_slot &slot, completion_token_output tkn) - { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = false; - res.data = json { - {"content", tkn.text_to_send}, - {"stop", false}, - {"id_slot", slot.id}, - {"multimodal", false}, - {"index", slot.index}, - }; + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); - if (slot.sparams.n_probs > 0) - { - const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); - const size_t probs_stop_pos = - std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; - std::vector probs_output; - if (probs_pos < probs_stop_pos) - { - probs_output = - std::vector(slot.generated_token_probs.begin() + probs_pos, - slot.generated_token_probs.begin() + probs_stop_pos); - } - slot.n_sent_token_probs = probs_stop_pos; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs } - if (slot.oaicompat) - { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); } - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_final_response(const server_slot &slot) - { - server_task_result res; - res.id = slot.id_task; - res.error = false; - 
res.stop = true; - res.data = json{{"content", !slot.params.stream ? slot.generated_text : ""}, - {"id_slot", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}, - {"index", slot.index}, - }; - - if (slot.sparams.n_probs > 0) - { - std::vector probs; - if (!slot.params.stream && slot.stopped_word) - { - const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = slot.generated_text; + res->tokens = slot.generated_tokens; + res->timings = slot.get_timings(); + res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->response_fields = slot.params.response_fields; + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - probs = std::vector(slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } - else - { - probs = std::vector(slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); } - - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs); } - if (slot.oaicompat) - { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } + res->generation_params = slot.params; // copy the parameters - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_embedding(const server_slot &slot, const llama_batch &batch) - { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embd_res(n_embd, 0.0f); - for (int i = 0; i < batch.n_tokens; ++i) - { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) - { 
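// Illustrative aside (not part of the patch): the new send_embedding() path below calls
// common_embd_normalize(embd, embd_res.data(), n_embd, 2) only when the context uses
// pooling; the last argument appears to select the Euclidean (L2) norm. A hypothetical,
// self-contained equivalent of that normalization step (l2_normalize is a placeholder
// name, not a llama.cpp function):

#include <cmath>

static void l2_normalize(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * inp[i]; // accumulate squared magnitudes
    }
    const float scale = sum > 0.0 ? (float) (1.0 / std::sqrt(sum)) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * scale; // scale the embedding to unit length
    }
}

// Without pooling (LLAMA_POOLING_TYPE_NONE), the raw per-token embeddings are pushed as-is.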
+ for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } - const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) - { + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } - if (embd == NULL) - { - LOG_ERROR("failed to get embeddings", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}}); - - res.data = json{ - {"embedding", std::vector(n_embd, 0.0f)}, - }; + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; } - llama_embd_normalize(embd, embd_res.data(), n_embd); - - res.data = json{ - {"embedding", embd_res}, - {"index", slot.index}, - }; + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } } - queue_results.send(res); + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); } - // - // Functions to create new task(s) and receive result(s) - // + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; - std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { - std::vector tasks; - auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { - server_task task; - task.id = queue_tasks.get_new_id(); - task.cmpl_type = cmpl_type; - task.type = SERVER_TASK_TYPE_COMPLETION; - if (replace_prompt) { - task.data = task_data; - task.data["prompt"] = prompt; - } else { - task.data = std::move(task_data); + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; } - tasks.push_back(std::move(task)); - }; - static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts"; - if (!data.contains("prompt")) { - throw std::runtime_error(error_msg); - } + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } - json prompt = data.at("prompt"); - - // if the prompt is a singleton (i.e. 
a string or a list of tokens), we only need to create single task - if (prompt.is_string() || json_is_array_of_numbers(prompt)) { - data["index"] = 0; - create_task(data, false, nullptr); - } - // otherwise, it's a multiple-prompt task, we break it into smaller tasks - else if (prompt.is_array()) { - std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); - } + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; } + + res->score = embd[0]; } - // invalid case - else { - throw std::runtime_error(error_msg); - } - return tasks; + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); } + // + // Functions to create new task(s) and receive result(s) + // + void cancel_tasks(const std::unordered_set & id_tasks) { std::vector cancel_tasks; cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { - LOG_VERBOSE("cancel task", {{"id_task", id_task}}); - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; - task.id_target = id_task; + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); } @@ -1659,361 +2336,312 @@ struct server_context queue_tasks.post(cancel_tasks, true); } - // receive the results from task(s) created by create_tasks_cmpl - void receive_cmpl_results(const std::unordered_set & id_tasks, std::function&)> result_handler, std::function error_handler) { - // TODO: currently, there is no way to detect the client has cancelled the request - std::vector results(id_tasks.size()); - for (size_t i = 0; i < id_tasks.size(); i++) { - server_task_result result = queue_results.recv(id_tasks); + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } - if (result.error) { - error_handler(result.data); + if (result->is_error()) { + error_handler(result->to_json()); cancel_tasks(id_tasks); - break; + return; } - size_t idx = result.data["index"]; - results[idx] = result; + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); } result_handler(results); } - // receive the results from task(s) created by create_tasks_cmpl, in stream mode - void receive_cmpl_results_stream(const std::unordered_set & id_tasks, std::function result_handler, std::function error_handler) { + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const 
std::function & error_handler, + const std::function & is_connection_closed) { size_t n_finished = 0; while (true) { - server_task_result result = queue_results.recv(id_tasks); - if (!result_handler(result)) { - cancel_tasks(id_tasks); - break; - } + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result.error) { - error_handler(result.data); + if (is_connection_closed()) { cancel_tasks(id_tasks); - break; - } - - if (result.stop) { - if (++n_finished == id_tasks.size()) { - break; - } - } - } - } - - // - // Functions to process the task - // - - void process_single_task(const server_task & task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - { - const int id_slot = json_value(task.data, "id_slot", -1); - - server_slot *slot; - - if (id_slot != -1) - { - slot = get_slot_by_id(id_slot); - } - else - { - std::string prompt; - if (task.data.contains("prompt") && task.data.at("prompt").is_string()) - { - prompt = json_value(task.data, "prompt", std::string()); - } - - slot = get_available_slot(prompt); - } - - if (slot == nullptr) - { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } - - if (task.data.contains("system_prompt")) - { - std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); - system_prompt_set(sys_prompt); - - for (server_slot &slot : slots) - { - slot.n_past = 0; - slot.n_past_se = 0; - } + return; } - slot->reset(); + if (result == nullptr) { + continue; // retry + } - slot->id_task = task.id; - slot->cmpl_type = task.cmpl_type; - slot->index = json_value(task.data, "index", 0); + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } - if (!launch_slot_with_task(*slot, task)) - { - LOG_ERROR("error while launching slot", task.data); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (!result_handler(result)) { + cancel_tasks(id_tasks); break; } - } - break; - case SERVER_TASK_TYPE_CANCEL: { - // release slot linked with the task id - for (auto &slot : slots) - { - if (slot.id_task == task.id_target) - { - slot.release(); + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { break; } } } - break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: { - // do nothing - } - break; - case SERVER_TASK_TYPE_METRICS: { - json slots_data = json::array(); + } - int n_idle_slots = 0; - int n_processing_slots = 0; + // + // Functions to process the task + // - for (server_slot &slot : slots) - { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }; - - if (slot_data["state"] == SLOT_STATE_IDLE) + void process_single_task(server_task task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case 
SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { - n_idle_slots++; - } - else + const int id_slot = task.id_selected_slot; + + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } + + if (!launch_slot_with_task(*slot, task)) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: { - n_processing_slots++; - } + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); - slots_data.push_back(slot_data); - } - LOG_INFO("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots} - }); + int n_idle_slots = 0; + int n_processing_slots = 0; - LOG_VERBOSE("slot data", {{"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data}}); - - server_task_result res; - res.id = task.id; - res.stop = true; - res.error = false; - res.data = { - {"idle", n_idle_slots}, - {"processing", n_processing_slots}, - {"deferred", queue_tasks.queue_tasks_deferred.size()}, - {"t_start", metrics.t_start}, - - {"n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - {"t_tokens_generation_total", metrics.t_tokens_generation_total}, - {"n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - {"t_prompt_processing_total", metrics.t_prompt_processing_total}, - - {"n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - {"t_prompt_processing", metrics.t_prompt_processing}, - {"n_tokens_predicted", metrics.n_tokens_predicted}, - {"t_tokens_generation", metrics.t_tokens_generation}, - - { "n_decode_total", metrics.n_decode_total}, - { "n_busy_slots_total", metrics.n_busy_slots_total}, - - {"kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - {"kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, - - {"slots", slots_data}, - }; + for (server_slot & slot : slots) { + json slot_data = slot.to_json(); - if (json_value(task.data, "reset_bucket", false)) - { - metrics.reset_bucket(); - } - queue_results.send(res); - } - break; - case SERVER_TASK_TYPE_SLOT_SAVE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = 
metrics.t_start; + + res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - const size_t token_count = slot->cache_tokens.size(); - const int64_t t_start = ggml_time_us(); + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; - const size_t nwrite = - llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, - {"filename", filename}, - {"n_saved", token_count}, // tokens saved - {"n_written", nwrite}, // bytes written - {"timings", {{"save_ms", t_save_ms}}}}; - queue_results.send(result); - } - break; - case SERVER_TASK_TYPE_SLOT_RESTORE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } if (slot->is_processing()) { - // if requested slot is unavailable, we 
defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - const int64_t t_start = ggml_time_us(); + const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; - slot->cache_tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), - slot->cache_tokens.size(), &token_count); - if (nread == 0) - { - slot->cache_tokens.resize(0); - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", - ERROR_TYPE_INVALID_REQUEST); - break; - } - slot->cache_tokens.resize(token_count); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, - {"filename", filename}, - {"n_restored", token_count}, // tokens restored - {"n_read", nread}, // bytes read - {"timings", {{"restore_ms", t_restore_ms}}}}; - queue_results.send(result); - } - break; - case SERVER_TASK_TYPE_SLOT_ERASE: { - int id_slot = task.data.at("id_slot"); - server_slot *slot = get_slot_by_id(id_slot); - if (slot == nullptr) - { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(task); + break; + } - // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); - slot->cache_tokens.clear(); + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); + 
slot->cache_tokens.clear(); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{"id_slot", id_slot}, {"n_erased", n_erased}}; - queue_results.send(result); - } - break; + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; case SERVER_TASK_TYPE_SET_LORA: - { - llama_lora_adapters_apply(ctx, lora_adapters); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{ "success", true }}; - queue_results.send(result); + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); } break; } } - void update_slots() - { - if (system_need_update) - { - system_prompt_update(); - } - + void update_slots() { // check if all slots are idle { bool all_idle = true; - for (auto &slot : slots) - { + for (auto & slot : slots) { if (slot.is_processing()) { all_idle = false; break; } } - if (all_idle) - { - LOG_INFO("all slots are idle", {}); - if (system_prompt.empty() && clean_kv_cache) - { + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { kv_cache_clear(); } @@ -2022,182 +2650,129 @@ struct server_context } { - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); queue_tasks.post(task); } // apply context-shift if needed // TODO: simplify and improve - for (server_slot &slot : slots) - { - if (slot.ga_n == 1) - { - if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) - { - // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = (int)system_tokens.size() + slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - - LOG_INFO("slot context shift", {{"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_keep", n_keep}, - {"n_left", n_left}, - {"n_discard", n_discard}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}}); - - llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, - -n_discard); - - if (slot.params.cache_prompt) - { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) - { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } + for (server_slot & slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - } + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? 
slot.params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); - slot.n_past -= n_discard; + if (slot.params.cache_prompt) { + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } - slot.truncated = true; + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); } + + slot.n_past -= n_discard; + + slot.truncated = true; } } // start populating the batch for this iteration - llama_batch_clear(batch); + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; // frist, add sampled tokens from any ongoing sequences - for (auto &slot : slots) - { + for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; } - slot.i_batch = batch.n_tokens; + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } - const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + slot.i_batch = batch.n_tokens; - // TODO: we always have to take into account the "system_tokens" - // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, {slot.id + 1}, true); + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - if (slot.params.cache_prompt) - { + if (slot.params.cache_prompt) { slot.cache_tokens.push_back(slot.sampled); } - LOG_VERBOSE("slot decode token", {{"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated}}); + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); + int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - int32_t batch_type = batch.n_tokens > 0 ? 
0 : -1; - // next, batch any pending prompts without exceeding n_batch - if (params.cont_batching || batch.n_tokens == 0) - { - for (auto &slot : slots) - { - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { - auto &prompt_tokens = slot.prompt_tokens; + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty()) - { - LOG_VERBOSE("tokenizing prompt", {{"id_slot", slot.id}, {"id_task", slot.id_task}}); + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto & prompt_tokens = slot.prompt_tokens; + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) { - const bool add_bos = llama_add_bos_token(model); - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) - { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - - auto prefix_tokens = tokenize(slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) - { - suffix_tokens.erase(suffix_tokens.begin()); - } + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); - auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; - auto embd_end = params.spm_infill ? 
prefix_tokens : suffix_tokens; - if (add_bos) - { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + // print prompt tokens (for debugging) + if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - const llama_token middle_token = llama_token_middle(model); - if (middle_token >= 0) - { - embd_inp.push_back(middle_token); + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - - prompt_tokens = embd_inp; - } - else - { - prompt_tokens = - tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } - slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); - - LOG_VERBOSE("prompt tokenized", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), - prompt_tokens.cend())}, - }); - // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) - { - LOG_INFO("empty prompt - releasing slot", { - {"id_slot", slot.id}, - {"id_task", slot.id_task} - }); + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); slot.release(); slot.print_timings(); @@ -2205,185 +2780,158 @@ struct server_context continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { - // this prompt is too large to process - discard it + if (slot.is_non_causal()) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); - send_error(slot, "input is too large to process. increase the physical batch size", - ERROR_TYPE_SERVER); + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } - } - else - { - if (slot.params.n_keep < 0) - { + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. 
try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { slot.params.n_keep = slot.n_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - // if input prompt is too big, truncate it (if group attention self-extend is disabled) - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) - { + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; - const int erased_blocks = - (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - std::vector new_tokens(prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + llama_tokens new_tokens( + prompt_tokens.begin(), + prompt_tokens.begin() + slot.params.n_keep); - new_tokens.insert(new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + - erased_blocks * n_block_size, - prompt_tokens.end()); + new_tokens.insert( + new_tokens.end(), + prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + prompt_tokens.end()); prompt_tokens = std::move(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated", - { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", - tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, - }); + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - gpt_sampler_reset(slot.smpl); + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt - if (!slot.params.cache_prompt) - { - slot.n_past_se = 0; - slot.ga_i = 0; - } - else - { - GGML_ASSERT(slot.ga_n == 1); + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); + 
llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); - // push the prompt into the sampling context (do not apply grammar) - for (int i = 0; i < slot.n_past; ++i) { - gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false); + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); } } } - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) - { + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "id_slot", slot.id }, - { "id_task", slot.id_task } - }); + SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); slot.n_past--; - if (slot.ga_i > 0) - { - slot.n_past_se--; - } } slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) - { + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; } } - // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; - if (batch_type == -1) - { - batch_type = slot_type; - } - else if (batch_type != slot_type) - { - continue; - } - // keep only the common part - int p0 = (int)system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) - { + if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); - - p0 = (int)system_tokens.size(); - if (p0 != 0) - { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); - } + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - // there is no common part left (except for the system prompt) + // there is no common part left slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? - gpt_sampler_reset(slot.smpl); } + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); - LOG_INFO("kv cache rm [p0, end)", {{"id_slot", slot.id}, {"id_task", slot.id_task}, {"p0", p0}}); - - int32_t slot_npast = slot.n_past_se > 0 ? 
slot.n_past_se : slot.n_past; - - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - // add prompt tokens for processing in the current batch - // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow - for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) - { - if (slot.ga_n != 1) - { - while (slot_npast >= ga_i + ga_w) - { - const int bd = (ga_w / ga_n) * (ga_n - 1); - slot_npast -= bd; - ga_i += ga_w / ga_n; - } - } + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, - {slot.id + 1}, false); + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - if (slot.params.cache_prompt) - { + if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); } slot.n_prompt_tokens_processed++; - slot_npast++; + slot.n_past++; } - LOG_VERBOSE("prompt processing progress", - { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, - }); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { @@ -2391,113 +2939,64 @@ struct server_context GGML_ASSERT(batch.n_tokens > 0); + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + common_sampler_accept(slot.smpl, prompt_tokens[i], false); + } + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - LOG_VERBOSE("prompt done", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - }); + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); } } - if (batch.n_tokens >= n_batch) - { + if (batch.n_tokens >= n_batch) { break; } } } - if (batch.n_tokens == 0) - { - LOG_VERBOSE("no tokens to decode", {}); + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); return; } - LOG_VERBOSE("decoding batch", { - {"n_tokens", batch.n_tokens}, - }); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) - { + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - for (auto &slot : slots) - { - if (slot.ga_n != 1) - { - // context extension via Self-Extend - // TODO: simplify and/or abstract this - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { - const int ib = 
(slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; - - LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, - slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, - slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, - (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, - slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, - slot.n_past_se + ib * bd + dd); - - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, - slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, - slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, - slot.ga_i); - } - - slot.n_past_se += n_tokens; - } - } - llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, + batch.pos + i, batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused + batch.seq_id + i, + batch.logits + i, }; const int ret = llama_decode(ctx, batch_view); metrics.on_decoded(slots); - if (ret != 0) - { - if (n_batch == 1 || ret < 0) - { + if (ret != 0) { + if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", - { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); + SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); for (auto & slot : slots) { slot.release(); send_error(slot, "Input prompt is too big compared to KV size. 
Please try increasing KV size."); @@ -2509,29 +3008,31 @@ struct server_context n_batch /= 2; i -= n_batch; - LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); continue; // continue loop of n_batch } - for (auto &slot : slots) - { + for (auto & slot : slots) { if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; // continue loop of slots } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; @@ -2539,29 +3040,34 @@ struct server_context continue; // continue loop of slots } - completion_token_output result; - const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; - gpt_sampler_accept(slot.smpl, id, true); + common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; - if (slot.n_decoded == 1) - { - slot.t_start_generation = ggml_time_us(); + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); } - result.tok = id; + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { - result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 
0.0f : cur_p->data[i].p, - }); - } + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + } if (!process_token(result, slot)) { // release slot because of stop condition @@ -2569,45 +3075,178 @@ struct server_context slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + continue; } + } - slot.i_batch = -1; + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); } } - LOG_VERBOSE("run slots completed", {}); + SRV_DBG("%s", "run slots completed\n"); } - json model_meta() const - { - return json{ - {"vocab_type", llama_vocab_type(model)}, {"n_vocab", llama_n_vocab(model)}, - {"n_ctx_train", llama_n_ctx_train(model)}, {"n_embd", llama_n_embd(model)}, - 
{"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)}, + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, }; } }; +static void common_params_handle_model_default( + std::string & model, + const std::string & model_url, + std::string & hf_repo, + std::string & hf_file, + const std::string & hf_token) { + if (!hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (hf_file.empty()) { + if (model.empty()) { + auto auto_detected = common_get_hf_file(hf_repo, hf_token); + if (auto_detected.first.empty() || auto_detected.second.empty()) { + exit(1); // built without CURL, error message already printed + } + hf_repo = auto_detected.first; + hf_file = auto_detected.second; + } else { + hf_file = model; + } + } + // make sure model path is present (for caching purposes) + if (model.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = hf_repo + "_" + hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model = fs_get_cache_file(filename); + } + } else if (!model_url.empty()) { + if (model.empty()) { + auto f = string_split(model_url, '#').front(); + f = string_split(f, '?').front(); + model = fs_get_cache_file(string_split(f, '/').back()); + } + } else if (model.empty()) { + model = DEFAULT_MODEL_PATH; + } +} + // parse the given jparams (see de.kherud.llama.args.ModelParameters#toString()) from JSON to the required C++ struct. 
-static void server_params_parse(json jparams, gpt_params ¶ms) +static void server_params_parse(json jparams, common_params ¶ms) { - gpt_params default_params; + common_params default_params; - params.seed = json_value(jparams, "seed", default_params.seed); - params.n_threads = json_value(jparams, "n_threads", default_params.n_threads); - params.n_threads_draft = json_value(jparams, "n_threads_draft", default_params.n_threads_draft); - params.n_threads_batch = json_value(jparams, "n_threads_batch", default_params.n_threads_batch); - params.n_threads_batch_draft = json_value(jparams, "n_threads_batch_draft", default_params.n_threads_batch_draft); + params.sampling.seed = json_value(jparams, "seed", default_params.sampling.seed); + params.cpuparams.n_threads = json_value(jparams, "n_threads", default_params.cpuparams.n_threads); + params.speculative.cpuparams.n_threads = json_value(jparams, "n_threads_draft", default_params.speculative.cpuparams.n_threads); + params.cpuparams_batch.n_threads = json_value(jparams, "n_threads_batch", default_params.cpuparams_batch.n_threads); + params.speculative.cpuparams_batch.n_threads = json_value(jparams, "n_threads_batch_draft", default_params.speculative.cpuparams_batch.n_threads ); params.n_predict = json_value(jparams, "n_predict", default_params.n_predict); params.n_ctx = json_value(jparams, "n_ctx", default_params.n_ctx); params.n_batch = json_value(jparams, "n_batch", default_params.n_batch); params.n_ubatch = json_value(jparams, "n_ubatch", default_params.n_ubatch); params.n_keep = json_value(jparams, "n_keep", default_params.n_keep); - params.n_draft = json_value(jparams, "n_draft", default_params.n_draft); + + params.speculative.n_max = json_value(jparams, "n_draft", default_params.speculative.n_max); + params.speculative.n_min = json_value(jparams, "n_draft_min", default_params.speculative.n_min); + params.n_chunks = json_value(jparams, "n_chunks", default_params.n_chunks); params.n_parallel = json_value(jparams, "n_parallel", default_params.n_parallel); params.n_sequences = json_value(jparams, "n_sequences", default_params.n_sequences); - params.p_split = json_value(jparams, "p_split", default_params.p_split); + params.speculative.p_split = json_value(jparams, "p_split", default_params.speculative.p_split); params.grp_attn_n = json_value(jparams, "grp_attn_n", default_params.grp_attn_n); params.grp_attn_w = json_value(jparams, "grp_attn_w", default_params.grp_attn_w); params.n_print = json_value(jparams, "n_print", default_params.n_print); @@ -2623,7 +3262,7 @@ static void server_params_parse(json jparams, gpt_params ¶ms) params.rope_scaling_type = json_value(jparams, "rope_scaling_type", default_params.rope_scaling_type); params.pooling_type = json_value(jparams, "pooling_type", default_params.pooling_type); params.model = json_value(jparams, "model", default_params.model); - params.model_draft = json_value(jparams, "model_draft", default_params.model_draft); + params.speculative.model = json_value(jparams, "model_draft", default_params.speculative.model); params.model_alias = json_value(jparams, "model_alias", default_params.model_alias); params.model_url = json_value(jparams, "model_url", default_params.model_url); params.hf_repo = json_value(jparams, "hf_repo", default_params.hf_repo); @@ -2637,17 +3276,16 @@ static void server_params_parse(json jparams, gpt_params ¶ms) params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static); params.lookup_cache_dynamic = json_value(jparams, 
"lookup_cache_dynamic", default_params.lookup_cache_dynamic); params.logits_file = json_value(jparams, "logits_file", default_params.logits_file); - params.lora_adapter = json_value(jparams, "lora_adapter", default_params.lora_adapter); + // params.lora_adapters = json_value(jparams, "lora_adapter", default_params.lora_adapters); params.embedding = json_value(jparams, "embedding", default_params.embedding); params.escape = json_value(jparams, "escape", default_params.escape); params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching); params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn); params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos); - params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos); + params.sampling.ignore_eos = json_value(jparams, "ignore_eos", default_params.sampling.ignore_eos); params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap); params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock); params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload); - params.system_prompt = json_value(jparams, "system_prompt", default_params.system_prompt); params.chat_template = json_value(jparams, "chat_template", default_params.chat_template); if (jparams.contains("n_gpu_layers")) @@ -2655,13 +3293,13 @@ static void server_params_parse(json jparams, gpt_params ¶ms) if (llama_supports_gpu_offload()) { params.n_gpu_layers = json_value(jparams, "n_gpu_layers", default_params.n_gpu_layers); - params.n_gpu_layers_draft = json_value(jparams, "n_gpu_layers_draft", default_params.n_gpu_layers_draft); + params.speculative.n_gpu_layers = json_value(jparams, "n_gpu_layers_draft", default_params.speculative.n_gpu_layers); } else { - LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " - "See main README.md for information on enabling GPU BLAS support", - {{"n_gpu_layers", params.n_gpu_layers}}); + SRV_WRN("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " + "See main README.md for information on enabling GPU BLAS support: %s = %d", + "n_gpu_layers", params.n_gpu_layers); } } @@ -2692,7 +3330,7 @@ static void server_params_parse(json jparams, gpt_params ¶ms) } } #else - LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {}); + SRV_WRN("%s","llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n"); #endif // GGML_USE_CUDA } @@ -2701,9 +3339,9 @@ static void server_params_parse(json jparams, gpt_params ¶ms) #if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL) params.main_gpu = json_value(jparams, "main_gpu", default_params.main_gpu); #else - LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {}); + SRV_WRN("%s","llama.cpp was compiled without CUDA. 
It is not possible to set a main GPU."); #endif } - gpt_params_handle_model_default(params); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); } diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 7de7eac4..42b871fe 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -1,202 +1,385 @@ #pragma once #include "common.h" +#include "log.h" #include "llama.h" +#include "base64.hpp" +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +//#include "httplib.h" + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT #include "json.hpp" #include #include #include #include +#include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" using json = nlohmann::ordered_json; -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type -{ - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - -extern bool log_json; -extern std::function log_callback; - -#if SERVER_VERBOSE -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - server_log(GGML_LOG_LEVEL_DEBUG, __func__, __LINE__, MSG, __VA_ARGS__); \ - } while (0) -#else -#define LOG_VERBOSE(MSG, ...) -#endif +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define LOG_ERROR(MSG, ...) server_log(GGML_LOG_LEVEL_ERROR, __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log(GGML_LOG_LEVEL_WARN, __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO(MSG, ...) server_log(GGML_LOG_LEVEL_INFO, __func__, __LINE__, MSG, __VA_ARGS__) +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, - const json &extra); +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) 
LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -template static T json_value(const json &body, const std::string &key, const T &default_value) -{ +template +static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) - { - try - { + if (body.contains(key) && !body.at(key).is_null()) { + try { return body.at(key); - } - catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) - { - std::stringstream ss; - ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() - << "', using default value."; - LOG_WARNING(ss.str().c_str(), body); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); return default_value; } - } - else - { + } else { return default_value; } } -static const char *log_level_to_string(ggml_log_level level) -{ - switch (level) - { - case GGML_LOG_LEVEL_ERROR: - return "ERROR"; - case GGML_LOG_LEVEL_WARN: - return "WARN"; - default: - case GGML_LOG_LEVEL_INFO: - return "INFO"; - case GGML_LOG_LEVEL_DEBUG: - return "DEBUG"; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +// +// tokenizer and input processing utils +// + +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; } + return false; } -static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, - const json &extra) -{ - std::stringstream ss_tid; - ss_tid << std::this_thread::get_id(); - - if (log_json) - { - json log = json{ - {"msg", message}, -#if SERVER_VERBOSE - {"ts", time(nullptr)}, {"level", log_level_to_string(level)}, {"tid", ss_tid.str()}, {"function", function}, - {"line", line}, -#endif - }; - - if (!extra.empty()) - { - log.merge_patch(extra); +// is array having BOTH numbers & strings? 
+static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } } + } + return false; +} - auto dump = log.dump(-1, ' ', false, json::error_handler_t::replace); - if (log_callback == nullptr) - { - printf("%s\n", dump.c_str()); +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } } - else - { - log_callback(level, dump.c_str(), nullptr); + if (valid_path) { + result[path] = current; } } - else - { - std::stringstream ss; - ss << message; - - if (!extra.empty()) - { - for (const auto &el : extra.items()) - { - const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); - ss << " " << el.key() << "=" << value; + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); } } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } -#if SERVER_VERBOSE - ss << " | ts " << time(nullptr) << " | tid " << ss_tid.str() << " | " << function << " line " << line; -#endif + return prompt_tokens; +} - const std::string str = ss.str(); - if (log_callback == nullptr) - { - printf("[%4s] %.*s\n", log_level_to_string(level), (int)str.size(), str.data()); +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if 
(json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } } - else - { - log_callback(level, str.c_str(), nullptr); + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; } } - fflush(stdout); + + // If no cut-off multi-byte character is found, return full length + return len; } // -// chat template utils +// template utils // +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_vocab_eos(vocab)); + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... 
+ // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? 
tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model *model, const std::string &tmpl, - const std::vector &messages) -{ - std::vector chat; +inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { + std::vector chat; - for (size_t i = 0; i < messages.size(); ++i) - { - const auto &curr_msg = messages[i]; + for (size_t i = 0; i < messages.size(); ++i) { + const auto & curr_msg = messages[i]; std::string role = json_value(curr_msg, "role", std::string("")); std::string content; - if (curr_msg.contains("content")) - { - if (curr_msg["content"].is_string()) - { + if (curr_msg.contains("content")) { + if (curr_msg["content"].is_string()) { content = curr_msg["content"].get(); - } - else if (curr_msg["content"].is_array()) - { - for (const auto &part : curr_msg["content"]) - { - if (part.contains("text")) - { + } else if (curr_msg["content"].is_array()) { + for (const auto & part : curr_msg["content"]) { + if (part.contains("text")) { content += "\n" + part["text"].get(); } } + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - else - { - throw std::runtime_error( - "Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - } - else - { + } else { throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } chat.push_back({role, content}); } - auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); - LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); + const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true); + LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); + return formatted_chat; } @@ -204,17 +387,16 @@ inline std::string format_chat(const struct llama_model *model, const std::strin // base64 utils (TODO: move to common in the future) // -static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; -static inline bool is_base64(uint8_t c) -{ +static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string &encoded_string) -{ +static inline std::vector base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -226,23 +408,18 @@ static inline std::vector base64_decode(const std::string &encoded_stri std::vector ret; - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { - char_array_4[i++] = encoded_string[in_]; - in_++; - if (i == 4) - { - for (i = 0; i < 4; i++) - { + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + 
char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) - { + for (i = 0; (i < 3); i++) { ret.push_back(char_array_3[i]); } @@ -250,24 +427,20 @@ static inline std::vector base64_decode(const std::string &encoded_stri } } - if (i) - { - for (j = i; j < 4; j++) - { + if (i) { + for (j = i; j < 4; j++) { char_array_4[j] = 0; } - for (j = 0; j < 4; j++) - { + for (j = 0; j < 4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } - char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; j < i - 1; j++) - { + for (j = 0; j < i - 1; j++) { ret.push_back(char_array_3[j]); } } @@ -279,8 +452,7 @@ static inline std::vector base64_decode(const std::string &encoded_stri // random string / id // -static std::string random_string() -{ +static std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); std::random_device rd; @@ -288,63 +460,32 @@ static std::string random_string() std::string result(32, ' '); - for (int i = 0; i < 32; ++i) - { + for (int i = 0; i < 32; ++i) { result[i] = str[generator() % str.size()]; } return result; } -static std::string gen_chatcmplid() -{ - std::stringstream chatcmplid; - chatcmplid << "chatcmpl-" << random_string(); - - return chatcmplid.str(); +static std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); } // // other common utils // -static size_t common_part(const std::vector &a, const std::vector &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { - } - - return i; -} - -static size_t common_part(const std::string &a, const std::string &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { - } - - return i; -} - -static bool ends_with(const std::string &str, const std::string &suffix) -{ +static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } -static size_t find_partial_stop_string(const std::string &stop, const std::string &text) -{ - if (!text.empty() && !stop.empty()) - { +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) - { - if (stop[char_index] == text_last_char) - { + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) - { + if (ends_with(text, current_partial)) { 
return text.size() - char_index - 1; } } @@ -355,26 +496,23 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin } // TODO: reuse llama_detokenize -template static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) -{ +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { std::string ret; - for (; begin != end; ++begin) - { - ret += llama_token_to_piece(ctx, *begin); + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); } return ret; } // format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) -{ - std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) - { + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { std::stringstream ss; ss << std::hex << (out[0] & 0xff); std::string res(ss.str()); @@ -383,127 +521,119 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c return out; } +/* +static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { + const std::string str = + std::string(event) + ": " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). -struct completion_token_output -{ - llama_token tok; - std::string text_to_send; + LOG_DBG("data stream, to_send: %s", str.c_str()); - struct token_prob - { - llama_token tok; - float prob; - }; + return sink.write(str.c_str(), str.size()); +} +*/ +// +// OAI utils +// - std::vector probs; -}; - -// convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context *ctx, const std::vector &probs) -{ - json out = json::array(); - - for (const auto &prob : probs) - { - json probs_for_token = json::array(); - - for (const auto &p : prob.probs) - { - const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json{ - {"tok_str", tok_str}, - {"prob", p.prob}, - }); +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); } + } - const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json{ - {"content", tok_str}, - {"probs", probs_for_token}, - }); + // Copy remaining properties to 
llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } } - return out; + return llama_params; } -// -// OAI utils -// - -static json oaicompat_completion_params_parse(const struct llama_model *model, - const json &body, /* openai api json semantics */ - const std::string &chat_template) -{ +static json oaicompat_chat_completion_params_parse( + const struct llama_model * model, + const json & body, /* openai api json semantics */ + const std::string & chat_template) { json llama_params; - llama_params["__oaicompat"] = true; - // Apply chat template to the list of messages llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) - { + if (body.contains("stop") && body.at("stop").is_string()) { llama_params["stop"] = json::array({body.at("stop").get()}); - } - else - { + } else { llama_params["stop"] = json_value(body, "stop", json::array()); } // Handle "response_format" field - if (body.contains("response_format")) - { - json response_format = json_value(body, "response_format", json::object()); + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); - if (response_type == "json_object") - { + if (response_type == "json_object") { llama_params["json_schema"] = json_value(response_format, "schema", json::object()); - } - else if (!response_type.empty() && response_type != "text") - { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + - response_type); + } else if (response_type == "json_schema") { + json json_schema = json_value(response_format, "json_schema", json::object()); + llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } // Handle "n" field int n_choices = json_value(body, "n", 1); - if (n_choices != 1) - { + if (n_choices != 1) { throw std::runtime_error("Only one completion choice is allowed"); } // Handle "logprobs" field - // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may - // need to fix it in the future - if (body.contains("logprobs")) - { + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future + if (json_value(body, "logprobs", false)) { llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } - else if (body.contains("top_logprobs")) - { + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { throw std::runtime_error("top_logprobs requires logprobs to be set to true"); } // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params{"tools", "tool_choice"}; - for (auto ¶m : unsupported_params) - { - if (body.contains(param)) - { + static const std::vector unsupported_params { "tools", "tool_choice" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { throw 
std::runtime_error("Unsupported param: " + param); } } // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto &item : body.items()) - { + for (const auto & item : body.items()) { // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") - { + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { llama_params[item.key()] = item.value(); } } @@ -511,219 +641,205 @@ static json oaicompat_completion_params_parse(const struct llama_model *model, return llama_params; } -static json format_final_response_oaicompat(const json &request, json result, const std::string &completion_id, - bool streaming = false) -{ - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) - { - finish_reason = "stop"; - } - - json choices = streaming - ? json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = json{{"choices", choices}, - {"created", t}, - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? 
"chat.completion.chunk" : "chat.completion"}, - {"usage", json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", completion_id}}; - -#if SERVER_VERBOSE - res["__verbose"] = result; -#endif +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); - if (result.contains("completion_probabilities")) - { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); + n_tokens += json_value(elem, "tokens_evaluated", 0); } + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + return res; } -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(json result, const std::string &completion_id) -{ - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) - { - return std::vector({result}); +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); } - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", data} + }; - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); + return res; +} - std::string finish_reason; - if (stopped_word || stopped_eos) - { - finish_reason = "stop"; - } - if (stopped_limit) - { - finish_reason = "length"; +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 
3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } } - std::time_t t = std::time(0); - - json choices; + return true; +} - if (!finish_reason.empty()) - { - choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}); - } - else - { - if (first) - { - if (content.empty()) - { - choices = json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); - } - else - { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = - json{{"choices", - json::array( - {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } - else - { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) - { - return std::vector({json::object()}); - } +static json format_tokenizer_response(const json & tokens) { + return json { + {"tokens", tokens} + }; +} - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} - json ret = json{{"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - if (!finish_reason.empty()) - { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({"usage", json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}); +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); } + return data; +} - return std::vector({ret}); +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); } -static json format_embeddings_response_oaicompat(const json &request, const json &embeddings) -{ - json data = json::array(); - int i = 0; - for (auto &elem : embeddings) - { - data.push_back( - json{{"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}}); +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + 
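// note: the loop below fills one llama_token_data candidate per vocab id; the candidates are then sorted by logit and normalized with a softmax over the full vocabulary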
cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{{"prompt_tokens", 0}, {"total_tokens", 0}}}, - {"data", data}}; + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); - return res; -} + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } -static json format_tokenizer_response(const std::vector &tokens) -{ - return json{{"tokens", tokens}}; + return cur; } -static json format_detokenized_response(const std::string &content) -{ - return json{{"content", content}}; +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; } -static json format_error_response(const std::string &message, const enum error_type type) -{ - std::string type_str; - int code = 500; - switch (type) - { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json{ - {"code", code}, - {"message", message}, - {"type", type_str}, - }; +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; } From 4dc9dd1c5a8ed909f622ad8652885ef192c62c7d Mon Sep 17 00:00:00 2001 From: "alexey.sokolov" Date: Tue, 28 Jan 2025 20:30:19 +0700 Subject: [PATCH 2/9] compiled version --- src/main/cpp/jllama.cpp | 174 +++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 81 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 8ae54811..b5bb2937 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -139,6 +139,9 @@ JNIEnv *get_jni_env() return env; } +bool log_json; +std::function log_callback; + /** * Invoke the log callback if there is any. 
*/ @@ -151,8 +154,7 @@ void log_callback_trampoline(ggml_log_level level, const char *text, void *user_ } } // namespace -bool log_json; -std::function log_callback; + /** * The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`). @@ -356,9 +358,6 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams) { common_params params; - common_init(); - - auto *ctx_server = new server_context(); std::string c_params = parse_jstring(env, jparams); json json_params = json::parse(c_params); @@ -366,23 +365,19 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo if (json_value(json_params, "disable_log", false)) { - common_log_disable(); + common_log_pause(common_log_main()); } else { - log_enable(); + common_log_resume(common_log_main()); } - if (!params.prompt.empty()) - { - ctx_server->system_prompt_set(params.prompt); - } + common_init(); - if (params.model_alias == "unknown") - { - params.model_alias = params.model; - } + // struct that contains llama context and inference + auto *ctx_server = new server_context(); + llama_backend_init(); llama_numa_init(params.numa); LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); @@ -392,13 +387,16 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo std::atomic state{SERVER_STATE_LOADING_MODEL}; + // Necessary similarity of prompt for slot selection ctx_server->slot_prompt_similarity = params.slot_prompt_similarity; + LOG_INF("%s: loading model\n", __func__); + // load the model if (!ctx_server->load_model(params)) { - state.store(SERVER_STATE_ERROR); + llama_backend_free();; env->ThrowNew(c_llama_error, "could not load model from given file path"); return; } @@ -406,41 +404,23 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo ctx_server->init(); state.store(SERVER_STATE_READY); - LOG_INFO("model loaded", {}); + LOG_INF("%s: model loaded\n", __func__); const auto model_meta = ctx_server->model_meta(); // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) - { - if (!ctx_server->validate_model_chat_template()) - { - LOG_ERR("%s", "The chat template that comes with this model is not yet supported, falling back to chatml. This " - "may cause the model to output suboptimal responses"); - params.chat_template = "chatml"; - } - } - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) - { - if (!ctx_server->validate_model_chat_template()) - { - LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This " - "may cause the model to output suboptimal responses", - {}); + if (params.chat_template.empty()) { + if (!ctx_server->validate_builtin_chat_template()) { + LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); params.chat_template = "chatml"; } } // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", - { - {"chat_example", llama_chat_format_example(ctx_server->model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(), + common_chat_format_example(ctx_server->model, params.chat_template).c_str()); + ctx_server->queue_tasks.on_new_task(std::bind( &server_context::process_single_task, ctx_server, std::placeholders::_1)); @@ -471,22 +451,46 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) std::string c_params = parse_jstring(env, jparams); - json json_params = json::parse(c_params); + json data = json::parse(c_params); - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; - if (json_params.contains("input_prefix") || json_params.contains("input_suffix")) { - cmpl_type = SERVER_TASK_CMPL_TYPE_INFILL; + server_task_type type = SERVER_TASK_TYPE_COMPLETION; + + if (data.contains("input_prefix") || data.contains("input_suffix")) { + type = SERVER_TASK_TYPE_INFILL; } - if (json_params.value("use_chat_template", false)) - { - json chat; - chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}}); - chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}}); - json_params["prompt"] = format_chat(ctx_server->model, ctx_server->params.chat_template, chat); + auto completion_id = gen_chatcmplid(); + std::vector tasks; + + try { + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab,data.at("prompt"), true, true); + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server->ctx, + ctx_server->params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(task); + } + } catch (const std::exception & e) { + const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); + env->ThrowNew(c_llama_error, err.dump().c_str()); + return 0; } - std::vector tasks = ctx_server->create_tasks_cmpl(json_params, cmpl_type); ctx_server->queue_results.add_waiting_tasks(tasks); ctx_server->queue_tasks.post(tasks); @@ -505,26 +509,26 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - server_task_result result = ctx_server->queue_results.recv(id_task); + server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - if (result.error) + if (result->is_error()) { - std::string response = result.data["message"].get(); + std::string response = result->to_json()["message"].get(); ctx_server->queue_results.remove_waiting_task_id(id_task); env->ThrowNew(c_llama_error, 
response.c_str()); return nullptr; } - - std::string response = result.data["content"].get(); - if (result.stop) + const auto out_res = result->to_json(); + std::string response = out_res["content"].get(); + if (result->is_stop()) { ctx_server->queue_results.remove_waiting_task_id(id_task); } jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); - if (result.data.contains("completion_probabilities")) + if (out_res.contains("completion_probabilities")) { - auto completion_probabilities = result.data["completion_probabilities"]; + auto completion_probabilities = out_res["completion_probabilities"]; for (const auto &entry : completion_probabilities) { auto probs = entry["probs"]; @@ -542,7 +546,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE } jbyteArray jbytes = parse_jbytes(env, response); - return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result.stop); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop()); } JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) @@ -550,7 +554,7 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (!ctx_server->params.embedding) + if (!ctx_server->params_base.embedding) { env->ThrowNew(c_llama_error, "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); @@ -559,36 +563,42 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, const std::string prompt = parse_jstring(env, jprompt); - std::vector tasks = ctx_server->create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); + std::vector tasks; + + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = 0; + task.prompt_tokens = std::move(tokens); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; + + tasks.push_back(task); + ctx_server->queue_results.add_waiting_tasks(tasks); ctx_server->queue_tasks.post(tasks); std::unordered_set task_ids = server_task::get_list_id(tasks); - + const auto id_task = *task_ids.begin(); json responses = json::array(); json error = nullptr; - ctx_server->receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); - } - }, [&](const json& error_data) { - error = error_data; - }); - if (error != nullptr) + server_task_result_ptr result = ctx_server->queue_results.recv(id_task); + + if (result->is_error()) { - std::string response = error["message"].get(); + std::string response = result->to_json()["message"].get(); + ctx_server->queue_results.remove_waiting_task_id(id_task); env->ThrowNew(c_llama_error, response.c_str()); return nullptr; } - if (responses.size() != 1) { - env->ThrowNew(c_llama_error, "could not compute embedding"); - return nullptr; - } + const auto out_res = result->to_json(); - std::vector embedding = responses[0]["embedding"].get>(); + std::vector embedding = out_res["embedding"].get>(); jsize embedding_size = embedding.size(); // NOLINT(*-narrowing-conversions) jfloatArray j_embedding = env->NewFloatArray(embedding_size); @@ -609,7 +619,8 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv 
*env, auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) const std::string c_prompt = parse_jstring(env, jprompt); - std::vector tokens = ctx_server->tokenize(c_prompt, false); + + llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true); jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions) jintArray java_tokens = env->NewIntArray(token_size); @@ -652,7 +663,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv * { jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - ctx_server->request_cancel(id_task); + std::unordered_set id_tasks = {id_task}; + ctx_server->cancel_tasks(id_tasks); ctx_server->queue_results.remove_waiting_task_id(id_task); } From bc9c85acb5e91e80134c02dba9e5b38c21d9cb40 Mon Sep 17 00:00:00 2001 From: Alexey Sokolov Date: Thu, 30 Jan 2025 16:01:49 +0700 Subject: [PATCH 3/9] fix: the external cmake option "LLAMA_BUILD_COMMON=ON" moved inside CMakeList.txt --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index a083ea10..48681b9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ FetchContent_MakeAvailable(json) #################### llama.cpp #################### +set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git From f446fb58908ebde653e3650427d33ac4fee57f36 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:57:09 +0100 Subject: [PATCH 4/9] add cli params abstract class --- .../java/de/kherud/llama/CliParameters.java | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/main/java/de/kherud/llama/CliParameters.java diff --git a/src/main/java/de/kherud/llama/CliParameters.java b/src/main/java/de/kherud/llama/CliParameters.java new file mode 100644 index 00000000..4142628e --- /dev/null +++ b/src/main/java/de/kherud/llama/CliParameters.java @@ -0,0 +1,40 @@ +package de.kherud.llama; + +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +abstract class CliParameters { + + final Map parameters = new HashMap<>(); + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + for (String key : parameters.keySet()) { + String value = parameters.get(key); + builder.append(key).append(" "); + if (value != null) { + builder.append(value).append(" "); + } + } + return builder.toString(); + } + + public String[] toArray() { + List result = new ArrayList<>(); + result.add(""); // c args contain the program name as the first argument, so we add an empty entry + for (String key : parameters.keySet()) { + result.add(key); + String value = parameters.get(key); + if (value != null) { + result.add(value); + } + } + return result.toArray(new String[0]); + } + +} From 027f2b66c6162adee49254f569202db07173702a Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:57:25 +0100 Subject: [PATCH 5/9] update config param enums --- .../java/de/kherud/llama/args/CacheType.java | 15 +++++++++++++++ .../de/kherud/llama/args/NumaStrategy.java | 4 +--- .../de/kherud/llama/args/PoolingType.java | 19 ++++++++++++++++--- .../de/kherud/llama/args/RopeScalingType.java | 19 ++++++++++++++++--- .../java/de/kherud/llama/args/Sampler.java | 16 ++++++++++------ 5 files changed, 58 insertions(+), 15 deletions(-) 
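As a usage sketch (not part of the patches themselves): the CliParameters base class from PATCH 4/9 is meant to be subclassed by parameter builders such as ModelParameters below, with each setter storing a flag in the parameters map and toArray() rendering the map as a CLI-style argument vector. The subclass name and the flags in the following sketch are hypothetical; only the parameters map and the toArray()/toString() behavior come from the class shown above.

package de.kherud.llama;

// Hypothetical illustration of the CliParameters base class introduced in PATCH 4/9 above.
// Only the parameters map and toArray()/toString() come from the patch; the subclass name
// and the flags chosen here are examples, not part of the library.
class ExampleParameters extends CliParameters {

    ExampleParameters setThreads(int nThreads) {
        parameters.put("--threads", String.valueOf(nThreads));
        return this;
    }

    ExampleParameters enableFlashAttn() {
        parameters.put("--flash-attn", null); // value-less flag: stored with a null value
        return this;
    }

    public static void main(String[] args) {
        String[] cli = new ExampleParameters().setThreads(4).enableFlashAttn().toArray();
        // cli contains {"", "--threads", "4", "--flash-attn"}; the leading "" stands in for
        // the program name (argv[0]) and the ordering follows the backing HashMap.
        System.out.println(String.join(" ", cli));
    }
}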
create mode 100644 src/main/java/de/kherud/llama/args/CacheType.java diff --git a/src/main/java/de/kherud/llama/args/CacheType.java b/src/main/java/de/kherud/llama/args/CacheType.java new file mode 100644 index 00000000..8404ed75 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/CacheType.java @@ -0,0 +1,15 @@ +package de.kherud.llama.args; + +public enum CacheType { + + F32, + F16, + BF16, + Q8_0, + Q4_0, + Q4_1, + IQ4_NL, + Q5_0, + Q5_1 + +} diff --git a/src/main/java/de/kherud/llama/args/NumaStrategy.java b/src/main/java/de/kherud/llama/args/NumaStrategy.java index 35b24e19..fa7a61b0 100644 --- a/src/main/java/de/kherud/llama/args/NumaStrategy.java +++ b/src/main/java/de/kherud/llama/args/NumaStrategy.java @@ -2,9 +2,7 @@ public enum NumaStrategy { - DISABLED, DISTRIBUTE, ISOLATE, - NUMA_CTL, - MIRROR + NUMACTL } diff --git a/src/main/java/de/kherud/llama/args/PoolingType.java b/src/main/java/de/kherud/llama/args/PoolingType.java index e9b441d4..a9c9dbae 100644 --- a/src/main/java/de/kherud/llama/args/PoolingType.java +++ b/src/main/java/de/kherud/llama/args/PoolingType.java @@ -2,7 +2,20 @@ public enum PoolingType { - UNSPECIFIED, - MEAN, - CLS + UNSPECIFIED(-1), + NONE(0), + MEAN(1), + CLS(2), + LAST(3), + RANK(4); + + private final int id; + + PoolingType(int value) { + this.id = value; + } + + public int getId() { + return id; + } } diff --git a/src/main/java/de/kherud/llama/args/RopeScalingType.java b/src/main/java/de/kherud/llama/args/RopeScalingType.java index a69596f5..eed939a1 100644 --- a/src/main/java/de/kherud/llama/args/RopeScalingType.java +++ b/src/main/java/de/kherud/llama/args/RopeScalingType.java @@ -2,7 +2,20 @@ public enum RopeScalingType { - UNSPECIFIED, - LINEAR, - YARN + UNSPECIFIED(-1), + NONE(0), + LINEAR(1), + YARN2(2), + LONGROPE(3), + MAX_VALUE(3); + + private final int id; + + RopeScalingType(int value) { + this.id = value; + } + + public int getId() { + return id; + } } diff --git a/src/main/java/de/kherud/llama/args/Sampler.java b/src/main/java/de/kherud/llama/args/Sampler.java index 0864e91b..564a2e6f 100644 --- a/src/main/java/de/kherud/llama/args/Sampler.java +++ b/src/main/java/de/kherud/llama/args/Sampler.java @@ -2,10 +2,14 @@ public enum Sampler { - TOP_K, - TFS_Z, - TYPICAL_P, - TOP_P, - MIN_P, - TEMPERATURE + DRY, + TOP_K, + TOP_P, + TYP_P, + MIN_P, + TEMPERATURE, + XTC, + INFILL, + PENALTIES + } From 658d9b50196f67f781f6bfe0e8483af5bb9dbf8f Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:57:45 +0100 Subject: [PATCH 6/9] refactor model parameters to build CLI string instead of JSON --- .../java/de/kherud/llama/ModelParameters.java | 1495 +++++++++++------ 1 file changed, 946 insertions(+), 549 deletions(-) diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index 3b34d3f3..91587001 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -1,557 +1,954 @@ package de.kherud.llama; -import java.util.Map; - -import de.kherud.llama.args.GpuSplitMode; -import de.kherud.llama.args.NumaStrategy; -import de.kherud.llama.args.PoolingType; -import de.kherud.llama.args.RopeScalingType; +import de.kherud.llama.args.*; /*** * Parameters used for initializing a {@link LlamaModel}. 
*/ -public final class ModelParameters extends JsonParameters { - - private static final String PARAM_SEED = "seed"; - private static final String PARAM_N_THREADS = "n_threads"; - private static final String PARAM_N_THREADS_DRAFT = "n_threads_draft"; - private static final String PARAM_N_THREADS_BATCH = "n_threads_batch"; - private static final String PARAM_N_THREADS_BATCH_DRAFT = "n_threads_batch_draft"; - private static final String PARAM_N_PREDICT = "n_predict"; - private static final String PARAM_N_CTX = "n_ctx"; - private static final String PARAM_N_BATCH = "n_batch"; - private static final String PARAM_N_UBATCH = "n_ubatch"; - private static final String PARAM_N_KEEP = "n_keep"; - private static final String PARAM_N_DRAFT = "n_draft"; - private static final String PARAM_N_CHUNKS = "n_chunks"; - private static final String PARAM_N_PARALLEL = "n_parallel"; - private static final String PARAM_N_SEQUENCES = "n_sequences"; - private static final String PARAM_P_SPLIT = "p_split"; - private static final String PARAM_N_GPU_LAYERS = "n_gpu_layers"; - private static final String PARAM_N_GPU_LAYERS_DRAFT = "n_gpu_layers_draft"; - private static final String PARAM_SPLIT_MODE = "split_mode"; - private static final String PARAM_MAIN_GPU = "main_gpu"; - private static final String PARAM_TENSOR_SPLIT = "tensor_split"; - private static final String PARAM_GRP_ATTN_N = "grp_attn_n"; - private static final String PARAM_GRP_ATTN_W = "grp_attn_w"; - private static final String PARAM_ROPE_FREQ_BASE = "rope_freq_base"; - private static final String PARAM_ROPE_FREQ_SCALE = "rope_freq_scale"; - private static final String PARAM_YARN_EXT_FACTOR = "yarn_ext_factor"; - private static final String PARAM_YARN_ATTN_FACTOR = "yarn_attn_factor"; - private static final String PARAM_YARN_BETA_FAST = "yarn_beta_fast"; - private static final String PARAM_YARN_BETA_SLOW = "yarn_beta_slow"; - private static final String PARAM_YARN_ORIG_CTX = "yarn_orig_ctx"; - private static final String PARAM_DEFRAG_THOLD = "defrag_thold"; - private static final String PARAM_NUMA = "numa"; - private static final String PARAM_ROPE_SCALING_TYPE = "rope_scaling_type"; - private static final String PARAM_POOLING_TYPE = "pooling_type"; - private static final String PARAM_MODEL = "model"; - private static final String PARAM_MODEL_DRAFT = "model_draft"; - private static final String PARAM_MODEL_ALIAS = "model_alias"; - private static final String PARAM_MODEL_URL = "model_url"; - private static final String PARAM_HF_REPO = "hf_repo"; - private static final String PARAM_HF_FILE = "hf_file"; - private static final String PARAM_LOOKUP_CACHE_STATIC = "lookup_cache_static"; - private static final String PARAM_LOOKUP_CACHE_DYNAMIC = "lookup_cache_dynamic"; - private static final String PARAM_LORA_ADAPTER = "lora_adapter"; - private static final String PARAM_EMBEDDING = "embedding"; - private static final String PARAM_CONT_BATCHING = "cont_batching"; - private static final String PARAM_FLASH_ATTENTION = "flash_attn"; - private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos"; - private static final String PARAM_IGNORE_EOS = "ignore_eos"; - private static final String PARAM_USE_MMAP = "use_mmap"; - private static final String PARAM_USE_MLOCK = "use_mlock"; - private static final String PARAM_NO_KV_OFFLOAD = "no_kv_offload"; - private static final String PARAM_SYSTEM_PROMPT = "system_prompt"; - private static final String PARAM_CHAT_TEMPLATE = "chat_template"; - - /** - * Set the RNG seed - */ - public ModelParameters setSeed(int seed) { - 
parameters.put(PARAM_SEED, String.valueOf(seed)); - return this; - } - - /** - * Set the number of threads to use during generation (default: 8) - */ - public ModelParameters setNThreads(int nThreads) { - parameters.put(PARAM_N_THREADS, String.valueOf(nThreads)); - return this; - } - - /** - * Set the number of threads to use during draft generation (default: same as {@link #setNThreads(int)}) - */ - public ModelParameters setNThreadsDraft(int nThreadsDraft) { - parameters.put(PARAM_N_THREADS_DRAFT, String.valueOf(nThreadsDraft)); - return this; - } - - /** - * Set the number of threads to use during batch and prompt processing (default: same as {@link #setNThreads(int)}) - */ - public ModelParameters setNThreadsBatch(int nThreadsBatch) { - parameters.put(PARAM_N_THREADS_BATCH, String.valueOf(nThreadsBatch)); - return this; - } - - /** - * Set the number of threads to use during batch and prompt processing (default: same as - * {@link #setNThreadsDraft(int)}) - */ - public ModelParameters setNThreadsBatchDraft(int nThreadsBatchDraft) { - parameters.put(PARAM_N_THREADS_BATCH_DRAFT, String.valueOf(nThreadsBatchDraft)); - return this; - } - - /** - * Set the number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled) - */ - public ModelParameters setNPredict(int nPredict) { - parameters.put(PARAM_N_PREDICT, String.valueOf(nPredict)); - return this; - } - - /** - * Set the size of the prompt context (default: 512, 0 = loaded from model) - */ - public ModelParameters setNCtx(int nCtx) { - parameters.put(PARAM_N_CTX, String.valueOf(nCtx)); - return this; - } - - /** - * Set the logical batch size for prompt processing (must be >=32 to use BLAS) - */ - public ModelParameters setNBatch(int nBatch) { - parameters.put(PARAM_N_BATCH, String.valueOf(nBatch)); - return this; - } - - /** - * Set the physical batch size for prompt processing (must be >=32 to use BLAS) - */ - public ModelParameters setNUbatch(int nUbatch) { - parameters.put(PARAM_N_UBATCH, String.valueOf(nUbatch)); - return this; - } - - /** - * Set the number of tokens to keep from the initial prompt (default: 0, -1 = all) - */ - public ModelParameters setNKeep(int nKeep) { - parameters.put(PARAM_N_KEEP, String.valueOf(nKeep)); - return this; - } - - /** - * Set the number of tokens to draft for speculative decoding (default: 5) - */ - public ModelParameters setNDraft(int nDraft) { - parameters.put(PARAM_N_DRAFT, String.valueOf(nDraft)); - return this; - } - - /** - * Set the maximal number of chunks to process (default: -1, -1 = all) - */ - public ModelParameters setNChunks(int nChunks) { - parameters.put(PARAM_N_CHUNKS, String.valueOf(nChunks)); - return this; - } - - /** - * Set the number of parallel sequences to decode (default: 1) - */ - public ModelParameters setNParallel(int nParallel) { - parameters.put(PARAM_N_PARALLEL, String.valueOf(nParallel)); - return this; - } - - /** - * Set the number of sequences to decode (default: 1) - */ - public ModelParameters setNSequences(int nSequences) { - parameters.put(PARAM_N_SEQUENCES, String.valueOf(nSequences)); - return this; - } - - /** - * Set the speculative decoding split probability (default: 0.1) - */ - public ModelParameters setPSplit(float pSplit) { - parameters.put(PARAM_P_SPLIT, String.valueOf(pSplit)); - return this; - } - - /** - * Set the number of layers to store in VRAM (-1 - use default) - */ - public ModelParameters setNGpuLayers(int nGpuLayers) { - parameters.put(PARAM_N_GPU_LAYERS, String.valueOf(nGpuLayers)); - return this; - } - - /** - * Set 
the number of layers to store in VRAM for the draft model (-1 - use default) - */ - public ModelParameters setNGpuLayersDraft(int nGpuLayersDraft) { - parameters.put(PARAM_N_GPU_LAYERS_DRAFT, String.valueOf(nGpuLayersDraft)); - return this; - } - - /** - * Set how to split the model across GPUs - */ - public ModelParameters setSplitMode(GpuSplitMode splitMode) { -// switch (splitMode) { -// case NONE: parameters.put(PARAM_SPLIT_MODE, "\"none\""); break; -// case ROW: parameters.put(PARAM_SPLIT_MODE, "\"row\""); break; -// case LAYER: parameters.put(PARAM_SPLIT_MODE, "\"layer\""); break; -// } - parameters.put(PARAM_SPLIT_MODE, String.valueOf(splitMode.ordinal())); - return this; - } - - /** - * Set the GPU that is used for scratch and small tensors - */ - public ModelParameters setMainGpu(int mainGpu) { - parameters.put(PARAM_MAIN_GPU, String.valueOf(mainGpu)); - return this; - } - - /** - * Set how split tensors should be distributed across GPUs - */ - public ModelParameters setTensorSplit(float[] tensorSplit) { - if (tensorSplit.length > 0) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < tensorSplit.length; i++) { - builder.append(tensorSplit[i]); - if (i < tensorSplit.length - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_TENSOR_SPLIT, builder.toString()); - } - return this; - } - - /** - * Set the group-attention factor (default: 1) - */ - public ModelParameters setGrpAttnN(int grpAttnN) { - parameters.put(PARAM_GRP_ATTN_N, String.valueOf(grpAttnN)); - return this; - } - - /** - * Set the group-attention width (default: 512.0) - */ - public ModelParameters setGrpAttnW(int grpAttnW) { - parameters.put(PARAM_GRP_ATTN_W, String.valueOf(grpAttnW)); - return this; - } - - /** - * Set the RoPE base frequency, used by NTK-aware scaling (default: loaded from model) - */ - public ModelParameters setRopeFreqBase(float ropeFreqBase) { - parameters.put(PARAM_ROPE_FREQ_BASE, String.valueOf(ropeFreqBase)); - return this; - } - - /** - * Set the RoPE frequency scaling factor, expands context by a factor of 1/N - */ - public ModelParameters setRopeFreqScale(float ropeFreqScale) { - parameters.put(PARAM_ROPE_FREQ_SCALE, String.valueOf(ropeFreqScale)); - return this; - } - - /** - * Set the YaRN extrapolation mix factor (default: 1.0, 0.0 = full interpolation) - */ - public ModelParameters setYarnExtFactor(float yarnExtFactor) { - parameters.put(PARAM_YARN_EXT_FACTOR, String.valueOf(yarnExtFactor)); - return this; - } - - /** - * Set the YaRN scale sqrt(t) or attention magnitude (default: 1.0) - */ - public ModelParameters setYarnAttnFactor(float yarnAttnFactor) { - parameters.put(PARAM_YARN_ATTN_FACTOR, String.valueOf(yarnAttnFactor)); - return this; - } - - /** - * Set the YaRN low correction dim or beta (default: 32.0) - */ - public ModelParameters setYarnBetaFast(float yarnBetaFast) { - parameters.put(PARAM_YARN_BETA_FAST, String.valueOf(yarnBetaFast)); - return this; - } - - /** - * Set the YaRN high correction dim or alpha (default: 1.0) - */ - public ModelParameters setYarnBetaSlow(float yarnBetaSlow) { - parameters.put(PARAM_YARN_BETA_SLOW, String.valueOf(yarnBetaSlow)); - return this; - } - - /** - * Set the YaRN original context size of model (default: 0 = model training context size) - */ - public ModelParameters setYarnOrigCtx(int yarnOrigCtx) { - parameters.put(PARAM_YARN_ORIG_CTX, String.valueOf(yarnOrigCtx)); - return this; - } - - /** - * Set the KV cache defragmentation threshold (default: -1.0, < 0 - 
disabled) - */ - public ModelParameters setDefragmentationThreshold(float defragThold) { - parameters.put(PARAM_DEFRAG_THOLD, String.valueOf(defragThold)); - return this; - } - - /** - * Set optimization strategies that help on some NUMA systems (if available) - *
- * <ul>
- *     <li>distribute: spread execution evenly over all nodes</li>
- *     <li>isolate: only spawn threads on CPUs on the node that execution started on</li>
- *     <li>numactl: use the CPU map provided by numactl</li>
- * </ul>
- * If run without this previously, it is recommended to drop the system page cache before using this - * (see #1437). - */ - public ModelParameters setNuma(NumaStrategy numa) { -// switch (numa) { -// case DISTRIBUTE: -// parameters.put(PARAM_NUMA, "\"distribute\""); -// break; -// case ISOLATE: -// parameters.put(PARAM_NUMA, "\"isolate\""); -// break; -// case NUMA_CTL: -// parameters.put(PARAM_NUMA, "\"numactl\""); -// break; -// case MIRROR: -// parameters.put(PARAM_NUMA, "\"mirror\""); -// break; -// } - parameters.put(PARAM_NUMA, String.valueOf(numa.ordinal())); - return this; - } - - /** - * Set the RoPE frequency scaling method, defaults to linear unless specified by the model - */ - public ModelParameters setRopeScalingType(RopeScalingType ropeScalingType) { -// switch (ropeScalingType) { -// case LINEAR: -// parameters.put(PARAM_ROPE_SCALING_TYPE, "\"linear\""); -// break; -// case YARN: -// parameters.put(PARAM_ROPE_SCALING_TYPE, "\"yarn\""); -// break; -// } - parameters.put(PARAM_ROPE_SCALING_TYPE, String.valueOf(ropeScalingType.ordinal())); - return this; - } - - /** - * Set the pooling type for embeddings, use model default if unspecified - */ - public ModelParameters setPoolingType(PoolingType poolingType) { -// switch (poolingType) { -// case MEAN: -// parameters.put(PARAM_POOLING_TYPE, "\"mean\""); -// break; -// case CLS: -// parameters.put(PARAM_POOLING_TYPE, "\"cls\""); -// break; -// } - parameters.put(PARAM_POOLING_TYPE, String.valueOf(poolingType.ordinal())); - return this; - } - - /** - * Set the model file path to load (default: models/7B/ggml-model-f16.gguf) - */ - public ModelParameters setModelFilePath(String model) { - parameters.put(PARAM_MODEL, toJsonString(model)); - return this; - } - - /** - * Set the draft model for speculative decoding (default: unused) - */ - public ModelParameters setModelDraft(String modelDraft) { - parameters.put(PARAM_MODEL_DRAFT, toJsonString(modelDraft)); - return this; - } - - /** - * Set a model alias - */ - public ModelParameters setModelAlias(String modelAlias) { - parameters.put(PARAM_MODEL_ALIAS, toJsonString(modelAlias)); - return this; - } - - /** - * Set a URL to download a model from (default: unused). - * Note, that this requires the library to be built with CURL (-DLLAMA_CURL=ON). 
- */ - public ModelParameters setModelUrl(String modelUrl) { - parameters.put(PARAM_MODEL_URL, toJsonString(modelUrl)); - return this; - } - - /** - * Set a Hugging Face model repository to use a model from (default: unused, see - * {@link #setHuggingFaceFile(String)}) - */ - public ModelParameters setHuggingFaceRepository(String hfRepo) { - parameters.put(PARAM_HF_REPO, toJsonString(hfRepo)); - return this; - } - - /** - * Set a Hugging Face model file to use (default: unused, see {@link #setHuggingFaceRepository(String)}) - */ - public ModelParameters setHuggingFaceFile(String hfFile) { - parameters.put(PARAM_HF_FILE, toJsonString(hfFile)); - return this; - } - - /** - * Set path to static lookup cache to use for lookup decoding (not updated by generation) - */ - public ModelParameters setLookupCacheStaticFilePath(String lookupCacheStatic) { - parameters.put(PARAM_LOOKUP_CACHE_STATIC, toJsonString(lookupCacheStatic)); - return this; - } - - /** - * Set path to dynamic lookup cache to use for lookup decoding (updated by generation) - */ - public ModelParameters setLookupCacheDynamicFilePath(String lookupCacheDynamic) { - parameters.put(PARAM_LOOKUP_CACHE_DYNAMIC, toJsonString(lookupCacheDynamic)); - return this; - } - - /** - * Set LoRA adapters to use (implies --no-mmap). - * The key is expected to be a file path, the values are expected to be scales. - */ - public ModelParameters setLoraAdapters(Map loraAdapters) { - if (!loraAdapters.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("{"); - int i = 0; - for (Map.Entry entry : loraAdapters.entrySet()) { - String key = entry.getKey(); - Float value = entry.getValue(); - builder.append(toJsonString(key)) - .append(": ") - .append(value); - if (i++ < loraAdapters.size() - 1) { - builder.append(", "); - } - } - builder.append("}"); - parameters.put(PARAM_LORA_ADAPTER, builder.toString()); - } - return this; - } - - /** - * Whether to load model with embedding support - */ - public ModelParameters setEmbedding(boolean embedding) { - parameters.put(PARAM_EMBEDDING, String.valueOf(embedding)); - return this; - } - - /** - * Whether to enable continuous batching (also called "dynamic batching") (default: disabled) - */ - public ModelParameters setContinuousBatching(boolean contBatching) { - parameters.put(PARAM_CONT_BATCHING, String.valueOf(contBatching)); - return this; - } - - /** - * Whether to enable Flash Attention (default: disabled) - */ - public ModelParameters setFlashAttention(boolean flashAttention) { - parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention)); - return this; - } - - /** - * Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string - */ - public ModelParameters setInputPrefixBos(boolean inputPrefixBos) { - parameters.put(PARAM_INPUT_PREFIX_BOS, String.valueOf(inputPrefixBos)); - return this; - } - - /** - * Whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf) - */ - public ModelParameters setIgnoreEos(boolean ignoreEos) { - parameters.put(PARAM_IGNORE_EOS, String.valueOf(ignoreEos)); - return this; - } - - /** - * Whether to use memory-map model (faster load but may increase pageouts if not using mlock) - */ - public ModelParameters setUseMmap(boolean useMmap) { - parameters.put(PARAM_USE_MMAP, String.valueOf(useMmap)); - return this; - } - - /** - * Whether to force the system to keep model in RAM rather than swapping or compressing - */ - public ModelParameters setUseMlock(boolean useMlock) { - 
parameters.put(PARAM_USE_MLOCK, String.valueOf(useMlock)); - return this; - } - - /** - * Whether to disable KV offload - */ - public ModelParameters setNoKvOffload(boolean noKvOffload) { - parameters.put(PARAM_NO_KV_OFFLOAD, String.valueOf(noKvOffload)); - return this; - } - - /** - * Set a system prompt to use - */ - public ModelParameters setSystemPrompt(String systemPrompt) { - parameters.put(PARAM_SYSTEM_PROMPT, toJsonString(systemPrompt)); - return this; - } - - /** - * The chat template to use (default: empty) - */ - public ModelParameters setChatTemplate(String chatTemplate) { - parameters.put(PARAM_CHAT_TEMPLATE, toJsonString(chatTemplate)); - return this; - } +@SuppressWarnings("unused") +public final class ModelParameters extends CliParameters { + + /** + * Set the number of threads to use during generation (default: -1). + */ + public ModelParameters setThreads(int nThreads) { + parameters.put("--threads", String.valueOf(nThreads)); + return this; + } + + /** + * Set the number of threads to use during batch and prompt processing (default: same as --threads). + */ + public ModelParameters setThreadsBatch(int nThreads) { + parameters.put("--threads-batch", String.valueOf(nThreads)); + return this; + } + + /** + * Set the CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: ""). + */ + public ModelParameters setCpuMask(String mask) { + parameters.put("--cpu-mask", mask); + return this; + } + + /** + * Set the range of CPUs for affinity. Complements --cpu-mask. + */ + public ModelParameters setCpuRange(String range) { + parameters.put("--cpu-range", range); + return this; + } + + /** + * Use strict CPU placement (default: 0). + */ + public ModelParameters setCpuStrict(int strictCpu) { + parameters.put("--cpu-strict", String.valueOf(strictCpu)); + return this; + } + + /** + * Set process/thread priority: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). + */ + public ModelParameters setPriority(int priority) { + if (priority < 0 || priority > 3) { + throw new IllegalArgumentException("Invalid value for priority"); + } + parameters.put("--prio", String.valueOf(priority)); + return this; + } + + /** + * Set the polling level to wait for work (0 - no polling, default: 0). + */ + public ModelParameters setPoll(int poll) { + parameters.put("--poll", String.valueOf(poll)); + return this; + } + + /** + * Set the CPU affinity mask for batch processing: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask). + */ + public ModelParameters setCpuMaskBatch(String mask) { + parameters.put("--cpu-mask-batch", mask); + return this; + } + + /** + * Set the ranges of CPUs for batch affinity. Complements --cpu-mask-batch. + */ + public ModelParameters setCpuRangeBatch(String range) { + parameters.put("--cpu-range-batch", range); + return this; + } + + /** + * Use strict CPU placement for batch processing (default: same as --cpu-strict). + */ + public ModelParameters setCpuStrictBatch(int strictCpuBatch) { + parameters.put("--cpu-strict-batch", String.valueOf(strictCpuBatch)); + return this; + } + + /** + * Set process/thread priority for batch processing: 0-normal, 1-medium, 2-high, 3-realtime (default: 0). 
+ */ + public ModelParameters setPriorityBatch(int priorityBatch) { + if (priorityBatch < 0 || priorityBatch > 3) { + throw new IllegalArgumentException("Invalid value for priority batch"); + } + parameters.put("--prio-batch", String.valueOf(priorityBatch)); + return this; + } + + /** + * Set the polling level for batch processing (default: same as --poll). + */ + public ModelParameters setPollBatch(int pollBatch) { + parameters.put("--poll-batch", String.valueOf(pollBatch)); + return this; + } + + /** + * Set the size of the prompt context (default: 0, 0 = loaded from model). + */ + public ModelParameters setCtxSize(int ctxSize) { + parameters.put("--ctx-size", String.valueOf(ctxSize)); + return this; + } + + /** + * Set the number of tokens to predict (default: -1 = infinity, -2 = until context filled). + */ + public ModelParameters setPredict(int nPredict) { + parameters.put("--predict", String.valueOf(nPredict)); + return this; + } + + /** + * Set the logical maximum batch size (default: 0). + */ + public ModelParameters setBatchSize(int batchSize) { + parameters.put("--batch-size", String.valueOf(batchSize)); + return this; + } + + /** + * Set the physical maximum batch size (default: 0). + */ + public ModelParameters setUbatchSize(int ubatchSize) { + parameters.put("--ubatch-size", String.valueOf(ubatchSize)); + return this; + } + + /** + * Set the number of tokens to keep from the initial prompt (default: -1 = all). + */ + public ModelParameters setKeep(int keep) { + parameters.put("--keep", String.valueOf(keep)); + return this; + } + + /** + * Disable context shift on infinite text generation (default: enabled). + */ + public ModelParameters disableContextShift() { + parameters.put("--no-context-shift", null); + return this; + } + + /** + * Enable Flash Attention (default: disabled). + */ + public ModelParameters enableFlashAttn() { + parameters.put("--flash-attn", null); + return this; + } + + /** + * Disable internal libllama performance timings (default: false). + */ + public ModelParameters disablePerf() { + parameters.put("--no-perf", null); + return this; + } + + /** + * Process escape sequences (default: true). + */ + public ModelParameters enableEscape() { + parameters.put("--escape", null); + return this; + } + + /** + * Do not process escape sequences (default: false). + */ + public ModelParameters disableEscape() { + parameters.put("--no-escape", null); + return this; + } + + /** + * Enable special tokens output (default: true). + */ + public ModelParameters enableSpecial() { + parameters.put("--special", null); + return this; + } + + /** + * Skip warming up the model with an empty run (default: false). + */ + public ModelParameters skipWarmup() { + parameters.put("--no-warmup", null); + return this; + } + + /** + * Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. + * (default: disabled) + */ + public ModelParameters setSpmInfill() { + parameters.put("--spm-infill", null); + return this; + } + + /** + * Set samplers that will be used for generation in the order, separated by ';' (default: all). + */ + public ModelParameters setSamplers(Sampler... 
samplers) { + if (samplers.length > 0) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < samplers.length; i++) { + Sampler sampler = samplers[i]; + builder.append(sampler.name().toLowerCase()); + if (i < samplers.length - 1) { + builder.append(";"); + } + } + parameters.put("--samplers", builder.toString()); + } + return this; + } + + /** + * Set RNG seed (default: -1, use random seed). + */ + public ModelParameters setSeed(long seed) { + parameters.put("--seed", String.valueOf(seed)); + return this; + } + + /** + * Ignore end of stream token and continue generating (implies --logit-bias EOS-inf). + */ + public ModelParameters ignoreEos() { + parameters.put("--ignore-eos", null); + return this; + } + + /** + * Set temperature for sampling (default: 0.8). + */ + public ModelParameters setTemp(float temp) { + parameters.put("--temp", String.valueOf(temp)); + return this; + } + + /** + * Set top-k sampling (default: 40, 0 = disabled). + */ + public ModelParameters setTopK(int topK) { + parameters.put("--top-k", String.valueOf(topK)); + return this; + } + + /** + * Set top-p sampling (default: 0.95, 1.0 = disabled). + */ + public ModelParameters setTopP(float topP) { + parameters.put("--top-p", String.valueOf(topP)); + return this; + } + + /** + * Set min-p sampling (default: 0.05, 0.0 = disabled). + */ + public ModelParameters setMinP(float minP) { + parameters.put("--min-p", String.valueOf(minP)); + return this; + } + + /** + * Set xtc probability (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setXtcProbability(float xtcProbability) { + parameters.put("--xtc-probability", String.valueOf(xtcProbability)); + return this; + } + + /** + * Set xtc threshold (default: 0.1, 1.0 = disabled). + */ + public ModelParameters setXtcThreshold(float xtcThreshold) { + parameters.put("--xtc-threshold", String.valueOf(xtcThreshold)); + return this; + } + + /** + * Set locally typical sampling parameter p (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setTypical(float typP) { + parameters.put("--typical", String.valueOf(typP)); + return this; + } + + /** + * Set last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size). + */ + public ModelParameters setRepeatLastN(int repeatLastN) { + if (repeatLastN < -1) { + throw new RuntimeException("Invalid repeat-last-n value"); + } + parameters.put("--repeat-last-n", String.valueOf(repeatLastN)); + return this; + } + + /** + * Set penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled). + */ + public ModelParameters setRepeatPenalty(float repeatPenalty) { + parameters.put("--repeat-penalty", String.valueOf(repeatPenalty)); + return this; + } + + /** + * Set repeat alpha presence penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setPresencePenalty(float presencePenalty) { + parameters.put("--presence-penalty", String.valueOf(presencePenalty)); + return this; + } + + /** + * Set repeat alpha frequency penalty (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setFrequencyPenalty(float frequencyPenalty) { + parameters.put("--frequency-penalty", String.valueOf(frequencyPenalty)); + return this; + } + + /** + * Set DRY sampling multiplier (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setDryMultiplier(float dryMultiplier) { + parameters.put("--dry-multiplier", String.valueOf(dryMultiplier)); + return this; + } + + /** + * Set DRY sampling base value (default: 1.75). 
+ */ + public ModelParameters setDryBase(float dryBase) { + parameters.put("--dry-base", String.valueOf(dryBase)); + return this; + } + + /** + * Set allowed length for DRY sampling (default: 2). + */ + public ModelParameters setDryAllowedLength(int dryAllowedLength) { + parameters.put("--dry-allowed-length", String.valueOf(dryAllowedLength)); + return this; + } + + /** + * Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size). + */ + public ModelParameters setDryPenaltyLastN(int dryPenaltyLastN) { + if (dryPenaltyLastN < -1) { + throw new RuntimeException("Invalid dry-penalty-last-n value"); + } + parameters.put("--dry-penalty-last-n", String.valueOf(dryPenaltyLastN)); + return this; + } + + /** + * Add sequence breaker for DRY sampling, clearing out default breakers (default: none). + */ + public ModelParameters setDrySequenceBreaker(String drySequenceBreaker) { + parameters.put("--dry-sequence-breaker", drySequenceBreaker); + return this; + } + + /** + * Set dynamic temperature range (default: 0.0, 0.0 = disabled). + */ + public ModelParameters setDynatempRange(float dynatempRange) { + parameters.put("--dynatemp-range", String.valueOf(dynatempRange)); + return this; + } + + /** + * Set dynamic temperature exponent (default: 1.0). + */ + public ModelParameters setDynatempExponent(float dynatempExponent) { + parameters.put("--dynatemp-exp", String.valueOf(dynatempExponent)); + return this; + } + + /** + * Use Mirostat sampling (default: PLACEHOLDER, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). + */ + public ModelParameters setMirostat(MiroStat mirostat) { + parameters.put("--mirostat", String.valueOf(mirostat.ordinal())); + return this; + } + + /** + * Set Mirostat learning rate, parameter eta (default: 0.1). + */ + public ModelParameters setMirostatLR(float mirostatLR) { + parameters.put("--mirostat-lr", String.valueOf(mirostatLR)); + return this; + } + + /** + * Set Mirostat target entropy, parameter tau (default: 5.0). + */ + public ModelParameters setMirostatEnt(float mirostatEnt) { + parameters.put("--mirostat-ent", String.valueOf(mirostatEnt)); + return this; + } + + /** + * Modify the likelihood of token appearing in the completion. + */ + public ModelParameters setLogitBias(String tokenIdAndBias) { + parameters.put("--logit-bias", tokenIdAndBias); + return this; + } + + /** + * Set BNF-like grammar to constrain generations (default: empty). + */ + public ModelParameters setGrammar(String grammar) { + parameters.put("--grammar", grammar); + return this; + } + + /** + * Specify the file to read grammar from. + */ + public ModelParameters setGrammarFile(String fileName) { + parameters.put("--grammar-file", fileName); + return this; + } + + /** + * Specify the JSON schema to constrain generations (default: empty). + */ + public ModelParameters setJsonSchema(String schema) { + parameters.put("--json-schema", schema); + return this; + } + + /** + * Set pooling type for embeddings (default: model default if unspecified). + */ + public ModelParameters setPoolingType(PoolingType type) { + parameters.put("--pooling", String.valueOf(type.getId())); + return this; + } + + /** + * Set RoPE frequency scaling method (default: linear unless specified by the model). + */ + public ModelParameters setRopeScaling(RopeScalingType type) { + parameters.put("--rope-scaling", String.valueOf(type.getId())); + return this; + } + + /** + * Set RoPE context scaling factor, expands context by a factor of N. 
+ */ + public ModelParameters setRopeScale(float ropeScale) { + parameters.put("--rope-scale", String.valueOf(ropeScale)); + return this; + } + + /** + * Set RoPE base frequency, used by NTK-aware scaling (default: loaded from model). + */ + public ModelParameters setRopeFreqBase(float ropeFreqBase) { + parameters.put("--rope-freq-base", String.valueOf(ropeFreqBase)); + return this; + } + + /** + * Set RoPE frequency scaling factor, expands context by a factor of 1/N. + */ + public ModelParameters setRopeFreqScale(float ropeFreqScale) { + parameters.put("--rope-freq-scale", String.valueOf(ropeFreqScale)); + return this; + } + + /** + * Set YaRN: original context size of model (default: model training context size). + */ + public ModelParameters setYarnOrigCtx(int yarnOrigCtx) { + parameters.put("--yarn-orig-ctx", String.valueOf(yarnOrigCtx)); + return this; + } + + /** + * Set YaRN: extrapolation mix factor (default: 0.0 = full interpolation). + */ + public ModelParameters setYarnExtFactor(float yarnExtFactor) { + parameters.put("--yarn-ext-factor", String.valueOf(yarnExtFactor)); + return this; + } + + /** + * Set YaRN: scale sqrt(t) or attention magnitude (default: 1.0). + */ + public ModelParameters setYarnAttnFactor(float yarnAttnFactor) { + parameters.put("--yarn-attn-factor", String.valueOf(yarnAttnFactor)); + return this; + } + + /** + * Set YaRN: high correction dim or alpha (default: 1.0). + */ + public ModelParameters setYarnBetaSlow(float yarnBetaSlow) { + parameters.put("--yarn-beta-slow", String.valueOf(yarnBetaSlow)); + return this; + } + + /** + * Set YaRN: low correction dim or beta (default: 32.0). + */ + public ModelParameters setYarnBetaFast(float yarnBetaFast) { + parameters.put("--yarn-beta-fast", String.valueOf(yarnBetaFast)); + return this; + } + + /** + * Set group-attention factor (default: 1). + */ + public ModelParameters setGrpAttnN(int grpAttnN) { + parameters.put("--grp-attn-n", String.valueOf(grpAttnN)); + return this; + } + + /** + * Set group-attention width (default: 512). + */ + public ModelParameters setGrpAttnW(int grpAttnW) { + parameters.put("--grp-attn-w", String.valueOf(grpAttnW)); + return this; + } + + /** + * Enable verbose printing of the KV cache. + */ + public ModelParameters enableDumpKvCache() { + parameters.put("--dump-kv-cache", null); + return this; + } + + /** + * Disable KV offload. + */ + public ModelParameters disableKvOffload() { + parameters.put("--no-kv-offload", null); + return this; + } + + /** + * Set KV cache data type for K (allowed values: F16). + */ + public ModelParameters setCacheTypeK(CacheType type) { + parameters.put("--cache-type-k", type.name().toLowerCase()); + return this; + } + + /** + * Set KV cache data type for V (allowed values: F16). + */ + public ModelParameters setCacheTypeV(CacheType type) { + parameters.put("--cache-type-v", type.name().toLowerCase()); + return this; + } + + /** + * Set KV cache defragmentation threshold (default: 0.1, < 0 - disabled). + */ + public ModelParameters setDefragThold(float defragThold) { + parameters.put("--defrag-thold", String.valueOf(defragThold)); + return this; + } + + /** + * Set the number of parallel sequences to decode (default: 1). + */ + public ModelParameters setParallel(int nParallel) { + parameters.put("--parallel", String.valueOf(nParallel)); + return this; + } + + /** + * Enable continuous batching (a.k.a dynamic batching) (default: disabled). 
+ */ + public ModelParameters enableContBatching() { + parameters.put("--cont-batching", null); + return this; + } + + /** + * Disable continuous batching. + */ + public ModelParameters disableContBatching() { + parameters.put("--no-cont-batching", null); + return this; + } + + /** + * Force system to keep model in RAM rather than swapping or compressing. + */ + public ModelParameters enableMlock() { + parameters.put("--mlock", null); + return this; + } + + /** + * Do not memory-map model (slower load but may reduce pageouts if not using mlock). + */ + public ModelParameters disableMmap() { + parameters.put("--no-mmap", null); + return this; + } + + /** + * Set NUMA optimization type for system. + */ + public ModelParameters setNuma(NumaStrategy numaStrategy) { + parameters.put("--numa", numaStrategy.name().toLowerCase()); + return this; + } + + /** + * Set comma-separated list of devices to use for offloading (none = don't offload). + */ + public ModelParameters setDevices(String devices) { + parameters.put("--device", devices); + return this; + } + + /** + * Set the number of layers to store in VRAM. + */ + public ModelParameters setGpuLayers(int gpuLayers) { + parameters.put("--gpu-layers", String.valueOf(gpuLayers)); + return this; + } + + /** + * Set how to split the model across multiple GPUs (none, layer, row). + */ + public ModelParameters setSplitMode(GpuSplitMode splitMode) { + parameters.put("--split-mode", splitMode.name().toLowerCase()); + return this; + } + + /** + * Set fraction of the model to offload to each GPU, comma-separated list of proportions N0,N1,N2,.... + */ + public ModelParameters setTensorSplit(String tensorSplit) { + parameters.put("--tensor-split", tensorSplit); + return this; + } + + /** + * Set the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row). + */ + public ModelParameters setMainGpu(int mainGpu) { + parameters.put("--main-gpu", String.valueOf(mainGpu)); + return this; + } + + /** + * Enable checking model tensor data for invalid values. + */ + public ModelParameters enableCheckTensors() { + parameters.put("--check-tensors", null); + return this; + } + + /** + * Override model metadata by key. This option can be specified multiple times. + */ + public ModelParameters setOverrideKv(String keyValue) { + parameters.put("--override-kv", keyValue); + return this; + } + + /** + * Add a LoRA adapter (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraAdapter(String fname) { + parameters.put("--lora", fname); + return this; + } + + /** + * Add a LoRA adapter with user-defined scaling (can be repeated to use multiple adapters). + */ + public ModelParameters addLoraScaledAdapter(String fname, float scale) { + parameters.put("--lora-scaled", fname + "," + scale); + return this; + } + + /** + * Add a control vector (this argument can be repeated to add multiple control vectors). + */ + public ModelParameters addControlVector(String fname) { + parameters.put("--control-vector", fname); + return this; + } + + /** + * Add a control vector with user-defined scaling (can be repeated to add multiple scaled control vectors). + */ + public ModelParameters addControlVectorScaled(String fname, float scale) { + parameters.put("--control-vector-scaled", fname + "," + scale); + return this; + } + + /** + * Set the layer range to apply the control vector(s) to (start and end inclusive). 
+ */ + public ModelParameters setControlVectorLayerRange(int start, int end) { + parameters.put("--control-vector-layer-range", start + "," + end); + return this; + } + + /** + * Set the model path from which to load the base model. + */ + public ModelParameters setModel(String model) { + parameters.put("--model", model); + return this; + } + + /** + * Set the model download URL (default: unused). + */ + public ModelParameters setModelUrl(String modelUrl) { + parameters.put("--model-url", modelUrl); + return this; + } + + /** + * Set the Hugging Face model repository (default: unused). + */ + public ModelParameters setHfRepo(String hfRepo) { + parameters.put("--hf-repo", hfRepo); + return this; + } + + /** + * Set the Hugging Face model file (default: unused). + */ + public ModelParameters setHfFile(String hfFile) { + parameters.put("--hf-file", hfFile); + return this; + } + + /** + * Set the Hugging Face model repository for the vocoder model (default: unused). + */ + public ModelParameters setHfRepoV(String hfRepoV) { + parameters.put("--hf-repo-v", hfRepoV); + return this; + } + + /** + * Set the Hugging Face model file for the vocoder model (default: unused). + */ + public ModelParameters setHfFileV(String hfFileV) { + parameters.put("--hf-file-v", hfFileV); + return this; + } + + /** + * Set the Hugging Face access token (default: value from HF_TOKEN environment variable). + */ + public ModelParameters setHfToken(String hfToken) { + parameters.put("--hf-token", hfToken); + return this; + } + + /** + * Enable embedding use case; use only with dedicated embedding models. + */ + public ModelParameters enableEmbedding() { + parameters.put("--embedding", null); + return this; + } + + /** + * Enable reranking endpoint on server. + */ + public ModelParameters enableReranking() { + parameters.put("--reranking", null); + return this; + } + + /** + * Set minimum chunk size to attempt reusing from the cache via KV shifting. + */ + public ModelParameters setCacheReuse(int cacheReuse) { + parameters.put("--cache-reuse", String.valueOf(cacheReuse)); + return this; + } + + /** + * Set the path to save the slot kv cache. + */ + public ModelParameters setSlotSavePath(String slotSavePath) { + parameters.put("--slot-save-path", slotSavePath); + return this; + } + + /** + * Set custom jinja chat template. + */ + public ModelParameters setChatTemplate(String chatTemplate) { + parameters.put("--chat-template", chatTemplate); + return this; + } + + /** + * Set how much the prompt of a request must match the prompt of a slot in order to use that slot. + */ + public ModelParameters setSlotPromptSimilarity(float similarity) { + parameters.put("--slot-prompt-similarity", String.valueOf(similarity)); + return this; + } + + /** + * Load LoRA adapters without applying them (apply later via POST /lora-adapters). + */ + public ModelParameters setLoraInitWithoutApply() { + parameters.put("--lora-init-without-apply", null); + return this; + } + + /** + * Disable logging. + */ + public ModelParameters disableLog() { + parameters.put("--log-disable", null); + return this; + } + + /** + * Set the log file path. + */ + public ModelParameters setLogFile(String logFile) { + parameters.put("--log-file", logFile); + return this; + } + + /** + * Set verbosity level to infinity (log all messages, useful for debugging). + */ + public ModelParameters setVerbose() { + parameters.put("--verbose", null); + return this; + } + + /** + * Set the verbosity threshold (messages with a higher verbosity will be ignored). 
+ */ + public ModelParameters setLogVerbosity(int verbosity) { + parameters.put("--log-verbosity", String.valueOf(verbosity)); + return this; + } + + /** + * Enable prefix in log messages. + */ + public ModelParameters enableLogPrefix() { + parameters.put("--log-prefix", null); + return this; + } + + /** + * Enable timestamps in log messages. + */ + public ModelParameters enableLogTimestamps() { + parameters.put("--log-timestamps", null); + return this; + } + + /** + * Set the number of tokens to draft for speculative decoding. + */ + public ModelParameters setDraftMax(int draftMax) { + parameters.put("--draft-max", String.valueOf(draftMax)); + return this; + } + + /** + * Set the minimum number of draft tokens to use for speculative decoding. + */ + public ModelParameters setDraftMin(int draftMin) { + parameters.put("--draft-min", String.valueOf(draftMin)); + return this; + } + + /** + * Set the minimum speculative decoding probability for greedy decoding. + */ + public ModelParameters setDraftPMin(float draftPMin) { + parameters.put("--draft-p-min", String.valueOf(draftPMin)); + return this; + } + + /** + * Set the size of the prompt context for the draft model. + */ + public ModelParameters setCtxSizeDraft(int ctxSizeDraft) { + parameters.put("--ctx-size-draft", String.valueOf(ctxSizeDraft)); + return this; + } + + /** + * Set the comma-separated list of devices to use for offloading the draft model. + */ + public ModelParameters setDeviceDraft(String deviceDraft) { + parameters.put("--device-draft", deviceDraft); + return this; + } + + /** + * Set the number of layers to store in VRAM for the draft model. + */ + public ModelParameters setGpuLayersDraft(int gpuLayersDraft) { + parameters.put("--gpu-layers-draft", String.valueOf(gpuLayersDraft)); + return this; + } + + /** + * Set the draft model for speculative decoding. + */ + public ModelParameters setModelDraft(String modelDraft) { + parameters.put("--model-draft", modelDraft); + return this; + } } From ec592d51c203e8aee5af69ab3d0d4b95c8fd7165 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:58:10 +0100 Subject: [PATCH 7/9] update examples to use new model parameters --- .../java/de/kherud/llama/InferenceParameters.java | 6 ------ src/main/java/de/kherud/llama/LlamaModel.java | 13 ++++++------- src/test/java/de/kherud/llama/LlamaModelTest.java | 8 ++++---- src/test/java/examples/GrammarExample.java | 2 +- src/test/java/examples/InfillExample.java | 4 ++-- src/test/java/examples/MainExample.java | 4 ++-- 6 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index d2698753..2c494c8c 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -459,12 +459,6 @@ public InferenceParameters setSamplers(Sampler... samplers) { case TOP_K: builder.append("\"top_k\""); break; - case TFS_Z: - builder.append("\"tfs_z\""); - break; - case TYPICAL_P: - builder.append("\"typical_p\""); - break; case TOP_P: builder.append("\"top_p\""); break; diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index b78e056e..2fe36ec3 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -16,7 +16,7 @@ *
  *     <li>Streaming answers (and probabilities) via {@link #generate(InferenceParameters)}</li>
  *     <li>Creating whole responses to prompts via {@link #complete(InferenceParameters)}</li>
- *     <li>Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#setEmbedding(boolean)}</li>
+ *     <li>Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#enableEmbedding()}</li>
  *     <li>Accessing the tokenizer via {@link #encode(String)} and {@link #decode(int[])}</li>
  * </ul>
  */
@@ -32,16 +32,16 @@ public class LlamaModel implements AutoCloseable {
     /**
      * Load with the given {@link ModelParameters}. Make sure to either set
      * <ul>
-     *     <li>{@link ModelParameters#setModelFilePath(String)}</li>
+     *     <li>{@link ModelParameters#setModel(String)}</li>
      *     <li>{@link ModelParameters#setModelUrl(String)}</li>
-     *     <li>{@link ModelParameters#setHuggingFaceRepository(String)}}, {@link ModelParameters#setHuggingFaceFile(String)}</li>
+     *     <li>{@link ModelParameters#setHfRepo(String)}, {@link ModelParameters#setHfFile(String)}</li>
      * </ul>
* * @param parameters the set of options * @throws LlamaException if no model could be loaded from the given file path */ public LlamaModel(ModelParameters parameters) { - loadModel(parameters.toString()); + loadModel(parameters.toArray()); } /** @@ -73,8 +73,7 @@ public LlamaIterable generate(InferenceParameters parameters) { * * @param prompt the string to embed * @return an embedding float array - * @throws IllegalStateException if embedding mode was not activated (see - * {@link ModelParameters#setEmbedding(boolean)}) + * @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters#enableEmbedding()}) */ public native float[] embed(String prompt); @@ -124,7 +123,7 @@ public void close() { native byte[] decodeBytes(int[] tokens); - private native void loadModel(String parameters) throws LlamaException; + private native void loadModel(String... parameters) throws LlamaException; private native void delete(); diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index b5481cef..2a93e93e 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -24,11 +24,11 @@ public static void setup() { // LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg)); model = new LlamaModel( new ModelParameters() - .setNCtx(128) - .setModelFilePath("models/codellama-7b.Q2_K.gguf") + .setCtxSize(128) + .setModel("models/codellama-7b.Q2_K.gguf") // .setModelUrl("https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf") - .setNGpuLayers(43) - .setEmbedding(true) + .setGpuLayers(43) + .enableEmbedding() ); } diff --git a/src/test/java/examples/GrammarExample.java b/src/test/java/examples/GrammarExample.java index a2fec2fb..d90de206 100644 --- a/src/test/java/examples/GrammarExample.java +++ b/src/test/java/examples/GrammarExample.java @@ -13,7 +13,7 @@ public static void main(String... args) { "expr ::= term ([-+*/] term)*\n" + "term ::= [0-9]"; ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/mistral-7b-instruct-v0.2.Q2_K.gguf"); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf"); InferenceParameters inferParams = new InferenceParameters("") .setGrammar(grammar); try (LlamaModel model = new LlamaModel(modelParams)) { diff --git a/src/test/java/examples/InfillExample.java b/src/test/java/examples/InfillExample.java index b73eeb0f..e13ecb7c 100644 --- a/src/test/java/examples/InfillExample.java +++ b/src/test/java/examples/InfillExample.java @@ -9,8 +9,8 @@ public class InfillExample { public static void main(String... args) { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/codellama-7b.Q2_K.gguf") - .setNGpuLayers(43); + .setModel("models/codellama-7b.Q2_K.gguf") + .setGpuLayers(43); String prefix = "def remove_non_ascii(s: str) -> str:\n \"\"\" "; String suffix = "\n return result\n"; diff --git a/src/test/java/examples/MainExample.java b/src/test/java/examples/MainExample.java index 92581144..2b5150a5 100644 --- a/src/test/java/examples/MainExample.java +++ b/src/test/java/examples/MainExample.java @@ -16,8 +16,8 @@ public class MainExample { public static void main(String... 
args) throws IOException { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("models/mistral-7b-instruct-v0.2.Q2_K.gguf") - .setNGpuLayers(43); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + "requests immediately and with precision.\n\n" + From ffdbf4ef31384b5eed7a3ea31b1df9ce13e907aa Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:58:22 +0100 Subject: [PATCH 8/9] gitignore cmake build files --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 8857fd04..74bff7f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .idea target build +cmake-build-* .DS_Store .directory .vscode @@ -41,4 +42,4 @@ src/test/resources/**/*.gbnf **/*.etag **/*.lastModified -src/main/cpp/llama.cpp/ \ No newline at end of file +src/main/cpp/llama.cpp/ From 83cd1cff162b1209b832daba7f3b5c6eb622c3f2 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Mon, 3 Feb 2025 22:58:50 +0100 Subject: [PATCH 9/9] implement jni cli param parsing --- src/main/cpp/jllama.cpp | 126 ++++++++++++++++++++++++++-------------- src/main/cpp/jllama.h | 4 +- 2 files changed, 85 insertions(+), 45 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index b5bb2937..b971210f 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -1,11 +1,13 @@ #include "jllama.h" -#include "log.h" +#include "arg.h" #include "llama.h" +#include "log.h" #include "nlohmann/json.hpp" #include "server.hpp" #include +#include #include // We store some references to Java classes and their fields/methods here to speed up things for later and to fail @@ -94,6 +96,38 @@ std::string parse_jstring(JNIEnv *env, jstring java_string) return string; } +char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) +{ + auto *const result = static_cast(malloc(length * sizeof(char *))); + + if (result == nullptr) + { + return nullptr; + } + + for (jsize i = 0; i < length; i++) + { + auto *const javaString = static_cast(env->GetObjectArrayElement(string_array, i)); + const char *cString = env->GetStringUTFChars(javaString, nullptr); + result[i] = strdup(cString); + env->ReleaseStringUTFChars(javaString, cString); + } + + return result; +} + +void free_string_array(char **array, jsize length) +{ + if (array != nullptr) + { + for (jsize i = 0; i < length; i++) + { + free(array[i]); + } + free(array); + } +} + /** * Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env-NewString`, * but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to @@ -154,8 +188,6 @@ void log_callback_trampoline(ggml_log_level level, const char *text, void *user_ } } // namespace - - /** * The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`). * `JNI_OnLoad` must return the JNI version needed by the native library. 
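Context for the next hunk: with this patch the model options no longer reach the native layer as a JSON string but as a CLI-style argv that common_params_parse consumes, so the Java side only has to flatten its flag map into a String array. Below is a minimal sketch of that flattening, assuming the builder stores flags as key/value pairs and that value-less switches map to null; the real CliParameters.toArray() is not part of this excerpt, and CliArgsSketch/toArgv are illustrative names only. Since common_params_parse skips argv[0] like an ordinary argv parser, the real array most likely leads with a placeholder program name, a detail omitted here.

// Sketch only (not part of the patch): how the CLI-style parameters are expected
// to reach the new JNI loadModel(String...) entry point.
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public final class CliArgsSketch {

    /** Flattens a flag map into the argv-style array the native loadModel expects. */
    static String[] toArgv(Map<String, String> flags) {
        List<String> args = new ArrayList<>();
        for (Map.Entry<String, String> entry : flags.entrySet()) {
            args.add(entry.getKey());          // e.g. "--model"
            if (entry.getValue() != null) {    // boolean switches carry no value
                args.add(entry.getValue());    // e.g. "models/codellama-7b.Q2_K.gguf"
            }
        }
        return args.toArray(new String[0]);
    }

    public static void main(String[] args) {
        Map<String, String> flags = new LinkedHashMap<>();
        flags.put("--model", "models/codellama-7b.Q2_K.gguf");
        flags.put("--gpu-layers", "43");
        flags.put("--embedding", null); // switch without a value
        // The resulting array is what common_params_parse() sees as argv on the C++ side.
        System.out.println(String.join(" ", toArgv(flags)));
    }
}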
@@ -355,21 +387,22 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) llama_backend_free(); } -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams) +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) { common_params params; - std::string c_params = parse_jstring(env, jparams); - json json_params = json::parse(c_params); - server_params_parse(json_params, params); - - if (json_value(json_params, "disable_log", false)) + const jsize argc = env->GetArrayLength(jparams); + char **argv = parse_string_array(env, jparams, argc); + if (argv == nullptr) { - common_log_pause(common_log_main()); + return; } - else + + const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER); + free_string_array(argv, argc); + if (!parsed_params) { - common_log_resume(common_log_main()); + return; } common_init(); @@ -380,14 +413,14 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo llama_backend_init(); llama_numa_init(params.numa); - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, + params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); LOG_INF("\n"); LOG_INF("%s\n", common_params_get_system_info(params).c_str()); LOG_INF("\n"); std::atomic state{SERVER_STATE_LOADING_MODEL}; - // Necessary similarity of prompt for slot selection ctx_server->slot_prompt_similarity = params.slot_prompt_similarity; @@ -396,7 +429,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo // load the model if (!ctx_server->load_model(params)) { - llama_backend_free();; + llama_backend_free(); + ; env->ThrowNew(c_llama_error, "could not load model from given file path"); return; } @@ -409,9 +443,13 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo const auto model_meta = ctx_server->model_meta(); // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) { - if (!ctx_server->validate_builtin_chat_template()) { - LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + if (params.chat_template.empty()) + { + if (!ctx_server->validate_builtin_chat_template()) + { + LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. " + "This may cause the model to output suboptimal responses\n", + __func__); params.chat_template = "chatml"; } } @@ -421,11 +459,9 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo params.chat_template.empty() ? 
"(built-in)" : params.chat_template.c_str(), common_chat_format_example(ctx_server->model, params.chat_template).c_str()); - - ctx_server->queue_tasks.on_new_task(std::bind( - &server_context::process_single_task, ctx_server, std::placeholders::_1)); - ctx_server->queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, ctx_server)); + ctx_server->queue_tasks.on_new_task( + std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1)); + ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server)); std::thread t([ctx_server]() { JNIEnv *env; @@ -455,37 +491,40 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv server_task_type type = SERVER_TASK_TYPE_COMPLETION; - if (data.contains("input_prefix") || data.contains("input_suffix")) { + if (data.contains("input_prefix") || data.contains("input_suffix")) + { type = SERVER_TASK_TYPE_INFILL; } auto completion_id = gen_chatcmplid(); std::vector tasks; - try { - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab,data.at("prompt"), true, true); + try + { + std::vector tokenized_prompts = + tokenize_input_prompts(ctx_server->vocab, data.at("prompt"), true, true); tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { + for (size_t i = 0; i < tokenized_prompts.size(); i++) + { server_task task = server_task(type); - task.id = ctx_server->queue_tasks.get_new_id(); + task.id = ctx_server->queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server->ctx, - ctx_server->params_base, - data); + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat - task.params.oaicompat = OAICOMPAT_TYPE_NONE; + task.params.oaicompat = OAICOMPAT_TYPE_NONE; task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(task); } - } catch (const std::exception & e) { + } + catch (const std::exception &e) + { const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST); env->ThrowNew(c_llama_error, err.dump().c_str()); return 0; @@ -496,7 +535,8 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv const auto task_ids = server_task::get_list_id(tasks); - if (task_ids.size() != 1) { + if (task_ids.size() != 1) + { env->ThrowNew(c_llama_error, "multitasking currently not supported"); return 0; } @@ -566,16 +606,16 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); std::vector tasks; - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server->queue_tasks.get_new_id(); - task.index = 0; - task.prompt_tokens = std::move(tokens); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = 0; + task.prompt_tokens = std::move(tokens); - // OAI-compat - task.params.oaicompat = OAICOMPAT_TYPE_NONE; + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_NONE; - tasks.push_back(task); + tasks.push_back(task); ctx_server->queue_results.add_waiting_tasks(tasks); ctx_server->queue_tasks.post(tasks); diff --git a/src/main/cpp/jllama.h 
b/src/main/cpp/jllama.h index 2fd0529e..4008f030 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -66,10 +66,10 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes /* * Class: de_kherud_llama_LlamaModel * Method: loadModel - * Signature: (Ljava/lang/String;)V + * Signature: ([Ljava/lang/String;)V */ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel - (JNIEnv *, jobject, jstring); + (JNIEnv *, jobject, jobjectArray); /* * Class: de_kherud_llama_LlamaModel