Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 46 additions & 53 deletions ggml/src/ggml-sycl/vecdotq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned(
(const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}

// Packed byte-wise subtraction: treats a and b as four independent uint8
// lanes and returns (a[i] - b[i]) mod 256 in each lane.
//
// Trick: forcing the high bit of every lane of a guarantees each lane's
// minuend is >= 0x80, so no borrow can propagate across lane boundaries as
// long as every lane of b is <= 0x80 (here callers pass 0x20202020). The
// final XOR undoes the injected high bits, yielding the per-lane difference.
static __dpct_inline__ int byte_sub_4(const int a, const int b) {
    const uint32_t lane_guard = 0x80808080u;
    const uint32_t minuend    = static_cast<uint32_t>(a) | lane_guard;
    const uint32_t difference = minuend - static_cast<uint32_t>(b);
    return static_cast<int>(difference ^ lane_guard);
}

// Scalar fast path for the q6_K * q8_1 MMVQ dot product, with the QR6_K == 2
// iterations of the original loop unrolled into explicit halves.
//
// Each q6_K value is rebuilt from its low nibble (in vl) and its two high
// bits (in vh, shifted into bit positions 4..5), then re-centered by
// subtracting 32 from every byte lane. dpct::dp4a accumulates the 4-wide
// SIMD dot product against the q8_1 quants u0/u1; each half is weighted by
// its int8 scale (sc0/sc1) and q8_1 block scale (d80/d81), and the total is
// scaled by the q6_K super-block scale d.
static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar(
    const int vl, const int vh, const int u0, const int u1, const int8_t sc0,
    const int8_t sc1, const float d, const float d80, const float d81) {
    static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2");

    // First half: low nibbles of vl, bits 0..1 of each vh lane as the high bits.
    const int q0 = byte_sub_4((vl & 0x0F0F0F0F) | ((vh << 4) & 0x30303030),
                              0x20202020);

    // Second half: high nibbles of vl, bits 4..5 of each vh lane as the high bits.
    const int q1 = byte_sub_4(((vl >> 4) & 0x0F0F0F0F) |
                                  (((vh >> 4) << 4) & 0x30303030),
                              0x20202020);

    return d * (d80 * (dpct::dp4a(q0, u0, 0) * sc0) +
                d81 * (dpct::dp4a(q1, u1, 0) * sc1));
}

static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
const uint8_t *values,
int &val1, int &val2) {
Expand Down Expand Up @@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
const int *__restrict__ u,
const int8_t *__restrict__ scales, const float &d,
const float *__restrict__ d8) {

float sumf = 0.0f;

#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
const int sc = scales[4*i];

const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

const int vih = ((vh >> (4*i)) << 4) & 0x30303030;

const int vi = dpct::vectorized_binary<sycl::char4>(
(vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32

sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
}

return d*sumf;
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]);
}

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
Expand Down Expand Up @@ -490,23 +500,8 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
const int8_t * __restrict__ scales, const float d,
const float * __restrict__ d8) {
float sumf = 0.0f;

#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
const int sc = scales[4 * i];

const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;

const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;

const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
dpct::sub_sat()); // vi = (vil | vih) - 32

sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
}

return d * sumf;
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]);
}

__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
Expand All @@ -527,16 +522,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {

const int8_t * scs = scales + scale_offset;

int u[QR6_K];
float d8[QR6_K];
const int u0 = get_int_from_int8_aligned(
q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1);
const int u1 = get_int_from_int8_aligned(
q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1);
const float d80 = (*(q8_1_ds + bq8_offset + 0))[0];
const float d81 = (*(q8_1_ds + bq8_offset + 2))[0];

#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
d8[i] = ds_values[0];
}
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81);
}
};
#define VDR_Q4_0_Q8_1_MMVQ 2
Expand Down Expand Up @@ -1115,16 +1109,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq,

const int8_t * scales = bq6_K->scales + scale_offset;

int u[QR6_K];
float d8[QR6_K];

#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
}
const int u0 = get_int_from_int8_aligned(
bq8_1[bq8_offset + 0].qs, iqs % QI8_1);
const int u1 = get_int_from_int8_aligned(
bq8_1[bq8_offset + 2].qs, iqs % QI8_1);
const float d80 = bq8_1[bq8_offset + 0].ds[0];
const float d81 = bq8_1[bq8_offset + 2].ds[0];

return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81);
}


Expand Down