diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0a3e071e57c..a0be7b26654 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -1941,8 +1941,16 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m256i ones = _mm256_set1_epi16(1); + __m256i q8s_32 = _mm256_madd_epi16(q8sums, ones); + __m128i m128 = _mm256_extracti128_si256(mins_and_scales, 1); + __m256i m256 = _mm256_cvtepi16_epi32(m128); + __m256i prod_256 = _mm256_mullo_epi32(q8s_32, m256); + __m128i pL = _mm256_castsi256_si128(prod_256); + __m128i pH = _mm256_extracti128_si256(prod_256, 1); + __m128i even = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(pL), _mm_castsi128_ps(pH), _MM_SHUFFLE(2, 0, 2, 0))); + __m128i odd = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(pL), _mm_castsi128_ps(pH), _MM_SHUFFLE(3, 1, 3, 1))); + __m128i prod = _mm_add_epi32(even, odd); acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); @@ -2121,10 +2129,19 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); + const __m256i ones = _mm256_set1_epi16(1); + __m256i q8s_32 = _mm256_madd_epi16(q8sums, ones); + __m128i m128 = _mm256_extracti128_si256(mins_and_scales, 1); + __m256i m256 = _mm256_cvtepi16_epi32(m128); + __m256i prod_256 = _mm256_mullo_epi32(q8s_32, m256); + __m128i pL = _mm256_castsi256_si128(prod_256); + __m128i pH = _mm256_extracti128_si256(prod_256, 1); + __m128i sum128 = _mm_add_epi32(pL, pH); + __m128i upper = _mm_shuffle_epi32(sum128, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i sum64 = _mm_add_epi32(sum128, upper); + __m128i upper32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(1, 1, 1, 1)); + __m128i final_sum = _mm_add_epi32(sum64, upper32); + summs += dmin * _mm_cvtsi128_si32(final_sum); const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); const __m256i scales = MM256_SET_M128I(sc128, sc128);