Add GGML fused quantized dense and gather ops #364

ajroetker wants to merge 17 commits into gomlx:main from
Conversation
- Add QuantGGML quantization scheme with GGMLQuantType enum supporting Q4_0, Q8_0, IQ4_NL, Q2_K, Q3_K, Q4_K, Q5_K, Q6_K block formats
- Add FusedQuantizedGather op for quantized embedding lookups with graph builder, shape inference, and simplego executor
- Implement GGML dequantization executors (Q4_0, Q8_0, Q4_K, Q6_K) for both FusedQuantizedDense and FusedQuantizedGather
- Add QuantizedGather nn layer for quantized embedding tables
- Fix packedLen to only pack sub-byte integer types (Int4/Uint4), not Bool
- Add benchmarks for GGML dequantization and quantized dense tests
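To make the block formats concrete, here is a minimal sketch of the Q4_0 layout as documented in ggml/llama.cpp: 32 weights per block, one per-block scale, and 16 bytes of nibbles, with byte j holding element j in its low nibble and element j+16 in its high nibble. This is an illustration, not the PR's executor code; real Q4_0 blocks store the scale as fp16, which is taken as a float32 parameter here to keep the sketch self-contained.

```go
package main

import "fmt"

// dequantQ4_0Block dequantizes one GGML Q4_0 block: 32 weights stored as
// 4-bit unsigned nibbles with an implicit -8 offset, times a per-block scale.
// Byte j holds element j in its low nibble and element j+16 in its high nibble.
func dequantQ4_0Block(scale float32, qs [16]byte) [32]float32 {
	var out [32]float32
	for j, b := range qs {
		out[j] = float32(int8(b&0x0F)-8) * scale
		out[j+16] = float32(int8(b>>4)-8) * scale
	}
	return out
}

func main() {
	var qs [16]byte
	qs[0] = 0x8F // low nibble 15 -> +7, high nibble 8 -> 0
	out := dequantQ4_0Block(0.5, qs)
	fmt.Println(out[0], out[16], out[1]) // 3.5 0 -4
}
```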
…y bench dep

- exec_shift_ops.go: use unexported binaryOperandsAndOutput/newBroadcastIterator
- Remove BenchmarkQuantizedDenseGGML that depends on go-highway/gguf
Resolve conflict in buffers.go: keep bucketSize/subSliceFlat/fullCapFlat from gguf branch, drop isPackedSubByteDType/packedLen (replaced by dtype.IsPacked() and dtype.SizeForDimensions()). Fix subSliceFlat call in getBuffer to use element count for non-packed types and byte count for packed types.
The parameter describes the quantization of the table (embedding matrix), not generic weights. This makes the API more precise and consistent with the table/indices naming used elsewhere in the function signature.
- Fix ggmlFp16LE subnormal path: use int32 for exponent to avoid fragile uint32 underflow/cast chain
- Add allow-list check to FusedQuantizedGather so unsupported GGML types (e.g. IQ4_NL) return ErrNotImplemented at build time, enabling transparent fallback via InternalFusedOpCaller
- Extract deriveGGMLK helper to deduplicate K derivation+validation in fusedQuantizedDenseGGML and FusedQuantizedGather
- Extract processColumn closure in quantizedDenseGGML to deduplicate serial/parallel branches
- Extract flatToIntSlice helper to deduplicate 3-way type switch in execFusedQuantizedGather
- Extract extractNibbleBlock helper to deduplicate shared Q4_0/IQ4_NL nibble unpacking in dequant.go
- Remove dead _ = vpb assignment in Dequant
- Remove duplicated comment in local.go
- Merge identical BytesPerBlock cases (GGMLQ4_0, GGMLIQ4NL)
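The subnormal fix is easier to see with the whole conversion written out. Below is a hedged sketch (not the PR's actual ggmlFp16LE, which a later commit replaces with float16.Frombits) of an IEEE binary16-to-float32 decode where the exponent is held in a signed int32, so the rebias `exp - 15 + 127` and the subnormal branch involve no uint32 underflow/cast chain:

```go
package main

import (
	"fmt"
	"math"
)

// fp16ToFloat32 decodes an IEEE-754 binary16 value. Keeping the exponent in a
// signed int32 makes the subnormal and rebias arithmetic straightforward.
func fp16ToFloat32(h uint16) float32 {
	sign := uint32(h>>15) << 31
	exp := int32((h >> 10) & 0x1F)
	mant := uint32(h & 0x3FF)
	switch {
	case exp == 0x1F: // Inf or NaN: keep the mantissa bits
		return math.Float32frombits(sign | 0x7F800000 | mant<<13)
	case exp == 0: // signed zero or subnormal: value = mant * 2^-24
		f := float32(mant) / (1 << 24)
		return math.Float32frombits(sign | math.Float32bits(f))
	default: // normal: rebias exponent from 15 to 127
		return math.Float32frombits(sign | uint32(exp-15+127)<<23 | mant<<13)
	}
}

func main() {
	fmt.Println(fp16ToFloat32(0x3C00)) // 1
	fmt.Println(fp16ToFloat32(0xC000)) // -2
	fmt.Println(fp16ToFloat32(0x0001)) // smallest subnormal, 2^-24
}
```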
…xecutor, harden gather

- Hand-write BackendFusedQuantizedGather in fused_ops.go using toBackend(), matching the BackendFusedQuantizedDense pattern. Remove auto-generated version.
- Add dequantIQ4NLRow fused executor and register IQ4_NL in builder switches for both Dense and Gather, so IQ4_NL no longer always falls back to decomposed.
- Add bounds check on rowIdx in execFusedQuantizedGather to produce a meaningful error instead of an opaque Go panic on out-of-range indices.
- Refactor quantizedDenseGGML to use quantizedDenseParallel for consistent parallelism with NF4/Linear paths, including M=1 column tiling.
- Add QuantizedGather tests for Q8_0, Q4_0, and IQ4_NL covering both fused (simplego) and decomposed (xla:cpu) paths.
- Clarify IQ4NLLookupTable comment and K-quant backend requirements in docs.
janpfeifer left a comment
Not fully reviewed, but some important change requests (sry), so let's start with this.
- Merge upstream/main (DTypeMap error returns, pad support, etc.)
- Revert buffer pool bucketing changes to separate PR per review
- Revert exec_bitcast changes to separate PR per review
- Split quantized code into exec_fused_quantized.go and exec_fused_quantized_ggml.go per review
- Add GGML block format documentation with references to ggml/llama.cpp
- Add detailed deriveGGMLK documentation explaining N/K dimensions
- Rename table/tableQuantization -> data/dataQuantization in FusedQuantizedGather (parameter names, not just for embeddings)
- Use errors.Wrapf(backends.ErrNotImplemented, ...) for unsupported quantization schemes in FusedQuantizedGather
- Add feature-request guidance in error messages
janpfeifer left a comment
Apologies for the delay. This took some time to review.
…acked sub-byte types

- Rename FusedQuantizedGather → QuantizedEmbeddingLookup across all backends, graph layer, and ggml package (it doesn't fuse operations)
- Rename Dequant → Dequantize, GatherDecomposed → EmbeddingLookupDecomposed
- Rename flatToIntSlice → quantGatherIntSliceOfFlat (too generic for specialized use)
- Extract validateGGMLTypeSupported helper to deduplicate GGML type validation
- Fix ValueSafe for packed sub-byte tensors (Int4, Uint4, Int2, Uint2): unpack before scalar check and multi-dimensional slice building
- Fix Summary to show actual dtype name for packed types instead of Go storage type
- Add comprehensive tests for packed sub-byte ValueSafe and Summary
- Improve Dequantize doc comments explaining N, K, and why N is explicit
- Sort shift ops alphabetically in capabilities.go
- Regenerate enumerators and backend ops wrappers
…s, fix dead branch

- Replace hand-rolled ggmlFp16LE with float16.Frombits from existing x448/float16 dependency (already used elsewhere in simplego)
- Pre-allocate per-worker scratch buffers in quantizedDenseGGML to avoid heap allocation per tile invocation on the inference hot path
- Eliminate intermediate dequantRow + copy in execQuantizedEmbeddingLookup by dequantizing directly into the output slice
- Remove dead branch in execBitcast canReuse (targetDType == dtypes.Uint8 is unreachable when !sameBitWidth and srcIsUint8)
Tried to address everything, found a couple of bugs while refactoring after the fixups too!
Extract parallelTileCount helper and add workerIdx callback parameter to quantizedDenseParallel, eliminating duplicated tileSize/numWorkers computation in quantizedDenseGGML.
janpfeifer left a comment
I still haven't reviewed everything ... Question: in many cases it's not making use of the infra that is already there; I assume most of the code was written by the AI, correct?
If yes, I wonder if there is a context prompt we can add to the AGENTS.md to try to make it more attentive to following the patterns in the package already.
func quantizedDenseParallel(backend *Backend, M, K, N int, rowFn func(m, nStart, nEnd int)) {

// parallelTileCount returns the number of parallel work units that
// quantizedDenseParallel will dispatch for the given dimensions.
func parallelTileCount(backend *Backend, M, K, N int) int {
nit: parallelTileCount -> quantizedDenseParallelTileCount ?
(Since simplego is this giant package ...)
backends/simplego/exec_bitcast.go (Outdated)
_, srcIsUint8 := src.flat.([]uint8)
dstIsUint8Storage := targetDType == dtypes.Uint8 || targetDType.Bits() < 8
canReuse = srcIsUint8 && dstIsUint8Storage
canReuse = srcIsUint8 && targetDType.Bits() < 8
Maybe:
tgtIsUint8 := targetDType.GoType().Kind() == reflect.Uint8
canReuse = srcIsUint8 && tgtIsUint8
Just in case there eventually is a packed data type that doesn't use uint8 as the storage dtype.
// unpackWeightsToInt8 unpacks sub-byte weight data (Int4, Uint4) from packed
// []byte storage into []int8 (one value per element) for the matmul kernel.
// For non-sub-byte types, returns the flat data as-is.
func unpackWeightsToInt8(wBuf *Buffer) any {
We probably should return a buffer here: the buffer pool was created exactly for these quick temporary allocations that will be the same size at every execution of the graph -- which we expect to happen often.
And, if that makes sense, then we may as well use the already existing execConvertDType from Uint4/Int4 to Int8.
return output, nil
}

// quantGatherIntSliceOfFlat converts a flat index slice ([]int32, []int64, or []int) to []int.
Same as above, we likely want to use the buffer pool to allocate the converted ints. In which case, you may just use the corresponding execConvertDType from whatever is the indices to int64 (better use int64 than int in this case, with explicit number of bits -- for all our platforms int=int64 anyway).
return nil, err
}

numIndices := indicesBuf.shape.Size() / indicesBuf.shape.Dimensions[indicesBuf.shape.Rank()-1]
The logic here is very odd.
If the last dimension of indices != 1, you would simply be truncating the indices and throwing the rest away!?
But we know the last dimension is 1 (it's pre-checked) and so numIndices := indicesBuf.shape.Size().
numIndices := indicesBuf.shape.Size() / indicesBuf.shape.Dimensions[indicesBuf.shape.Rank()-1]

indices, err := quantGatherIntSliceOfFlat(indicesBuf.flat, numIndices)
Optional: instead of converting to int, what about having a generic function for all types of ints and use the usual registration system to map the index dtype to the corresponding generic instantiation ? You save one temporary allocation, and it will probably be a tiny bit faster for smaller integers (less memory to scan).
Very optional, since likely this won't make much of a difference, I can't imagine the embedding lookup being the bottleneck :)
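The registration idea above can be sketched with Go generics: one lookup body instantiated per index type, reading the indices directly in their native width instead of converting them to []int first. The name gatherRows and the flat [][]float32 table are illustrative, not the gomlx/simplego API:

```go
package main

import "fmt"

// gatherRows is a hypothetical generic embedding lookup: it reads indices in
// their native integer type, so no temporary conversion slice is needed. A
// real backend would register one instantiation per supported index dtype.
func gatherRows[I int32 | int64 | int](table [][]float32, indices []I) ([][]float32, error) {
	out := make([][]float32, len(indices))
	for i, idx := range indices {
		if idx < 0 || int(idx) >= len(table) {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, len(table))
		}
		out[i] = table[idx]
	}
	return out, nil
}

func main() {
	table := [][]float32{{1, 2}, {3, 4}, {5, 6}}
	rows, err := gatherRows(table, []int32{2, 0}) // I inferred as int32
	fmt.Println(rows, err)                        // [[5 6] [1 2]] <nil>
}
```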
}
}

// convertToIntSlice converts the first n elements of an integer slice to []int.
Not needed, if you use the execConvertDType... per comment above.
- Use buffer pool + ConvertDType for weight unpacking and index conversion instead of ad-hoc allocations (unpackWeightsToBuffer, convertIndicesToInt64)
- Inline shift operations following binary ops pattern to eliminate per-element closure overhead (shiftLeftOp, shiftRightArithmeticOp, shiftRightLogicalUnsignedOp, shiftRightLogicalSignedOp)
- Rename parallelTileCount → quantizedDenseParallelTileCount
- Simplify numIndices calculation (last dim pre-validated as 1)
- Use tgtIsUint8 GoType check in exec_bitcast instead of Bits() < 8
- Add GGML format references and doc links to fused_ops.go
- Add "Follow Existing Patterns" guidance to AGENTS.md
Force-pushed 5250386 to 18d38b4 (Compare)
janpfeifer left a comment
Very neat, thanks @ajroetker !!
A few more minor comments, all reviewed now!
func unpackWeightsToBuffer(backend *Backend, wBuf *Buffer) (*Buffer, bool, error) {
	var targetDType dtypes.DType
	switch wBuf.shape.DType {
	case dtypes.Int4:
Question: what about Int2 and Uint2 ?
// For packed sub-byte weights (from Bitcast), unpack nibbles via the buffer pool
// and ConvertDType infrastructure. Non-sub-byte types pass through unchanged.
unpackedBuf, unpackedPooled, err := unpackWeightsToBuffer(backend, wBuf)
nit: maybe unpackedPooled -> isUnpackedPooled or isUnpackedOwned since it comes from the pool anyway (wBuf is pooled), the question is whether we need to release it.
numIndices := indicesBuf.shape.Size()

// Convert indices to int64 via the buffer pool and ConvertDType infrastructure.
idxBuf, idxPooled, err := convertIndicesToInt64(backend, indicesBuf)
nit: when reading I keep thinking idxPooled is another buffer. I was going to suggest as above, prefix is "is" or "are" or "has" for booleans, and maybe call it "owned" since all buffers are pooled. So idxPooled -> isIdxOwned.
Which reminds me that I should have named inputsOwned to areInputsOwned 😄
}

// execShiftLeft executes lhs << rhs for integer types.
func execShiftLeft(backend *Backend, node *Node, inputs []*Buffer, inputsOwned []bool) (*Buffer, error) {
Very optional: you could use the DTypeMap to register functions per dtype (and get rid of the various laundry-list switches). Well, ShiftLogicalRight would have to be manually registered because it takes two type parameters.
But ... when we have SIMD, it will be easy to simply register the supported SIMD version detected in runtime.
But it can also wait ...
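The per-dtype registration could look roughly like this with generics: one kernel body, instantiated once per integer dtype, so the laundry-list switch disappears. The names below (shiftLeftSlice) are illustrative, not the simplego registration API:

```go
package main

import "fmt"

// shiftLeftSlice is one generic shift kernel body; a dtype-dispatch table
// (like DTypeMap) would map each integer dtype to its instantiation,
// replacing the per-dtype switch in execShiftLeft.
func shiftLeftSlice[T int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64](lhs, rhs, out []T) {
	for i := range lhs {
		out[i] = lhs[i] << rhs[i]
	}
}

func main() {
	lhs := []uint8{1, 2, 3}
	rhs := []uint8{1, 2, 3}
	out := make([]uint8, 3)
	shiftLeftSlice(lhs, rhs, out)
	fmt.Println(out) // [2 8 24]
}
```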
w("[%d]", dim)
}
w("%s", values.Type().Elem())
if t.shape.DType.IsPacked() {
Hmm ... this is delicate: so far I used Summary() to print a sample of even very large tensors, because it doesn't create a copy of it.
If we use unpackFlatValues() here this becomes a very costly operation, potentially requiring a giant allocation.
Can I suggest (I'm hoping the AI can easily handle this):

- Refactor wValue() above to take, instead of the value itself, just an index of the value in the flat vector.
- Write separately a small extractPackedElement(flatPacked []uint8, packedDType dtypes.DType, index int) int that takes the packed bytes (in flatPacked), the packed dtype (Int2, Int4, Uint2, or Uint4), and the index, and returns the corresponding value unpacked to an int.

Then in wValue(index ii) you can check if dtype.IsPacked() and, if true, call this extractPackedElement() instead.

Wdyt?
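A sketch of what such a single-element decoder could look like, assuming low-bits-first packing (two 4-bit or four 2-bit values per byte). The signature takes a bit width plus a signed flag instead of a dtypes.DType, since the exact packing convention and dtype API in simplego are assumptions here:

```go
package main

import "fmt"

// extractPackedElement decodes element `index` from sub-byte packed storage
// without unpacking the whole tensor. bits is 2 or 4; signed selects
// sign-extension (Int2/Int4) versus zero-extension (Uint2/Uint4).
// Assumes the lowest-order bits of each byte hold the earliest element.
func extractPackedElement(flatPacked []uint8, bits int, signed bool, index int) int {
	perByte := 8 / bits
	shift := uint((index % perByte) * bits)
	mask := uint8(1<<bits - 1)
	v := int((flatPacked[index/perByte] >> shift) & mask)
	if signed && v >= 1<<(bits-1) {
		v -= 1 << bits // sign-extend, e.g. nibble 0xF -> -1
	}
	return v
}

func main() {
	packed := []uint8{0xF0} // low nibble 0, high nibble 15
	fmt.Println(extractPackedElement(packed, 4, true, 0))  // 0
	fmt.Println(extractPackedElement(packed, 4, true, 1))  // -1
	fmt.Println(extractPackedElement(packed, 4, false, 1)) // 15
}
```

Because it touches one byte per call, Summary() can keep sampling huge tensors without the giant allocation a full unpack would need.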
dequantW := Dequantize(weights, ggmlType, N)

// Transpose to [K, N] for matmul.
dequantW = Transpose(dequantW, 0, 1) // [K, N]
Instead of transpose here, just change the below Dot.Product() to the corresponding Dot.General() with the appropriate contraction axes. Let the dotgeneral algorithm decide what it wants to do (if it wants to transpose or not).
- Register Int2/Uint2 → {Int8,Uint8,Int32,Int64,Float32,Float64}
converters via execConvertPackedSubByte with valuesPerByte=4
- Add unpackInt2Bits and unpackUint2Bits for 2-bit packed data
- Handle Int2/Uint2 in unpackWeightsToBuffer alongside Int4/Uint4
- Register mutableBytes and fillBuffer for Int2/Uint2
- Rename unpackedPooled → isUnpackedOwned, idxPooled → isIdxOwned
for clarity (all buffers are pooled; the bool tracks ownership)
Force-pushed c93f94a to a1494db (Compare)
Summary
- FusedQuantizedGather for quantized embedding lookups (analogous to ggml_get_rows)
- FusedQuantizedDense for fused dequant + matmul + bias + activation on GGML-quantized weights
- QuantizedGather high-level layer in pkg/ml/nn

Review feedback addressed
- execBitcast fix to separate PR: Fix execBitcast buffer reuse for cross-bit-width types #374
- Split quantized code into exec_fused_quantized.go and exec_fused_quantized_ggml.go
- deriveGGMLK docs explaining N/K dimensions
- table/tableQuantization → data/dataQuantization in FusedQuantizedGather
- errors.Wrapf(backends.ErrNotImplemented, ...) with feature-request guidance

Test plan
- go build ./... passes
- go test ./backends/simplego/... — includes new benchmark for fused ops
- go test ./pkg/ml/nn/... — includes new quantized dense test