Skip to content

Commit ac04335

Browse files
committed
sycl: use DNN for matrices multiplication
1 parent 1f73301 commit ac04335

File tree

6 files changed

+221
-124
lines changed

6 files changed

+221
-124
lines changed

docs/backend/SYCL.md

+2
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
731731
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
732732
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
733733
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
734+
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
734735
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
735736
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
736737

@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
741742
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
742743
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
743744
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
745+
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
744746
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
745747

746748

ggml/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
193193
option(GGML_SYCL "ggml: use SYCL" OFF)
194194
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
195195
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
196+
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
196197
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
197198
"ggml: sycl target device")
198199
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING

ggml/src/ggml-sycl/CMakeLists.txt

+27-23
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,39 @@ endif()
4949
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
5050

5151
# Link against oneDNN
52-
find_package(DNNL)
5352
set(GGML_SYCL_DNNL 0)
54-
if(DNNL_FOUND)
55-
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
56-
# Assuming oneDNN packaged with oneapi release is used which
57-
# supports only intel target
58-
set(DNNL_GPU_VENDOR "INTEL")
59-
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
60-
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
53+
if(GGML_SYCL_DNN)
54+
find_package(DNNL)
55+
if(DNNL_FOUND)
56+
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
57+
# Assuming oneDNN packaged with oneapi release is used which
58+
# supports only intel target
59+
set(DNNL_GPU_VENDOR "INTEL")
60+
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
61+
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
62+
endif()
6163
endif()
62-
endif()
6364

64-
# Verify oneDNN was compiled for the same target as llama
65-
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
66-
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
67-
set(GGML_SYCL_DNNL 1)
68-
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
69-
foreach(CONFIG ${CONFIGS})
70-
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
71-
message(STATUS "Found oneDNN: ${DNNL_LIB}")
72-
endforeach()
65+
# Verify oneDNN was compiled for the same target as llama
66+
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
67+
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
68+
set(GGML_SYCL_DNNL 1)
69+
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
70+
foreach(CONFIG ${CONFIGS})
71+
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
72+
message(STATUS "Found oneDNN: ${DNNL_LIB}")
73+
endforeach()
74+
else()
75+
message(WARNING
76+
"oneDNN must be compiled for the same target as llama.cpp.
77+
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
78+
Disabling oneDNN support.")
79+
endif()
7380
else()
74-
message(WARNING
75-
"oneDNN must be compiled for the same target as llama.cpp.
76-
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
77-
Disabling oneDNN support.")
81+
message(STATUS "oneDNN not found, disabling oneDNN support")
7882
endif()
7983
else()
80-
message(STATUS "oneDNN not found, disabling oneDNN support")
84+
message(STATUS "oneDNN support disabled by the user")
8185
endif()
8286
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
8387

ggml/src/ggml-sycl/gemm.hpp

+37-8
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,36 @@ class DnnlGemmWrapper {
3232
else static_assert(0);
3333
}
3434

35-
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36-
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
35+
// matrix A has m rows, k columns
36+
// matrix B has k rows, n columns
37+
// nra - number of elements to skip when moving into next row in A
38+
// nrb - number of elements to skip when moving into next row in B
39+
// nca - number of elements to skip when moving into next column in A
40+
// ncb - number of elements to skip when moving into next column in B
41+
// stride_a - number of elements to skip when moving to next A matrix
42+
// stride_b - number of elements to skip when moving to next B matrix
43+
// batches_a - number of A matrices
44+
// batches_b - number of B matrices
45+
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
46+
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
47+
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
48+
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
49+
3750
auto stream = ctx.stream_dnnl(q);
3851
auto eng = ctx.engine_dnnl(q);
39-
dnnl::memory::dims a_dims = { m, k };
40-
dnnl::memory::dims b_dims = { k, n };
41-
dnnl::memory::dims c_dims = { m, n };
42-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
43-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
44-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
52+
53+
// { # batches, # rows, # columns }
54+
dnnl::memory::dims a_dims = { batches_a, m, k };
55+
dnnl::memory::dims b_dims = { batches_b, k, n };
56+
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
57+
58+
// { # elements to skip to next matrix, # elements to skip to next row, # elements to skip to next column }
59+
dnnl::memory::dims a_strides = { stride_a, nra, nca };
60+
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
61+
62+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
63+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
64+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
4565

4666
dnnl::primitive_attr primitive_attr;
4767
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ class DnnlGemmWrapper {
6383

6484
matmul_prim.execute(stream, matmul_args);
6585
}
86+
87+
// matrices A and B are column major, both having k rows
88+
// matrix A has m columns, matrix B has n columns
89+
// output: column major matrix C = A transposed * B
90+
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
91+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
92+
93+
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
94+
}
6695
};
6796

6897
#endif

0 commit comments

Comments
 (0)