From 97c30e5039df58295b5eba4a2978e42f62046fe7 Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Sun, 28 Jun 2026 23:31:03 +0800 Subject: [PATCH] update dsk-v3 --- fastdeploy/model_executor/layers/linear.py | 24 ++++++++++ .../model_executor/models/deepseek_v3.py | 44 ++++++------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 90de0f7d033..193bd80cd90 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -1129,6 +1129,30 @@ def forward_k_b(self, x: paddle.Tensor) -> paddle.Tensor: out = paddle.bmm(x, self.k_b_proj_weight) return out + def forward_k_b_thd(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Forward K_b projection for token-head-dim layout. + + Args: + x: Input tensor with shape [tokens, heads, qk_nope_head_dim] + + Returns: + K_b projection output with shape [tokens, heads, kv_lora_rank] + """ + return paddle.einsum("thd,hdr->thr", x, self.k_b_proj_weight) + + def forward_v_b_thr(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Forward V_b projection for token-head-rank layout. 
+ + Args: + x: Input tensor with shape [tokens, heads, kv_lora_rank] + + Returns: + V_b projection output with shape [tokens, heads, v_head_dim] + """ + return paddle.einsum("thr,hrv->thv", x, self.v_b_proj_weight) + def forward_v_b(self, x: paddle.Tensor) -> paddle.Tensor: """ Forward pass for V_b projection using bmm diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 953a294914e..14a6d62d117 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -454,8 +454,7 @@ def forward_swa_static( key_pe: paddle.Tensor, ): """MLA static attention with sliding window indexer.""" - q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) - + q_nope_out = self.kv_b_proj_bmm.forward_k_b_thd(query_nope) q_input = paddle.concat([q_nope_out, query_pe], axis=-1) q_input.reshape_( [ @@ -488,13 +487,10 @@ def forward_swa_static( attn_softmax_scale=self.attn_softmax_scale, ) - fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) + fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]) - return ( - self.kv_b_proj_bmm(fmqa_out, proj_type="v") - .transpose([1, 0, 2]) - .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - ) + fmqa_out = self.kv_b_proj_bmm.forward_v_b_thr(fmqa_out) + return fmqa_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) def forward( self, @@ -578,8 +574,6 @@ def forward( ) key[..., : self.qk_nope_head_dim] = key_nope key[..., self.qk_nope_head_dim :] = full_k_pe.unsqueeze(1) - if self.qk_head_dim - self.v_head_dim != 0: - value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) fmha_out = self.mla_attn( q=query, @@ -591,8 +585,6 @@ def forward( forward_meta=forward_meta, ) - fmha_out.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) - fmha_out = 
fmha_out[:, :, : self.v_head_dim] fmha_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) attn_out = fmha_out @@ -627,7 +619,7 @@ def forward( query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim]) query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim]) - q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) + q_nope_out = self.kv_b_proj_bmm.forward_k_b_thd(query_nope) q_input = paddle.concat([q_nope_out, query_pe], axis=-1) q_input.reshape_( @@ -647,13 +639,9 @@ def forward( forward_meta=forward_meta, ) - fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) - - fmqa_out = ( - self.kv_b_proj_bmm(fmqa_out, proj_type="v") - .transpose([1, 0, 2]) - .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - ) + fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]) + fmqa_out = self.kv_b_proj_bmm.forward_v_b_thr(fmqa_out) + fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: pass @@ -1036,8 +1024,8 @@ def forward( query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe) - q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]).contiguous(), proj_type="k") - q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1) + q_nope_out = self.kv_b_proj_bmm.forward_k_b_thd(query_nope) + q_input = paddle.concat([q_nope_out, query_pe], axis=-1) compressed_kv = self.kv_a_layernorm(compressed_kv)[0] # kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1) @@ -1053,15 +1041,9 @@ def forward( forward_meta=forward_meta, ) - fmha_out = fmha_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2]) - fmha_out = ( - 
self.kv_b_proj_bmm( - fmha_out, - proj_type="v", - ) - .transpose([1, 0, 2]) - .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - ) + fmha_out = fmha_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]) + fmha_out = self.kv_b_proj_bmm.forward_v_b_thr(fmha_out) + fmha_out = fmha_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) output = self.o_proj(fmha_out)