diff --git a/common/sampling.cpp b/common/sampling.cpp
index d4a2fdcdacc..5665d0a706c 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
+    gsmpl->set_logits(ctx, idx);
+
     // Check if a backend sampler has already sampled a token, in which case we
     // return that token id directly.
     {
@@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
             GGML_ASSERT(!gsmpl->grmr    && "using grammar in combination with backend sampling is not supported");
             GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
 
-            // TODO: simplify
-            gsmpl->cur.resize(1);
-            gsmpl->cur[0] = { id, 0.0f, 1.0f };
-            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
+            for (size_t i = 0; i < cur_p.size; ++i) {
+                if (cur_p.data[i].id == id) {
+                    cur_p.selected = i;
+                    break;
+                }
+            }
 
             return id;
         }
     }
 
-    gsmpl->set_logits(ctx, idx);
-
     // apply reasoning budget first
     llama_sampler_apply(rbudget, &cur_p);
 
diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp
index ec8156bb868..a978812cd90 100644
--- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp
+++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp
@@ -1,5 +1,7 @@
 #include "ggml-remoting.h"
 
+#include
+
 static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
     virtgpu * gpu = DEV_TO_GPU(dev);
 
diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py
index 467a0660870..274b06440e7 100755
--- a/scripts/sync_vendor.py
+++ b/scripts/sync_vendor.py
@@ -5,7 +5,7 @@
 import sys
 import subprocess
 
-HTTPLIB_VERSION = "refs/tags/v0.43.3"
+HTTPLIB_VERSION = "refs/tags/v0.43.4"
 
 vendor = {
     "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3f20c94c550..637f8d21647 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1317,7 +1317,7 @@ struct server_context_impl {
            return false;
        }
 
-       const bool need_logits = task.params.sampling.n_probs > 0;
+       const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;
 
        bool backend_sampling = true;
 
@@ -1326,8 +1326,8 @@ struct server_context_impl {
        // TODO: speculative decoding requires multiple samples per batch - not supported yet
        backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
 
-       // TODO: getting post/pre sampling logits is not yet supported with backend sampling
-       backend_sampling &= !need_logits;
+       // TODO: getting pre-sampling logits is not yet supported with backend sampling
+       backend_sampling &= !need_pre_sample_logits;
 
        // TODO: tmp until backend sampling is fully implemented
        if (backend_sampling) {
@@ -1504,6 +1504,12 @@ struct server_context_impl {
        // set probability for top n_probs tokens
        result.probs.reserve(n_probs);
        for (size_t i = 0; i < n_probs; i++) {
+           // Some samplers do return 0.0 probabilities, others don't.
+           // Filter 0.0 probabilities to ensure the behavior is consistent.
+           if (cur_p->data[i].p == 0.0) {
+               break;
+           }
+
            result.probs.push_back({
                cur_p->data[i].id,
                common_token_to_piece(ctx, cur_p->data[i].id, special),
diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp
index a2a026a12ce..d5fceb1b131 100644
--- a/tools/server/server-queue.cpp
+++ b/tools/server/server-queue.cpp
@@ -381,7 +381,8 @@ server_task_result_ptr server_response_reader::next(const std::function
        if (result == nullptr) {
            // timeout, check stop condition
            if (should_stop()) {
-               SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
+               SRV_WRN("%s", "stopping wait for next result due to should_stop condition (adjust the --timeout argument if needed)\n");
+               SRV_WRN("%s", "ref: https://github.com/ggml-org/llama.cpp/pull/22907\n");
                return nullptr;
            }
        } else {
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index c1a19785434..1e0891987a9 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "I believe the meaning of life is",
+        "prompt": "Today was the day. Today I would finally become a",
         "n_probs": 10,
-        "temperature": 0.0,
+        "temperature": 1.0,
         "n_predict": 5,
         "post_sampling_probs": True,
     })
     assert res.status_code == 200
     assert "completion_probabilities" in res.body
     assert len(res.body["completion_probabilities"]) == 5
-    for tok in res.body["completion_probabilities"]:
+    for (i, tok) in enumerate(res.body["completion_probabilities"]):
         assert "id" in tok and tok["id"] > 0
         assert "token" in tok and type(tok["token"]) == str
         assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
         assert "bytes" in tok and type(tok["bytes"]) == list
-        assert len(tok["top_probs"]) == 10
+        assert "top_probs" in tok and type(tok["top_probs"]) == list
+
         for prob in tok["top_probs"]:
             assert "id" in prob and prob["id"] > 0
             assert "token" in prob and type(prob["token"]) == str
-            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            # 0.0 probability tokens should never be returned by the server
+            assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
-        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
-        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+
+        if i == 0:
+            # The prompt is vague enough that we should get at least 10 possibilities
+            # for the first token.
+            assert len(tok["top_probs"]) == 10
+
+        if len(tok["top_probs"]) < 10:
+            # Getting fewer than the requested number of probabilities should only happen
+            # if the ones we did get already sum to 1.0.
+ assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0) + +def test_n_probs_post_backend_sampling(): + """Verify that the same probabilities are returned with and without backend sampling.""" + global server + server.backend_sampling = True + server.start() + + def make_request(backend_sampling): + n_predict = 20 + + res = server.make_request("POST", "/completion", data={ + "prompt": "The countries of Europe, in random order, are:", + "n_probs": 10, + "n_predict": n_predict, + "post_sampling_probs": True, + "seed": 4242, + "backend_sampling": backend_sampling, + }) + assert res.status_code == 200 + + total_probs = 0 + completions = res.body["completion_probabilities"] + assert len(completions) == n_predict + for tok in completions: + # Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the + # data. + tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0] + total_probs += len(tok["top_probs"]) + # Verify that we got at least two top probs on average, to ensure the effectiveness of the test. + assert total_probs >= 2 * n_predict + return completions + + def verify_token(a, b): + assert a["id"] == b["id"] + assert a["token"] == b["token"] + assert a["bytes"] == b["bytes"] + assert a["prob"] == pytest.approx(b["prob"], abs=0.01) + + for (a, b) in zip(make_request(True), make_request(False)): + verify_token(a, b) + assert len(a["top_probs"]) == len(b["top_probs"]) + + for (aa, bb) in zip(a["top_probs"], b["top_probs"]): + verify_token(aa, bb) @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)]) def test_logit_bias(tokenize, openai_style): diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index ce939038726..c5dba1c139f 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -108,6 +108,7 @@ class ServerProcess: no_cache_idle_slots: bool = False log_path: str | None = None webui_mcp_proxy: bool = False + backend_sampling: bool = False gcp_compat: bool = False # session variables @@ -252,6 +253,8 @@ def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None: server_args.append("--no-cache-idle-slots") if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") + if self.backend_sampling: + server_args.append("--backend_sampling") if self.gcp_compat: env["AIP_MODE"] = "PREDICTION" diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index cb8ea9742b6..024e9a3d581 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -8980,10 +8980,22 @@ ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, stream_line_reader lr(strm, line_buf, sizeof(line_buf)); if (!lr.getline()) { return -1; } - char *endptr = nullptr; - unsigned long chunk_len = std::strtoul(lr.ptr(), &endptr, 16); - if (endptr == lr.ptr()) { return -1; } - if (chunk_len == ULONG_MAX) { return -1; } + // RFC 9112 ยง7.1: chunk-size = 1*HEXDIG + const char *p = lr.ptr(); + int v = 0; + if (!is_hex(*p, v)) { return -1; } + + size_t chunk_len = 0; + constexpr size_t chunk_len_max = (std::numeric_limits::max)(); + for (; is_hex(*p, v); ++p) { + if (chunk_len > (chunk_len_max >> 4)) { return -1; } + chunk_len = (chunk_len << 4) | static_cast(v); + } + + while (is_space_or_tab(*p)) { + ++p; + } + if (*p != '\0' && *p != ';' && *p != '\r' && *p != '\n') { return -1; } if (chunk_len == 0) { chunk_remaining = 0; @@ -8993,7 +9005,7 @@ ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, return 0; 
  }
 
-  chunk_remaining = static_cast<size_t>(chunk_len);
+  chunk_remaining = chunk_len;
   last_chunk_total = chunk_remaining;
   last_chunk_offset = 0;
 }
diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h
index 8d3c4c2c5b2..25dc7fee7a7 100644
--- a/vendor/cpp-httplib/httplib.h
+++ b/vendor/cpp-httplib/httplib.h
@@ -8,8 +8,8 @@
 #ifndef CPPHTTPLIB_HTTPLIB_H
 #define CPPHTTPLIB_HTTPLIB_H
 
-#define CPPHTTPLIB_VERSION "0.43.3"
-#define CPPHTTPLIB_VERSION_NUM "0x002b03"
+#define CPPHTTPLIB_VERSION "0.43.4"
+#define CPPHTTPLIB_VERSION_NUM "0x002b04"
 
 #ifdef _WIN32
 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
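
Note on the ChunkedDecoder change above: the new code replaces strtoul, which only failed on ULONG_MAX and silently ignored trailing garbage, with an explicit RFC 9112 hex loop that rejects any chunk-size whose next 4-bit shift would overflow, and rejects anything after the digits other than whitespace, a ';' chunk extension, or CRLF. The following is a minimal standalone sketch of that technique; parse_chunk_size and its local is_hex helper are hypothetical stand-ins written for illustration, not the vendored httplib code.

#include <cstddef>
#include <cstdio>
#include <limits>

// Hypothetical stand-in for httplib's internal is_hex(): returns true and
// stores the digit value in v when c is a hexadecimal digit.
static bool is_hex(char c, int &v) {
  if (c >= '0' && c <= '9') { v = c - '0';      return true; }
  if (c >= 'a' && c <= 'f') { v = c - 'a' + 10; return true; }
  if (c >= 'A' && c <= 'F') { v = c - 'A' + 10; return true; }
  return false;
}

// Parse an RFC 9112 chunk-size line: 1*HEXDIG, optionally followed by
// whitespace, a ';'-prefixed chunk extension, or CRLF. Returns false on
// malformed input or on a value that would not fit in size_t.
static bool parse_chunk_size(const char *p, size_t &chunk_len) {
  int v = 0;
  if (!is_hex(*p, v)) { return false; } // at least one HEXDIG is required

  chunk_len = 0;
  constexpr size_t chunk_len_max = (std::numeric_limits<size_t>::max)();
  for (; is_hex(*p, v); ++p) {
    // Shifting left by 4 would overflow if any of the top 4 bits are set.
    if (chunk_len > (chunk_len_max >> 4)) { return false; }
    chunk_len = (chunk_len << 4) | static_cast<size_t>(v);
  }

  while (*p == ' ' || *p == '\t') { ++p; }
  return *p == '\0' || *p == ';' || *p == '\r' || *p == '\n';
}

int main() {
  size_t n = 0;
  std::printf("%d %zu\n", parse_chunk_size("1a3\r\n", n), n);    // 1 419
  std::printf("%d\n", parse_chunk_size("zz\r\n", n));            // 0 (no hex digit)
  std::printf("%d\n", parse_chunk_size("10000000000000000", n)); // 0 on 64-bit (2^64 overflows)
  return 0;
}

The overflow guard compares against chunk_len_max >> 4 before shifting, so the check itself can never wrap; for example, seventeen hex digits ("1" followed by sixteen "0"s, i.e. 2^64) are rejected on a 64-bit size_t before the final shift executes.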