Commit a19b5ce

llama : fix FA when KV cache is not used (i.e. embeddings) (#12825)
* ggml : FA supports F32 V
* graph : cast KV to F16 when the KV cache is not used ggml-ci
* server : add test that exercises embeddings with FA enabled ggml-ci
1 parent 78a1ba0 commit a19b5ce

File tree: 6 files changed (+59, -6 lines)

examples/server/tests/unit/test_embedding.py

+20 lines

@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
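The scenario this new test exercises can also be reproduced by hand with a short client script. The sketch below is not part of the commit; it assumes a llama-server instance is already running at http://127.0.0.1:8080 (an assumed address) with an embedding model loaded, embeddings enabled, and flash attention turned on, mirroring the bert_bge_small_with_fa preset added in utils.py below.

# hypothetical standalone reproduction of test_embedding_multiple_with_fa
import requests

url = "http://127.0.0.1:8080/v1/embeddings"  # assumed local server address

# prompt lengths chosen so that at least one request hits the FA code path
# (context size that is a multiple of 256), mirroring the test above
payload = {"input": ["a " * 253, "b " * 254, "c " * 255, "d " * 256]}

res = requests.post(url, json=payload)
res.raise_for_status()

for item in res.json()["data"]:
    # each entry should contain a non-empty embedding vector
    print(len(item["embedding"]))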

examples/server/tests/utils.py

+15 lines

@@ -323,6 +323,21 @@ def bert_bge_small() -> ServerProcess:
         server.server_embeddings = True
         return server
 
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()

examples/server_embd.py

+1 -1 lines

@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:

ggml/src/ggml-cpu/ops.cpp

+9 -5 lines

@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t  const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 vs = expf(s - M);
             }
 
-            v_to_float(v_data, V32, DV);
-
             // V += v*expf(s - M)
-            ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+            if (v_to_float) {
+                v_to_float(v_data, V32, DV);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+            } else {
+                // V is F32
+                ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+            }
         }
 
         S = S*ms + vs; // scale and increment sum with partial sum

ggml/src/ggml-metal/ggml-metal.m

+5 lines

@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }

src/llama-graph.cpp

+9 lines

@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
12201229
