
Support for DeepseekV32ForCausalLM with DeepSeek Sparse Attention (DSA)#21149

Draft
fairydreaming wants to merge 52 commits into ggml-org:master from fairydreaming:deepseek-dsa

Conversation

@fairydreaming
Collaborator

@fairydreaming fairydreaming commented Mar 29, 2026

Overview

This PR adds support for DeepseekV32ForCausalLM (DeepSeek V3.2 Exp, DeepSeek V3.2, DeepSeek V3.2 Speciale) models. It contains an implementation of the lightning indexer and DeepSeek Sparse Attention (DSA), both implemented in the simplest possible way as a proof of concept. So far only the CPU and CUDA backends are supported.

Due to the way it is currently implemented, it does not improve long-context performance yet; more work is needed for this.

Some GGUFs for testing are available here (the -light models). I uploaded Q8_0/Q4_K_M quants, so you need over 700 GB/400 GB of RAM/VRAM to run them.

I also created a 16 GB baby DeepSeek V3.2 GGUF for VRAM-deprived people. It outputs incoherent gibberish, but it should be useful for testing and optimizing this implementation even with limited resources.

I could really use some help with verifying the correctness of the implementation. If you have a large GPU cluster and can run some benchmarks to compare results with the officially reported benchmark results for the DeepSeek V3.2 models, go for it. More details in #21183.

Fixes #16331, #20363

Additional information

Decisions I made when implementing this:

  • a new model arch DEEPSEEK32 was added (mostly a copy of the existing GLM_DSA arch),
  • sparse attention was implemented by masking the KQ-mask entries corresponding to tokens that are not in the set of top-k tokens selected by the lightning indexer,
  • for this purpose I added a new GGML op GGML_OP_SCATTER that works similarly to the torch scatter_ operation, but is currently limited to setting tensor elements at specified indices to a given scalar value,
  • the Hadamard transform was added as another new GGML op GGML_OP_HADAMARD, with an implementation borrowed from ik_llama.cpp (thx @ikawrakow); the implementation from llama : rotate activations for better quantization #21038 was used in the lightning indexer,
  • the KV cache was implemented as a new llama_kv_cache_dsa class that aggregates two caches: the usual llama_kv_cache that caches MLA latent representations (same as before for DeepSeek V3) and a new llama_ik_cache class (basically a copy of llama_kv_cache stripped of the code related to the V vector) that caches lightning indexer keys,
  • since there are no official jinja templates for V3.2 and V3.2 Speciale, I simply decided to ignore this problem for now. You have to explicitly set the chat template for these models (using the jinja template from V3.2 Exp with these models will let you chat, but tool calls won't work correctly). PR chat: dedicated DeepSeek v3.2 parser + "official" template #21785 added a DeepSeek V3.2 chat template that you can use with --chat-template-file models/templates/deepseek-ai-DeepSeek-V3.2.jinja
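The masking mechanism described in the bullets above can be sketched in NumPy (a minimal sketch with hypothetical sizes; `indexer_scores` stands in for the lightning-indexer logits, which the real implementation computes per query token):

```python
import numpy as np

# Hypothetical sizes: one query token attending over n_kv cached tokens.
n_kv, top_k = 16, 4
rng = np.random.default_rng(1)
indexer_scores = rng.standard_normal(n_kv)  # stand-in for lightning-indexer logits
kq_mask = np.zeros(n_kv, dtype=np.float32)  # base mask: 0 = attend, -inf = masked

# Keep only the top-k tokens by indexer score; scatter -inf everywhere else,
# which is the effect GGML_OP_SCATTER has on the KQ mask.
keep = np.argpartition(-indexer_scores, top_k)[:top_k]
sparse_mask = np.full(n_kv, -np.inf, dtype=np.float32)
sparse_mask[keep] = kq_mask[keep]
print(int(np.isfinite(sparse_mask).sum()))  # number of attendable tokens
```

The dense attention that follows then sees `-inf` for every non-selected token, so its softmax weight becomes zero.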

Requirements

Due to limitations of the current CUDA ggml_top_k() implementation, the NVIDIA CUDA CCCL library (version > 3.2) and enabling GGML_CUDA_USE_CUB during CUDA backend compilation are needed; otherwise the CUDA implementation will crash for context sizes larger than (I think) 1024 tokens. I use it with CUDA 13.2 and CCCL 13.2.27.
The bug in ggml_top_k() is now fixed and the fix is merged, so it should work even on CUDA 12.8/12.9 without CCCL.
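A build invocation matching the requirements above might look like this (a sketch; the build directory name is arbitrary, and GGML_CUDA_USE_CUB only matters on setups still affected by the top-k limitation):

```shell
# Configure the CUDA backend with CUB enabled (assumes CCCL > 3.2 is installed).
cmake -B build-cuda -DGGML_CUDA=ON -DGGML_CUDA_USE_CUB=ON
cmake --build build-cuda --config Release -j
```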

Also, if you want to convert the model yourself, set add_bos_token to true in tokenizer_config.json before the model conversion - this is needed for DeepSeek V3.2 and DeepSeek V3.2 Speciale. The conversion script has an assert that checks this.

Next Steps

  • I'd like to confirm my architectural choices regarding the implementation,
  • If they are accepted, I will clean up the code if needed, merge with the current master, and it will be ready for code review,
  • If not, then So Long, and Thanks for All the Fish. Just joking, we can talk about this.

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, AI was used as an assistant helping me find bugs in CUDA kernel implementations.

sszymczy added 26 commits March 12, 2026 13:15
…e attention). Needs manual change of add_bos_token to true in tokenizer_config.json before conversion.
…indexer implementation since the former fails for large tensors even when using CCCL.
… of llama_kv_cache and new llama_ik_cache (lightning indexer key cache).

model : used new llama_kv_cache_dsa instead of modified llama_kv_cache with indexer keys in DeepseekV32ForCausalLM
model : removed non-MLA path in DeepseekV32ForCausalLM
…e can get rid of ggml_cast() calls in sparse attention implementation
@fairydreaming fairydreaming requested review from a team, CISC and ggerganov as code owners March 29, 2026 12:56
@fairydreaming fairydreaming marked this pull request as draft March 29, 2026 12:56
@fairydreaming
Collaborator Author

I managed to get rid of GGML_OP_SCATTER by using ggml_set_rows() with 1-element rows, @jeffbolznv thanks for inspiration!
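A NumPy analogue of that trick (hypothetical sizes): viewing the per-token mask as n_kv rows of length one turns the scalar scatter into a plain row write, which is the operation ggml_set_rows() already provides:

```python
import numpy as np

# View the mask as n_kv rows of length 1; "set rows" at the top-k indices
# is then exactly the scalar scatter that GGML_OP_SCATTER used to do.
n_kv = 16
mask = np.full((n_kv, 1), -np.inf, dtype=np.float32)  # start fully masked
keep = np.array([2, 5, 7, 11])                        # indices from the indexer
mask[keep] = 0.0                                      # 1-element rows set to "attend"
print(int(np.isfinite(mask).sum()))
```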

@fairydreaming
Collaborator Author

I ran some experiments trying to optimize long-context inference on a small 4-layer DeepSeek V3.2 model:

No optimization

(base) phm@epyc:~/projects/llama.cpp-deepseek-dsa/build-cuda$ ./bin/llama-bench -m ../models/DeepSeek-V3.2-4Layers-Q8_0.gguf -p 0 -n 32 -ub 2048 -d 0,8192,16386,32768,65536,131072 -r 3 -fa 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |            tg32 |       334.62 ± 16.92 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |    tg32 @ d8192 |        295.75 ± 3.15 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d16386 |        279.14 ± 2.38 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d32768 |        258.88 ± 1.91 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d65536 |        208.80 ± 1.17 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |  tg32 @ d131072 |        156.44 ± 0.65 |

Using ggml_get_rows() on KQ mask and KV cache to convert sparse attention into a dense one for batch size 1

(base) phm@epyc:~/projects/llama.cpp-deepseek-dsa/build-cuda$ ./bin/llama-bench -m ../models/DeepSeek-V3.2-4Layers-Q8_0.gguf -p 0 -n 32 -ub 2048 -d 0,8192,16386,32768,65536,131072 -r 3 -fa 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |            tg32 |       311.49 ± 14.36 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |    tg32 @ d8192 |        283.81 ± 2.98 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d16386 |        277.83 ± 3.28 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d32768 |        241.28 ± 2.16 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d65536 |        228.12 ± 1.57 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |  tg32 @ d131072 |        198.19 ± 1.44 |

so it works, but only for very long contexts. There is one problem, though: ggml_get_rows() currently dequantizes/converts the output to f32 automatically, so I had to add a boolean parameter to disable this behavior, which resulted in changes scattered over the whole codebase (all ggml_get_rows() occurrences). This was messy, so I'm not going to commit it here - maybe later as a separate PR.
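The get_rows approach can be sketched in NumPy (a sketch with hypothetical shapes, not the actual graph code): gathering only the selected cache rows turns the sparse problem into a small dense one for batch size 1:

```python
import numpy as np

n_kv, d_k, top_k = 1024, 64, 256   # hypothetical cache size / head dim / selection
rng = np.random.default_rng(0)
k_cache = rng.standard_normal((n_kv, d_k)).astype(np.float32)
idx = rng.choice(n_kv, size=top_k, replace=False)  # top-k from the indexer

# get_rows: gather only the selected keys (and, in the real graph, the
# matching KQ-mask entries), so the attention that follows is dense over
# top_k entries instead of sparse over n_kv.
k_dense = k_cache[idx]
print(k_dense.shape)
```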

Using sparse variant of CUDA fattn-vec.cuh kernel

(base) phm@epyc:~/projects/llama.cpp-deepseek-dsa/build-cuda$ ./bin/llama-bench -m ../models/DeepSeek-V3.2-4Layers-Q8_0.gguf -p 0 -n 32 -ub 2048 -d 0,8192,16386,32768,65536,131072 -r 3 -fa 1
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97247 MiB):
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97247 MiB
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |            tg32 |       334.76 ± 15.91 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |    tg32 @ d8192 |        278.86 ± 3.02 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d16386 |        273.21 ± 2.23 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d32768 |        269.34 ± 2.02 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |   tg32 @ d65536 |        246.40 ± 1.78 |
| deepseek32 ?B Q8_0             |  15.02 GiB |    15.17 B | CUDA       |  99 |     2048 |  1 |  tg32 @ d131072 |        212.32 ± 1.33 |

This worked better than ggml_get_rows(), but first I had to modify the kernel to support different QK and V dimensions, then implement the simplest possible sparse attention variant (iterate over the top_k KV-cache vectors instead of the whole KV cache) and modify GGML_OP_FLASH_ATTN_EXT by passing an additional top_k argument. From what I read in #16817 this is not how sparse attention support should work (it's clearly stated there that the KQ mask should be used as the source of the sparsity information), so again I'm not going to commit it here - maybe later as a separate PR.
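The kernel change amounts to the following (a NumPy reference sketch of the idea, not the CUDA code): single-query attention that iterates only over the top-k cache entries, with separate K and V dimensions:

```python
import numpy as np

def sparse_attn_1q(q, k_cache, v_cache, idx, scale):
    """Single-query attention restricted to the cache entries in idx,
    mimicking what the modified fattn-vec kernel iterates over."""
    k = k_cache[idx]                   # (top_k, d_k)
    v = v_cache[idx]                   # (top_k, d_v)
    logits = (k @ q) * scale           # (top_k,)
    w = np.exp(logits - logits.max())  # numerically stable softmax
    w /= w.sum()
    return w @ v                       # (d_v,)
```

This is numerically equivalent to dense attention with a mask that is `-inf` outside `idx`, which is why the KQ-mask-driven design discussed in #16817 can express the same computation.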

Also, are there any Vulkanologists here who would know why the Vulkan CIs crash on test-llama-archs (likely related to the added DeepSeek V3.2 arch)?

@pwilkin
Member

pwilkin commented Apr 20, 2026

Vulkanologists needed! @jeffbolznv @0cc4m

@fairydreaming
Collaborator Author

fairydreaming commented Apr 20, 2026

Vulkanologists needed! @jeffbolznv @0cc4m

Some more info - looks like a problem in SET_ROWS:

ggml_vk_build_graph(0x5555565c1d30, SET_ROWS)
ggml_vk_op_f32((0x5555565c1bc0, name=node_41, type=0, ne0=1, ne1=512, ne2=512, ne3=1, nb0=4, nb1=4, nb2=2048, nb3=1048576), (0x5555565c18e0, name=top_k-0 (view), type=26, ne0=512, ne1=512, ne2=1, ne3=1, nb0=4, nb1=2048, nb2=1048576, nb3=1048576), (0x5555565c1d30, name= (view) (view), type=1, ne0=1, ne1=512, ne2=512, ne3=1, nb0=2, nb1=2, nb2=1024, nb3=524288), SET_ROWS)
ggml_pipeline_request_descriptor_sets(set_rows_f16_i32, 1)
ggml_vk_load_shaders(Vulkan0)
[New Thread 0x7fff8afe56c0 (LWP 114349)]
ggml_vk_create_pipeline(Vulkan0, set_rows_f16_i32, main, 3, (1,1,1), specialization_constants, 1, 0, 0)
set_rows_f16_i32 Register Count: 16 registers
[Thread 0x7fff8afe56c0 (LWP 114349) exited]

Thread 1 "llama-bench" received signal SIGSEGV, Segmentation fault.
0x00007ffff7961678 in __gnu_cxx::__atomic_add (__val=1, __mem=0xd44a00011407) at /usr/include/c++/13/ext/atomicity.h:71
71	  { __atomic_fetch_add(__mem, __val, __ATOMIC_ACQ_REL); }
(gdb) bt
#0  0x00007ffff7961678 in __gnu_cxx::__atomic_add (__val=1, __mem=0xd44a00011407) at /usr/include/c++/13/ext/atomicity.h:71
#1  __gnu_cxx::__atomic_add_dispatch (__val=1, __mem=0xd44a00011407) at /usr/include/c++/13/ext/atomicity.h:111
#2  std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_add_ref_copy (this=0xd44a000113ff)
    at /usr/include/c++/13/bits/shared_ptr_base.h:152
#3  0x00007ffff7955af3 in std::__shared_count<(__gnu_cxx::_Lock_policy)2>::operator= (this=0x7fffffff6828, __r=...)
    at /usr/include/c++/13/bits/shared_ptr_base.h:1088
#4  0x00007ffff2949d4d in std::__shared_ptr<vk_buffer_struct, (__gnu_cxx::_Lock_policy)2>::operator= (this=0x7fffffff6820)
    at /usr/include/c++/13/bits/shared_ptr_base.h:1523
#5  0x00007ffff2949d7b in std::shared_ptr<vk_buffer_struct>::operator= (this=0x7fffffff6820) at /usr/include/c++/13/bits/shared_ptr.h:414
#6  0x00007ffff2877369 in ggml_vk_tensor_subbuffer (ctx=0x555555e4ac40, tensor=0x5555565c1d30, allow_misalign=true)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-vulkan/ggml-vulkan.cpp:6536
#7  0x00007ffff28e8b99 in ggml_vk_op_f32<vk_op_binary_push_constants> (ctx=0x555555e4ac40, 
    subctx=std::shared_ptr<vk_context_struct> (use count 3, weak count 1) = {...}, src0=0x5555565c1bc0, src1=0x5555565c18e0, src2=0x0, 
    src3=0x0, dst=0x5555565c1d30, op=GGML_OP_SET_ROWS, pc=...)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-vulkan/ggml-vulkan.cpp:9974
#8  0x00007ffff28acf25 in ggml_vk_set_rows (ctx=0x555555e4ac40, 
    subctx=std::shared_ptr<vk_context_struct> (use count 3, weak count 1) = {...}, src0=0x5555565c1bc0, src1=0x5555565c18e0, 
    dst=0x5555565c1d30) at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-vulkan/ggml-vulkan.cpp:10962
#9  0x00007ffff28b5f0d in ggml_vk_build_graph (ctx=0x555555e4ac40, cgraph=0x5555560c9110, node_idx=0, node_begin=0x5555565c1d30, 
    node_idx_begin=0, last_node=false, almost_ready=false, submit=false)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-vulkan/ggml-vulkan.cpp:13150
#10 0x00007ffff28c4e5e in ggml_backend_vk_graph_compute (backend=0x555557bb5f40, cgraph=0x5555560c9110)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-vulkan/ggml-vulkan.cpp:14722
#11 0x00007ffff252f262 in ggml_backend_graph_compute_async (backend=0x555557bb5f40, cgraph=0x5555560c9110)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-backend.cpp:452
#12 0x00007ffff25343d1 in ggml_backend_sched_compute_splits (sched=0x555557ff3540)
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-backend.cpp:1678
#13 0x00007ffff253539d in ggml_backend_sched_graph_compute_async (sched=0x555557ff3540, graph=0x5555565b1df0)
--Type <RET> for more, q to quit, c to continue without paging--
    at /home/phm/projects/llama.cpp-deepseek-dsa/ggml/src/ggml-backend.cpp:1901
#14 0x00007ffff6cdce9a in llama_context::graph_compute (this=0x555555f2eef0, gf=0x5555565b1df0, batched=true)
    at /home/phm/projects/llama.cpp-deepseek-dsa/src/llama-context.cpp:2191
#15 0x00007ffff6cd8019 in llama_context::process_ubatch (this=0x555555f2eef0, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, 
    mctx=0x555555f20df0, ret=@0x7fffffffce30: GGML_STATUS_SUCCESS) at /home/phm/projects/llama.cpp-deepseek-dsa/src/llama-context.cpp:1231
#16 0x00007ffff6cda1d3 in llama_context::decode (this=0x555555f2eef0, batch_inp=...)
    at /home/phm/projects/llama.cpp-deepseek-dsa/src/llama-context.cpp:1692
#17 0x00007ffff6ce1baf in llama_decode (ctx=0x555555f2eef0, batch=...)
    at /home/phm/projects/llama.cpp-deepseek-dsa/src/llama-context.cpp:3454
#18 0x00005555555d4510 in test_prompt (ctx=0x555555f2eef0, n_prompt=512, n_batch=2048, n_threads=32)
    at /home/phm/projects/llama.cpp-deepseek-dsa/tools/llama-bench/llama-bench.cpp:2081
#19 0x00005555555d54dd in main (argc=15, argv=0x7fffffffe018)
    at /home/phm/projects/llama.cpp-deepseek-dsa/tools/llama-bench/llama-bench.cpp:2305

Edit: I added this case (same type and dimensions) to test-backend-ops and it works, so it may be a different problem.

@0cc4m
Contributor

0cc4m commented Apr 20, 2026

I'll take a look.

@0cc4m
Contributor

0cc4m commented Apr 20, 2026

For me it crashes inside overlaps_unsynced - can you take a look, @jeffbolznv? You know more about how that works.

@jeffbolznv
Contributor

Sure, I'll look.

@jeffbolznv
Contributor

I got the same crash as in #21149 (comment). Looks like the destination tensor is in a host buffer, which I think is unexpected. But I'm not sure why that's happening.

@jeffbolznv
Contributor

I guess what's happening is that the model-building code takes the kq_mask input and tries to set_rows into a view of it, but we shouldn't be writing to input tensors. This diff works around it, but I don't know the model-building code well enough to say what the right fix really is:

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a50715bed..459dbfcb8 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2400,6 +2400,8 @@ ggml_tensor * llm_graph_context::build_attn(
     // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream]
     kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0);

+    kq_mask_all = ggml_cont(ctx0, kq_mask_all);
+
     // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1]
     top_k = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0);

I don't think the vulkan backend is doing anything wrong, though maybe we should sanity check that we're not trying to write to host buffers?

@fairydreaming
Collaborator Author

@jeffbolznv Hmm, but in the DSA attention implementation the first argument to ggml_set_rows() is kq_mask_all, which is the result of ggml_fill(kq_mask). Sure, kq_mask is an input tensor, but kq_mask_all shouldn't be one? This ggml_fill() is not inplace, so it should do ggml_dup_tensor() first, which would allocate a Vulkan buffer. Or perhaps I'm misunderstanding something there?

@jeffbolznv
Contributor

Hmm, I see this is an f16 fill, which the vulkan backend doesn't currently support, so I guess that makes it still be an input tensor after the graph is split. I don't know how this is supposed to be handled - a tensor ends up being an input to the graph split and we're supposed to write to it?

But ignoring this more general issue, I'll add f16 support for fill and see if it helps.

BTW, I had asked Claude 4.7 about this and it was grinding away when I realized what was happening, but it came to the same conclusion about missing f16 fill support. Nice!

@jeffbolznv
Contributor

#22177 should fix this.


Labels

  • ggml - changes relating to the ggml tensor library for machine learning
  • model - Model specific
  • Nvidia GPU - Issues specific to Nvidia GPUs
  • python - python script changes
  • testing - Everything test related


Development

Successfully merging this pull request may close these issues.

Feature Request: DeepSeek V3.2-Exp support