Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2223,7 +2223,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
"comma separated list of RPC servers (host:port)",
"comma-separated list of RPC servers (host:port)",
[](common_params & params, const std::string & value) {
add_rpc_devices(value);
GGML_UNUSED(params);
Expand Down Expand Up @@ -3555,7 +3555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-type"}, common_speculative_all_types_str(),
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
string_format("comma-separated list of types of speculative decoding to use (default: %s)\n",
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
const auto enabled_types = string_split<std::string>(value, ',');
Expand Down
7 changes: 4 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ enum common_params_sampling_config : uint64_t {

enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
Expand Down Expand Up @@ -342,6 +342,7 @@ struct common_params_speculative_ngram_cache {
struct common_params_speculative {
std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };

// used by Simple, MTP, Eagle3, etc. - all methods that require some kind of draft model
common_params_speculative_draft draft;

common_params_speculative_ngram_mod ngram_mod;
Expand Down
78 changes: 39 additions & 39 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

const std::map<std::string, common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
Expand Down Expand Up @@ -145,15 +145,15 @@ struct common_speculative_impl {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
};

struct common_speculative_state_draft : public common_speculative_impl {
struct common_speculative_impl_draft_simple : public common_speculative_impl {
common_params_speculative_draft params;

llama_batch batch;

std::vector<common_sampler_ptr> smpls;

common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq)
common_speculative_impl_draft_simple(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, n_seq)
, params(params.draft)
{
auto * ctx_dft = this->params.ctx_dft;
Expand Down Expand Up @@ -206,7 +206,7 @@ struct common_speculative_state_draft : public common_speculative_impl {
}
}

~common_speculative_state_draft() override {
~common_speculative_impl_draft_simple() override {
llama_batch_free(batch);
}

Expand Down Expand Up @@ -340,11 +340,11 @@ struct common_speculative_state_draft : public common_speculative_impl {
}
};

struct common_speculative_state_eagle3 : public common_speculative_impl {
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
//common_params_speculative_eagle3 params;

common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {}
common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {}

void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
Expand All @@ -365,13 +365,13 @@ struct common_speculative_state_eagle3 : public common_speculative_impl {
};

// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_impl {
struct common_speculative_impl_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;

// shared across all sequences
common_ngram_simple_config config;

common_speculative_state_ngram_simple(
common_speculative_impl_ngram_simple(
const common_params_speculative & params, uint32_t n_seq,
common_ngram_simple_config config)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
Expand Down Expand Up @@ -405,13 +405,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_impl {
}
};

struct common_speculative_state_ngram_map_k : public common_speculative_impl {
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
common_params_speculative_ngram_map params;

// n_seq configs
std::vector<common_ngram_map> config;

common_speculative_state_ngram_map_k(
common_speculative_impl_ngram_map_k(
const common_params_speculative & params,
const common_ngram_map & config,
uint32_t n_seq)
Expand Down Expand Up @@ -453,7 +453,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_impl {
}
};

struct common_speculative_state_ngram_mod : public common_speculative_impl {
struct common_speculative_impl_ngram_mod : public common_speculative_impl {
common_params_speculative_ngram_mod params;

// shared across all sequences
Expand All @@ -475,7 +475,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_impl {

std::vector<seq_info> sinfos;

common_speculative_state_ngram_mod(
common_speculative_impl_ngram_mod(
const common_params_speculative & params,
uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq)
Expand Down Expand Up @@ -621,7 +621,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_impl {
}
};

struct common_speculative_state_ngram_cache : public common_speculative_impl {
struct common_speculative_impl_ngram_cache : public common_speculative_impl {
common_params_speculative_ngram_cache params;

uint16_t n_draft;
Expand All @@ -639,7 +639,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_impl {

std::vector<seq_info> sinfos;

common_speculative_state_ngram_cache(
common_speculative_impl_ngram_cache(
const common_params_speculative & params,
uint32_t n_seq,
uint16_t n_draft,
Expand Down Expand Up @@ -775,7 +775,7 @@ static common_ngram_map get_common_ngram_map(
return common_ngram_map(size_key, size_value, key_only, min_hits);
}

static common_speculative_state_ngram_cache create_state_ngram_cache(
static common_speculative_impl_ngram_cache create_state_ngram_cache(
const common_speculative_config & config,
uint32_t n_seq,
const std::string & path_static,
Expand All @@ -786,7 +786,7 @@ static common_speculative_state_ngram_cache create_state_ngram_cache(
bool save_static = false;
bool save_dynamic = false;

common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic);
common_speculative_impl_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic);

return state;
}
Expand Down Expand Up @@ -818,8 +818,8 @@ const char * common_speculative_all_types_str() {
std::string common_speculative_type_to_str(common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
Expand Down Expand Up @@ -872,9 +872,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
{
uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types);

bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT));
bool has_draft_model = !params.draft.mparams.path.empty();
bool has_draft_model_path = !params.draft.mparams.path.empty();

bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
// bool has_mtp = false; // TODO: add MTP here
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3

Expand Down Expand Up @@ -906,22 +906,22 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_draft) {
if (!has_draft_model) {
if (has_draft_simple) {
if (!has_draft_model_path) {
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
has_draft = false;
has_draft_simple = false;
}
} else if (has_draft_model) {
} else if (has_draft_model_path) {
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
has_draft = true;
has_draft_simple = true;
}

if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
if (has_draft_simple) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
}
// TODO: add MTP here
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
}

Expand All @@ -932,12 +932,12 @@ common_speculative * common_speculative_init(common_params_speculative & params,
switch (config.type) {
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.params, n_seq));
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: {
impls.push_back(std::make_unique<common_speculative_impl_draft_simple>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.params, n_seq));
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
Expand All @@ -950,7 +950,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
auto state = std::make_unique<common_speculative_impl_ngram_simple>(
/* .params = */ config.params,
/* .n_seq = */ n_seq,
/* .state = */ config_simple
Expand All @@ -961,21 +961,21 @@ common_speculative * common_speculative_init(common_params_speculative & params,
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(
std::make_unique<common_speculative_state_ngram_map_k>(
std::make_unique<common_speculative_impl_ngram_map_k>(
config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
impls.push_back(
std::make_unique<common_speculative_state_ngram_mod>(config.params, n_seq));
std::make_unique<common_speculative_impl_ngram_mod>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(
config, n_seq,
params.ngram_cache.lookup_cache_static,
params.ngram_cache.lookup_cache_dynamic);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
impls.push_back(std::make_unique<common_speculative_impl_ngram_cache>(state));
break;
}
default:
Expand Down
11 changes: 6 additions & 5 deletions examples/llama-eval/llama-eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# type: ignore

import argparse
import json
Expand Down Expand Up @@ -100,6 +99,8 @@ def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, f


class BaseDataset(ABC):
questions: List[Dict]

@abstractmethod
def get_question(self, index: int) -> Dict:
pass
Expand Down Expand Up @@ -573,7 +574,7 @@ def normalize_number(s: str) -> Optional[int]:
class AimeDataset(BaseDataset):
def __init__(self, split: str = "train"):
self.split = split
self.questions: List[Dict] = []
self.questions = []
self._load_dataset()

def _load_dataset(self):
Expand Down Expand Up @@ -618,7 +619,7 @@ def get_prompt(self, question: Dict) -> str:

class Aime2025Dataset(BaseDataset):
def __init__(self):
self.questions: List[Dict] = []
self.questions = []
self._load_dataset()

def _load_dataset(self):
Expand Down Expand Up @@ -681,7 +682,7 @@ def get_prompt(self, question: Dict) -> str:
class Gsm8kDataset(BaseDataset):
def __init__(self, split: str = "test"):
self.split = split
self.questions: List[Dict] = []
self.questions = []
self._load_dataset()

def _load_dataset(self):
Expand Down Expand Up @@ -742,7 +743,7 @@ class GpqaDataset(BaseDataset):
def __init__(self, variant: str = "diamond", seed: int = 1234):
self.variant = variant
self.seed = seed
self.questions: List[Dict] = []
self.questions = []
self._load_dataset()

def _load_dataset(self):
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-zendnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
Expand Down
Loading
Loading