Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions ggml/src/ggml-sycl/cumsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#include "cumsum.hpp"
#include "common.hpp"

#include <algorithm>

#define SYCL_CUMSUM_BLOCK_SIZE 256

// Returns the inclusive prefix sum of `x` across the calling sub-group (warp).
static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) {
    const auto sg = item.get_sub_group();
    return sycl::inclusive_scan_over_group(sg, x, sycl::plus<float>());
}

// Inclusive cumulative sum (prefix scan) along dim 0 of an f32 tensor.
// One work-group processes one row (i1, i2, i3); the row is consumed in tiles
// of num_unroll * block_size elements, with the running row total carried
// across tiles in local memory.
//
// smem layout (floats):
//   [0, block_size)                          s_vals      per-thread inclusive scan of segment totals
//   [block_size, block_size + warps)         s_warp_sums per-warp exclusive offsets within the tile
//   [block_size + warps, block_size+warps+2) s_carry     [0] = row total so far, [1] = current tile total
static void cumsum_f32_kernel(
    const float * __restrict__ src, float * __restrict__ dst,
    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const int64_t s01, const int64_t s02, const int64_t s03,
    const int64_t d1, const int64_t d2, const int64_t d3,
    const sycl::nd_item<3> & item, float * smem) {

    const int tid = item.get_local_id(2);
    const int block_size = item.get_local_range(2);
    const int lane = tid % WARP_SIZE;
    const int warp = tid / WARP_SIZE;
    const int warps_per_block = block_size / WARP_SIZE;

    // Partition the dynamic local memory (see layout above).
    float * s_vals = smem;
    float * s_warp_sums = smem + block_size;
    float * s_carry = smem + block_size + warps_per_block;

    if (tid == 0) {
        s_carry[0] = 0.0f;
    }
    item.barrier(sycl::access::fence_space::local_space);

    const int64_t i3 = item.get_group(0);
    const int64_t i2 = item.get_group(1);
    const int64_t i1 = item.get_group(2);
    // The guard depends only on group ids, so the whole group returns
    // together and the barriers inside the loop cannot deadlock.
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;

    // Each thread owns num_unroll consecutive elements of the tile.
    constexpr int num_unroll = 4;
    float temp[num_unroll];

    for (int64_t i = 0; i < ne00; i += num_unroll * block_size) {
        int64_t idx = i + tid * num_unroll;

        // Sequential inclusive scan of this thread's segment; elements past
        // the end of the row contribute 0.
        temp[0] = (idx < ne00 ? src_row[idx] : 0.0f);
#pragma unroll
        for (int j = 1; j < num_unroll; j++) {
            temp[j] = temp[j - 1];
            if (idx + j < ne00) {
                temp[j] += src_row[idx + j];
            }
        }

        // Segment total (0 for threads that are entirely out of range).
        float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f;

        // Inclusive scan of segment totals across the warp.
        val = warp_prefix_inclusive_sum_f32(val, item);
        s_vals[tid] = val;

        // The last lane's inclusive value is the warp's tile total.
        if (lane == WARP_SIZE - 1) {
            s_warp_sums[warp] = val;
        }
        item.barrier(sycl::access::fence_space::local_space);

        // Warp 0 scans the per-warp totals: s_warp_sums[w] becomes the
        // exclusive offset of warp w, and s_carry[1] the whole tile's total.
        if (warp == 0) {
            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
            float inc = warp_prefix_inclusive_sum_f32(w, item);
            if (tid < warps_per_block) {
                s_warp_sums[tid] = inc - w;
            }
            if (tid == warps_per_block - 1) {
                s_carry[1] = inc;
            }
        }
        item.barrier(sycl::access::fence_space::local_space);

        // Base offset for this thread's segment: inclusive scan up to and
        // including this thread, minus its own segment total (-> exclusive),
        // plus the warp offset and the carry from previous tiles.
        float carry = s_carry[0];
        float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];

#pragma unroll
        for (int j = 0; j < num_unroll; j++) {
            if (idx + j < ne00) {
                dst_row[idx + j] = temp[j] + final_offset;
            }
        }

        // All reads of s_carry/s_vals/s_warp_sums for this tile must finish
        // before they are overwritten for the next tile.
        item.barrier(sycl::access::fence_space::local_space);

        if (tid == 0) {
            s_carry[0] += s_carry[1];
        }
    }
}

