#include "ggml-blas.h"
#include "ggml-backend-impl.h"

#include <algorithm> // std::max, std::min
#include <cstdio>    // fprintf
#include <future>
#include <memory>    // std::unique_ptr
#include <vector>

#if defined(GGML_USE_ACCELERATE)
#   include <Accelerate/Accelerate.h>
#elif defined(GGML_BLAS_USE_MKL)
#   include <mkl.h>
#else
#   include <cblas.h>
#   ifdef BLIS_ENABLE_CBLAS
#       include <blis.h>
#   endif
#endif

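// per-backend state: the number of threads used for the F32 conversion and the BLAS
// calls, plus a lazily grown scratch buffer that holds src0 converted to F32.
// Without OpenMP, the std::async futures used for the conversion are stored here so
// they can be joined before calling into BLAS.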
struct ggml_backend_blas_context {
    int n_threads = GGML_DEFAULT_N_THREADS;
    std::unique_ptr<char[]> work_data;
    size_t work_size = 0;
#ifndef GGML_USE_OPENMP
    std::vector<std::future<void>> tasks;
#endif
};

// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    const int64_t ne10 = src1->ne[0];

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];

    // TODO: find the optimal values for these
    if (ggml_is_contiguous(src0) &&
        ggml_is_contiguous(src1) &&
        src1->type == GGML_TYPE_F32 &&
        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {

        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
        return true;
    }

    return false;
}

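// GGML_OP_MUL_MAT: for each 2D plane, computes dst = src1 * src0^T with cblas_sgemm
// (dst is ne1 x ne01 in row-major terms); if src0 is not F32, it is first converted
// to F32 into the context work buffer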
static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const enum ggml_type type = src0->type;

    GGML_ASSERT(ne0 == ne01);
    GGML_ASSERT(ne1 == ne11);
    GGML_ASSERT(ne2 == ne12);
    GGML_ASSERT(ne3 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    // broadcast factors
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

    const int64_t ne_plane      = ne01*ne00;
    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);

    if (ctx->work_size < desired_wsize) {
        ctx->work_data.reset(new char[desired_wsize]);
        ctx->work_size = desired_wsize;
    }
    void * wdata = ctx->work_data.get();

    // convert src0 to float
    if (type != GGML_TYPE_F32) {
        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
        ggml_to_float_t const to_float = type_traits.to_float;

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                const void  *       x      = (char *)  src0->data + i02*nb02     + i03*nb03;
                float       * const wplane = (float *) wdata      + i02*ne_plane + i03*ne02*ne_plane;

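                // parallelize the conversion across rows, but only use extra threads
                // when each thread gets at least ~4096 elements to convert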
                const int min_cols_per_thread = 4096;
                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);

#ifdef GGML_USE_OPENMP
                #pragma omp parallel for num_threads(n_threads)
                for (int64_t i01 = 0; i01 < ne01; i01++) {
                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                }
#else
                for (int i = 1; i < n_threads; i++) {
                    const int64_t start =       i*ne01/n_threads;
                    const int64_t end   = (i + 1)*ne01/n_threads;
                    if (start < end) {
                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
                            for (int64_t i01 = start; i01 < end; i01++) {
                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                            }
                        }));
                    }
                }
                {
                    // reuse the current thread for the first task
                    const int64_t start = 0;
                    const int64_t end   = ne01/n_threads;
                    for (int64_t i01 = start; i01 < end; i01++) {
                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                    }
                }
#endif
            }
        }

#ifndef GGML_USE_OPENMP
        // wait for all tasks to finish
        for (auto & task : ctx->tasks) {
            task.get();
        }
        ctx->tasks.clear();
#endif
    }

#if defined(OPENBLAS_VERSION)
    openblas_set_num_threads(ctx->n_threads);
#endif

#if defined(BLIS_ENABLE_CBLAS)
    bli_thread_set_num_threads(ctx->n_threads);
#endif

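    // multiply the 2D planes one by one; i02/i03 broadcast src0 over src1 when src1
    // has more planes (r2 = ne12/ne02, r3 = ne13/ne03)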
    for (int64_t i13 = 0; i13 < ne13; i13++) {
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            const int64_t i03 = i13/r3;
            const int64_t i02 = i12/r2;

            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);

            if (type != GGML_TYPE_F32) {
                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
            }

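            // d (ne1 x ne01, row-major) = y (ne1 x ne10) * x^T (ne10 x ne01),
            // i.e. dst = src1 * src0^T for this plane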
            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne1, ne01, ne10,
                        1.0f,   y, ne10,
                                x, ne00,
                        0.0f,   d, ne01);
        }
    }
}

static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    GGML_ASSERT(ne0  == ne00);
    GGML_ASSERT(ne1  == ne10);
    GGML_ASSERT(ne2  == ne02);
    GGML_ASSERT(ne02 == ne12);
    GGML_ASSERT(ne3  == ne13);
    GGML_ASSERT(ne03 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == sizeof(float));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    // GGML_ASSERT(nb0 <= nb1);
    // GGML_ASSERT(nb1 <= nb2);
    // GGML_ASSERT(nb2 <= nb3);

    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
    // src0: (k,n)
    // src1: (k,m)
    // dst:  (m,n)
    //
    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
    // Also expressed as (major,minor)
    // a: (m,k): so src1 transposed
    // b: (k,n): so src0
    // c: (m,n)
    //
    // However, if ggml_is_transposed(src1) is true, then
    // src1->data already contains a transposed version, so sgemm mustn't
    // transpose it further.

    int n = src0->ne[0];
    int k = src0->ne[1];
    int m = src1->ne[0];

    CBLAS_TRANSPOSE transposeA;
    int lda;

    if (!ggml_is_transposed(src1)) {
        transposeA = CblasTrans;
        lda = m;
    } else {
        transposeA = CblasNoTrans;
        lda = k;
    }

    float * a = (float *) ((char *) src1->data);
    float * b = (float *) ((char *) src0->data);
    float * c = (float *) ((char *) dst->data);

    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);

    GGML_UNUSED(ctx);
}

// backend interface

GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
    return "BLAS";

    GGML_UNUSED(backend);
}

GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
    delete ctx;
    delete backend;
}

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
    return ggml_backend_cpu_buffer_type();

    GGML_UNUSED(backend);
}

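// execute the graph: MUL_MAT and OUT_PROD nodes are computed with BLAS; NONE and the
// view ops (reshape/view/permute/transpose) are no-ops since they only change tensor
// metadata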
GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        switch (node->op) {
            case GGML_OP_MUL_MAT:
                ggml_backend_blas_mul_mat(ctx, node);
                break;

            case GGML_OP_OUT_PROD:
                ggml_backend_blas_out_prod(ctx, node);
                break;

            case GGML_OP_NONE:
            case GGML_OP_RESHAPE:
            case GGML_OP_VIEW:
            case GGML_OP_PERMUTE:
            case GGML_OP_TRANSPOSE:
                break;

            default:
                fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
                GGML_ASSERT(false);
        }
    }

    return GGML_STATUS_SUCCESS;

    GGML_UNUSED(backend);
}

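// claim MUL_MAT only when the matrices are large enough for BLAS to pay off
// (see ggml_backend_blas_use_blas), and OUT_PROD only for 2D F32 operands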
GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];

    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
                                          op->src[1]->type == GGML_TYPE_F32 &&
                                          ggml_is_matrix(src0) &&
                                          ggml_is_matrix(src1) &&
                                          ggml_is_contiguous(src0) &&
                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));

    GGML_UNUSED(backend);
}

GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    return ggml_backend_buft_is_host(buft);

    GGML_UNUSED(backend);
}

static struct ggml_backend_i blas_backend_i = {
    /* .get_name                = */ ggml_backend_blas_name,
    /* .free                    = */ ggml_backend_blas_free,
    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ NULL,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
    /* .supports_op             = */ ggml_backend_blas_supports_op,
    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
    /* .offload_op              = */ NULL,
    /* .event_new               = */ NULL,
    /* .event_free              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .event_synchronize       = */ NULL,
};

static ggml_guid_t ggml_backend_blas_guid(void) {
    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
    return &guid;
}

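// Typical usage (sketch, using the generic ggml-backend API):
//
//     ggml_backend_t blas = ggml_backend_blas_init();
//     ggml_backend_blas_set_n_threads(blas, 8);
//     ...
//     ggml_backend_graph_compute(blas, graph);
//     ggml_backend_free(blas);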
ggml_backend_t ggml_backend_blas_init(void) {
    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;

    ggml_backend_t backend = new ggml_backend {
        /* .guid      = */ ggml_backend_blas_guid(),
        /* .interface = */ blas_backend_i,
        /* .context   = */ ctx,
    };

#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
    }
#endif

#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
#endif

    return backend;
}

GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
}

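// controls both the number of threads used for the F32 conversion and, for
// OpenBLAS/BLIS builds, the thread count passed to the BLAS library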
void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
    GGML_ASSERT(ggml_backend_is_blas(backend_blas));

    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
    ctx->n_threads = n_threads;
}