diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5b4fb79fc1b..8179bd1c1f9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -831,6 +831,8 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.SSM_CONV1D_Q, gguf.MODEL_TENSOR.SSM_CONV1D_K, gguf.MODEL_TENSOR.SSM_CONV1D_V, + # DSA indexer weights should be F32 + gguf.MODEL_TENSOR.INDEXER_PROJ, ) ) or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") @@ -9186,6 +9188,147 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register( + "DeepseekV32ForCausalLM", +) +class DeepseekV32Model(TextModel): + model_arch = gguf.MODEL_ARCH.DEEPSEEK32 + + # TODO @ngxson : remove this when we support MTP for deepseek models + skip_mtp = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + assert getattr(tokenizer, "add_bos_token", False), "Change value of add_bos_token to true in tokenizer_config.json file." + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + + # note: deepseek32 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + + super().set_gguf_parameters() + hparams = self.hparams + + # first_k_dense_replace: number of leading layers using dense FFN instead of MoE + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + + # note: deepseek32 using MLA converts into MQA with larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + + # MoE parameters (required by C++ code for DEEPSEEK32 arch) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None: + # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul + # ref https://github.com/ggml-org/llama.cpp/pull/17945 + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + # DSA indexer parameters + self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"]) + self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) + 
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("language_model."): + name = name.replace("language_model.", "") + + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + if self.skip_mtp: + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + yield from super().modify_tensors(data_torch, merged_name, bid) + return + else: + return + + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + yield from super().modify_tensors(k_b, name_kb, bid) + yield from super().modify_tensors(v_b, name_vb, bid) + return + + yield from super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register( "Mistral3ForConditionalGeneration", "Ministral3ForCausalLM", diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 6fcf5a43393..5ad121ae57f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,10 +8,10 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 4 #define RPC_PROTO_MINOR_VERSION 0 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 #ifdef __cplusplus -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); #endif #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e3783136..5b6d24f6355 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -561,6 +561,7 @@ extern "C" { GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, 
GGML_OP_GATED_DELTA_NET, + GGML_OP_LIGHTNING_INDEXER, GGML_OP_UNARY, @@ -2539,6 +2540,14 @@ extern "C" { struct ggml_tensor * beta, struct ggml_tensor * state); + GGML_API struct ggml_tensor * ggml_lightning_indexer( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * weights, + float scale_embd, + float scale_heads); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b3eb5b5ce6..9d146f1051b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2037,6 +2037,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_gated_delta_net(params, tensor); } break; + case GGML_OP_LIGHTNING_INDEXER: + { + ggml_compute_forward_lightning_indexer(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2356,6 +2360,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + case GGML_OP_LIGHTNING_INDEXER: { n_tasks = n_threads; } break; @@ -2939,6 +2944,12 @@ struct ggml_cplan ggml_graph_plan( { GGML_ABORT("fatal error"); } + case GGML_OP_LIGHTNING_INDEXER: + { + // temp buffer for dequantizing lightning indexer keys + const int64_t ne10 = node->src[1]->ne[0]; + cur += sizeof(float)*ne10*n_tasks; + } break; default: break; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0..efd960823a3 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2235,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri @@ -11212,3 +11246,76 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_ } } } + +// ggml_compute_forward_lightning_indexer + +void ggml_compute_forward_lightning_indexer( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // q + const ggml_tensor * src1 = dst->src[1]; // k + const ggml_tensor * src2 = dst->src[2]; // weights + + const float scale_embd = ggml_get_op_params_f32(dst, 0); + const float scale_heads = ggml_get_op_params_f32(dst, 
1); + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + int n_embd = src0->ne[0]; + int n_head = src0->ne[1]; + int n_batch = src0->ne[2]; + int n_stream = src0->ne[3]; + int n_kv = src1->ne[2]; + + ggml_to_float_t const k_to_float = ggml_get_type_traits(src1->type)->to_float; + GGML_ASSERT((src1->type == GGML_TYPE_F32 || k_to_float) && "lightning indexer: unsupported K-type"); + + const int nr = n_kv; + const int ith = params->ith; + const int nth = params->nth; + + // (temporary) buffer for K converted to float + float * src1_row_f32 = (float *) params->wdata + ith*(1*n_embd + CACHE_LINE_SIZE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i_stream = 0; i_stream < n_stream; ++i_stream) { + for (int i_batch = 0; i_batch < n_batch; ++i_batch) { + for (int i_kv = ir0; i_kv < ir1; ++i_kv) { + char * src1_row = (char *) src1->data + i_kv*nb12 + i_stream*nb13; + if (k_to_float) { + k_to_float(src1_row, src1_row_f32, n_embd); + } else { + src1_row_f32 = (float *) src1_row; + } + float * src2_row = (float *) ((char *) src2->data + i_batch*nb21 + i_stream*nb23); + float * dst_row = (float *) ((char *) dst->data + i_batch*nb1 + i_stream*nb3); + float score = 0.0f; + for (int i_head = 0; i_head < n_head; ++i_head) { + // dot product of q and k for head i_head + float qk = 0.0f; + float * src0_row = (float *) ((char *) src0->data + i_head*nb01 + i_batch*nb02 + i_stream*nb03); + ggml_vec_dot_f32(n_embd, &qk, 0, src0_row, 0, src1_row_f32, 0, 1); + qk *= scale_embd; + // ReLU and weights + score += MAX(qk, 0.0f) * src2_row[i_head]; + } + score *= scale_heads; + dst_row[i_kv] = score; + } + } + } +} diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3fa1443abc4..c3f4a0a6c07 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -103,6 +103,7 @@ void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, s void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_lightning_indexer(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 185956317e0..2ec25830c0b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -61,6 +61,7 @@ #include "ggml-cuda/tri.cuh" #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" +#include "ggml-cuda/lightning_indexer.cuh" #include "ggml.h" #include @@ -2952,6 +2953,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_FILL: ggml_cuda_op_fill(ctx, dst); break; + case GGML_OP_LIGHTNING_INDEXER: + 
ggml_cuda_op_lightning_indexer(ctx, dst); + break; default: return false; } @@ -5142,6 +5146,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_TRI: case GGML_OP_DIAG: case GGML_OP_SOLVE_TRI: + case GGML_OP_LIGHTNING_INDEXER: return true; default: diff --git a/ggml/src/ggml-cuda/lightning_indexer.cu b/ggml/src/ggml-cuda/lightning_indexer.cu new file mode 100644 index 00000000000..c8a2d829da9 --- /dev/null +++ b/ggml/src/ggml-cuda/lightning_indexer.cu @@ -0,0 +1,172 @@ +#include "lightning_indexer.cuh" +#include "fattn-common.cuh" +#include "convert.cuh" + +constexpr int KVS_PER_WARP = 8; +constexpr int WARPS_PER_BLOCK = 8; + +template +static __global__ void lightning_indexer_kernel( + const float * src0, const char * src1, const float * src2, float * dst, + const float scale_embd, const float scale_heads, + int64_t n_stream, int64_t n_batch, int64_t n_kv, + size_t nb1, size_t nb2, size_t nb3, + size_t nb01, size_t nb02, size_t nb03, + size_t nb11, size_t nb12, size_t nb13, + size_t nb21, size_t nb22, size_t nb23 + ) { + + int i_batch = blockIdx.y; + int i_stream = blockIdx.z; + int i_warp = threadIdx.y; + int i_lane = threadIdx.x; + + // each warp processes KVS_PER_WARP KV elements + // each block processes WARPS_PER_BLOCK * KVS_PER_WARP KV elements + int start_kv_block = blockIdx.x * (WARPS_PER_BLOCK * KVS_PER_WARP); + int start_kv = start_kv_block + i_warp * KVS_PER_WARP; + + const char * q_base = (const char *) src0 + i_batch*nb02 + i_stream*nb03; + const float * w_base = (const float *) ((const char *) src2 + i_batch*nb21 + i_stream*nb23); + + float4 k_vecs[KVS_PER_WARP]; + float score_k[KVS_PER_WARP] = {0.0f}; + + constexpr dequantize_V_t dequantize_k = get_dequantize_V(); + + // preload k values (they are reused in a loop below) + #pragma unroll + for (int k = 0; k < KVS_PER_WARP; ++k) { + int i_kv = start_kv + k; + if (i_kv < n_kv) { + const void * k_base = (const void *) ((const char *) src1 + i_kv*nb12 + i_stream*nb13); + dequantize_k(k_base, &k_vecs[k], i_lane * 4); + } else { + k_vecs[k] = make_float4(0, 0, 0, 0); + } + } + + for (int h = 0; h < n_head; ++h) { + const float4 q_vec = *(const float4 *) (q_base + h*nb01 + i_lane*4*sizeof(float)); + const float w_val = w_base[h]; + + float qk[KVS_PER_WARP] = {0.0f}; + + #pragma unroll + for (int k = 0; k < KVS_PER_WARP; ++k) { + const float4 k_vec = k_vecs[k]; + qk[k] += q_vec.x * k_vec.x; + qk[k] += q_vec.y * k_vec.y; + qk[k] += q_vec.z * k_vec.z; + qk[k] += q_vec.w * k_vec.w; + } + + #pragma unroll + for (int k = 0; k < KVS_PER_WARP; ++k) { + float sum = warp_reduce_sum(qk[k]); + + // scale_embd, ReLU, weight + if (i_lane == 0) { + sum *= scale_embd; + sum = (sum > 0.0f) ? 
sum : 0.0f; + score_k[k] += sum * w_val; + } + } + } + + // scale_heads, store output + if (i_lane == 0) { + float * dst_base = (float *) ((char *) dst + i_batch*nb1 + i_stream*nb3); + #pragma unroll + for (int k = 0; k < KVS_PER_WARP; ++k) { + int i_kv = start_kv + k; + if (i_kv < n_kv) { + dst_base[i_kv] = score_k[k] * scale_heads; + } + } + } +} + +#define DECL_LIGHTNING_INDEXER_CASE(n_embd, n_head, type_K) \ + template __global__ void lightning_indexer_kernel ( \ + const float * src0, const char * src1, const float * src2, float * dst, \ + const float scale_embd, const float scale_heads, \ + int64_t n_stream, int64_t n_batch, int64_t n_kv, \ + size_t nb1, size_t nb2, size_t nb3, \ + size_t nb01, size_t nb02, size_t nb03, \ + size_t nb11, size_t nb12, size_t nb13, \ + size_t nb21, size_t nb22, size_t nb23); + +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_F16) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_Q4_0) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_Q4_1) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_Q5_0) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_Q5_1) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_Q8_0) +DECL_LIGHTNING_INDEXER_CASE(128, 64, GGML_TYPE_BF16) + +#define LIGHTNING_INDEXER_CASE(n_embd, n_head, K, type_K) \ + if (K->type == (type_K)) { \ + lightning_indexer_kernel<<>>( \ + src0_d, src1_d, src2_d, dst_d, scale_embd, scale_heads, \ + n_stream, n_batch, n_kv, \ + nb1, nb2, nb3, \ + nb01, nb02, nb03, \ + nb11, nb12, nb13, \ + nb21, nb22, nb23 \ + ); \ + } else + +void ggml_cuda_op_lightning_indexer(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float scale_embd = ggml_get_op_params_f32(dst, 0); + const float scale_heads = ggml_get_op_params_f32(dst, 1); + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + + GGML_TENSOR_TERNARY_OP_LOCALS + + // input tensor rows must be contiguous + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + GGML_ASSERT(nb20 == ggml_type_size(src2->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + const int n_embd = src0->ne[0]; + const int n_head = src0->ne[1]; + const int n_batch = src0->ne[2]; + const int n_stream = src0->ne[3]; + const int n_kv = src1->ne[2]; + + const float * src0_d = (const float *) src0->data; + const char * src1_d = (const char *) src1->data; + const float * src2_d = (const float *) src2->data; + float * dst_d = (float *) dst->data; + + dim3 block(32, WARPS_PER_BLOCK); + int num_kv_blocks = (n_kv + (KVS_PER_WARP * WARPS_PER_BLOCK) - 1) / (KVS_PER_WARP * WARPS_PER_BLOCK); + dim3 grid(num_kv_blocks, n_batch, n_stream); + + if (n_embd == 128 && n_head == 64) { + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_F16) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_Q4_0) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_Q4_1) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_Q5_0) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_Q5_1) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_Q8_0) + LIGHTNING_INDEXER_CASE(128, 64, src1, GGML_TYPE_BF16) + GGML_ABORT("fatal error"); + } else { + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/lightning_indexer.cuh b/ggml/src/ggml-cuda/lightning_indexer.cuh new 
file mode 100644 index 00000000000..31fcc7d5ae0 --- /dev/null +++ b/ggml/src/ggml-cuda/lightning_indexer.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_lightning_indexer(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eda041f4518..ad43e7949bf 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1058,6 +1058,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RWKV_WKV7", "SOLVE_TRI", "GATED_DELTA_NET", + "LIGHTNING_INDEXER", "UNARY", @@ -1075,7 +1076,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1168,6 +1169,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", "gated_delta_net(q, k, v, g, beta, s)", + "lightning_indexer(q, k, weights, scale_embd, scale_heads)", "unary(x)", @@ -1185,7 +1187,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5208,7 +5210,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -6213,6 +6215,40 @@ struct ggml_tensor * ggml_gated_delta_net( return result; } +// ggml_lightning_indexer + +struct ggml_tensor * ggml_lightning_indexer( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * weights, + float scale_embd, + float scale_heads) { + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(q->ne[0] == k->ne[0]); + GGML_ASSERT(q->ne[1] == weights->ne[0]); + GGML_ASSERT(k->ne[1] == 1); + GGML_ASSERT(q->ne[2] == weights->ne[1]); + GGML_ASSERT(weights->ne[2] == 1); + GGML_ASSERT(q->ne[3] == k->ne[3]); + GGML_ASSERT(k->ne[3] == weights->ne[3]); + + int64_t ne[4] = { k->ne[2], q->ne[2], 1, q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_f32(result, 0, scale_embd); + ggml_set_op_params_f32(result, 1, scale_heads); + + result->op = GGML_OP_LIGHTNING_INDEXER; + result->src[0] = q; + result->src[1] = k; + result->src[2] = weights; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5297a2f440..a636de384f7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -441,6 +441,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK = auto() DEEPSEEK2 = auto() DEEPSEEK2OCR = auto() + DEEPSEEK32 = auto() CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() @@ -926,6 +927,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DEEPSEEK: "deepseek", MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr", + MODEL_ARCH.DEEPSEEK32: "deepseek32", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", @@ -2813,6 +2815,46 @@ class MODEL_TENSOR(IntEnum): 
MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.DEEPSEEK32: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.INDEXER_K_NORM, + MODEL_TENSOR.INDEXER_PROJ, + MODEL_TENSOR.INDEXER_ATTN_K, + MODEL_TENSOR.INDEXER_ATTN_Q_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.ERNIE4_5_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3926,6 +3968,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK32: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.ROPE_FREQS, ], diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7b1fcfca0ad..d15ccfd99f1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ add_library(llama llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-dsa.cpp llama-memory.cpp llama-memory-hybrid.cpp llama-memory-hybrid-iswa.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6904b9c1a64..b95a095bbcc 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, + { LLM_ARCH_DEEPSEEK32, "deepseek32" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -888,6 +889,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) { case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_GLM_DSA: case LLM_ARCH_BITNET: case LLM_ARCH_T5: diff --git a/src/llama-arch.h b/src/llama-arch.h index c4aabab7e0c..fc298daba2b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -79,6 +79,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2OCR, + LLM_ARCH_DEEPSEEK32, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 2ff23f87cf4..a50715bed34 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -494,6 +495,34 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) { + mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch); + + mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn); + + 
mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch); + + mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_rot(self_k_rot_lid); +} + +bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens; + res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams); + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -2333,6 +2362,85 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + ggml_tensor * top_k, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx->get_mla(); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs_mla(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask_mla(); + + // prepare new kq mask - starts filled with -INFINITY + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + + // reshape KQ mask into tensor with rows of size 1: + // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream] + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1] + top_k = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream] + // this will be our source of zero values for unmasking top k mask elements + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k->ne[0], top_k->ne[1], top_k->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + // modify KQ mask by unmasking elements that are in top_k indices + // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1]) + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k); + + // reshape to restore the original shape of KQ mask: + // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream] + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + 
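// --- illustrative aside (not part of the patch) ---------------------------------
// A minimal standalone sketch of what the mask construction above computes for a
// single query row, assuming one stream and plain std::vector in place of ggml
// tensors; all names here are illustrative only. The graph code fills a copy of the
// KQ mask with -INF, unmasks the indexer's top-k positions via ggml_set_rows with a
// zero source, and then adds the original causal mask back, so a KV position is
// attended only if it is both in the top-k set and causally visible.
//
//   #include <cstdio>
//   #include <limits>
//   #include <vector>
//
//   int main() {
//       const float NEG_INF = -std::numeric_limits<float>::infinity();
//
//       // original causal mask for one query token: 0 = visible, -INF = masked
//       std::vector<float> kq_mask   = {0, 0, 0, 0, 0, NEG_INF, NEG_INF, NEG_INF};
//       // positions selected by the lightning indexer for this query token
//       std::vector<int>   top_k_idx = {1, 3, 6};
//
//       // start fully masked (ggml_fill with -INF in the graph)
//       std::vector<float> mask_top_k(kq_mask.size(), NEG_INF);
//
//       // unmask the top-k positions (ggml_set_rows with a zero-filled source)
//       for (int i : top_k_idx) {
//           mask_top_k[i] = 0.0f;
//       }
//
//       // combine with the original mask (ggml_add): a position stays visible only
//       // if it is in the top-k set AND not masked causally
//       for (size_t i = 0; i < kq_mask.size(); ++i) {
//           mask_top_k[i] += kq_mask[i];
//           printf("kv %zu: %s\n", i, mask_top_k[i] == 0.0f ? "attend" : "masked");
//       }
//       return 0;
//   }
// ---------------------------------------------------------------------------------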
+ ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2476,6 +2584,30 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); + inp->self_kq_mask_mla_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_mla, GGML_TYPE_F16) : inp->self_kq_mask_mla; + } + + { + inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams); + inp->self_kq_mask_lid_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_lid, GGML_TYPE_F16) : inp->self_kq_mask_lid; + + inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); + } + + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); +} + // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. 
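// --- illustrative aside (not part of the patch) ---------------------------------
// For reference, a naive scalar sketch of the score that the new
// GGML_OP_LIGHTNING_INDEXER computes, mirroring the F32 CPU path earlier in this
// diff: score[kv] = scale_heads * sum_h( w[h] * relu(scale_embd * dot(q[h], k[kv])) ).
// A single batch/stream is assumed and tensors are simplified to flat float arrays;
// the function name and layout are illustrative only. The per-token top-k over these
// scores is what feeds the sparse-mask construction in build_attn above.
//
//   #include <algorithm>
//   #include <vector>
//
//   std::vector<float> lightning_indexer_ref(
//           const std::vector<float> & q,  // [n_head * n_embd] query heads for one token
//           const std::vector<float> & k,  // [n_kv   * n_embd] cached indexer keys
//           const std::vector<float> & w,  // [n_head]          per-head weights for this token
//           int n_embd, int n_head, int n_kv,
//           float scale_embd, float scale_heads) {
//       std::vector<float> scores(n_kv, 0.0f);
//       for (int i_kv = 0; i_kv < n_kv; ++i_kv) {
//           float score = 0.0f;
//           for (int h = 0; h < n_head; ++h) {
//               float qk = 0.0f;
//               for (int d = 0; d < n_embd; ++d) {
//                   qk += q[h*n_embd + d] * k[i_kv*n_embd + d];
//               }
//               qk *= scale_embd;                   // scale the dot product
//               score += std::max(qk, 0.0f) * w[h]; // ReLU, then per-head weight
//           }
//           scores[i_kv] = score * scale_heads;     // final per-KV score
//       }
//       return scores;
//   }
// ---------------------------------------------------------------------------------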
diff --git a/src/llama-graph.h b/src/llama-graph.h index 5cb1756c6a9..016683af0cd 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -22,6 +22,7 @@ struct llama_layer; struct llama_memory_context_i; class llama_kv_cache_context; +class llama_kv_cache_dsa_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -355,6 +356,44 @@ class llm_graph_input_attn_k : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +class llm_graph_input_attn_k_dsa : public llm_graph_input_i { +public: + llm_graph_input_attn_k_dsa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_dsa_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k_dsa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs_mla() const { return self_k_idxs_mla; } + ggml_tensor * get_k_idxs_lid() const { return self_k_idxs_lid; } + + ggml_tensor * get_kq_mask_mla() const { return self_kq_mask_mla_cnv; } + ggml_tensor * get_kq_mask_lid() const { return self_kq_mask_lid; } + + ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask_mla = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot_lid = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_dsa_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -952,6 +991,22 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_k_dsa * build_attn_inp_k_dsa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * top_k, // [n_indexer_top_k, n_tokens] + float kq_scale, + int il) const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory diff --git a/src/llama-kv-cache-dsa.cpp b/src/llama-kv-cache-dsa.cpp new file mode 100644 index 00000000000..a7d9513917d --- /dev/null +++ b/src/llama-kv-cache-dsa.cpp @@ -0,0 +1,260 @@ +#include "llama-kv-cache-dsa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include + +// +// llama_kv_cache_dsa +// + +llama_kv_cache_dsa::llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_lid(model.hparams), n_stream(unified ? 
1 : n_seq_max) { + + LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); + + kv_mla = std::make_unique( + model, model.hparams, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, filter, reuse); + + // we use llama_kv_cache for caching indexer keys + // by hand-tweaking some hparams we fool it to create + // indexer key cache tensors with correct dimensions + // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 + + // DSA lightning indexer uses MQA with single key head + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + + LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); + + kv_lid = std::make_unique( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, filter, reuse); +} + +void llama_kv_cache_dsa::clear(bool data) { + kv_mla->clear(data); + kv_lid->clear(data); +} + +bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_mla->seq_rm(seq_id, p0, p1); + res = res & kv_lid->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { + kv_mla->seq_keep(seq_id); + kv_lid->seq_keep(seq_id); +} + +void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_mla->seq_add(seq_id, p0, p1, shift); + kv_lid->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_mla->seq_div(seq_id, p0, p1, d); + kv_lid->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { + return kv_mla->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { + return kv_mla->seq_pos_max(seq_id); +} + +std::map llama_kv_cache_dsa::memory_breakdown() const { + std::map mb = kv_mla->memory_breakdown(); + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + balloc.split_reset(); + + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? 
balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos_mla = kv_mla->prepare(ubatches); + if (sinfos_mla.empty()) { + break; + } + + auto sinfos_lid = kv_lid->prepare(ubatches); + if (sinfos_lid.empty()) { + break; + } + + assert(sinfos_mla.size() == sinfos_lid.size()); + + return std::make_unique( + this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); + } while (false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_kv_cache_dsa::get_can_shift() const { + return kv_mla->get_can_shift() && + kv_lid->get_can_shift() && + kv_mla->get_size() == kv_lid->get_size(); +} + +void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + kv_mla->state_write(io, seq_id, flags); + kv_lid->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + kv_mla->state_read(io, seq_id, flags); + kv_lid->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_kv_cache_dsa::get_mla() const { + return kv_mla.get(); +} + +llama_kv_cache * llama_kv_cache_dsa::get_lid() const { + return kv_lid.get(); +} + +// +// llama_kv_cache_dsa_context +// + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv) : + ctx_mla(kv->get_mla()->init_full()), + ctx_lid(kv->get_lid()->init_full()), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize) : + ctx_mla(kv->get_mla()->init_update(lctx, optimize)), + ctx_lid(kv->get_lid()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_mla, + slot_info_vec_t sinfos_lid, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. 
not sure if this is ideal + ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), + ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; + +bool llama_kv_cache_dsa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_mla->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_mla->apply(); + res = res & ctx_lid->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(ctx_mla.get()); +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(ctx_lid.get()); +} diff --git a/src/llama-kv-cache-dsa.h b/src/llama-kv-cache-dsa.h new file mode 100644 index 00000000000..e2b330993b8 --- /dev/null +++ b/src/llama-kv-cache-dsa.h @@ -0,0 +1,138 @@ +#pragma once + +#include "llama-kv-cache.h" + +#include + +// +// llama_kv_cache_dsa +// + +// utilizes two instances of llama_kv_cache: +// - the first instance is for caching key tensors of the model, +// - the second instance is for caching lightning indexer key tensors + +class llama_kv_cache_dsa : public llama_memory_i { +public: + llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsa specific API + // + + llama_kv_cache * get_mla() const; + llama_kv_cache * get_lid() 
const; + +private: + // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it + llama_hparams hparams_lid; + const uint32_t n_stream = 1; + + std::unique_ptr kv_mla; + std::unique_ptr kv_lid; +}; + +class llama_kv_cache_dsa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // used for errors + llama_kv_cache_dsa_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv); + + // used to create an update context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize); + + // used to create a batch processing context from a batch + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_ik, + std::vector ubatches); + + virtual ~llama_kv_cache_dsa_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsa_context specific API + // + + const llama_kv_cache_context * get_mla() const; + const llama_kv_cache_context * get_lid() const; + +private: + //llama_kv_cache_dsa * kv; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_mla; + const llama_memory_context_ptr ctx_lid; + + const llama_memory_status status; +}; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 26e2cb4270b..9b9f1790363 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); kv_base = std::make_unique( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type, filter_swa, reuse); } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 09102f549c8..914528a3ca1 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -78,6 +78,7 @@ static ggml_tensor * ggml_mul_mat_aux( llama_kv_cache::llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -90,7 +91,7 @@ llama_kv_cache::llama_kv_cache( llama_swa_type swa_type, const layer_filter_cb & filter, const layer_reuse_cb & reuse) : - model(model), hparams(model.hparams), v_trans(v_trans), + model(model), hparams(hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 
1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -252,7 +253,7 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { ggml_backend_buffer_t buf; - if (model.hparams.no_alloc) { + if (hparams.no_alloc) { buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it @@ -292,6 +293,11 @@ llama_kv_cache::llama_kv_cache( ggml_is_quantized(type_k) && hparams.n_embd_head_k() % 64 == 0; + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } + attn_rot_v = !attn_rot_disable && n_embd_head_v_all > 0 && diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0b62dc7b232..0b0a56ce92f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -95,6 +95,7 @@ class llama_kv_cache : public llama_memory_i { llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 4ce1af592c1..74cb550195c 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -32,6 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid( hparams(model.hparams), mem_attn(new llama_kv_cache( model, + model.hparams, type_k, type_v, v_trans, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f77b2e9217f..297b70638fc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10,6 +10,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -498,6 +499,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_685B_A37B: return "685B.A37B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -2026,6 +2028,56 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK32: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-6; // eps for layer norm + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, 
false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_685B_A37B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5397,6 +5449,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_DEEPSEEK32: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("DEEPSEEK32 architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), 
{q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | 
TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_PLM: { const int64_t n_embd_head_qk_rope = hparams.n_rot(); @@ -8238,7 +8392,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -8426,6 +8580,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; + case LLM_ARCH_DEEPSEEK32: + { + res = new llama_kv_cache_dsa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } break; // Models that need standard caching should rely on recurrent/hybrid // checks default: @@ -8529,6 +8700,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, res = new llama_kv_cache( *this, + hparams, params.type_k, params.type_v, !cparams.flash_attn, @@ -8822,6 +8994,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DEEPSEEK32: + { + llm = std::make_unique<llm_build_deepseek32>(*this, params); + } break; case LLM_ARCH_CHATGLM: { llm = std::make_unique(*this, params); } break; @@ -9217,6 +9393,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2OCR: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: diff --git a/src/llama-model.h b/src/llama-model.h index 5f101bd6374..2e32361ec8d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -137,6 +137,7 @@ enum llm_type { LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_685B_A37B, // DeepSeek V3.2 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, diff --git a/src/models/deepseek32.cpp b/src/models/deepseek32.cpp new file mode 100644 index 00000000000..cb4171ca940 --- /dev/null +++ b/src/models/deepseek32.cpp @@ -0,0 +1,354 @@ +#include "models.h" + +#include "llama-kv-cache.h" +#include "llama-kv-cache-dsa.h" + +llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const bool is_mla = hparams.is_mla(); + GGML_ASSERT(is_mla); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); + GGML_UNUSED(n_embd_head_v); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k - 
n_embd_head_qk_rope; + + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_k_dsa * inp_attn_dsa = build_attn_inp_k_dsa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < effective_n_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * qr = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(qr, "qr", il); + + qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(qr, "qr", il); + + ggml_tensor * top_k = nullptr; + + // lightning indexer + { + ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); + cb(indexer_q, "indexer_q", il); + + // split into {n_embd_indexer_head_rope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_pe = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, 0); + cb(indexer_q_pe, "indexer_q_pe", il); + + // and {n_embd_indexer_head_nope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_nope = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + cb(indexer_q_nope, "indexer_q_nope", il); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_q_pe, "indexer_q_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, n_head, n_tokens} + indexer_q = ggml_concat(ctx0, indexer_q_pe, indexer_q_nope, 0); + cb(indexer_q, "indexer_q", il); + + ggml_tensor * indexer_k = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_k, cur); + cb(indexer_k, "indexer_k", il); + + indexer_k = build_norm(indexer_k, model.layers[il].indexer_k_norm, 
model.layers[il].indexer_k_norm_b, LLM_NORM, il); + cb(indexer_k, "indexer_k", il); + + // split into {n_embd_indexer_head_rope, 1, n_tokens} + ggml_tensor * indexer_k_pe = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_rope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, 0); + cb(indexer_k_pe, "indexer_k_pe", il); + + // and {n_embd_indexer_head_nope, 1, n_tokens} + ggml_tensor * indexer_k_nope = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_nope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, + ggml_row_size(indexer_k->type, n_embd_indexer_head_nope)); + cb(indexer_k_nope, "indexer_k_nope", il); + + indexer_k_pe = ggml_rope_ext(ctx0, indexer_k_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_k_pe, "indexer_k_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, 1, n_tokens} + indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); + cb(indexer_k, "indexer_k", il); + + // perform Hadamard transform on indexer q and k + indexer_q = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_q); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_k); + cb(indexer_k, "indexer_k", il); + + // store indexer keys to KV cache + const auto * mctx_lid = inp_attn_dsa->mctx->get_lid(); + const auto & k_idxs_lid = inp_attn_dsa->get_k_idxs_lid(); + ggml_build_forward_expand(gf, mctx_lid->cpy_k(ctx0, indexer_k, k_idxs_lid, il)); + + // prepare indexer weights + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_lid->get_k(ctx0, il); + + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + +#if 1 + ggml_tensor * indexer_score = ggml_lightning_indexer(ctx0, indexer_q, indexer_k, indexer_weights, 1.0f / sqrtf(float(n_embd_indexer_head)), 1.0f / sqrtf(float(n_indexer_head))); + cb(indexer_score, "indexer_score", il); +#else + // calculate indexer kq + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + // ReLU requires contiguous tensors + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + // apply ReLU + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + // scale weights + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // multiply scores by indexer weights + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, 
"indexer_score", il); + + // sum by q n_indexer_head dimension + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_permute(ctx0, indexer_score, 2, 1, 0, 3); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_cont(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + // TODO maybe pre-scale indexer weights, so we won't have to do it here + indexer_score = ggml_scale(ctx0, indexer_score, 1.0f / sqrtf(float(n_embd_indexer_head))); + cb(indexer_score, "indexer_score", il); +#endif + // mask indexer scores + ggml_tensor * indexer_kq_mask = inp_attn_dsa->get_kq_mask_lid(); + indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask); + cb(indexer_score, "indexer_score", il); + + // get indices of top k indexer scores + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + } + + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); + cb(q, "q", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = + ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = + ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // MLA attention + { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + 
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_dsa, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il); + } + } + if (il == effective_n_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il, + nullptr, + model.layers[il].ffn_gate_up_exps); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 94991c55fe8..de8daad172a 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -186,12 +186,16 @@ struct llm_build_deci : public llm_graph_context { llm_build_deci(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_deepseek : public llm_graph_context { + llm_build_deepseek(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_deepseek2 : public llm_graph_context { llm_build_deepseek2(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params); +struct llm_build_deepseek32 : public llm_graph_context { + llm_build_deepseek32(const llama_model & model, const llm_graph_params & params); }; struct llm_build_dots1 : public llm_graph_context { diff --git a/tests/test-backend-ops.cpp 
b/tests/test-backend-ops.cpp index 828a9c14a45..de94dd1125b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6622,6 +6622,57 @@ struct test_diag : public test_case { } }; +// GGML_OP_LIGHTNING_INDEXER +struct test_lightning_indexer : public test_case { + const ggml_type type_a; + const ggml_type type_b; + const ggml_type type_c; + const std::array<int64_t, 4> ne_a; + const std::array<int64_t, 4> ne_b; + const std::array<int64_t, 4> ne_c; + float scale_embd; + float scale_heads; + + std::string vars() override { + return VARS_TO_STR8(type_a, type_b, type_c, ne_a, ne_b, ne_c, scale_embd, scale_heads); + } + + test_lightning_indexer(ggml_type type_a = GGML_TYPE_F32, + ggml_type type_b = GGML_TYPE_F16, + ggml_type type_c = GGML_TYPE_F32, + std::array<int64_t, 4> ne_a = {128, 64, 128, 1}, + std::array<int64_t, 4> ne_b = {128, 1, 256, 1}, + std::array<int64_t, 4> ne_c = {64, 128, 1, 1}, + float scale_embd = 1.0f / sqrtf(float(128)), + float scale_heads = 1.0f / sqrtf(float(64))) + : type_a(type_a), type_b(type_b), type_c(type_c), ne_a(ne_a), ne_b(ne_b), ne_c(ne_c), scale_embd(scale_embd), scale_heads(scale_heads) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne_a.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne_b.data()); + ggml_set_param(b); + ggml_set_name(b, "b"); + + ggml_tensor * c = ggml_new_tensor(ctx, type_c, 4, ne_c.data()); + ggml_set_param(c); + ggml_set_name(c, "c"); + + ggml_tensor * out = ggml_lightning_indexer(ctx, a, b, c, scale_embd, scale_heads); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t); + } + } +}; + // Deserializable generic test case struct input_tensor { ggml_type type; @@ -8519,6 +8570,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); + test_cases.emplace_back(new test_fill(0.0f, GGML_TYPE_F16)); + test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F16, { 303, 207, 11, 3 })); + test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F16, { 800, 600, 4, 4 })); + test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F16, { 2048, 512, 2, 2 })); test_cases.emplace_back(new test_diag()); test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 })); @@ -8689,6 +8744,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { test_cases.emplace_back(new test_falcon(2)); #endif + // lightning_indexer + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q4_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q4_1, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q5_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 
1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q5_1, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q8_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_BF16, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + return test_cases; } #ifdef _MSC_VER @@ -8964,6 +9028,15 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64 + // lightning_indexer + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q4_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q4_1, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q5_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q5_1, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_Q8_0, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + test_cases.emplace_back(new test_lightning_indexer(GGML_TYPE_F32, GGML_TYPE_BF16, GGML_TYPE_F32, {128, 64, 128, 1}, {128, 1, 256, 1}, {64, 128, 1, 1}, 1.0f / sqrtf(float(128)), 1.0f / sqrtf(float(64)))); + return test_cases; } diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 16af11a2862..81d4b8c9926 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -99,6 +99,7 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) { n_ff = 96; n_layer = 22; // hparams.n_layer_kv_from_start = 20 is hardcoded } else if (arch == LLM_ARCH_DEEPSEEK2 + || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR || arch == LLM_ARCH_MISTRAL4) { @@ -155,6 +156,7 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) { ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f); if (arch == LLM_ARCH_DEEPSEEK2 + || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR || arch == LLM_ARCH_MISTRAL4) { @@ -193,8 +195,8 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) { ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, uint32_t(2)); } - ms.add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, uint32_t(1)); - 
ms.add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, uint32_t(64)); + ms.add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, uint32_t(64)); + ms.add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, uint32_t(128)); ms.add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, uint32_t(8)); ms.add_kv(LLM_KV_ROPE_DIMENSION_SECTIONS, std::vector({n_embd_head/4, n_embd_head/4, n_embd_head/4, n_embd_head/4})); ms.add_kv(LLM_KV_TOKENIZER_MODEL, "no_vocab"); @@ -331,6 +333,7 @@ static bool moe_mandatory(const llm_arch arch) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_GLM_DSA: case LLM_ARCH_EXAONE_MOE:
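For reference, the scoring that the fused GGML_OP_LIGHTNING_INDEXER op computes (and that the #else branch in deepseek32.cpp spells out with individual ggml ops) can be sketched in plain scalar C++ as below. This is only an illustration of the math, not the ggml kernel: the buffer layouts, the helper names indexer_scores/top_k_indices, and the use of std::partial_sort for the top-k step are assumptions made for the sketch.

// Illustrative scalar reference for the lightning-indexer scoring (assumed layouts; not the ggml kernel).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

// q: [n_tokens][n_head][head_dim] indexer queries, k: [n_kv][head_dim] cached indexer keys (single K head),
// w: [n_tokens][n_head] per-head gates from indexer_proj; returns scores indexed as scores[t*n_kv + j].
static std::vector<float> indexer_scores(const std::vector<float> & q, const std::vector<float> & k,
                                         const std::vector<float> & w,
                                         int64_t n_tokens, int64_t n_head, int64_t head_dim, int64_t n_kv) {
    const float scale_embd  = 1.0f / std::sqrt((float) head_dim); // 1/sqrt(index_head_dim)
    const float scale_heads = 1.0f / std::sqrt((float) n_head);   // 1/sqrt(index_n_heads)
    std::vector<float> scores(n_tokens * n_kv, 0.0f);
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t j = 0; j < n_kv; ++j) {
            float s = 0.0f;
            for (int64_t h = 0; h < n_head; ++h) {
                float dot = 0.0f;
                for (int64_t d = 0; d < head_dim; ++d) {
                    dot += q[(t*n_head + h)*head_dim + d] * k[j*head_dim + d];
                }
                s += scale_heads * w[t*n_head + h] * std::max(dot, 0.0f); // ReLU(q.k), gated per head
            }
            scores[t*n_kv + j] = scale_embd * s;
        }
    }
    return scores;
}

// after the sequence mask is added to a token's score row, only the indices of its
// index_topk best scores are kept; these indices drive the sparse MLA attention
static std::vector<int32_t> top_k_indices(const float * row, int64_t n_kv, int64_t k_top) {
    std::vector<int32_t> idx(n_kv);
    std::iota(idx.begin(), idx.end(), 0);
    k_top = std::min<int64_t>(k_top, n_kv);
    std::partial_sort(idx.begin(), idx.begin() + k_top, idx.end(),
                      [&](int32_t a, int32_t b) { return row[a] > row[b]; });
    idx.resize(k_top);
    return idx;
}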