14 changes: 8 additions & 6 deletions common/sampling.cpp
@@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

gsmpl->set_logits(ctx, idx);

// Check if a backend sampler has already sampled a token in which case we
// return that token id directly.
{
@@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");

// TODO: simplify
gsmpl->cur.resize(1);
gsmpl->cur[0] = { id, 0.0f, 1.0f };
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
for (size_t i = 0; i < cur_p.size; ++i) {
if (cur_p.data[i].id == id) {
cur_p.selected = i;
break;
}
}

return id;
}
}

gsmpl->set_logits(ctx, idx);

// apply reasoning budget first
llama_sampler_apply(rbudget, &cur_p);

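For context, the hunk above moves `set_logits` ahead of the backend-sampling check, so a token id already chosen by a backend sampler can be marked as selected inside the full candidate array instead of being wrapped in a one-element array, presumably keeping the distribution intact for later probability reporting. A self-contained sketch of that selection step, using simplified stand-ins for `llama_token_data` and `llama_token_data_array`:

```cpp
#include <cstddef>

// Simplified stand-ins for llama_token_data / llama_token_data_array,
// only to illustrate the selection step from the hunk above.
struct token_data       { int id; float logit; float p; };
struct token_data_array { token_data * data; size_t size; size_t selected; bool sorted; };

// Mark the backend-sampled token as the selected entry of the full candidate
// array; returns false if the id is not among the candidates.
static bool select_sampled_token(token_data_array & cur_p, int id) {
    for (size_t i = 0; i < cur_p.size; ++i) {
        if (cur_p.data[i].id == id) {
            cur_p.selected = i; // the rest of the distribution stays available
            return true;
        }
    }
    return false;
}
```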
2 changes: 2 additions & 0 deletions ggml/src/ggml-virtgpu/ggml-backend-device.cpp
@@ -1,5 +1,7 @@
#include "ggml-remoting.h"

#include <mutex>

static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);

2 changes: 1 addition & 1 deletion scripts/sync_vendor.py
@@ -5,7 +5,7 @@
import sys
import subprocess

HTTPLIB_VERSION = "refs/tags/v0.43.3"
HTTPLIB_VERSION = "refs/tags/v0.43.4"

vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
12 changes: 9 additions & 3 deletions tools/server/server-context.cpp
@@ -1317,7 +1317,7 @@ struct server_context_impl {
return false;
}

const bool need_logits = task.params.sampling.n_probs > 0;
const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;

bool backend_sampling = true;

@@ -1326,8 +1326,8 @@ struct server_context_impl {
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);

// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
// TODO: getting pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_pre_sample_logits;

// TODO: tmp until backend sampling is fully implemented
if (backend_sampling) {
@@ -1504,6 +1504,12 @@ struct server_context_impl {
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < n_probs; i++) {
// Some samplers return 0.0 probabilities, others don't.
// Filter out 0.0 probabilities to ensure the behavior is consistent.
if (cur_p->data[i].p == 0.0) {
break;
}

result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),
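A self-contained sketch of the truncation behavior added in the loop above, assuming the candidates are sorted by descending probability (so a `break` at the first 0.0 entry drops all remaining zeros). `token_prob` and `collect_top_probs` are illustrative names, not server types:

```cpp
#include <cstddef>
#include <vector>

struct token_prob { int id; float p; };

// Collect at most n_probs entries, stopping at the first zero-probability
// token so that samplers which report trailing 0.0 entries and those which
// omit them produce the same output.
static std::vector<token_prob> collect_top_probs(const token_prob * data, size_t n_probs) {
    std::vector<token_prob> out;
    out.reserve(n_probs);
    for (size_t i = 0; i < n_probs; i++) {
        if (data[i].p == 0.0f) {
            break; // sorted input: everything after this is also 0.0
        }
        out.push_back(data[i]);
    }
    return out;
}
```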
3 changes: 2 additions & 1 deletion tools/server/server-queue.cpp
@@ -381,7 +381,8 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
SRV_WRN("%s", "stopping wait for next result due to should_stop condition (adjust the --timeout argument if needed)\n");
SRV_WRN("%s", "ref: https://github.com/ggml-org/llama.cpp/pull/22907\n");
return nullptr;
}
} else {
67 changes: 60 additions & 7 deletions tools/server/tests/unit/test_completion.py
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"prompt": "Today was the day. Today I would finally become a",
"n_probs": 10,
"temperature": 0.0,
"temperature": 1.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
for (i, tok) in enumerate(res.body["completion_probabilities"]):
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_probs"]) == 10
assert "top_probs" in tok and type(tok["top_probs"]) == list

for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
# 0.0 probability tokens should never be returned by the server
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])

if i == 0:
# The prompt is vague enough that we should get at least 10 possibilities
# for the first token.
assert len(tok["top_probs"]) == 10

if len(tok["top_probs"]) < 10:
# Getting less than the requested number of probabilities should only happen
# if the ones we did get already sum to 1.0.
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)

def test_n_probs_post_backend_sampling():
"""Verify that the same probabilities are returned with and without backend sampling."""
global server
server.backend_sampling = True
server.start()

def make_request(backend_sampling):
n_predict = 20

res = server.make_request("POST", "/completion", data={
"prompt": "The countries of Europe, in random order, are:",
"n_probs": 10,
"n_predict": n_predict,
"post_sampling_probs": True,
"seed": 4242,
"backend_sampling": backend_sampling,
})
assert res.status_code == 200

total_probs = 0
completions = res.body["completion_probabilities"]
assert len(completions) == n_predict
for tok in completions:
# Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
# data.
tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
total_probs += len(tok["top_probs"])
# Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
assert total_probs >= 2 * n_predict
return completions

def verify_token(a, b):
assert a["id"] == b["id"]
assert a["token"] == b["token"]
assert a["bytes"] == b["bytes"]
assert a["prob"] == pytest.approx(b["prob"], abs=0.01)

for (a, b) in zip(make_request(True), make_request(False)):
verify_token(a, b)
assert len(a["top_probs"]) == len(b["top_probs"])

for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
verify_token(aa, bb)

@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
3 changes: 3 additions & 0 deletions tools/server/tests/utils.py
@@ -108,6 +108,7 @@ class ServerProcess:
no_cache_idle_slots: bool = False
log_path: str | None = None
webui_mcp_proxy: bool = False
backend_sampling: bool = False
gcp_compat: bool = False

# session variables
@@ -252,6 +253,8 @@ def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.append("--no-cache-idle-slots")
if self.webui_mcp_proxy:
server_args.append("--webui-mcp-proxy")
if self.backend_sampling:
server_args.append("--backend_sampling")
if self.gcp_compat:
env["AIP_MODE"] = "PREDICTION"

22 changes: 17 additions & 5 deletions vendor/cpp-httplib/httplib.cpp
@@ -8980,10 +8980,22 @@ ssize_t ChunkedDecoder::read_payload(char *buf, size_t len,
stream_line_reader lr(strm, line_buf, sizeof(line_buf));
if (!lr.getline()) { return -1; }

char *endptr = nullptr;
unsigned long chunk_len = std::strtoul(lr.ptr(), &endptr, 16);
if (endptr == lr.ptr()) { return -1; }
if (chunk_len == ULONG_MAX) { return -1; }
// RFC 9112 §7.1: chunk-size = 1*HEXDIG
const char *p = lr.ptr();
int v = 0;
if (!is_hex(*p, v)) { return -1; }

size_t chunk_len = 0;
constexpr size_t chunk_len_max = (std::numeric_limits<size_t>::max)();
for (; is_hex(*p, v); ++p) {
if (chunk_len > (chunk_len_max >> 4)) { return -1; }
chunk_len = (chunk_len << 4) | static_cast<size_t>(v);
}

while (is_space_or_tab(*p)) {
++p;
}
if (*p != '\0' && *p != ';' && *p != '\r' && *p != '\n') { return -1; }

if (chunk_len == 0) {
chunk_remaining = 0;
@@ -8993,7 +9005,7 @@ ssize_t ChunkedDecoder::read_payload(char *buf, size_t len,
return 0;
}

chunk_remaining = static_cast<size_t>(chunk_len);
chunk_remaining = chunk_len;
last_chunk_total = chunk_remaining;
last_chunk_offset = 0;
}
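A standalone sketch of the overflow-safe chunk-size parse introduced above: hex digits are accumulated into a `size_t`, each shift is guarded by first checking the value against `max >> 4`, and after optional whitespace only a chunk extension (`;`) or the line terminator may follow (RFC 9112 §7.1). `parse_chunk_size` is an illustrative helper, not part of cpp-httplib's API:

```cpp
#include <cctype>
#include <cstddef>
#include <limits>
#include <optional>
#include <string>

// Illustrative helper, not cpp-httplib API: parse an RFC 9112 chunk-size line,
// rejecting non-hex input, overflow, and junk after the size.
static std::optional<size_t> parse_chunk_size(const std::string & line) {
    const char * p = line.c_str();
    auto hex_val = [](char c) -> int {
        const unsigned char u = static_cast<unsigned char>(c);
        if (!std::isxdigit(u)) { return -1; }
        return std::isdigit(u) ? c - '0' : (std::tolower(u) - 'a') + 10;
    };

    if (hex_val(*p) < 0) { return std::nullopt; }   // chunk-size = 1*HEXDIG

    constexpr size_t chunk_len_max = (std::numeric_limits<size_t>::max)();
    size_t chunk_len = 0;
    for (int v; (v = hex_val(*p)) >= 0; ++p) {
        if (chunk_len > (chunk_len_max >> 4)) { return std::nullopt; } // next shift would overflow
        chunk_len = (chunk_len << 4) | static_cast<size_t>(v);
    }

    while (*p == ' ' || *p == '\t') { ++p; }        // tolerate trailing whitespace
    // only a chunk extension or the CRLF terminator may follow the size
    if (*p != '\0' && *p != ';' && *p != '\r' && *p != '\n') { return std::nullopt; }
    return chunk_len;
}
```

For example, `"1a3;ext=1\r\n"` parses to 0x1a3, while an 18-digit hex value (which would overflow a 64-bit `size_t`) or a non-hex prefix yields `std::nullopt`.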
4 changes: 2 additions & 2 deletions vendor/cpp-httplib/httplib.h
@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H

#define CPPHTTPLIB_VERSION "0.43.3"
#define CPPHTTPLIB_VERSION_NUM "0x002b03"
#define CPPHTTPLIB_VERSION "0.43.4"
#define CPPHTTPLIB_VERSION_NUM "0x002b04"

#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00