diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 576f9581bdaee..f4b3d9cf5929c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4358,7 +4358,7 @@ static bool ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9cfddf4503abe..122ae59737196 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec( sm[tiisg] = pm[ic + tiisg]; } + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + // Q*K^T { // each simdgroup processes 1 query and NE (NW/NL) head elements diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 0f4019293d581..119df471b25ee 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -123,8 +123,8 @@ int main(int argc, char ** argv) { common_batch_clear(batch); - for (int i = 0; i < pp; ++i) { - for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + for (int i = 0; i < pp; ++i) { common_batch_add(batch, 0, i, { j }, false); } }