sycl: addressing non-contiguous src1 mul_mats (nc and batched) #13343

Merged · 2 commits · May 8, 2025
17 changes: 6 additions & 11 deletions ggml/src/ggml-sycl/common.hpp
@@ -118,17 +118,12 @@ static void crash() {
GGML_ABORT("SYCL error");
}

#define SYCL_CHECK(err) \
do { \
auto err_ = (err); \
if (err_ != 0) \
ggml_sycl_error( \
#err, \
__func__, \
__FILE__, \
__LINE__, \
"Meet error in this line code!"); \
} while (0)
#define SYCL_CHECK(err) \
do { \
auto err_ = (err); \
if (err_ != 0) \
ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
} while (0)

#if DPCT_COMPAT_RT_VERSION >= 11100
#define GGML_SYCL_ASSUME(x) __builtin_assume(x)
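Reviewer note: the reworked macro keeps the do { ... } while (0) wrapper so SYCL_CHECK expands to a single statement, and it evaluates its argument exactly once into err_, so side-effecting expressions are safe. A minimal sketch of why the wrapper matters (use_device and some_fallback are illustrative, not from this PR):

// Without do/while(0), the semicolon after the first SYCL_CHECK would end the
// if and orphan the else; with it, the macro behaves like one ordinary statement.
if (use_device)
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
else
    SYCL_CHECK(some_fallback());  // hypothetical helper returning 0 on success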
66 changes: 43 additions & 23 deletions ggml/src/ggml-sycl/convert.cpp
@@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
}

template <typename src_t, typename dst_t>
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
const sycl::nd_item<3> & item_ct1) {

const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);

const int64_t i01 = item_ct1.get_group(1);
const int64_t i02 = item_ct1.get_group(0) % ne02;
const int64_t i03 = item_ct1.get_group(0) / ne02;

// make each work-item deal with more elements since sycl global range can not exceed max int
const src_t * x = (const src_t *) vx;
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
y[i] = x[i];
const src_t * x = static_cast<const src_t *>(vx);
const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;

#pragma unroll
for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
}
}
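Reviewer note: the kernel's index math is the heart of the change, so a plain CPU restating may help. Each work-group row (i01, i02, i03) copies one row of ne00 elements from a strided source into a packed destination; s01/s02/s03 are element strides (byte strides divided by the element size). A sketch, not part of the diff:

#include <cstdint>

// CPU reference for convert_unary_nc's index math (sketch, not part of this PR).
template <typename src_t, typename dst_t>
static void convert_unary_nc_ref(const src_t * x, dst_t * y,
                                 int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                 int64_t s01, int64_t s02, int64_t s03) {
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;         // strided source row
                const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;  // packed destination row
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
                }
            }
        }
    }
}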

template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) {
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });

sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));

// decrease global range when it exceeds the max int
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
sycl::range<3> block_nums(1, 1, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
// TODO: Downsample logic is separated from the kernel, a rewrite is desirable
int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
sycl::range<3> workgroup_size(1, 1, downsized_workgroup);

stream->parallel_for(
sycl::nd_range<3>(block_nums * local_range, local_range),
[=](sycl::nd_item<3> item_ct1) {
convert_unary<src_t>(vx, y, k, item_ct1);
});
}
queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
});
}

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
}
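Reviewer note: the contiguous wrapper reduces cleanly to the old behaviour. For a packed buffer of k elements it is equivalent to the explicit call below (a sketch):

// Contiguous fp32 -> fp16 conversion expressed through the nc entry point (sketch):
convert_unary_nc_sycl<float>(vx, y, /*ne00=*/k, /*ne01=*/1, /*ne02=*/1, /*ne03=*/1,
                             /*s01=*/k, /*s02=*/k, /*s03=*/k, queue);
// With the outer extents at 1, i01 = i02 = i03 = 0, so ix = iy = 0 and the
// kernel degenerates to the old y[i] = x[i] loop over k elements.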

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_Q4_0:
if (dst->src[0]->extra &&
@@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return nullptr;
}
}

to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>;
default:
return nullptr;
}
}
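Reviewer note: get_to_fp16_nc_sycl only wires up GGML_TYPE_F32 for now and returns nullptr otherwise, which the caller in ggml-sycl.cpp asserts against. Extending it later should be a one-case change, e.g. (hypothetical, not part of this PR):

case GGML_TYPE_F16:
    return convert_unary_nc_sycl<sycl::half>;  // hypothetical half -> half strided copy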
21 changes: 14 additions & 7 deletions ggml/src/ggml-sycl/convert.hpp
@@ -1,6 +1,6 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//

@@ -16,12 +16,19 @@
#include "common.hpp"

template <typename T>
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
int64_t k, dpct::queue_ptr stream);
typedef to_t_sycl_t<float> to_fp32_sycl_t;
using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
typedef to_t_sycl_t<float> to_fp32_sycl_t;
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);

#endif // GGML_SYCL_CONVERT_HPP
// Nc = Non-contiguous
template <typename T>
using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);

typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);

#endif // GGML_SYCL_CONVERT_HPP
159 changes: 75 additions & 84 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2694,138 +2694,132 @@ catch (sycl::exception const &exc) {
std::exit(1);
}

static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
const sycl::half *src1_as_f16, char *dst,
const void **ptrs_src, void **ptrs_dst,
int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12,
size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3,
const sycl::nd_item<3> &item_ct1) {
int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
item_ct1.get_local_id(2);
int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;
const int64_t i03 = i13 / r3;
const int64_t i02 = i12 / r2;

const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
}
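Reviewer note: this kernel fills per-batch pointer tables for the grouped gemm_batch call further down; the divisions by r2/r3 implement dim-2/dim-3 broadcast, so several dst batches can point at one src0 matrix (the grouped-query-attention case, e.g. ne12 = 32 Q heads over ne02 = 8 KV heads gives r2 = 4). A CPU restating of the same arithmetic (sketch, not part of the PR):

#include <cstddef>
#include <cstdint>

// Sketch: CPU equivalent of k_compute_batched_ptrs, byte strides as in the kernel.
static void compute_batched_ptrs_ref(const uint8_t * src0_bytes, const uint8_t * src1_bytes, uint8_t * dst_bytes,
                                     const void ** ptrs_src, void ** ptrs_dst,
                                     int64_t ne12, int64_t ne13, int64_t ne23,
                                     size_t nb02, size_t nb03, size_t nb12, size_t nb13,
                                     size_t nbd2, size_t nbd3, int64_t r2, int64_t r3) {
    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            const int64_t i03 = i13 / r3;  // broadcast: r3 dst batches share one src0 batch in dim 3
            const int64_t i02 = i12 / r2;  // broadcast: r2 dst batches share one src0 batch in dim 2
            ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
            ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
            ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
        }
    }
}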

static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) try {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);

GGML_TENSOR_BINARY_OP_LOCALS

// TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
// Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
GGML_ASSERT(ggml_is_contiguous(dst));

SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();;
queue_ptr queue = ctx.stream();

void * src0_ddq = src0->data;
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });

// convert src1 to fp16
const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
float * dst_ddf = static_cast<float *>(dst->data);

const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
const size_t type_size_src1 = ggml_type_size(src1->type);
GGML_ASSERT(nb10 == type_size_src1);

// SRC1 strides
int64_t s11 = nb11 / type_size_src1;
int64_t s12 = nb12 / type_size_src1;
int64_t s13 = nb13 / type_size_src1;
ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());

// convert src1 to fp16
if (src1->type != GGML_TYPE_F16) {
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
GGML_ASSERT(to_fp16_sycl != nullptr);
to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);

src1_f16 = src1_f16_alloc.get();
s11 = ne10;
s12 = ne11 * s11;
s13 = ne12 * s12;
}
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
: src1_f16_alloc.get();

char * dst_t;
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
char * dst_t = reinterpret_cast<char *>(dst_ddf);

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;

dst_t = (char *) dst_ddf;

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;

if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12*ne13;
const int ne23 = ne12 * ne13;

ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);

sycl::range<3> block_dims(1, ne12, ne13);
/*
DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
{
dpct::has_capability_or_fail(main_stream->get_device(),
{sycl::aspect::fp16});

main_stream->submit([&](sycl::handler &cgh) {
const void **ptrs_src_get = ptrs_src.get();
void **ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(
src0_as_f16, src1_f16,
dst_t, ptrs_src_get,
ptrs_dst_get, ne12, ne13, ne23,
nb02, nb03, nb12_scaled, nb13_scaled,
nbd2, nbd3, r2, r3, item_ct1);
});
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
});
}
});

SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
}
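Reviewer note: the contiguous fast path feeds dpct::gemm_batch with strides rather than pointer tables, and the stride bookkeeping is easy to lose in the diff. A sketch restating it (names from GGML_TENSOR_BINARY_OP_LOCALS):

// src1 element strides are its byte strides over the element size:
int64_t s11 = nb11 / ggml_type_size(src1->type);
int64_t s12 = nb12 / ggml_type_size(src1->type);
int64_t s13 = nb13 / ggml_type_size(src1->type);
if (src1->type != GGML_TYPE_F16) {  // the converted fp16 copy is packed, so strides collapse
    s11 = ne10;
    s12 = ne11 * s11;
    s13 = ne12 * s12;
}
// dst is asserted contiguous fp32: leading dimension ne0 and one ne1 x ne0 matrix
// per batch, hence ldc = ne0 and stride_c = ne1 * ne0 in the strided gemm_batch call.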

inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (a->ne[3] != b->ne[3]) {
return false;
}
if (!ggml_is_contiguous(b)) {
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
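Reviewer note: removing the ggml_is_contiguous(b) rejection here is what lets non-contiguous src1 reach the new conversion path in the first place. A sketch of how such an operand typically arises (hypothetical shapes; assumes an initialized ggml_context * ctx and the plain ggml graph API):

struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 32, 8, 1);
struct ggml_tensor * b    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 8, 32, 1);
struct ggml_tensor * b_nc = ggml_permute(ctx, b, 0, 2, 1, 3);  // swaps dims 1 and 2: the view is non-contiguous
struct ggml_tensor * out  = ggml_mul_mat(ctx, a, b_nc);        // previously rejected by supports_op on SYCL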