// GGML_OP_CUMSUM for the SYCL backend: inclusive cumulative sum along dim 0,
// f32 -> f32. Launches one work-group per row (ne01 * ne02 * ne03 groups).
void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src_d = static_cast<const float *>(src0->data);
    float * dst_d = static_cast<float *>(dst->data);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // Convert byte strides to f32-element strides.
    const size_t ts = sizeof(float);
    const int64_t s01 = src0->nb[1] / ts;
    const int64_t s02 = src0->nb[2] / ts;
    const int64_t s03 = src0->nb[3] / ts;
    const int64_t d1 = dst->nb[1] / ts;
    const int64_t d2 = dst->nb[2] / ts;
    const int64_t d3 = dst->nb[3] / ts;

    // Size the work-group to cover the row with whole warps, capped at
    // SYCL_CUMSUM_BLOCK_SIZE. Do the arithmetic in 64 bits first: the old
    // `int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE` narrowed before
    // the cap was applied and could overflow for very large ne00. The lower
    // clamp keeps the launch valid even for degenerate (empty) rows.
    const int64_t wanted_size = ((ne00 + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    const int block_size = static_cast<int>(
        std::clamp<int64_t>(wanted_size, WARP_SIZE, SYCL_CUMSUM_BLOCK_SIZE));
    const int warps_per_block = block_size / WARP_SIZE;
    // Local memory: block_size scan slots + per-warp offsets + 2 carry slots.
    const int smem_size = block_size + warps_per_block + 2;

    const sycl::range<3> grid(ne03, ne02, ne01);
    const sycl::range<3> block(1, 1, block_size);

    stream->submit([&](sycl::handler & cgh) {
        sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(grid * block, block),
            [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03,
                                  s01, s02, s03, d1, d2, d3,
                                  item, get_pointer(smem_acc));
            });
    });
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/cumsum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include "common.hpp"

// GGML_OP_CUMSUM: inclusive cumulative sum along dim 0 (f32 only).
// Reads dst->src[0]; writes the scanned result into dst.
void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
62 changes: 62 additions & 0 deletions ggml/src/ggml-sycl/diag.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "diag.hpp"
#include "common.hpp"

#define SYCL_DIAG_BLOCK_SIZE 256

// Expands a batched row vector into batched diagonal matrices: for each
// (i3, i2) batch, dst[i1][i0] = src[i0] when i0 == i1, else 0. One work-item
// per output element; the flat global id is exactly the contiguous dst index.
template <typename T>
static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src,
                        const int64_t ne0, const int64_t ne1,
                        const int64_t ne2, const int64_t ne3,
                        const int64_t total_elements,
                        const sycl::nd_item<1> & item) {
    const int64_t gid = item.get_global_id(0);
    if (gid >= total_elements) {
        return;
    }

    // Decompose the flat index into (col, row, plane, outer) coordinates.
    const int64_t col   = gid % ne0;
    int64_t rest        = gid / ne0;
    const int64_t row   = rest % ne1;
    rest               /= ne1;
    const int64_t plane = rest % ne2;
    const int64_t outer = rest / ne2;

    // ((outer*ne2 + plane)*ne1 + row)*ne0 + col == gid, so write at gid.
    const int64_t batch = outer * ne2 + plane;
    dst[gid] = (col == row) ? src[batch * ne0 + col] : T(0);

    (void) ne3;
}

