
Add GGML fused quantized dense and gather ops#364

Open
ajroetker wants to merge 17 commits into gomlx:main from ajroetker:gguf-fused-quantized-dense

Conversation

@ajroetker
Contributor

@ajroetker ajroetker commented Mar 11, 2026

Summary

  • Add FusedQuantizedGather for quantized embedding lookups (analogous to ggml_get_rows)
  • Add FusedQuantizedDense for fused dequant + matmul + bias + activation on GGML-quantized weights
  • Add GGML dequantization support for Q4_0, Q8_0, IQ4_NL, Q4_K, Q6_K types
  • Add shift operations (LogicalShiftLeft/Right) needed for sub-byte unpacking
  • Add QuantizedGather high-level layer in pkg/ml/nn
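
The shift operations exist because sub-byte formats store two 4-bit values per byte, which must be split apart before use. A minimal sketch of that unpacking, with illustrative names that are not the PR's actual API:

```go
package main

import "fmt"

// unpackNibbles splits each packed byte into two 4-bit values (low
// nibble first), the kind of sub-byte unpacking the new shift ops
// enable. The function name and low-nibble-first order are assumptions
// for illustration, not taken from the PR.
func unpackNibbles(packed []uint8) []uint8 {
	out := make([]uint8, 0, len(packed)*2)
	for _, b := range packed {
		out = append(out, b&0x0F) // mask keeps the low nibble
		out = append(out, b>>4)   // logical shift right extracts the high nibble
	}
	return out
}

func main() {
	fmt.Println(unpackNibbles([]uint8{0xAB, 0x3C})) // [11 10 12 3]
}
```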

Review feedback addressed

  • Reverted buffer pool bucketing changes to a future PR
  • Reverted execBitcast fix to separate PR: Fix execBitcast buffer reuse for cross-bit-width types #374
  • Split quantized executors into exec_fused_quantized.go and exec_fused_quantized_ggml.go
  • Added GGML block format documentation with references to ggml/llama.cpp
  • Added detailed deriveGGMLK docs explaining N/K dimensions
  • Renamed table/tableQuantizationdata/dataQuantization in FusedQuantizedGather
  • Used errors.Wrapf(backends.ErrNotImplemented, ...) with feature-request guidance

Test plan

  • go build ./... passes
  • Existing tests pass
  • go test ./backends/simplego/... — includes new benchmark for fused ops
  • go test ./pkg/ml/nn/... — includes new quantized dense test

- Add QuantGGML quantization scheme with GGMLQuantType enum supporting
  Q4_0, Q8_0, IQ4_NL, Q2_K, Q3_K, Q4_K, Q5_K, Q6_K block formats
- Add FusedQuantizedGather op for quantized embedding lookups with
  graph builder, shape inference, and simplego executor
- Implement GGML dequantization executors (Q4_0, Q8_0, Q4_K, Q6_K)
  for both FusedQuantizedDense and FusedQuantizedGather
- Add QuantizedGather nn layer for quantized embedding tables
- Fix packedLen to only pack sub-byte integer types (Int4/Uint4),
  not Bool
- Add benchmarks for GGML dequantization and quantized dense tests
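
For reference, the Q4_0 format listed above packs 32 weights per block: one scale plus 16 bytes of 4-bit quants, where each weight decodes as (nibble - 8) * scale. A minimal sketch of one block's dequantization (the real format stores the scale as fp16; it is taken as float32 here for brevity, and this is not the PR's executor code):

```go
package main

import "fmt"

// dequantQ4_0Block decodes one GGML Q4_0 block: low nibbles fill the
// first 16 outputs, high nibbles the last 16, each shifted by -8 and
// scaled by d.
func dequantQ4_0Block(d float32, qs [16]uint8) [32]float32 {
	var out [32]float32
	for i, b := range qs {
		out[i] = float32(int(b&0x0F)-8) * d  // low nibble -> first half
		out[i+16] = float32(int(b>>4)-8) * d // high nibble -> second half
	}
	return out
}

func main() {
	var qs [16]uint8
	qs[0] = 0x9F // low nibble 15 -> +7, high nibble 9 -> +1
	y := dequantQ4_0Block(0.5, qs)
	fmt.Println(y[0], y[16]) // 3.5 0.5
}
```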
…y bench dep

- exec_shift_ops.go: use unexported binaryOperandsAndOutput/newBroadcastIterator
- Remove BenchmarkQuantizedDenseGGML that depends on go-highway/gguf
Resolve conflict in buffers.go: keep bucketSize/subSliceFlat/fullCapFlat
from gguf branch, drop isPackedSubByteDType/packedLen (replaced by
dtype.IsPacked() and dtype.SizeForDimensions()). Fix subSliceFlat call
in getBuffer to use element count for non-packed types and byte count
for packed types.
The parameter describes the quantization of the table (embedding matrix),
not generic weights. This makes the API more precise and consistent with
the table/indices naming used elsewhere in the function signature.
- Fix ggmlFp16LE subnormal path: use int32 for exponent to avoid
  fragile uint32 underflow/cast chain
- Add allow-list check to FusedQuantizedGather so unsupported GGML
  types (e.g. IQ4_NL) return ErrNotImplemented at build time,
  enabling transparent fallback via InternalFusedOpCaller
- Extract deriveGGMLK helper to deduplicate K derivation+validation
  in fusedQuantizedDenseGGML and FusedQuantizedGather
- Extract processColumn closure in quantizedDenseGGML to deduplicate
  serial/parallel branches
- Extract flatToIntSlice helper to deduplicate 3-way type switch in
  execFusedQuantizedGather
- Extract extractNibbleBlock helper to deduplicate shared Q4_0/IQ4_NL
  nibble unpacking in dequant.go
- Remove dead _ = vpb assignment in Dequant
- Remove duplicated comment in local.go
- Merge identical BytesPerBlock cases (GGMLQ4_0, GGMLIQ4NL)
…xecutor, harden gather

- Hand-write BackendFusedQuantizedGather in fused_ops.go using toBackend(),
  matching the BackendFusedQuantizedDense pattern. Remove auto-generated version.
- Add dequantIQ4NLRow fused executor and register IQ4_NL in builder switches
  for both Dense and Gather, so IQ4_NL no longer always falls back to decomposed.
- Add bounds check on rowIdx in execFusedQuantizedGather to produce a meaningful
  error instead of an opaque Go panic on out-of-range indices.
- Refactor quantizedDenseGGML to use quantizedDenseParallel for consistent
  parallelism with NF4/Linear paths, including M=1 column tiling.
- Add QuantizedGather tests for Q8_0, Q4_0, and IQ4_NL covering both fused
  (simplego) and decomposed (xla:cpu) paths.
- Clarify IQ4NLLookupTable comment and K-quant backend requirements in docs.

@janpfeifer janpfeifer left a comment


Not fully reviewed, but some important change requests (sry), so let's start with this.

- Merge upstream/main (DTypeMap error returns, pad support, etc.)
- Revert buffer pool bucketing changes to separate PR per review
- Revert exec_bitcast changes to separate PR per review
- Split quantized code into exec_fused_quantized.go and
  exec_fused_quantized_ggml.go per review
- Add GGML block format documentation with references to ggml/llama.cpp
- Add detailed deriveGGMLK documentation explaining N/K dimensions
- Rename table/tableQuantization -> data/dataQuantization in
  FusedQuantizedGather (parameter names, not just for embeddings)
- Use errors.Wrapf(backends.ErrNotImplemented, ...) for unsupported
  quantization schemes in FusedQuantizedGather
- Add feature-request guidance in error messages

@janpfeifer janpfeifer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay. This took some time to review.

…acked sub-byte types

- Rename FusedQuantizedGather → QuantizedEmbeddingLookup across all
  backends, graph layer, and ggml package (it doesn't fuse operations)
- Rename Dequant → Dequantize, GatherDecomposed → EmbeddingLookupDecomposed
- Rename flatToIntSlice → quantGatherIntSliceOfFlat (too generic for specialized use)
- Extract validateGGMLTypeSupported helper to deduplicate GGML type validation
- Fix ValueSafe for packed sub-byte tensors (Int4, Uint4, Int2, Uint2):
  unpack before scalar check and multi-dimensional slice building
- Fix Summary to show actual dtype name for packed types instead of Go storage type
- Add comprehensive tests for packed sub-byte ValueSafe and Summary
- Improve Dequantize doc comments explaining N, K, and why N is explicit
- Sort shift ops alphabetically in capabilities.go
- Regenerate enumerators and backend ops wrappers
…s, fix dead branch

- Replace hand-rolled ggmlFp16LE with float16.Frombits from existing
  x448/float16 dependency (already used elsewhere in simplego)
- Pre-allocate per-worker scratch buffers in quantizedDenseGGML to avoid
  heap allocation per tile invocation on the inference hot path
- Eliminate intermediate dequantRow + copy in execQuantizedEmbeddingLookup
  by dequantizing directly into the output slice
- Remove dead branch in execBitcast canReuse (targetDType == dtypes.Uint8
  is unreachable when !sameBitWidth and srcIsUint8)
@ajroetker ajroetker requested a review from janpfeifer March 23, 2026 00:25
@ajroetker
Contributor Author

Tried to address everything, found a couple of bugs while refactoring after the fixups too!

Extract parallelTileCount helper and add workerIdx callback parameter
to quantizedDenseParallel, eliminating duplicated tileSize/numWorkers
computation in quantizedDenseGGML.

@janpfeifer janpfeifer left a comment


I still haven't reviewed everything... Question: in many cases the code isn't making use of the infra that is already there. I assume most of the code was written by the AI, correct?

If yes, I wonder if there is a context prompt we can add to the AGENTS.md to try to make it more attentive to following the patterns in the package already.

func quantizedDenseParallel(backend *Backend, M, K, N int, rowFn func(m, nStart, nEnd int)) {
// parallelTileCount returns the number of parallel work units that
// quantizedDenseParallel will dispatch for the given dimensions.
func parallelTileCount(backend *Backend, M, K, N int) int {

nit: parallelTileCount -> quantizedDenseParallelTileCount ?

(Since simplego is this giant package ...)

_, srcIsUint8 := src.flat.([]uint8)
dstIsUint8Storage := targetDType == dtypes.Uint8 || targetDType.Bits() < 8
canReuse = srcIsUint8 && dstIsUint8Storage
canReuse = srcIsUint8 && targetDType.Bits() < 8

Maybe:

tgtIsUint8 := targetDType.GoType().Kind() == reflect.Uint8 
canReuse = srcIsUint8 && tgtIsUint8

Just in case there eventually is a packed data type that doesn't use uint8 as the storage dtype.

// unpackWeightsToInt8 unpacks sub-byte weight data (Int4, Uint4) from packed
// []byte storage into []int8 (one value per element) for the matmul kernel.
// For non-sub-byte types, returns the flat data as-is.
func unpackWeightsToInt8(wBuf *Buffer) any {

We probably should return a buffer here: the buffer pool was created exactly for these quick temporary allocations that will be the same size at every execution of the graph -- which we expect to happen often.

And, if that makes sense, then we may as well use the already existing execConvertDType from Uint4/Int4 to Int8.

return output, nil
}

// quantGatherIntSliceOfFlat converts a flat index slice ([]int32, []int64, or []int) to []int.

Same as above, we likely want to use the buffer pool to allocate the converted ints. In which case, you may just use the corresponding execConvertDType from whatever is the indices to int64 (better use int64 than int in this case, with explicit number of bits -- for all our platforms int=int64 anyway).

return nil, err
}

numIndices := indicesBuf.shape.Size() / indicesBuf.shape.Dimensions[indicesBuf.shape.Rank()-1]

The logic here is very odd.

If the last dimension of indices != 1, you would be simply truncating the indices and throwing the rest away !?

But we know the last dimension is 1 (it's pre-checked) and so numIndices := indicesBuf.shape.Size().


numIndices := indicesBuf.shape.Size() / indicesBuf.shape.Dimensions[indicesBuf.shape.Rank()-1]

indices, err := quantGatherIntSliceOfFlat(indicesBuf.flat, numIndices)

Optional: instead of converting to int, what about having a generic function for all types of ints and use the usual registration system to map the index dtype to the corresponding generic instantiation ? You save one temporary allocation, and it will probably be a tiny bit faster for smaller integers (less memory to scan).

Very optional, since likely this won't make much of a difference, I can't imagine the embedding lookup being the bottleneck :)

}
}

// convertToIntSlice converts the first n elements of an integer slice to []int.

Not needed, if you use the execConvertDType... per comment above.

- Use buffer pool + ConvertDType for weight unpacking and index
  conversion instead of ad-hoc allocations (unpackWeightsToBuffer,
  convertIndicesToInt64)
- Inline shift operations following binary ops pattern to eliminate
  per-element closure overhead (shiftLeftOp, shiftRightArithmeticOp,
  shiftRightLogicalUnsignedOp, shiftRightLogicalSignedOp)
- Rename parallelTileCount → quantizedDenseParallelTileCount
- Simplify numIndices calculation (last dim pre-validated as 1)
- Use tgtIsUint8 GoType check in exec_bitcast instead of Bits() < 8
- Add GGML format references and doc links to fused_ops.go
- Add "Follow Existing Patterns" guidance to AGENTS.md
@ajroetker ajroetker force-pushed the gguf-fused-quantized-dense branch from 5250386 to 18d38b4 Compare March 25, 2026 19:40

@janpfeifer janpfeifer left a comment


Very neat, thanks @ajroetker !!

A few more minor comments, all reviewed now!

func unpackWeightsToBuffer(backend *Backend, wBuf *Buffer) (*Buffer, bool, error) {
var targetDType dtypes.DType
switch wBuf.shape.DType {
case dtypes.Int4:

Question: what about Int2 and Uint2 ?


// For packed sub-byte weights (from Bitcast), unpack nibbles via the buffer pool
// and ConvertDType infrastructure. Non-sub-byte types pass through unchanged.
unpackedBuf, unpackedPooled, err := unpackWeightsToBuffer(backend, wBuf)

nit: maybe unpackedPooled -> isUnpackedPooled or isUnpackedOwned since it comes from the pool anyway (wBuf is pooled), the question is whether we need to release it.

numIndices := indicesBuf.shape.Size()

// Convert indices to int64 via the buffer pool and ConvertDType infrastructure.
idxBuf, idxPooled, err := convertIndicesToInt64(backend, indicesBuf)

nit: when reading I keep thinking idxPooled is another buffer. I was going to suggest as above, prefix is "is" or "are" or "has" for booleans, and maybe call it "owned" since all buffers are pooled. So idxPooled -> isIdxOwned.

Which reminds me that I should have named inputsOwned to areInputsOwned 😄

}

// execShiftLeft executes lhs << rhs for integer types.
func execShiftLeft(backend *Backend, node *Node, inputs []*Buffer, inputsOwned []bool) (*Buffer, error) {

Very optional: you could use the DTypeMap to register functions per dtype (and get rid of the various laundry-list switches). Well, ShiftLogicalRight would have to be manually registered because it takes two type parameters.

But ... when we have SIMD, it will be easy to simply register the supported SIMD version detected in runtime.

But it can also wait ...

w("[%d]", dim)
}
w("%s", values.Type().Elem())
if t.shape.DType.IsPacked() {

Hmm ... this is delicate: so far I used Summary() to print a sample of even very large tensors, because it doesn't create a copy of it.

If we use unpackFlatValues() here this becomes a very costly operation, potentially requiring a giant allocation.

Can I suggest (I'm hoping the AI can easily handle this):

  • Refactor wValue() above to take, instead of the value itself, just the index of the value in the flat vector.
  • Write separately a small extractPackedElement(flatPacked []uint8, packedDType dtypes.DType, index int) int, that takes the packed bytes (in flatPacked), the packed dtype (Int2, Int4, Uint2, or Uint4) and the index, and returns the corresponding value unpacked to an int.

Then in wValue(ii) you can check dtype.IsPacked() and, if true, call this extractPackedElement() instead.

Wdyt ?
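
A minimal sketch of the suggested helper, under the assumption that packed elements are stored low-bits-first within each byte; it takes an explicit bit width and signedness flag in place of the proposed dtypes.DType parameter:

```go
package main

import "fmt"

// extractPackedElement pulls a single sub-byte value out of packed
// bytes without unpacking the whole tensor. bits is the element width
// (2 or 4); signed selects sign extension (Int2/Int4 vs Uint2/Uint4).
// The low-bits-first packing order is an assumption for illustration.
func extractPackedElement(flatPacked []uint8, bits int, signed bool, index int) int {
	perByte := 8 / bits
	b := flatPacked[index/perByte]
	shift := uint((index % perByte) * bits)
	mask := uint8(1<<bits - 1)
	v := int((b >> shift) & mask)
	if signed && v >= 1<<(bits-1) {
		v -= 1 << bits // sign-extend: e.g. 4-bit pattern 0xF decodes to -1
	}
	return v
}

func main() {
	packed := []uint8{0xF2} // low nibble 2, high nibble 15
	fmt.Println(extractPackedElement(packed, 4, false, 1)) // 15
	fmt.Println(extractPackedElement(packed, 4, true, 1))  // -1
}
```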

dequantW := Dequantize(weights, ggmlType, N)

// Transpose to [K, N] for matmul.
dequantW = Transpose(dequantW, 0, 1) // [K, N]

Instead of transpose here, just change the below Dot.Product() to the corresponding Dot.General() with the appropriate contraction axes. Let the dotgeneral algorithm decide what it wants to do (if it wants to transpose or not).

- Register Int2/Uint2 → {Int8,Uint8,Int32,Int64,Float32,Float64}
  converters via execConvertPackedSubByte with valuesPerByte=4
- Add unpackInt2Bits and unpackUint2Bits for 2-bit packed data
- Handle Int2/Uint2 in unpackWeightsToBuffer alongside Int4/Uint4
- Register mutableBytes and fillBuffer for Int2/Uint2
- Rename unpackedPooled → isUnpackedOwned, idxPooled → isIdxOwned
  for clarity (all buffers are pooled; the bool tracks ownership)
@ajroetker ajroetker force-pushed the gguf-fused-quantized-dense branch from c93f94a to a1494db Compare March 26, 2026 14:28