Skip to content

Commit a71a407

Browse files
eddnjjnchaxu01
and authored
ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (ggml-org#13053)
ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel

* code review fixes

* adds a comment that clarifies barrier usage

Signed-off-by: Dan Johansson <[email protected]>
Co-authored-by: Charles Xu <[email protected]>
1 parent 95e1888 commit a71a407

File tree

4 files changed

+414
-97
lines changed

4 files changed

+414
-97
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

+18-11
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
428428
${KLEIDIAI_SRC}/kai/ukernels/
429429
${KLEIDIAI_SRC}/kai/ukernels/matmul/
430430
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
431+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
431432
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
432433

433434
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -438,27 +439,33 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
438439
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
439440
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
440441

441-
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
442+
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
442443

443-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
444-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
445-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
446-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
444+
list(APPEND GGML_KLEIDIAI_SOURCES
445+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
446+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
447+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
448+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
447449

448450
if (NOT DOTPROD_ENABLED MATCHES -1)
449-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
450-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
451-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
451+
list(APPEND GGML_KLEIDIAI_SOURCES
452+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
453+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
454+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
452455
endif()
453456

454457
if (NOT I8MM_ENABLED MATCHES -1)
455458
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
456459
endif()
457460

458461
if (NOT SME_ENABLED MATCHES -1)
459-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
460-
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
461-
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
462+
list(APPEND GGML_KLEIDIAI_SOURCES
463+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
464+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
465+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
466+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
467+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
468+
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
462469
endif()
463470

464471
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")

ggml/src/ggml-cpu/kleidiai/kernels.cpp

+88-5
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@
44

55
// KleidiAI micro-kernels
66
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
7-
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
8-
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
9-
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
10-
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
117
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
128
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
139
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
1410
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
1511
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
1612
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
13+
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
14+
15+
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
16+
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
17+
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
18+
19+
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
20+
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
21+
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
22+
1723
#include "kai_common.h"
1824

1925
#include "kernels.h"
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
6167
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
6268
},
6369
/* .required_cpu = */ CPU_FEATURE_SME,
70+
/* .lhs_type = */ GGML_TYPE_F32,
71+
/* .rhs_type = */ GGML_TYPE_Q4_0,
72+
/* .op_type = */ GGML_TYPE_F32,
73+
},
74+
{
75+
/* SME GEMM */
76+
/* .kern_info = */ {
77+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
78+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
79+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
80+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
81+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
82+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
83+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
84+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
85+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
86+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
87+
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
88+
},
89+
/* SME GEMV */
90+
/* .kern_info = */ {
91+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
92+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
93+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
94+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
95+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
96+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
97+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
98+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
99+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
100+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
101+
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
102+
},
103+
/* .lhs_info = */ {
104+
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
105+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
106+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
107+
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
108+
},
109+
/* .rhs_info = */ {
110+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
111+
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
112+
},
113+
/* .required_cpu = */ CPU_FEATURE_SME,
114+
/* .lhs_type = */ GGML_TYPE_F32,
115+
/* .rhs_type = */ GGML_TYPE_F16,
116+
/* .op_type = */ GGML_TYPE_F32,
64117
},
65118
#endif
66119
#if defined(__APPLE__)
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
105158
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
106159
},
107160
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
161+
/* .lhs_type = */ GGML_TYPE_F32,
162+
/* .rhs_type = */ GGML_TYPE_Q4_0,
163+
/* .op_type = */ GGML_TYPE_F32,
108164
},
109165
#endif
110166
#if defined(__ARM_FEATURE_MATMUL_INT8)
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
148204
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
149205
},
150206
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
207+
/* .lhs_type = */ GGML_TYPE_F32,
208+
/* .rhs_type = */ GGML_TYPE_Q4_0,
209+
/* .op_type = */ GGML_TYPE_F32,
151210
},
152211
#endif
153212
#else
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
192251
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
193252
},
194253
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
254+
/* .lhs_type = */ GGML_TYPE_F32,
255+
/* .rhs_type = */ GGML_TYPE_Q4_0,
256+
/* .op_type = */ GGML_TYPE_F32,
195257
},
196258
#endif
197259
#if defined(__ARM_FEATURE_DOTPROD)
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
235297
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
236298
},
237299
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
300+
/* .lhs_type = */ GGML_TYPE_F32,
301+
/* .rhs_type = */ GGML_TYPE_Q4_0,
302+
/* .op_type = */ GGML_TYPE_F32,
238303
},
239304
#endif
240305
#endif
241306
};
242307

243-
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
308+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
309+
ggml_kleidiai_kernels * kernel = nullptr;
310+
311+
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
312+
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
313+
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
314+
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
315+
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
316+
gemm_gemv_kernels[i].op_type == tensor->type) {
317+
kernel = &gemm_gemv_kernels[i];
318+
break;
319+
}
320+
}
321+
}
322+
323+
return kernel;
324+
}
325+
326+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
244327
ggml_kleidiai_kernels * kernels = nullptr;
245328

246329
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {

ggml/src/ggml-cpu/kleidiai/kernels.h

+46-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
#pragma once
66

7+
#include <functional>
8+
#include "ggml.h"
9+
710
enum cpu_feature {
811
CPU_FEATURE_NONE = 0,
912
CPU_FEATURE_DOTPROD = 1,
@@ -26,26 +29,53 @@ struct kernel_info {
2629
size_t (*get_nr)(void);
2730
size_t (*get_kr)(void);
2831
size_t (*get_sr)(void);
29-
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
30-
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
32+
std::variant<
33+
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
34+
std::function<size_t(size_t m_idx, size_t k)>
35+
> get_lhs_offset;
36+
std::variant<
37+
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
38+
std::function<size_t(size_t n_idx, size_t k)>
39+
> get_rhs_packed_offset;
3140
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
3241
size_t (*get_dst_size)(size_t m, size_t n);
33-
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
34-
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
42+
std::variant<
43+
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
44+
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
45+
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
46+
size_t dst_stride_col, float clamp_min, float clamp_max)>
47+
> run_kernel;
3548
};
3649

3750
struct lhs_packing_info {
3851
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
39-
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
40-
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
41-
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
42-
size_t lhs_stride, void* lhs_packed);
52+
std::variant<
53+
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
54+
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
55+
> get_packed_offset;
56+
std::variant<
57+
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
58+
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
59+
> packed_size;
60+
std::variant<
61+
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
62+
size_t lhs_stride, void* lhs_packed)>,
63+
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
64+
void* lhs_packed)>
65+
> pack_func;
4366
};
4467

4568
struct rhs_packing_info {
46-
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
47-
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
48-
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
69+
std::variant<
70+
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
71+
std::function<size_t(size_t n, size_t k)>
72+
> packed_size;
73+
std::variant<
74+
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
75+
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
76+
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
77+
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
78+
> pack_func;
4979
};
5080

5181
struct ggml_kleidiai_kernels {
@@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels {
5585
rhs_packing_info rhs_info;
5686

5787
cpu_feature required_cpu;
88+
ggml_type lhs_type;
89+
ggml_type rhs_type;
90+
ggml_type op_type;
5891
};
5992

60-
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
93+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
94+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);

0 commit comments

Comments
 (0)