// GGML_OP_DIAG: turn a contiguous batched row vector (src0, with ne[1] == 1)
// into batched diagonal matrices in dst. f32 only.
void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->ne[1] == 1);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const void * src_data = src0->data;
    void * dst_data = dst->data;

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];
    const int64_t n_elems  = ggml_nelements(dst);
    const int64_t n_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE;

    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    const sycl::nd_range<1> launch_range(n_blocks * SYCL_DIAG_BLOCK_SIZE,
                                         SYCL_DIAG_BLOCK_SIZE);
    stream->parallel_for(launch_range, [=](sycl::nd_item<1> item) {
        diag_kernel(static_cast<float *>(dst_data),
                    static_cast<const float *>(src_data),
                    ne0, ne1, ne2, ne3, n_elems, item);
    });
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/diag.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include "common.hpp"

// GGML_OP_DIAG: expand the row vector dst->src[0] into (batched) diagonal
// matrices written to dst. Both tensors must be contiguous.
void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
50 changes: 50 additions & 0 deletions ggml/src/ggml-sycl/fill.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "fill.hpp"
#include "common.hpp"

#define SYCL_FILL_BLOCK_SIZE 256

// Writes `value` into dst[0..k); one work-item per element, items past the
// end of the buffer do nothing.
template <typename T>
static void fill_kernel(T * dst, const int64_t k, const T value,
                        const sycl::nd_item<1> & item) {
    const int64_t idx = (int64_t) item.get_global_id(0);
    if (idx < k) {
        dst[idx] = value;
    }
}

// GGML_OP_FILL: set every element of the contiguous tensor dst to the scalar
// stored in dst->op_params. Supports f32 and f16 destinations.
void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(ggml_is_contiguous(dst));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    // The fill value is packed into op_params as a raw f32.
    float fill_value;
    memcpy(&fill_value, dst->op_params, sizeof(float));

    const int64_t n = ggml_nelements(dst);
    const int64_t n_blocks = (n + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE;
    const sycl::nd_range<1> launch_range(n_blocks * SYCL_FILL_BLOCK_SIZE,
                                         SYCL_FILL_BLOCK_SIZE);
    void * data = dst->data;

    switch (dst->type) {
        case GGML_TYPE_F32:
            stream->parallel_for(launch_range, [=](sycl::nd_item<1> item) {
                fill_kernel(static_cast<float *>(data), n, fill_value, item);
            });
            break;
        case GGML_TYPE_F16:
            {
                // Convert once on the host instead of per work-item.
                const sycl::half half_value = sycl::half(fill_value);
                stream->parallel_for(launch_range, [=](sycl::nd_item<1> item) {
                    fill_kernel(static_cast<sycl::half *>(data), n, half_value, item);
                });
            }
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/fill.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include "common.hpp"

// GGML_OP_FILL: fill the contiguous tensor dst with the scalar stored in
// dst->op_params (f32 and f16 destinations).
void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
1 change: 1 addition & 0 deletions ggml/src/ggml-sycl/gated_delta_net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
#include "common.hpp"
#include "ggml.h"

void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
35 changes: 34 additions & 1 deletion ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml-sycl/sycl_hw.hpp"

#include "ggml-sycl/ssm_scan.hpp"
#include "ggml-sycl/fill.hpp"
#include "ggml-sycl/cumsum.hpp"
#include "ggml-sycl/diag.hpp"
#include "ggml-sycl/solve_tri.hpp"
#include "ggml-sycl/gated_delta_net.hpp"

static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
Expand Down Expand Up @@ -4309,6 +4314,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_sycl_ssm_scan(ctx, dst);
break;
case GGML_OP_FILL:
ggml_sycl_op_fill(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_sycl_op_cumsum(ctx, dst);
break;
case GGML_OP_DIAG:
ggml_sycl_op_diag(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_sycl_op_solve_tri(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
Expand Down Expand Up @@ -5019,6 +5039,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0;
} else {
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
case GGML_OP_FILL:
case GGML_OP_CUMSUM:
case GGML_OP_DIAG:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(device, op);
default:
Expand Down
Loading