Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions ggml/src/ggml-cpu/arch/x86/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -1941,8 +1941,16 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
const __m256i ones = _mm256_set1_epi16(1);
__m256i q8s_32 = _mm256_madd_epi16(q8sums, ones);
__m128i m128 = _mm256_extracti128_si256(mins_and_scales, 1);
__m256i m256 = _mm256_cvtepi16_epi32(m128);
__m256i prod_256 = _mm256_mullo_epi32(q8s_32, m256);
__m128i pL = _mm256_castsi256_si128(prod_256);
__m128i pH = _mm256_extracti128_si256(prod_256, 1);
__m128i even = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(pL), _mm_castsi128_ps(pH), _MM_SHUFFLE(2, 0, 2, 0)));
__m128i odd = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(pL), _mm_castsi128_ps(pH), _MM_SHUFFLE(3, 1, 3, 1)));
__m128i prod = _mm_add_epi32(even, odd);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);

const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
Expand Down Expand Up @@ -2121,10 +2129,19 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0);
const __m256i ones = _mm256_set1_epi16(1);
__m256i q8s_32 = _mm256_madd_epi16(q8sums, ones);
__m128i m128 = _mm256_extracti128_si256(mins_and_scales, 1);
__m256i m256 = _mm256_cvtepi16_epi32(m128);
__m256i prod_256 = _mm256_mullo_epi32(q8s_32, m256);
__m128i pL = _mm256_castsi256_si128(prod_256);
__m128i pH = _mm256_extracti128_si256(prod_256, 1);
__m128i sum128 = _mm_add_epi32(pL, pH);
__m128i upper = _mm_shuffle_epi32(sum128, _MM_SHUFFLE(1, 0, 3, 2));
__m128i sum64 = _mm_add_epi32(sum128, upper);
__m128i upper32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(1, 1, 1, 1));
__m128i final_sum = _mm_add_epi32(sum64, upper32);
summs += dmin * _mm_cvtsi128_si32(final_sum);

const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
const __m256i scales = MM256_SET_M128I(sc128, sc128);
Expand Down