diff --git a/common/arg.cpp b/common/arg.cpp index 55ec9389b69..9fefe411ee2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } - for (auto & pair : params.speculative.draft.replacements) { - string_process_escapes(pair.first); - string_process_escapes(pair.second); - } } if (!params.kv_overrides.empty()) { @@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); - add_opt(common_arg( - {"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N", - string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx), - [](common_params & params, int value) { - params.speculative.draft.n_ctx = value; - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" @@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); add_opt(common_arg( - {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT", - "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", - [](common_params & params, const std::string & tgt, const std::string & dft) { - params.speculative.draft.replacements.push_back({ tgt, dft 
}); - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); - add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + {"--spec-type"}, common_speculative_all_types_str(), string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", - common_speculative_type_to_str(params.speculative.type).c_str()), + common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - if (value == "none") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - } else if (value == "ngram-cache") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; - } else if (value == "ngram-simple") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; - } else if (value == "ngram-map-k") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; - } else if (value == "ngram-map-k4v") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; - } else if (value == "ngram-mod") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; - } else { - throw std::invalid_argument("unknown speculative decoding type without draft model"); - } + const auto enabled_types = string_split(value, ','); + params.speculative.types = common_speculative_types_from_names(enabled_types); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); add_opt(common_arg( @@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & params) { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; 
params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index 793b8fee7b8..352af0b1787 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + LOG_WRN("%s: the context does not support partial sequence removal\n", __func__); res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; goto done; } @@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode( return true; } + +size_t common_prompt_checkpoint::size() const { + return data_tgt.size() + data_dft.size(); +} + +bool common_prompt_checkpoint::empty() const { + return data_tgt.empty(); +} + +void common_prompt_checkpoint::clear() { + n_tokens = 0; + + pos_min = 0; + pos_max = 0; + + data_tgt.clear(); + data_dft.clear(); +} + +void common_prompt_checkpoint::update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max) { + this->n_tokens = n_tokens; + this->pos_min = pos_min; + this->pos_max = pos_max; +} + +void common_prompt_checkpoint::update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_tgt.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_dft.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), 
ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_tgt.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags); + if (n != data_tgt.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n); + } +} + +void common_prompt_checkpoint::load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_dft.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags); + if (n != data_dft.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); + } +} diff --git a/common/common.h b/common/common.h index a564b3b8c2b..aafc376f2e7 100644 --- a/common/common.h +++ b/common/common.h @@ -295,8 +295,6 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; -struct common_ngram_mod; - // draft-model-based speculative decoding parameters struct common_params_speculative_draft { int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding @@ -307,11 +305,9 @@ struct common_params_speculative_draft { common_params_model mparams; - llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts - - llama_context_params cparams; // these are the parameters for the draft llama_context + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; - int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = 
GGML_TYPE_F16; // KV cache data type for the K @@ -322,7 +318,6 @@ struct common_params_speculative_draft { std::vector devices; // devices to use for offloading - std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; }; @@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod { int32_t n_max = 64; int32_t n_min = 48; - - // shared instance of the ngram container for all speculative decoding contexts - std::shared_ptr obj; }; struct common_params_speculative_ngram_map { @@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache { }; struct common_params_speculative { - // TODO: become a vector in order to support "chains of speculators" - common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; + std::vector types = { COMMON_SPECULATIVE_TYPE_NONE }; common_params_speculative_draft draft; @@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std // "adamw" or "sgd" (case insensitive) enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); + +// +// prompt utils +// + +struct common_prompt_checkpoint { + int64_t n_tokens; + + llama_pos pos_min; + llama_pos pos_max; + + std::vector data_tgt; + std::vector data_dft; + + size_t size() const; + + bool empty() const; + void clear(); + + void update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max); + + void update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; + + void load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; +}; diff --git a/common/speculative.cpp b/common/speculative.cpp index e9fa751e20e..e487e003d39 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,6 +10,7 @@ #include 
"sampling.h" #include +#include #include #include #include @@ -18,26 +19,15 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -const std::vector common_speculative_types = { - COMMON_SPECULATIVE_TYPE_NONE, - COMMON_SPECULATIVE_TYPE_DRAFT, - COMMON_SPECULATIVE_TYPE_EAGLE3, - COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, - COMMON_SPECULATIVE_TYPE_NGRAM_MOD, - COMMON_SPECULATIVE_TYPE_NGRAM_CACHE -}; - -const std::map common_speculative_type_from_name_map = { +const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, - {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, - {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, - {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, - {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, - {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} + {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram-mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, + {"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} }; struct common_speculative_config { @@ -115,12 +105,16 @@ static bool common_speculative_are_compatible( return true; } +using common_speculative_draft_params_vec = std::vector; + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific -// in a subclass of common_speculative_state -struct common_speculative_state { - const enum common_speculative_type type; +// in a subclass of common_speculative_impl +struct common_speculative_impl { + const common_speculative_type type; + + uint32_t n_seq; size_t n_call_begin = 0; // number of times this implementation was called for refresh. 
size_t n_call_draft = 0; // number of times this implementation was called for generation. @@ -138,65 +132,34 @@ struct common_speculative_state { int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. - common_speculative_state(enum common_speculative_type type) : type(type) {} - - virtual ~common_speculative_state() = default; - - virtual void begin(const llama_tokens & prompt) = 0; - - virtual void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) = 0; - - virtual void accept(uint16_t n_accepted) = 0; + common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} - virtual int32_t n_max(const common_params_speculative & params) const = 0; - virtual int32_t n_min(const common_params_speculative & params) const = 0; -}; + virtual ~common_speculative_impl() = default; -struct common_speculative_checkpoint { - llama_pos pos_min = 0; - llama_pos pos_max = 0; + virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; - int64_t n_tokens = 0; + virtual bool process(const llama_batch & batch) = 0; - std::vector data; + virtual void draft(common_speculative_draft_params_vec & dparams) = 0; - size_t size() const { - return data.size(); - } + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - bool use_ckpt = false; - common_speculative_checkpoint ckpt; +struct common_speculative_state_draft : public common_speculative_impl { + common_params_speculative_draft params; - common_sampler * smpl; + llama_batch batch; - llama_batch batch; - llama_tokens prompt_dft; + std::vector smpls; - bool vocab_cmpt 
= true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements, - bool use_ckpt) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , use_ckpt(use_ckpt) + common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) + , params(params.draft) { + auto * ctx_dft = this->params.ctx_dft; + auto * ctx_tgt = this->params.ctx_tgt; + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; // TODO: optimize or pass from outside? // { @@ -214,7 +177,9 @@ struct common_speculative_state_draft : public common_speculative_state { // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } - { + + smpls.resize(n_seq); + for (auto & smpl : smpls) { common_params_sampling params; params.no_perf = false; params.top_k = 10; @@ -222,482 +187,321 @@ struct common_speculative_state_draft : public common_speculative_state { COMMON_SAMPLER_TYPE_TOP_K, }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); + smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } + throw std::runtime_error("draft model vocab type must match target model to 
use speculation"); } - } - ~common_speculative_state_draft() override { - llama_perf_context_print(ctx_dft); - - llama_free(ctx_dft); + if (n_seq != llama_n_seq_max(ctx_dft)) { + LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); - common_sampler_free(smpl); + throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); + } + } + ~common_speculative_state_draft() override { llama_batch_free(batch); } - void begin(const llama_tokens & /*prompt*/) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - size_t create_checkpoint(int n_tokens_prompt) { - int slot_id = 0; - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - - ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); - ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); - ckpt.n_tokens = n_tokens_prompt; - ckpt.data.resize(checkpoint_size); + bool process(const llama_batch & batch) override { + auto * ctx_dft = params.ctx_dft; - const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); - } + const int ret = llama_decode(ctx_dft, batch); - LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, - ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); - return n; - } + if (ret != 0) { + LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret); - size_t restore_checkpoint() { - int slot_id = 0; - LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); - const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), 
slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size()); + return false; } - llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); - return n; + return true; } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.draft; - - auto * spec = this; - - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; - - auto * mem_dft = llama_get_memory(ctx_dft); - - int reuse_i = 0; // index of part to be reused in prompt_dft - int reuse_n = 0; // length of part to be reused in prompt_dft - - const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); + void draft(common_speculative_draft_params_vec & dparams) override { + auto & ctx_dft = params.ctx_dft; - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + common_batch_clear(batch); - prompt_cnv = common_tokenize(ctx_dft, text, false, true); + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); - // convert id_last to draft vocab. 
llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + if (!dp.drafting) { + continue; + } - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); } - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - if (use_ckpt && i_start > 0) { - LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } - - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } - - if (use_ckpt) { - break; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", - __func__, reuse_i, reuse_n, 
prompt_dft.size(), prompt_cur.size()); - if (use_ckpt && ckpt.n_tokens > reuse_n) { - LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens); + int i = 0; - reuse_i = 0; - reuse_n = 0; + while (n_drafting > 0) { + int i_batch = 0; - ckpt = {}; - } - - result.clear(); - result.reserve(sparams.n_max); + common_batch_clear(batch); - if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (sparams.n_max <= (int) result.size()) { - break; - } + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; } - return; - } + auto * smpl = smpls[seq_id].get(); - if (reuse_i > 0) { - GGML_ASSERT(!use_ckpt); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + ++i_batch; - bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); - return; - } - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } + const auto * cur_p = common_sampler_get_candidates(smpl, true); - if (reuse_n < (int) prompt_dft.size()) { - if (use_ckpt) { - if (ckpt.n_tokens > 0) { - LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - restore_checkpoint(); - reuse_n = ckpt.n_tokens; - prompt_dft.resize(reuse_n); - } - } else { - const bool is_removed = llama_memory_seq_rm(mem_dft, 0, 
reuse_n, -1); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - return; - } - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } - } - } - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens); + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", - __func__, ret, prompt_cur.size()); - } - - if (use_ckpt) { - create_checkpoint(prompt_dft.size()); - } - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && 
ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, ret, prompt_cur.size(), prompt_dft.size()); - } - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < sparams.n_max; ++i) { - common_batch_clear(batch); + continue; + } - common_sampler_sample(smpl, ctx_dft, 0, true); + common_sampler_accept(smpl, id, true); - const auto * cur_p = common_sampler_get_candidates(smpl, true); + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); - } + result.push_back(id); - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, id, true); + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; + } - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < sparams.p_min) { - break; + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } - result.push_back(id); - - if (sparams.n_max <= (int) result.size()) { + if (batch.n_tokens == 0) { break; } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; } - prompt_dft.push_back(id); + ++i; } - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: 
'%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t) sparams.n_max) { - result.resize(sparams.n_max); + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; } - } - if (result.size() < (size_t) sparams.n_min) { - result.clear(); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; } +}; - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; +struct common_speculative_state_eagle3 : public common_speculative_impl { + //common_params_speculative_eagle3 params; - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } + common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} - return result; + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } -}; - -struct common_speculative_state_eagle3 : public common_speculative_state { - 
common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { + void draft(common_speculative_draft_params_vec & /*dparams*/) override { // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; } }; // state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { +struct common_speculative_state_ngram_simple : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // shared across all sequences common_ngram_simple_config config; common_speculative_state_ngram_simple( - enum common_speculative_type type, + const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) + , params(params.ngram_simple) + , config(config) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - 
llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - void accept(uint16_t n_accepted) override { - // noop - GGML_UNUSED(n_accepted); - } + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); + } } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map config; +struct common_speculative_state_ngram_map_k : public common_speculative_impl { + common_params_speculative_ngram_map params; - common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map config) - : common_speculative_state(type), config(std::move(config)) {} + // n_seq configs + std::vector config; - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(config, prompt); - } - - void draft( + common_speculative_state_ngram_map_k( const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - common_ngram_map_draft(config, prompt_tgt, id_last, result); - GGML_UNUSED(params); + const common_ngram_map & config, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) + , params(params.ngram_map_k) { + for 
(uint32_t i = 0; i < n_seq; i++) { + this->config.push_back(config); + } } - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(config, n_accepted); + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT(seq_id < (llama_seq_id) n_seq); + + common_ngram_map_begin(config[seq_id], prompt); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_value; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_value; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); + } } -}; -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; + common_ngram_map_accept(config[seq_id], n_accepted); + } +}; - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; +struct common_speculative_state_ngram_mod : public common_speculative_impl { + common_params_speculative_ngram_mod params; - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; + // shared across all sequences + common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != 
nullptr) { + struct seq_info { + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + }; + + std::vector sinfos; + + common_speculative_state_ngram_mod( + const common_params_speculative & params, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) + , params(params.ngram_mod) + , mod(params.ngram_mod.n_match, 4*1024*1024) + , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + + LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, + this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + + if (this->params.n_match < 16) { + LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " + "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); + } + + sinfos.resize(n_seq); } - void begin(const llama_tokens & prompt) override { - i_last = 0; + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + auto & sinfo = sinfos[seq_id]; - n_draft_last = 0; + sinfo.i_last = 0; + sinfo.n_draft_last = 0; const size_t n = mod.get_n(); - if (prompt.size() < n) { return; } @@ -706,7 +510,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.add(prompt.data() + i); } - i_last = prompt.size() - n; + sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); @@ -719,16 +523,17 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token 
id_last, - llama_tokens & result) override { - const auto & sparams = params.ngram_mod; + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; + + const auto & prompt = *dparams.prompt; - n_draft_last = 0; + sinfo.n_draft_last = 0; - const size_t cur_len = prompt_tgt.size(); + const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } @@ -736,24 +541,24 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { const size_t n = mod.get_n(); // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); + if (sinfo.i_last + 32 < cur_len) { + for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { + mod.add(prompt.data() + i); } - i_last = cur_len - n; + sinfo.i_last = cur_len - n; } - result.resize(n + sparams.n_max); + result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; + result[i] = prompt.at(cur_len - n + 1 + i); } - result[n - 1] = id_last; + result[n - 1] = dparams.id_last; - for (int i = 0; i < sparams.n_max; ++i) { + for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { - if (i < sparams.n_min) { + if (i < params.n_min) { result.clear(); return; } @@ -771,65 +576,92 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); + sinfo.n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; + } + + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < 
(llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + draft_one(seq_id, dp); + } + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + auto & sinfo = sinfos[seq_id]; + // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; + if (sinfo.n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { + sinfo.n_low++; + if (sinfo.n_low >= 3) { if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); - n_low = 0; - i_last = 0; + sinfo.n_low = 0; + sinfo.i_last = 0; } } else { - n_low = 0; + sinfo.n_low = 0; } } } - - int32_t n_max(const common_params_speculative & params) const override { - return params.ngram_mod.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.ngram_mod.n_min; - } }; -struct common_speculative_state_ngram_cache : public common_speculative_state { +struct common_speculative_state_ngram_cache : public common_speculative_impl { + common_params_speculative_ngram_cache params; + uint16_t n_draft; + bool save_dynamic; bool save_static; - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; + struct seq_info { + size_t cache_size = 0; // number of tokens in n-gram cache + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + }; - size_t cache_size = 0; // number of tokens in n-gram cache + std::vector sinfos; common_speculative_state_ngram_cache( - const enum common_speculative_type type, + const common_params_speculative & params, + uint32_t n_seq, + uint16_t 
n_draft, const std::string & path_static, const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) + bool save_dynamic, + bool save_static) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) + , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { + sinfos.resize(n_seq); + if (!path_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(path_static); + auto ngram_cache_static = common_ngram_cache_load(path_static); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_static = ngram_cache_static; + } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); @@ -838,7 +670,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { if (!path_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_dynamic = ngram_cache_dynamic; + } } catch (...) 
{ LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); @@ -846,44 +682,48 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - if (cache_size < prompt_tgt.size() + 1) { + const auto & prompt = *dparams.prompt; + + if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); + tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { + tokens_new.push_back(prompt[j]); } - tokens_new.push_back(id_last); // add the last token + tokens_new.push_back(dparams.id_last); // add the last token - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + // Update context ngram cache with new dparams.prompt: + common_ngram_cache_update( + sinfo.ngram_cache_context, + LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; + sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); + inp.reserve(prompt.size() + 1); + for (size_t j = 0; j < prompt.size(); ++j) { + inp.push_back(prompt[j]); } - 
inp.push_back(id_last); + inp.push_back(dparams.id_last); - result.push_back(id_last); + result.push_back(dparams.id_last); - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); + common_ngram_cache_draft( + inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + sinfo.ngram_cache_context, + sinfo.ngram_cache_dynamic, + sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) @@ -891,24 +731,37 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return n_draft; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + draft_one(seq_id, dp); + } } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return 0; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; struct common_speculative { - std::vector> impls; // list of implementations to use and their states + common_speculative_draft_params_vec dparams; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + // list of implementations to use and their states + std::vector> impls; + + // which implementaion was used for a given seq_id + std::vector impl_last; }; static common_ngram_map get_common_ngram_map( @@ -923,45 +776,79 @@ static common_ngram_map get_common_ngram_map( } static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & 
path_static, const std::string & path_dynamic, - const common_speculative_config & config) { + const common_speculative_config & config, + uint32_t n_seq, + const std::string & path_static, + const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } -std::string common_speculative_type_name_str() { +std::string common_speculative_type_name_str(const std::vector & types) { std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { + + for (size_t i = 0; i < types.size(); i++) { if (i > 0) { - result += ", "; + result += ","; } - result += common_speculative_type_to_str(common_speculative_types[i]); + result += common_speculative_type_to_str(types[i]); } return result; } -std::string common_speculative_type_to_str(enum common_speculative_type type) { +const char * common_speculative_all_types_str() { + static std::string all_types_str = []() { + std::vector types; + types.reserve(COMMON_SPECULATIVE_TYPE_COUNT); + for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) { + types.push_back((common_speculative_type) i); + } + return common_speculative_type_name_str(types); + }(); + return all_types_str.c_str(); +} + +std::string common_speculative_type_to_str(common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache"; default: return "unknown"; } } -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { +std::vector common_speculative_types_from_names(const std::vector & names) { + std::vector types; + types.reserve(names.size()); + + for (const auto & name : names) { + auto type = common_speculative_type_from_name_map.find(name); + if (type != common_speculative_type_from_name_map.end()) { + if (type->second == COMMON_SPECULATIVE_TYPE_NONE) { + return std::vector { COMMON_SPECULATIVE_TYPE_NONE }; + } + types.push_back(type->second); + continue; + } + throw std::invalid_argument("unknown speculative type: " + name); + } + + return types; +} + +common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; @@ -969,34 +856,39 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } -// initialization of the speculative decoding system -// -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - llama_context * ctx_dft = nullptr; - if (params.draft.model) { - ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } +static 
uint32_t common_get_enabled_speculative_configs(const std::vector & configs) { + uint32_t result = 0; + for (size_t i = 0; i < configs.size(); i++) { + result |= (1u << configs[i]); } + return result; +} +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { - bool has_draft = !params.draft.mparams.path.empty(); + uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types); + + bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT)); + bool has_draft_model = !params.draft.mparams.path.empty(); + + // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); - bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); - bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); - bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); + bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); + bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K)); + bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V)); + bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); - // In a more complex implementation we could use the same implementation but with different parameters. - // This was initially used in PR-18471 but removed to simplify the code. 
+ // when adding a new type - update here the logic above + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + + // this list here defines the priority of the speculators + // the one with highest priority are listed first if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); @@ -1009,53 +901,43 @@ common_speculative * common_speculative_init( configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { - auto & sparams = params.ngram_mod; - - if (!sparams.obj) { - sparams.obj = std::make_shared(sparams.n_match, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024); - - if (sparams.n_match < 16) { - LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " - "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); - } - } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_draft) { + if (!has_draft_model) { + LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); + has_draft = false; + } + } else if (has_draft_model) { + LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); + has_draft = true; + } + if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } + // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } } - std::vector> impls = {}; + std::vector> impls = {}; for (const common_speculative_config & config : 
configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { - const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.draft.replacements, - /* .use_ckpt = */ use_ckpt - )); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -1069,27 +951,30 @@ common_speculative * common_speculative_init( /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple + /* .params = */ config.params, + /* .n_seq = */ n_seq, + /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config.type, config.params.ngram_map_k) - )); + impls.push_back( + std::make_unique( + config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod.obj); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod.obj)); + impls.push_back( + std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config); + auto 
state = create_state_ngram_cache( + config, n_seq, + params.ngram_cache.lookup_cache_static, + params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } @@ -1104,8 +989,9 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { + /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), - /* .curr_impl = */ nullptr, + /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; @@ -1119,65 +1005,128 @@ void common_speculative_free(common_speculative * spec) { delete spec; } -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { +common_speculative_draft_params & common_speculative_get_draft_params( + common_speculative * spec, + llama_seq_id seq_id) { + GGML_ASSERT(spec); + GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); + + return spec->dparams[seq_id]; +} + +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(seq_id, prompt); impl->n_call_begin++; } } -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last) { - llama_tokens result; +bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { + bool result = true; + + if (spec == nullptr) { + return result; + } + + for (auto & impl : spec->impls) { + result = result && impl->process(batch); + } + + return result; +} + +void common_speculative_draft(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + auto & dparams = spec->dparams; + + { + int n_drafting = 0; + + for (auto & dp : dparams) { + GGML_ASSERT(!dp.drafting || dp.result->empty()); + + if 
(dp.drafting) { + n_drafting++; + } + } - spec->curr_impl = nullptr; // reset current implementation + if (n_drafting == 0) { + return; + } + } for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, id_last, result); + impl->draft(dparams); impl->n_call_draft++; } - { - const int n_min = impl->n_min(params); + int n_drafting = 0; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + auto & result = *dp.result; + + // a new draft has been sampled + if (dp.drafting && !result.empty()) { + dp.drafting = false; + + if (dp.n_max > 0) { + if (!result.empty() && (int) result.size() > dp.n_max) { + LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); + result.resize(dp.n_max); + } + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), + impl.get()->n_call_draft, result.size()); + + // remember which implementation was used + spec->impl_last[seq_id] = impl.get(); - if (!result.empty() && (int) result.size() < n_min) { - LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min); - result.clear(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + } + } + + if (dp.drafting) { + n_drafting++; } } - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); + if (n_drafting == 0) { + break; + } + } - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); + // these sequences failed to generate a draft + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp 
= dparams[seq_id]; - break; // we have a draft, so break out of the loop and return it. + if (dp.drafting) { + dp.drafting = false; } } - - return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { +void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } - common_speculative_state * impl = spec->curr_impl; + common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); @@ -1188,37 +1137,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_max = 0; - for (const auto & impl : spec->impls) { - n_max = std::max(n_max, impl->n_max(params)); - } - - return n_max; -} - -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_min = 0; - for (const auto & impl : spec->impls) { - n_min = std::max(n_min, impl->n_min(params)); - } - - return n_min; -} - void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 14744763170..51f0b059fa4 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,8 +5,14 @@ struct common_speculative; +// comma separated list the provided types +std::string common_speculative_type_name_str(const std::vector & types); + // comma separated list of all types -std::string common_speculative_type_name_str(); +const char * common_speculative_all_types_str(); + +// parse user provided types +std::vector common_speculative_types_from_names(const std::vector & names); // convert 
string to type enum common_speculative_type common_speculative_type_from_name(const std::string & name); @@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt); +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); void common_speculative_free(common_speculative * spec); +struct common_speculative_draft_params { + // this flag is used to chain the drafts through all the available implementations + // after the first successful draft from an implementation, we set it + // to false to prevent further drafts for that sequence + // at the end of the draft() call, all drafting flags will be reset to false + bool drafting = false; + + // overrides individual configurations (-1 disabled) + // can be used to constraint the max draft based on the remaining context size + int32_t n_max = -1; + + llama_pos n_past; + llama_token id_last; + + // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls + const llama_tokens * prompt; + + // the generated draft from the last _draft() call + llama_tokens * result; +}; + +common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); + // optionally call once at the beginning of a new generation -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); -// sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt, - 
llama_token id_last); +// process the batch and update the internal state of the speculative context +bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// informs the speculative decoder that n_accepted tokens were accepted by the target model -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// generate drafts for the sequences specified with `common_speculative_get_draft_params` +void common_speculative_draft(common_speculative * spec); -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params); -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params); +// informs the speculative context that n_accepted tokens were accepted by the target model +void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 5b61b62a1bf..5325bcc9e3f 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,20 +13,6 @@ #include #include -struct spec_checkpoint { - int64_t n_tokens = 0; - - std::vector data; - - size_t size() const { - return data.size(); - } - - bool empty() const { - return data.empty(); - } -}; - int main(int argc, char ** argv) { std::setlocale(LC_NUMERIC, "C"); @@ -43,11 +29,6 @@ int main(int argc, char ** argv) { return 1; } - if (params.speculative.draft.mparams.path.empty()) { - LOG_ERR("%s: --model-draft is required\n", __func__); - return 1; - } - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -62,18 +43,11 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt->model(); ctx_tgt = llama_init_tgt->context(); - // 
check if the context supports partial sequence removal - const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); - const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); - - if (use_ckpt) { - LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); - } - const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model llama_model_ptr model_dft; + llama_context_ptr ctx_dft; // TODO: simplify this logic { @@ -81,9 +55,6 @@ int main(int argc, char ** argv) { auto params_dft = params; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -103,8 +74,19 @@ int main(int argc, char ** argv) { return 1; } - params.speculative.draft.model = model_dft.get(); - params.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + params.speculative.draft.ctx_tgt = ctx_tgt; + params.speculative.draft.ctx_dft = ctx_dft.get(); + } + + // check if the context supports partial sequence removal + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + + if (use_ckpt_tgt) { + LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); } // Tokenize the prompt @@ -136,6 +118,8 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + llama_seq_id seq_id = 0; + // ================================================ // everything until here is standard initialization // the relevant stuff for speculative decoding starts here 
@@ -146,7 +130,8 @@ int main(int argc, char ** argv) { common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling)); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1)); // note: keep the last token separate! llama_token id_last = inp.back(); @@ -160,16 +145,16 @@ int main(int argc, char ** argv) { // init the speculator const auto & params_spec = params.speculative; - struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + struct common_speculative * spec = common_speculative_init(params.speculative, 1); - common_speculative_begin(spec, prompt_tgt); + common_speculative_begin(spec, seq_id, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); size_t n_draft = 0; llama_tokens draft; - spec_checkpoint spec_ckpt; + common_prompt_checkpoint ckpt; const auto t_enc_end = ggml_time_us(); @@ -184,40 +169,57 @@ int main(int argc, char ** argv) { // from a cache or lookup tables. 
// if (draft.empty()) { + ckpt.update_pos( + prompt_tgt.size(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); + + if (use_ckpt_dft) { + ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + // generate a new draft - draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + common_speculative_get_draft_params(spec, seq_id) = { + /* .drafting = */ true, + /* .n_max = */ -1, + /* .n_past = */ n_past, + /* .id_last = */ id_last, + /* .prompt = */ &prompt_tgt, + /* .result = */ &draft, // output + }; + common_speculative_draft(spec); // save the original draft size n_draft = draft.size(); // save a checkpoint of the target context before evaluating the draft // this allows us to restore the state if partial draft acceptance occurs - if (!draft.empty() && use_ckpt) { - const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - spec_ckpt.data.resize(ckpt_size); + if (!draft.empty()) { + if (use_ckpt_tgt) { + ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + } - const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == ckpt_size); + { + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); - LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n", - spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); } } else { // we have a previous (partial) draft to reuse from checkpoint restoration - if (use_ckpt) { - GGML_ASSERT(!spec_ckpt.empty()); + if (use_ckpt_tgt) { + GGML_ASSERT(!ckpt.empty()); } } 
// always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); @@ -225,9 +227,15 @@ int main(int argc, char ** argv) { llama_decode(ctx_tgt, batch_tgt); } + // evaluate the same batch with the draft model + { + // TODO: extend to support MTP, Eagle, etc. See server code for reference + llama_decode(ctx_dft.get(), batch_tgt); + } + // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(smpl.get())); } @@ -247,17 +255,24 @@ int main(int argc, char ** argv) { // check for partial draft acceptance: // if the context doesn't support partial sequence removal, restore the checkpoint // and make the accepted tokens the new partial draft for the next iteration - if (use_ckpt && ids.size() - 1 < draft.size()) { + if (use_ckpt_tgt && ids.size() - 1 < draft.size()) { LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); draft = std::move(ids); - const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == spec_ckpt.size()); + { + ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1); + } + + { + ckpt.load_dft(ctx_dft.get(), seq_id, 
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); + } - prompt_tgt.resize(spec_ckpt.n_tokens); + prompt_tgt.resize(ckpt.n_tokens); smpl = std::move(smpl_save); n_past = (int) prompt_tgt.size(); @@ -265,7 +280,7 @@ int main(int argc, char ** argv) { continue; } - common_speculative_accept(spec, ids.size() - 1); + common_speculative_accept(spec, seq_id, ids.size() - 1); // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; @@ -305,7 +320,8 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b92a208705d..e25be3592fd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. 
+ const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 
56dc0545742..28c79ab462e 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template @@ -18,22 +19,23 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } @@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -136,23 +138,24 
@@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / 
CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ffde6a4f063..7edb3eb4e9c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -104,6 +104,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_q4_0_f32_ns gemv_moe_q4_0_f32_ns + gemm_moe_q4_1_f32_ns + gemv_moe_q4_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4e6f6fb43d2..73a58f74a94 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -544,6 +544,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; + cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -602,6 +603,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; + cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, 
kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -958,6 +960,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2856,6 +2860,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable " " -cl-fast-relaxed-math"; + // gemv_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, 
"kernel_gemv_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3749,11 +3785,14 @@ struct ggml_tensor_extra_cl_q4_1 { CL_CHECK(clReleaseMemObject(m)); m = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; m_img = nullptr; size_q = 0; @@ -4189,6 +4228,35 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm return GGML_STATUS_SUCCESS; } +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. 
+inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); +} + +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -4385,6 +4453,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } } + // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, + // the quantizations here currently do not - they are only supported by Adreno with certain shapes + if (op->src[0]->type == GGML_TYPE_Q4_1) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (op->src[1]->type == GGML_TYPE_F32) { + return use_adreno_moe_kernels(backend_ctx, op->src[0]) + && ggml_is_contiguous(op->src[0]) + 
&& ggml_is_contiguous(op->src[1]); + } +#endif + return false; + } return false; case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -4555,6 +4635,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4868,35 +4954,6 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff return GGML_STATUS_SUCCESS; } -// The optimized gemm and gemv kernels are used for large matrices without batch. -// tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - int64_t threshold_ne0 = 512; - int64_t threshold_ne1 = 512; - if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && - backend_ctx->adreno_cl_compiler_version.type != DX) { - threshold_ne0 = 128; - threshold_ne1 = 128; - } - return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && - tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - GGML_UNUSED(backend_ctx); - int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); -} - -inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - - bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); - - size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; - - return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 -} - static void 
ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -5097,15 +5154,54 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // normal q4_1 repack +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = 
backend_ctx->kernel_convert_block_q4_1; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -5862,6 +5958,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static 
ggml_cl_buffer buf_trans_q; static ggml_cl_buffer buf_trans_m; @@ -12862,6 +12988,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13131,6 +13258,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, 
NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; + + if (strstr(src0->name, "as") != NULL) { + moe_router_reoerder(backend, src2, ne20); + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = 
ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 
256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + 
clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c87450dc49e..5bbf09710f9 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -370,6 +370,96 @@ kernel void kernel_restore_block_q4_1_noshuffle( } } +kernel void kernel_convert_block_q4_1_trans4_ns( + __global struct block_q4_1 * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = 
q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_1_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_m, + __global struct block_q4_1 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_dm_offset]; + b->m = src_m[src_dm_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_1 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..e2574ae0187 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -0,0 +1,254 @@ +#pragma OPENCL 
EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_1(q4, a_f16, scale, m) \ + a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \ + a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \ + a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \ + a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \ + a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \ + a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \ + a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \ + a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \ + a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \ + a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \ + a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \ + a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \ + a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \ + a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \ + a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \ + a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = 
dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + 
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_1_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = 
b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q4_1 block + uint sm_offset = s_sub_offset + get_global_id(0); + half s = src0_d[sm_offset]; + half m = src0_m[sm_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + 
sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + 
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..3739a215705 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m); + fp32x8.s3 = 
(float)(((q4x8.s0 & 0xF000) >> 12) * s + m); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_1_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * 
fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/include/llama.h b/include/llama.h index 2ea226726ad..308e8ba9dbd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -858,6 +858,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 71a59395eb2..3d9714ab166 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2475,11 +2475,29 @@ class llama_io_write_device : public llama_io_write_i { } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + 
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2559,8 +2577,7 @@ class llama_io_read_device : public llama_io_read_i { mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 922ad493a34..33311948669 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case { // and dispatches a single fused kernel. 
struct test_snake_fuse : public test_case { const ggml_type type; - const std::array ne; // [T, C] + const std::array ne; // [T, C, D2, D3] std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case { } test_snake_fuse(ggml_type type = GGML_TYPE_F32, - std::array ne = {256, 192}) + std::array ne = {256, 192, 1, 1}) : type(type), ne(ne) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(x, "x"); ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]); @@ -7558,11 +7558,15 @@ static std::vector> make_test_cases_eval() { // SNAKE activation fusion: x + sin(a*x)^2 * inv_b for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { - test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block - test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary - test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride - test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two - test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish + test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block + test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary + test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride + test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two + test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish + // higher-rank shapes: matcher must reject fusion, fallback to naive chain + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1 + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // 
ne[3] > 1 + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1 } // glu ops @@ -9093,9 +9097,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192) - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); diff --git a/tools/cli/README.md b/tools/cli/README.md index bca4da7efbb..02c564a2906 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -195,11 +195,9 @@ | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) | | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) | | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) | -| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) | | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | diff --git a/tools/server/README.md b/tools/server/README.md index 024760da6b6..7f856faa813 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) | | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) | | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) | -| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) | | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | @@ -1045,16 +1043,23 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat This endpoint is only accessible if `--metrics` is set. -Available metrics: -- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. -- `llamacpp:tokens_predicted_total`: Number of generation tokens processed. -- `llamacpp:prompt_tokens_seconds`: Average prompt throughput in tokens/s. -- `llamacpp:predicted_tokens_seconds`: Average generation throughput in tokens/s. -- `llamacpp:kv_cache_usage_ratio`: KV-cache usage. `1` means 100 percent usage. -- `llamacpp:kv_cache_tokens`: KV-cache tokens. -- `llamacpp:requests_processing`: Number of requests processing. -- `llamacpp:requests_deferred`: Number of requests deferred. -- `llamacpp:n_tokens_max`: High watermark of the context size observed. +In *router mode* the query param `?model={model_id}` has to be set. This endpoint will respond with status code 400 `model name is missing from the request` if not set. + +#### Available metrics + +| Metric | Type | Description | +| ------ | ---------------------- | ----------- | +| `llamacpp:prompt_tokens_total` | Counter | Number of prompt tokens processed. | +| `llamacpp:prompt_seconds_total` | Counter | Prompt process time in seconds. | +| `llamacpp:prompt_tokens_seconds` | Gauge | Average prompt throughput in tokens/s. | +| `llamacpp:tokens_predicted_total` | Counter | Number of generation tokens processed. | +| `llamacpp:tokens_predicted_seconds_total` | Counter | Predict process time in seconds. | +| `llamacpp:predicted_tokens_seconds` | Gauge | Average generation throughput in tokens/s. | +| `llamacpp:requests_processing` | Gauge | Number of requests processing. 
| +| `llamacpp:requests_deferred` | Gauge | Number of requests deferred. | +| `llamacpp:n_tokens_max` | Counter | High watermark of the context size observed. | +| `llamacpp:n_decode_total` | Counter | Total Number of llama_decode() calls. | +| `llamacpp:n_busy_slots_per_decode` | Gauge | Average number of busy slots per llama_decode() call. | ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 637f8d21647..ce743e6656d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -36,32 +36,6 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) { - if (pos_min == -1) { - pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); - } - if (pos_max == -1) { - pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); - } - - auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY; - if (on_device) { - flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE; - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags); - - ckpt.pos_min = pos_min; - ckpt.pos_max = pos_max; - ckpt.n_tokens = n_tokens; - ckpt.data.resize(checkpoint_size); - - const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); - } -} - // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -80,18 +54,19 @@ enum server_state { struct server_slot { int id; - llama_context * ctx = nullptr; - - common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft 
= nullptr; // multimodal mtmd_context * mctx = nullptr; // speculative decoding + common_speculative * spec; + llama_tokens spec_draft; + llama_tokens spec_prompt; std::vector spec_i_batch; - server_prompt_checkpoint spec_ckpt; - common_speculative_ptr spec; + common_prompt_checkpoint spec_ckpt; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -135,21 +110,27 @@ struct server_slot { void prompt_save(server_prompt_cache & prompt_cache) const { GGML_ASSERT(prompt.data.size() == 0); - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + const size_t cur_size = cur_size_tgt + cur_size_dft; - auto * cur = prompt_cache.alloc(prompt, cur_size); + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft); if (cur == nullptr) { return; } - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + if (ctx_dft) { + llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE); + } } bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); + bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id); if (!res) { 
SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } @@ -164,7 +145,11 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + } + prompt.tokens.clear(); } @@ -222,7 +207,7 @@ struct server_slot { task_prev = std::move(task); task.reset(); - llama_set_sampler(ctx, id, nullptr); + llama_set_sampler(ctx_tgt, id, nullptr); // clear alora start alora_invocation_start = -1; @@ -259,7 +244,7 @@ struct server_slot { return !task->need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST); } bool can_batch_with(server_slot & other_slot) const { @@ -310,14 +295,10 @@ struct server_slot { return 0; } - const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative); - // determine the max draft that fits the current slot state - int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative); - // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + int n_draft_max = n_ctx - prompt.n_tokens() - 2; if (n_remaining > 0) { n_draft_max = std::min(n_draft_max, n_remaining - 1); @@ -325,61 +306,10 @@ struct server_slot { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < n_draft_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min); - n_draft_max = 0; - } - return n_draft_max; } void update_batch(llama_batch & batch) { - const int n_draft_max = get_n_draft_max(); - if (n_draft_max > 0) { - 
GGML_ASSERT(can_speculate()); - - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const llama_tokens & tokens = prompt.tokens.get_text_tokens(); - - const auto & params_spec = task->params.speculative; - - if (!spec_draft.empty()) { - // we have a previous (partial) draft to reuse - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - GGML_ASSERT(!spec_ckpt.empty()); - } - } else { - GGML_ASSERT(spec_i_batch.empty()); - - // generate a new draft - spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); - n_draft_total += spec_draft.size(); - - if (spec_draft.size() > (size_t) n_draft_max) { - SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); - spec_draft.resize(n_draft_max); - } - - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - const auto n_tokens = prompt.tokens.size(); - - //const int64_t t_start = ggml_time_us(); - - server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true); - - //const int64_t t_total = ggml_time_us() - t_start; - //printf("checkpoint total: %f ms\n", t_total / 1000.0); - - SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); - } - } - - GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); - } - if (spec_draft.empty()) { // no speculative decoding i_batch = batch.n_tokens; @@ -511,7 +441,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec.get()); + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -539,7 +469,7 @@ struct server_slot { }; if (!only_metrics) { - res["prompt"] = 
ptask->tokens.detokenize(ctx, true); + res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true); res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -550,8 +480,13 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + } other.n_decoded = n_decoded; other.n_remaining = n_remaining; @@ -642,7 +577,8 @@ struct server_context_impl { // only use these pointers outside of this class: // - when not in sleeping state // - and, with thread-safe APIs (e.g., tokenizer calls) - llama_model * model = nullptr; + llama_model * model_tgt = nullptr; + mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; @@ -669,11 +605,17 @@ struct server_context_impl { // note: keep these alive - they determine the lifetime of the model, context, etc. 
common_init_result_ptr llama_init; - llama_context * ctx = nullptr; + llama_context * ctx_tgt = nullptr; llama_batch batch {}; llama_model_ptr model_dft; + llama_context_ptr ctx_dft; + + common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + + common_speculative_ptr spec; bool add_bos_token = true; @@ -708,18 +650,12 @@ struct server_context_impl { void destroy() { llama_init.reset(); - ctx = nullptr; - model = nullptr; + ctx_tgt = nullptr; + model_tgt = nullptr; mtmd_free(mctx); mctx = nullptr; - for (server_slot & slot : slots) { - if (slot.can_speculate()) { - slot.spec.reset(); - } - } - llama_batch_free(batch); } @@ -759,17 +695,17 @@ struct server_context_impl { llama_init = common_init_from_params(params_base); - model = llama_init->model(); - ctx = llama_init->context(); + model_tgt = llama_init->model(); + ctx_tgt = llama_init->context(); - if (model == nullptr) { + if (model_tgt == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - vocab = llama_model_get_vocab(model); + vocab = llama_model_get_vocab(model_tgt); - n_ctx = llama_n_ctx(ctx); + n_ctx = llama_n_ctx(ctx_tgt); add_bos_token = llama_vocab_get_add_bos(vocab); @@ -781,9 +717,6 @@ struct server_context_impl { auto params_dft = params_base; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx == 0 ? 
llama_n_ctx_seq(ctx) : params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -805,8 +738,13 @@ struct server_context_impl { return false; } - params_base.speculative.draft.model = model_dft.get(); - params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); } std::string & mmproj_path = params_base.mmproj.path; @@ -826,7 +764,7 @@ struct server_context_impl { mparams.image_max_tokens = params_base.image_max_tokens; mparams.media_marker = get_media_marker(); - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); return false; @@ -844,7 +782,7 @@ struct server_context_impl { } } - if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -856,14 +794,14 @@ struct server_context_impl { } } - if (llama_model_n_swa(model) == 0) { + if (llama_model_n_swa(model_tgt) == 0) { if (params_base.swa_full) { params_base.swa_full = false; SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled"); } } - n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model); + n_swa = params_base.swa_full ? 
0 : llama_model_n_swa(model_tgt); // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -871,9 +809,9 @@ struct server_context_impl { // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model_tgt); - int n_ctx_slot = llama_n_ctx_seq(ctx); + int n_ctx_slot = llama_n_ctx_seq(ctx_tgt); if (n_ctx_slot > n_ctx_train) { SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); n_ctx_slot = n_ctx_train; @@ -881,12 +819,12 @@ struct server_context_impl { slots.clear(); - const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); } @@ -895,27 +833,33 @@ struct server_context_impl { slots.emplace_back(); } + // try speculative decoding + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + try { + spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } + } + + if (spec) { + SRV_INF("%s", "speculative decoding context initialized\n"); + } else { + ctx_dft.reset(); + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - - slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.id = i; + slot.ctx_tgt = ctx_tgt; + slot.ctx_dft = ctx_dft.get(); 
+ slot.spec = spec.get(); + slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - // try speculative decoding - if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); - - if (slot.spec) { - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } - } - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int id_slot) { @@ -946,7 +890,7 @@ struct server_context_impl { // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { - const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_batch = llama_n_batch(ctx_tgt); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -990,8 +934,9 @@ struct server_context_impl { // unlike load_model(), this is only called once during initialization bool init() { - GGML_ASSERT(ctx != nullptr); - GGML_ASSERT(model != nullptr); + GGML_ASSERT(ctx_tgt != nullptr); + GGML_ASSERT(model_tgt != nullptr); + GGML_ASSERT(!sleeping); // wiring up server queues @@ -1037,7 +982,7 @@ struct server_context_impl { common_chat_templates_ptr chat_templates; try { - chat_templates = common_chat_templates_init(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template); LOG_INF("%s: chat template, example_format: '%s'\n", __func__, common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); @@ -1300,7 +1245,7 @@ struct server_context_impl { } } - if (!task.tokens.validate(ctx)) { + if (!task.tokens.validate(ctx_tgt)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1310,7 +1255,7 @@ struct server_context_impl { // initialize samplers if 
(task.need_sampling()) { try { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling)); } catch (std::exception & e) { std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); @@ -1324,16 +1269,16 @@ struct server_context_impl { backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0); + backend_sampling &= !(slot.can_speculate()); // TODO: getting pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_pre_sample_logits; // TODO: tmp until backend sampling is fully implemented if (backend_sampling) { - llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get())); } else { - llama_set_sampler(ctx, slot.id, nullptr); + llama_set_sampler(ctx_tgt, slot.id, nullptr); } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); @@ -1512,13 +1457,13 @@ struct server_context_impl { result.probs.push_back({ cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), + common_token_to_piece(ctx_tgt, cur_p->data[i].id, special), cur_p->data[i].p }); } } else { // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -1536,7 +1481,7 @@ struct server_context_impl { for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), + common_token_to_piece(ctx_tgt, cur[i].id, special), cur[i].p }); 
} @@ -1639,7 +1584,7 @@ struct server_context_impl { res->tokens = std::move(slot.generated_tokens); } res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); + res->prompt = slot.task->tokens.detokenize(ctx_tgt, true); res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; @@ -1662,7 +1607,7 @@ struct server_context_impl { // populate res.probs_output if (slot.task->params.sampling.n_probs > 0) { if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); res->probs_output = std::vector( @@ -1687,7 +1632,7 @@ struct server_context_impl { res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd_out = llama_model_n_embd_out(model); + const int n_embd_out = llama_model_n_embd_out(model_tgt); std::vector embd_res(n_embd_out, 0.0f); @@ -1697,10 +1642,10 @@ struct server_context_impl { } const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); + if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(slot.ctx_tgt, i); } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]); } if (embd == nullptr) { @@ -1711,7 +1656,7 @@ struct server_context_impl { } // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) { common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; @@ -1736,9 +1681,9 @@ 
struct server_context_impl { continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]); if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); + embd = llama_get_embeddings_ith(ctx_tgt, i); } if (embd == NULL) { @@ -1843,18 +1788,22 @@ struct server_context_impl { const auto & cur = slot.prompt.checkpoints.front(); SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } auto & cur = slot.prompt.checkpoints.emplace_back(); - server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max); + + cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + + cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2009,7 +1958,7 @@ struct server_context_impl { std::string filepath = task.slot_action.filepath; const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); 
const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2048,7 +1997,7 @@ struct server_context_impl { llama_tokens tokens; tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -2207,8 +2156,13 @@ struct server_context_impl { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + } // add generated tokens to cache // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 @@ -2236,12 +2190,10 @@ struct server_context_impl { // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + std::vector generating; + std::vector drafting; - // 
first, add sampled tokens from any ongoing sequences + // determine which slots are generating and drafting for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; @@ -2254,12 +2206,103 @@ struct server_context_impl { continue; } + generating.push_back(&slot); + + if (spec) { + common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; + + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + const int n_draft_max = slot.get_n_draft_max(); + + if (n_draft_max > 0) { + GGML_ASSERT(slot.can_speculate()); + + if (!slot.spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (use_ckpt_tgt) { + GGML_ASSERT(!slot.spec_ckpt.empty()); + } + } else { + GGML_ASSERT(slot.spec_i_batch.empty()); + + slot.spec_ckpt.update_pos( + slot.prompt.n_tokens(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + + if (use_ckpt_dft) { + slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ slot.prompt.n_tokens(), + /* .id_last = */ slot.sampled, + /* .prompt = */ &slot.spec_prompt, + /* .result = */ &slot.spec_draft, + }; + + drafting.push_back(&slot); + } + } + } + } + + // generate the actual drafts (if any) + { + common_speculative_draft(spec.get()); + } + + // make checkpoints if needed + for (auto * slot_ptr : drafting) { + auto & slot = *slot_ptr; + + auto & draft = slot.spec_draft; + auto & ckpt = slot.spec_ckpt; + + slot.n_draft_total += draft.size(); + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed 
[TAG_SPEC_AVOID_DRAFT_REEVAL] + if (ctx_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); + } + + if (!draft.empty()) { + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + if (use_ckpt_tgt) { + //const int64_t t_start = ggml_time_us(); + + ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); + + SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", + ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(), + (float) ckpt.size() / 1024 / 1024, + (float) ckpt.data_dft.size() / 1024 / 1024); + } + } + } + + // update the batch with the sampled/drafted tokens + for (auto * slot_ptr : generating) { + auto & slot = *slot_ptr; + slot.update_batch(batch); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); + int32_t n_batch = llama_n_batch(ctx_tgt); + int32_t n_ubatch = llama_n_ubatch(ctx_tgt); float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2303,12 +2346,12 @@ struct server_context_impl { /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } } else { // all for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: 
%6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } }*/ @@ -2327,7 +2370,7 @@ struct server_context_impl { } // TODO: support memory-less logits computation - if (slot.task->need_logits() && !llama_get_memory(ctx)) { + if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2379,7 +2422,7 @@ struct server_context_impl { const auto n_cache_reuse = slot.task->params.n_cache_reuse; const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx)) && + llama_memory_can_shift(llama_get_memory(ctx_tgt)) && !slot.prompt.tokens.has_mtmd; if (!can_cache_reuse && n_cache_reuse > 0) { @@ -2413,13 +2456,18 @@ struct server_context_impl { if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str()); //} const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + } for (size_t i = 0; i < n_match; i++) { 
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); @@ -2446,7 +2494,7 @@ struct server_context_impl { const auto pos_min_thold = std::max(0, pos_next - n_swa); if (n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -2475,14 +2523,14 @@ struct server_context_impl { { const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -2514,18 +2562,13 @@ struct server_context_impl { if (!do_reset) { // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); - n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024); - } + it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2542,7 +2585,7 @@ struct server_context_impl { for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, 
size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2582,14 +2625,18 @@ struct server_context_impl { SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); slot.prompt_clear(true); // there is no common part left slot.n_prompt_tokens_cache = 0; - } + } else { + if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { + GGML_ABORT("failed to truncate draft context\n"); + } + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. 
When this happens, the @@ -2615,7 +2662,7 @@ struct server_context_impl { // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || (n_swa > 0)); bool has_mtmd = false; @@ -2624,7 +2671,7 @@ struct server_context_impl { while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2632,6 +2679,16 @@ struct server_context_impl { continue; } + if (ctx_dft) { + // TODO: in the future, figure out how to infuse target embeddings to the images + // for now, we skip this for simplicity + // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? 
+ res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + GGML_ABORT("failed to process multi-modal data on draft context\n"); + } + } + slot.n_prompt_tokens_processed += n_tokens_out; // add the image chunk to cache @@ -2733,8 +2790,8 @@ struct server_context_impl { SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); @@ -2765,9 +2822,14 @@ struct server_context_impl { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + if (slot_batched) { // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); + common_set_adapter_lora(ctx_tgt, slot_batched->lora); // if the lora is temporarily disabled for an alora, re-enable it // for next time @@ -2776,7 +2838,7 @@ struct server_context_impl { slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2805,7 +2867,7 @@ struct server_context_impl { batch.logits + i, }; - const int ret = llama_decode(ctx, 
batch_view); + const int ret = llama_decode(ctx_tgt, batch_view); metrics.on_decoded(slots); @@ -2858,11 +2920,63 @@ struct server_context_impl { continue; // continue loop of n_batch } + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + // + // | spec type | need re-eval | + // | --- | --- | + // | draft model | no | because the draft model does not use embeddings from the target + // | MTP (std) | yes | + // | MTP Gemma4 | no | because the KV cache is shared + // | Eagle3 | yes | + // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 + // + // note: this logic is now moved into `common_speculative_process()` + // keeping the sketch here for a bit, until the logic is finalized + // + //if (ctx_dft) { + // // TODO: update as needed for MTP, Eagle3, etc. + // const bool need_tgt_embd = false; + + // if (need_tgt_embd) { + // llama_synchronize(ctx_tgt); + // } + + // // the logic here varies depending on the speculative decoding method + // // - some draft contexts require embeddings from the target context, others don't + // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings + // // TODO: extract this in a function ?
+ // { + // // TODO: hook the embeddings from the last target batch here + // if (llama_model_has_encoder(model_dft.get())) { + // //llama_encode(ctx_dft, ...); + + // GGML_ABORT("not implemented yet\n"); + // } + + // const int ret = llama_decode(ctx_dft.get(), batch_view); + + // if (ret != 0) { + // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); + + // // TODO: handle error + // break; + // } + // } + //} + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + break; + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = llama_n_batch(ctx_tgt); // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2921,7 +3035,7 @@ struct server_context_impl { slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2933,7 +3047,7 @@ struct server_context_impl { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); slot.i_batch = -1; @@ -2954,7 +3068,7 @@ struct server_context_impl { completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if 
(slot.task->params.sampling.n_probs > 0) { @@ -2985,23 +3099,23 @@ struct server_context_impl { // verify and try to accept the draft { - const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3011,16 +3125,19 @@ struct server_context_impl { const auto & ckpt = slot.spec_ckpt; - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", - ckpt.pos_min, ckpt.pos_max, ckpt.size()); + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + 
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); } - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + } slot.prompt.tokens.keep_first(ckpt.n_tokens); slot.smpl = std::move(smpl_save); @@ -3033,7 +3150,7 @@ struct server_context_impl { SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } - common_speculative_accept(slot.spec.get(), accepted.size() - 1); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); slot.spec_draft = std::move(accepted); } @@ -3055,13 +3172,16 @@ struct server_context_impl { slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3113,7 +3233,7 @@ void server_context::terminate() { } llama_context * server_context::get_llama_context() const { - return impl->ctx; + return impl->ctx_tgt; } server_response_reader server_context::get_response_reader() { @@ -3123,8 +3243,8 @@ server_response_reader 
server_context::get_response_reader() { server_context_meta server_context::get_meta() const { auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); - auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; - auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : ""; + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : ""; return server_context_meta { /* build_info */ std::string(llama_build_info()), @@ -3137,7 +3257,7 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), - /* pooling_type */ llama_pooling_type(impl->ctx), + /* pooling_type */ llama_pooling_type(impl->ctx_tgt), /* chat_params */ impl->chat_params, /* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()), @@ -3155,10 +3275,10 @@ server_context_meta server_context::get_meta() const { /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), - /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), - /* model_n_embd_inp */ llama_model_n_embd(impl->model), - /* model_n_params */ llama_model_n_params(impl->model), - /* model_size */ llama_model_size(impl->model), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt), + /* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt), + /* model_n_params */ llama_model_n_params(impl->model_tgt), + /* model_size */ llama_model_size(impl->model_tgt), }; } @@ -3502,10 +3622,6 @@ void server_routes::init_routes() { {"name", "n_tokens_max"}, {"help", "Largest observed n_tokens."}, {"value", res_task->n_tokens_max} - }, { - {"name", 
"n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, @@ -3523,6 +3639,10 @@ void server_routes::init_routes() { {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, {"value", (uint64_t) res_task->n_tasks_deferred} + },{ + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}} }; @@ -4045,7 +4165,7 @@ void server_routes::init_routes() { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 45e5168fabe..b9b4a704a14 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", 
chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl( params.speculative = defaults.speculative; + // TODO: to keep things simple, we disable speculative parameter adjustments for now +#if 0 // TODO: for now, be able to adjust only the draft-model based speculative parameters params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min); params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max); @@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl( params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0); params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0); -#if 0 // for debugging and research purposes params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); @@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const { return res; } -server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) { // first check if the current state is contained fully in the cache for (auto it = states.begin(); it != states.end(); ++it) { const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); @@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t } } - std::vector state_data; + std::vector state_data_tgt; + std::vector state_data_dft; // check if we can 
allocate enough memory for the new state try { - state_data.resize(state_size); + state_data_tgt.resize(state_size_tgt); + state_data_dft.resize(state_size_dft); } catch (const std::bad_alloc & e) { SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); @@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t return nullptr; } - auto & cur = states.emplace_back(); - cur = { + states.push_back({ /*.tokens =*/ prompt.tokens.clone(), - /*.data =*/ std::move(state_data), + /*.data =*/ { + /*.main =*/ std::move(state_data_tgt), + /*.drft =*/ std::move(state_data_dft), + }, /*.checkpoints =*/ prompt.checkpoints, - }; + }); - return &cur; + return &states.back(); } -bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) { const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); float f_keep_best = prompt.tokens.size() > 0 ? 
float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins @@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok if (it_best != states.end()) { SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); + { + auto & data = it_best->data.main; + + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } - return false; + data.clear(); + data.shrink_to_fit(); } - it_best->data.clear(); - it_best->data.shrink_to_fit(); + { + auto & data = it_best->data.drft; + + if (!data.empty()) { + GGML_ASSERT(ctx_dft); + + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + data.clear(); + data.shrink_to_fit(); + } + } prompt = std::move(*it_best); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 289e1fb8d24..64bdecd794f 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result { virtual json to_json() override; }; -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - int64_t n_tokens; - - std::vector data; +struct server_prompt_data { + std::vector main; + std::vector drft; size_t size() const { - return data.size(); - } - - bool empty() const { - return data.empty(); - } - - void clear() { - pos_min = 0; - pos_max = 0; - n_tokens = 0; - data.clear(); + return main.size() + 
drft.size(); } }; struct server_prompt { server_tokens tokens; - std::vector data; + server_prompt_data data; - std::list checkpoints; + std::list checkpoints; size_t size() const { - size_t res = data.size(); + size_t res = 0; + + res += data.size(); - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); + for (const auto & ckpt : checkpoints) { + res += ckpt.size(); } return res; @@ -614,7 +601,7 @@ struct server_prompt { return server_prompt { tokens.clone(), data, - checkpoints + checkpoints, }; } }; @@ -637,9 +624,9 @@ struct server_prompt_cache { size_t n_tokens() const; - server_prompt * alloc(const server_prompt & prompt, size_t state_size); + server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft); - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot); void update(); }; diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index eebd3cc8fa2..84cd77e6f2e 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe() -MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf" def create_server(): global server