|
4 | 4 |
|
5 | 5 | // KleidiAI micro-kernels
|
6 | 6 | #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
7 |
| -#include "kai_lhs_quant_pack_qsi8d32p_f32.h" |
8 |
| -#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" |
9 |
| -#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" |
10 |
| -#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" |
11 | 7 | #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
12 | 8 | #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
13 | 9 | #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
14 | 10 | #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
15 | 11 | #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
16 | 12 | #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
| 13 | +#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" |
| 14 | + |
| 15 | +#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" |
| 16 | +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" |
| 17 | +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" |
| 18 | + |
| 19 | +#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" |
| 20 | +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" |
| 21 | +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" |
| 22 | + |
17 | 23 | #include "kai_common.h"
|
18 | 24 |
|
19 | 25 | #include "kernels.h"
|
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
61 | 67 | /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
62 | 68 | },
|
63 | 69 | /* .required_cpu = */ CPU_FEATURE_SME,
|
| 70 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 71 | + /* .rhs_type = */ GGML_TYPE_Q4_0, |
| 72 | + /* .op_type = */ GGML_TYPE_F32, |
| 73 | + }, |
| 74 | + { |
| 75 | + /* SME GEMM */ |
| 76 | + /* .kern_info = */ { |
| 77 | + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 78 | + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 79 | + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 80 | + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 81 | + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 82 | + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 83 | + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 84 | + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 85 | + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 86 | + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 87 | + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 88 | + }, |
| 89 | + /* SME GEMV */ |
| 90 | + /* .kern_info = */ { |
| 91 | + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 92 | + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 93 | + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 94 | + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 95 | + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 96 | + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 97 | + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 98 | + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 99 | + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 100 | + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 101 | + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, |
| 102 | + }, |
| 103 | + /* .lhs_info = */ { |
| 104 | + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, |
| 105 | + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, |
| 106 | + /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, |
| 107 | + /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, |
| 108 | + }, |
| 109 | + /* .rhs_info = */ { |
| 110 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
| 111 | + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
| 112 | + }, |
| 113 | + /* .required_cpu = */ CPU_FEATURE_SME, |
| 114 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 115 | + /* .rhs_type = */ GGML_TYPE_F16, |
| 116 | + /* .op_type = */ GGML_TYPE_F32, |
64 | 117 | },
|
65 | 118 | #endif
|
66 | 119 | #if defined(__APPLE__)
|
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
105 | 158 | /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
106 | 159 | },
|
107 | 160 | /* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
| 161 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 162 | + /* .rhs_type = */ GGML_TYPE_Q4_0, |
| 163 | + /* .op_type = */ GGML_TYPE_F32, |
108 | 164 | },
|
109 | 165 | #endif
|
110 | 166 | #if defined(__ARM_FEATURE_MATMUL_INT8)
|
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
148 | 204 | /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
149 | 205 | },
|
150 | 206 | /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
| 207 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 208 | + /* .rhs_type = */ GGML_TYPE_Q4_0, |
| 209 | + /* .op_type = */ GGML_TYPE_F32, |
151 | 210 | },
|
152 | 211 | #endif
|
153 | 212 | #else
|
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
192 | 251 | /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
193 | 252 | },
|
194 | 253 | /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
| 254 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 255 | + /* .rhs_type = */ GGML_TYPE_Q4_0, |
| 256 | + /* .op_type = */ GGML_TYPE_F32, |
195 | 257 | },
|
196 | 258 | #endif
|
197 | 259 | #if defined(__ARM_FEATURE_DOTPROD)
|
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
235 | 297 | /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
236 | 298 | },
|
237 | 299 | /* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
| 300 | + /* .lhs_type = */ GGML_TYPE_F32, |
| 301 | + /* .rhs_type = */ GGML_TYPE_Q4_0, |
| 302 | + /* .op_type = */ GGML_TYPE_F32, |
238 | 303 | },
|
239 | 304 | #endif
|
240 | 305 | #endif
|
241 | 306 | };
|
242 | 307 |
|
243 |
| -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { |
| 308 | +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { |
| 309 | + ggml_kleidiai_kernels * kernel = nullptr; |
| 310 | + |
| 311 | + if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { |
| 312 | + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { |
| 313 | + if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && |
| 314 | + gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && |
| 315 | + gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && |
| 316 | + gemm_gemv_kernels[i].op_type == tensor->type) { |
| 317 | + kernel = &gemm_gemv_kernels[i]; |
| 318 | + break; |
| 319 | + } |
| 320 | + } |
| 321 | + } |
| 322 | + |
| 323 | + return kernel; |
| 324 | +} |
| 325 | + |
| 326 | +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { |
244 | 327 | ggml_kleidiai_kernels * kernels = nullptr;
|
245 | 328 |
|
246 | 329 | for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
0 commit comments