diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9cfddf4503abe..122ae59737196 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec( sm[tiisg] = pm[ic + tiisg]; } + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + // Q*K^T { // each simdgroup processes 1 query and NE (NW/NL) head elements diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 0f4019293d581..119df471b25ee 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -123,8 +123,8 @@ int main(int argc, char ** argv) { common_batch_clear(batch); - for (int i = 0; i < pp; ++i) { - for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + for (int i = 0; i < pp; ++i) { common_batch_add(batch, 0, i, { j }, false); } }