feat: x86 matmul and transpose support (SDPA llama runnable) #641
oreomaker merged 3 commits into UbiquitousLearning:main
Conversation
@coderabbitai review

✅ Actions performed: Review triggered.
📝 Walkthrough

This PR introduces optimized x86 CPU kernels for BLAS matrix multiplication (GEMM/GEMV with AVX/AVX2 acceleration) and tensor transpose operations (with SIMD variants), integrating them into the MatMul and Transpose operators. It also adds support for the GGUF Q4_0 quantized format in embedding operations.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
mllm/backends/cpu/ops/TransposeOp.cpp (1)
95-103: ⚠️ Potential issue | 🔴 Critical

Fix critical bug: `batch` initialization must be 1, not 0, to correctly compute batch dimensions.

Line 96 initializes `batch = 0`, then multiplies it in a loop — the result is always 0. The x86 dispatch on lines 102–103 and 121–122 (and the ARM equivalents at 105–106 and 124–125) inherit this bug, causing `transpose_last_dims_*` to process 0 batches and produce no output for any CASE 3 transpose operation.

Fix:

```diff
- int batch = 0;
+ int batch = 1;
  for (int i = 0; i < input_shape.size() - 2; i++) { batch *= input_shape[i]; }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm/backends/cpu/ops/TransposeOp.cpp` around lines 95 - 103, The batch accumulator in TransposeOp.cpp is initialized to 0 which makes subsequent multiplications yield 0; change its initialization to 1 so the loop that multiplies over the leading dims correctly computes the batch size (in the branch that checks input_shape.size() != 2 && (dim0 + dim1 == 1)); update the variable declared as int batch = 1 and leave the existing loop and downstream calls to x86::transpose_last_dims_fp32 / transpose_last_dims_fp16 (and ARM equivalents) unchanged so they receive the correct batch value.
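For a quick sanity check, here is a minimal standalone sketch of the corrected batch computation (the shape below is a made-up example, not taken from the PR):

```cpp
#include <cstdio>
#include <vector>

// Product of all leading dims except the last two. The accumulator must start
// at 1: starting at 0 collapses every product to 0, which is exactly the bug.
int leading_batch(const std::vector<int>& shape) {
  int batch = 1;  // was 0 in the buggy code
  for (size_t i = 0; i + 2 < shape.size(); ++i) batch *= shape[i];
  return batch;
}

int main() {
  std::vector<int> shape = {2, 8, 64, 128};  // hypothetical [B, H, S, D]
  std::printf("batch = %d\n", leading_batch(shape));  // prints 16, not 0
}
```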
🧹 Nitpick comments (2)
mllm/core/aops/EmbeddingOp.cpp (1)
74-77: LGTM — also incidentally fixes a pre-existing Q4_K output-dtype bug

Before this change, a `kGGUF_Q4_K` weight would fall through to `out_dtype = weight_.dtype()` (line 74), allocating the output tensor with dtype `kGGUF_Q4_K`. `CPUEmbeddingOp::forward` then wrote `float` values into it via `coffsettedPtr<float>`, producing a type mismatch. This PR correctly maps both quantised formats to `kFloat32`.

Optionally, the two sequential `if` statements can be made an `else if` chain to make the mutual exclusivity explicit and avoid evaluating the second condition when the first matches.

♻️ Optional: use else-if chain

```diff
  auto out_dtype = weight_.dtype();
- if (weight_.dtype() == kUInt16) { out_dtype = kUInt16PerTensorAsy; }
- if (weight_.dtype() == kGGUF_Q4_0 || weight_.dtype() == kGGUF_Q4_K) { out_dtype = kFloat32; }
+ if (weight_.dtype() == kUInt16) {
+   out_dtype = kUInt16PerTensorAsy;
+ } else if (weight_.dtype() == kGGUF_Q4_0 || weight_.dtype() == kGGUF_Q4_K) {
+   out_dtype = kFloat32;
+ }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm/core/aops/EmbeddingOp.cpp` around lines 74 - 77, The current logic sets out_dtype from weight_.dtype() then overwrites it for certain formats; change the two separate conditionals to an explicit mutually-exclusive chain so only one branch runs: check weight_.dtype() once, then use "if (weight_.dtype() == kUInt16) { out_dtype = kUInt16PerTensorAsy; } else if (weight_.dtype() == kGGUF_Q4_0 || weight_.dtype() == kGGUF_Q4_K) { out_dtype = kFloat32; }" to ensure kGGUF_Q4_K and kGGUF_Q4_0 map to kFloat32 before calling Tensor::empty(o_shape, out_dtype, i.device()), avoiding the CPUEmbeddingOp::forward write via coffsettedPtr<float> into a quantised-typed tensor.

mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp (1)
43-61: Guard against null-pointer arithmetic on `C`.

When `C` is `nullptr` (common — `MatMulOp.cpp` passes `nullptr`), `C + b * C_batch_stride` is technically undefined behavior even when the stride is 0. The same issue exists in the `__mllm_blas_batch_matmul_fp32_gemv_nt_nt_decode_small_d_wv` template (lines 90–108).

🛡️ Suggested defensive fix (applies to both templates)

```diff
  auto a_ptr = A + b * A_batch_stride;
  auto b_ptr = B + b * B_batch_stride;
- auto c_ptr = C + b * C_batch_stride;
+ auto c_ptr = C ? C + b * C_batch_stride : nullptr;
  auto d_ptr = dst + b * Dst_batch_stride;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp` around lines 43 - 61, The pointer arithmetic on C can invoke UB when C == nullptr; change the per-batch pointer computation in both templates (the loop in the function containing __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk and the other template __mllm_blas_batch_matmul_fp32_gemv_nt_nt_decode_small_d_wv) to compute c_ptr conditionally (e.g., set c_ptr to nullptr when C is nullptr, otherwise C + b * C_batch_stride) both inside the MLLM_AUTO_PARALLEL_FOR_BEGIN_NT block and the fallback for-loop so you never perform pointer arithmetic on a null base pointer.
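As a general illustration of the suggested guard (names and signature are invented for this sketch, not the project's API):

```cpp
#include <cstddef>

// Per-batch pointer setup for a batched kernel with an optional bias C.
// Offsetting a null base pointer is UB in C++ even when the offset is 0,
// so c_ptr is only derived from C when C is actually present.
void run_batched(const float* A, const float* B, const float* C, float* dst,
                 int batch_count, std::size_t A_stride, std::size_t B_stride,
                 std::size_t C_stride, std::size_t dst_stride) {
  for (int b = 0; b < batch_count; ++b) {
    const float* a_ptr = A + b * A_stride;
    const float* b_ptr = B + b * B_stride;
    const float* c_ptr = C ? C + b * C_stride : nullptr;  // the guard
    float* d_ptr = dst + b * dst_stride;
    // ... invoke the single-batch kernel with a_ptr, b_ptr, c_ptr, d_ptr ...
    (void)a_ptr; (void)b_ptr; (void)c_ptr; (void)d_ptr;
  }
}
```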
```cpp
__m256 a_vec =
    _mm256_set_ps(A[k + 3], A[k + 3], A[k + 2], A[k + 2], A[k + 1], A[k + 1], A[k], A[k]); // For broadcasting later
```
Dead code: a_vec is computed but never used.
The _mm256_set_ps(...) result on line 174 is assigned to a_vec but the subsequent code uses individual scalar broadcasts (a0_vec, a1_vec, etc.) instead. Remove the unused variable to avoid confusion and compiler warnings.
🧹 Proposed fix
```diff
  for (; k <= K - 4; k += 4) {
-   __m256 a_vec =
-       _mm256_set_ps(A[k + 3], A[k + 3], A[k + 2], A[k + 2], A[k + 1], A[k + 1], A[k], A[k]); // For broadcasting later
    float a0 = A[k + 0];
    float a1 = A[k + 1];
    float a2 = A[k + 2];
    float a3 = A[k + 3];
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp` around lines 174
- 175, Remove the unused AVX variable a_vec declared with _mm256_set_ps in this
SGEMM implementation: the expression assigned to a_vec is never used later (the
code uses a0_vec, a1_vec, etc.), so delete the declaration of a_vec to eliminate
dead code and compiler warnings; verify no other code depends on the a_vec
symbol in the same function.
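For reference, a self-contained sketch of the scalar-broadcast pattern the surrounding code actually relies on (`_mm256_set1_ps` per scalar), which is what makes the interleaved `_mm256_set_ps` value dead. The loop body here is illustrative, not the project's kernel; compile with -mavx2 -mfma:

```cpp
#include <immintrin.h>

// y[0..7] += a0 * x0[0..7] + a1 * x1[0..7]. Each scalar is broadcast to all
// eight lanes with _mm256_set1_ps; no packed _mm256_set_ps value is needed.
void axpy2_avx(float a0, float a1, const float* x0, const float* x1, float* y) {
  __m256 a0_vec = _mm256_set1_ps(a0);  // broadcast a0
  __m256 a1_vec = _mm256_set1_ps(a1);  // broadcast a1
  __m256 acc = _mm256_loadu_ps(y);
  acc = _mm256_fmadd_ps(a0_vec, _mm256_loadu_ps(x0), acc);
  acc = _mm256_fmadd_ps(a1_vec, _mm256_loadu_ps(x1), acc);
  _mm256_storeu_ps(y, acc);
}
```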
```cpp
  }
  return;
} else {
  NYI("transpose_a && transpose_b not supported not supported in mllm_blas_matmul_fp32 gemm/gemv");
```
Typo: doubled "not supported" in NYI message.
Same typo on line 527.
✏️ Proposed fix
- NYI("transpose_a && transpose_b not supported not supported in mllm_blas_matmul_fp32 gemm/gemv");
+ NYI("transpose_a && transpose_b not supported in mllm_blas_matmul_fp32 gemm/gemv");And line 527:
- NYI("transpose_a && transpose_b not supported not supported in mllm_blas_matmul_fp32 gemm/gemv");
+ NYI("transpose_a && transpose_b not supported in mllm_blas_batch_matmul_fp32 gemm/gemv");📝 Committable suggestion
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp` at line 483, Fix
the typo in the NYI message used in mllm_blas_matmul_fp32 (gemm/gemv) by
removing the duplicated phrase "not supported" so the message reads e.g.
"transpose_a && transpose_b not supported in mllm_blas_matmul_fp32 gemm/gemv";
update both NYI(...) calls (the two occurrences within mllm_blas_matmul_fp32) to
use the corrected single "not supported" wording.
```cpp
// Parallel is in the inner loops, not here.
for (int i = 0; i < BATCH; ++i) {
  __mllm_blas_sgemm_nt_t(M, N, K, A + i * A_batch_stride, K, B + i * B_batch_stride, K, dst + i * Dst_batch_stride, N, 0,
                         C, thread_count);
}
```
Batch NT_T GEMM path does not offset C (bias) by batch stride.
The GEMV batch templates (__mllm_blas_batch_matmul_fp32_gemv_nt_t_decode_small_d_qk) correctly compute c_ptr = C + b * C_batch_stride per batch, but the GEMM path here passes the same C pointer for every batch iteration. If C is ever non-null with a per-batch layout, this will produce incorrect results.
Currently `C` is always `nullptr` from MatMulOp.cpp, so there is no immediate breakage, but this is inconsistent with the GEMV path and leaves a latent bug.
🐛 Suggested fix for consistency
```diff
  for (int i = 0; i < BATCH; ++i) {
    __mllm_blas_sgemm_nt_t(M, N, K, A + i * A_batch_stride, K, B + i * B_batch_stride, K, dst + i * Dst_batch_stride, N, 0,
-                          C, thread_count);
+                          C ? C + i * C_batch_stride : nullptr, thread_count);
  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp` around lines 519
- 523, The batch GEMM loop passes the same C pointer for every batch, causing
incorrect bias usage when C is per-batch; inside the for loop that calls
__mllm_blas_sgemm_nt_t, compute a per-batch c_ptr = (C ? C + i * C_batch_stride
: nullptr) (matching the GEMV pattern) and pass c_ptr instead of C; update the
call site in the batch loop to use that c_ptr so each batch uses C + i *
C_batch_stride when C is non-null.
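For context, a hedged sketch of how per-batch strides are commonly derived for this NT_T layout; the concrete values in the PR come from MatMulOp.cpp, and the shapes here are assumptions:

```cpp
#include <cstddef>

// Row-major, per-batch element counts for A[M][K], B[N][K] (NT_T: B is
// transposed), an optional bias C[N], and dst[M][N]. Batch i of each operand
// then starts at base + i * stride, including C when it is non-null.
struct BatchStrides {
  std::size_t a, b, c, dst;
};

inline BatchStrides nt_t_batch_strides(std::size_t M, std::size_t N, std::size_t K) {
  return BatchStrides{M * K, N * K, N, M * N};
}
```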
```cpp
case kGGUF_Q4_0: {
  auto token_idx = *ins.coffsettedPtr<mllm_int64_t>({b, (int)s});
  if (token_idx >= 0) {
    dequantize_row_q4_0(weight_.ptr<block_q4_0>() + token_idx * options_.hidden_size / QK4_0,
                        ous.coffsettedPtr<float>({b, (int)s, 0}), options_.hidden_size);
  }
  break;
}
```
Uninitialized output row when token_idx < 0
When token_idx < 0, the output row at {b, s, 0} is silently left uninitialized. This is consistent with the kGGUF_Q4_K case, but both share the same gap: a downstream consumer reading that row will see garbage. The kFloat32/kFloat16 cases are even less defensive (no guard at all), but since Q4_0 adds the check, a zero-fill for the negative branch would make it fully safe.
🛡️ Optional: zero-fill the output row for negative token indices
```diff
  case kGGUF_Q4_0: {
    auto token_idx = *ins.coffsettedPtr<mllm_int64_t>({b, (int)s});
    if (token_idx >= 0) {
      dequantize_row_q4_0(weight_.ptr<block_q4_0>() + token_idx * options_.hidden_size / QK4_0,
                          ous.coffsettedPtr<float>({b, (int)s, 0}), options_.hidden_size);
+   } else {
+     std::memset(ous.coffsettedPtr<char>({b, (int)s, 0}), 0, options_.hidden_size * sizeof(float));
    }
    break;
  }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm/backends/cpu/ops/EmbeddingOp.cpp` around lines 49 - 56, In the
kGGUF_Q4_0 branch of the switch, when token_idx < 0 the output row at
ous.coffsettedPtr<float>({b, (int)s, 0}) is left uninitialized; update the
negative-token branch to explicitly zero-fill that output row (length
options_.hidden_size) instead of doing nothing. Use the same target pointer used
for dequantize_row_q4_0 (ous.coffsettedPtr<float>({b, (int)s, 0})) and fill
options_.hidden_size floats (e.g., memset to 0 or std::fill_n) so downstream
readers see zeros; keep the dequantize call for token_idx >= 0 unchanged.
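A minimal sketch of the zero-fill convention for negative indices (the table pointer and `hidden` are stand-ins, and a plain float table replaces the Q4_0 dequantization for brevity):

```cpp
#include <algorithm>
#include <cstdint>

// Write one embedding row: copy for valid token ids, zero-fill for negative
// ids so downstream consumers never read an uninitialized row.
void write_row(std::int64_t token_idx, float* out_row, int hidden,
               const float* table) {
  if (token_idx >= 0) {
    std::copy_n(table + token_idx * hidden, hidden, out_row);
  } else {
    std::fill_n(out_row, hidden, 0.0f);  // explicit zeros, not garbage
  }
}
```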
```cpp
    }
    break;
  }
  default: NYI("Not supported weight dtype for arm llm embedding token op");
```
Misleading "arm" in NYI error message
The default-case error message "Not supported weight dtype for arm llm embedding token op" refers to "arm" in the generic CPU backend that now also serves x86 paths (the whole point of this PR). This message will appear to x86 users encountering an unsupported dtype.
✏️ Suggested fix
```diff
- default: NYI("Not supported weight dtype for arm llm embedding token op");
+ default: NYI("Not supported weight dtype for CPU embedding token op");
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm/backends/cpu/ops/EmbeddingOp.cpp` at line 57, The NYI message in the
default branch of the dtype switch in EmbeddingOp.cpp is misleadingly specific
to "arm"; update the error emitted by NYI("Not supported weight dtype for arm
llm embedding token op") to a generic message (e.g., "Not supported weight dtype
for llm embedding token op" or include the actual backend name dynamically) so
it no longer references "arm"—locate the default case in the dtype switch inside
the embedding operator implementation and replace the string passed to NYI
accordingly.
Please check Guidelines for Contributing.