diff --git a/src/voxtral.cpp b/src/voxtral.cpp index 62eaccc..289a29a 100644 --- a/src/voxtral.cpp +++ b/src/voxtral.cpp @@ -1076,14 +1076,8 @@ static void clear_kv_cache(voxtral_context * ctx) { if (!ctx || !ctx->kv_self_k || !ctx->kv_self_v) { return; } - void * k_data = ggml_get_data(ctx->kv_self_k); - void * v_data = ggml_get_data(ctx->kv_self_v); - if (k_data) { - memset(k_data, 0, ggml_nbytes(ctx->kv_self_k)); - } - if (v_data) { - memset(v_data, 0, ggml_nbytes(ctx->kv_self_v)); - } + ggml_backend_tensor_memset(ctx->kv_self_k, 0, 0, ggml_nbytes(ctx->kv_self_k)); + ggml_backend_tensor_memset(ctx->kv_self_v, 0, 0, ggml_nbytes(ctx->kv_self_v)); ctx->kv_used = 0; } @@ -1097,24 +1091,23 @@ static void kv_cache_shift_left(voxtral_context * ctx, int32_t shift) { return; } - uint8_t * k_data = (uint8_t *) ggml_get_data(ctx->kv_self_k); - uint8_t * v_data = (uint8_t *) ggml_get_data(ctx->kv_self_v); - if (!k_data || !v_data) { - return; - } - - const size_t row_bytes = ctx->kv_self_k->nb[1]; + const size_t row_bytes = ctx->kv_self_k->nb[1]; const size_t layer_stride = ctx->kv_self_k->nb[2]; + const size_t keep_bytes = (size_t)(window - shift) * row_bytes; + const size_t shift_offset = (size_t)shift * row_bytes; + + std::vector tmp(keep_bytes); for (int32_t l = 0; l < VOXTRAL_DEC_LAYERS; ++l) { - uint8_t * k_base = k_data + (size_t) l * layer_stride; - uint8_t * v_base = v_data + (size_t) l * layer_stride; + const size_t base = (size_t)l * layer_stride; - memmove(k_base, k_base + (size_t) shift * row_bytes, (size_t) (window - shift) * row_bytes); - memmove(v_base, v_base + (size_t) shift * row_bytes, (size_t) (window - shift) * row_bytes); + ggml_backend_tensor_get(ctx->kv_self_k, tmp.data(), base + shift_offset, keep_bytes); + ggml_backend_tensor_set(ctx->kv_self_k, tmp.data(), base, keep_bytes); + ggml_backend_tensor_memset(ctx->kv_self_k, 0, base + keep_bytes, shift_offset); - memset(k_base + (size_t) (window - shift) * row_bytes, 0, (size_t) shift * row_bytes); - memset(v_base + (size_t) (window - shift) * row_bytes, 0, (size_t) shift * row_bytes); + ggml_backend_tensor_get(ctx->kv_self_v, tmp.data(), base + shift_offset, keep_bytes); + ggml_backend_tensor_set(ctx->kv_self_v, tmp.data(), base, keep_bytes); + ggml_backend_tensor_memset(ctx->kv_self_v, 0, base + keep_bytes, shift_offset); } } @@ -1592,26 +1585,21 @@ static ggml_tensor * build_decoder_layer( ctx->kv_self_v->nb[1], layer_idx * ctx->kv_self_v->nb[2]); // [kv_dim, n_kv] - // Flash attention with GQA - // Q: [n_heads*head_dim, n_tokens] -> [head_dim, n_heads, n_tokens] -> [head_dim, n_tokens, n_heads] + // Reshape for flash attention: [head_dim, n_tokens/n_kv, n_heads/n_kv_heads] ggml_tensor * q3 = ggml_reshape_3d(gctx, q, VOXTRAL_DEC_HEAD_DIM, VOXTRAL_DEC_HEADS, n_tokens); q3 = ggml_permute(gctx, q3, 0, 2, 1, 3); // [head_dim, n_tokens, n_heads] - // K: [kv_dim, n_kv] -> [head_dim, n_kv_heads, n_kv] -> [head_dim, n_kv, n_kv_heads] ggml_tensor * k3 = ggml_reshape_3d(gctx, k_full, VOXTRAL_DEC_HEAD_DIM, VOXTRAL_DEC_KV_HEADS, n_kv); k3 = ggml_permute(gctx, k3, 0, 2, 1, 3); // [head_dim, n_kv, n_kv_heads] - // V: [kv_dim, n_kv] -> [head_dim, n_kv_heads, n_kv] -> [head_dim, n_kv, n_kv_heads] ggml_tensor * v3 = ggml_reshape_3d(gctx, v_full, VOXTRAL_DEC_HEAD_DIM, VOXTRAL_DEC_KV_HEADS, n_kv); v3 = ggml_permute(gctx, v3, 0, 2, 1, 3); // [head_dim, n_kv, n_kv_heads] const float scale = 1.0f / sqrtf((float)VOXTRAL_DEC_HEAD_DIM); - // ggml_flash_attn_ext fuses Q@K^T, scale, mask, softmax, @V in one op - // GQA broadcast is built-in (n_heads % n_kv_heads == 0) - // Mask is cast to F16 inside the graph if provided - ggml_tensor * attn_mask_f16 = attn_mask ? ggml_cast(gctx, attn_mask, GGML_TYPE_F16) : nullptr; - ggml_tensor * attn_out = ggml_flash_attn_ext(gctx, q3, k3, v3, attn_mask_f16, scale, 0.0f, 0.0f); + ggml_tensor * mask_f16 = attn_mask ? ggml_cast(gctx, attn_mask, GGML_TYPE_F16) : nullptr; + + ggml_tensor * attn_out = ggml_flash_attn_ext(gctx, q3, k3, v3, mask_f16, scale, 0.0f, 0.0f); // Output: [head_dim, n_heads, n_tokens] (already permuted by flash_attn_ext) attn_out = ggml_cont(gctx, attn_out); attn_out = ggml_reshape_2d(gctx, attn_out, VOXTRAL_DEC_HEADS * VOXTRAL_DEC_HEAD_DIM, n_tokens);