llama : fix embeddings #5796


Merged (9 commits), Mar 4, 2024
Changes from 1 commit
llama : add pooling switch
ggerganov committed Mar 4, 2024
commit e66da356a41530137161d20feb224c76f5bc13ec
43 changes: 25 additions & 18 deletions llama.cpp
@@ -8113,7 +8113,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
        const llama_pos pos = batch.pos[i];
if (pos == 0) {
data[seq_id] = i;
}
@@ -8379,10 +8379,17 @@ static int llama_decode_internal(
if (batch.logits[i] == 0) {
continue;
}
-            if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-                ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
-            } else {
-                ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+            switch (hparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_CLS:
+                    ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
+                    break;
+                case LLAMA_POOLING_TYPE_MEAN:
Collaborator:

We should have the LLAMA_POOLING_TYPE_MEAN case join the LLAMA_POOLING_TYPE_CLS case due to the output order of the averaging matrix.
Member Author:

Yes. I keep getting confused with the sequence-based instead of token-based embedding extraction.

Will try to modify the API to make things more clear

Collaborator @cebtenzzre (Mar 4, 2024):
Cool, this change fixes the NaNs from embedding.cpp, and results in the MSE of nomic-embed-text-v1.f16.gguf actually being lower than before (4.71e-10 vs 5.62e-10). Also fp32 is down from 9.34e-11 to 1.18e-14.

Member Author:

> Also fp32 is down from 9.34e-11 to 1.18e-14.

This is likely due to no longer going through the KV cache.

Do you have any performance benchmarks to see if this change improved the speed?

+                case LLAMA_POOLING_TYPE_NONE:
+                    ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown pooling type");
+                    break;
}
}
}
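The switch above copies one embedding row per output token; the only thing that differs between the pooling modes is which source row is read. With CLS pooling the row is indexed by the token's sequence id, while with NONE it is indexed by the token's position in the batch (and in this commit MEAN falls through to the NONE indexing, which the review comment above flags). A standalone sketch of that indexing (names and signature are ours, not the llama.cpp API):

```cpp
#include <cassert>
#include <cstring>
#include <vector>

enum pooling_type { POOLING_NONE, POOLING_MEAN, POOLING_CLS };

// Copy one n_embd-sized embedding row per token into 'out'.
// CLS pooling reads the row for the token's sequence id; the other
// modes read the row at the token's own batch index, as in this commit.
void extract_embeddings(pooling_type pool,
                        const std::vector<float> & embd,  // n_rows * n_embd floats
                        const std::vector<int> & seq_id,  // sequence id per token
                        int n_embd,
                        std::vector<float> & out) {
    const int n_tokens = (int) seq_id.size();
    out.assign((size_t) n_tokens * n_embd, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        const int src_row = (pool == POOLING_CLS) ? seq_id[i] : i;
        std::memcpy(out.data() + (size_t) n_embd * i,
                    embd.data() + (size_t) n_embd * src_row,
                    n_embd * sizeof(float));
    }
}
```

For example, with `n_embd = 2`, rows `{1,1}, {2,2}, {3,3}` and `seq_id = {0, 0, 1}`, CLS pooling yields `{1,1, 1,1, 2,2}` (both tokens of sequence 0 get row 0), while NONE returns the rows unchanged.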
@@ -8680,19 +8687,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(llama_is_byte_token(vocab, id));
const auto& token_data = vocab.id_to_token.at(id);
    switch (llama_vocab_get_type(vocab)) {
-    case LLAMA_VOCAB_TYPE_SPM: {
-        auto buf = token_data.text.substr(3, 2);
-        return strtol(buf.c_str(), NULL, 16);
-    }
-    case LLAMA_VOCAB_TYPE_BPE: {
-        GGML_ASSERT(false);
-        return unicode_to_bytes_bpe(token_data.text);
-    }
-    case LLAMA_VOCAB_TYPE_WPM: {
-        GGML_ASSERT(false);
-    }
-    default:
-        GGML_ASSERT(false);
+        case LLAMA_VOCAB_TYPE_SPM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ASSERT(false);
+            return unicode_to_bytes_bpe(token_data.text);
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ASSERT(false);
+        }
+        default:
+            GGML_ASSERT(false);
}
}
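The SPM branch above relies on SentencePiece encoding raw bytes as tokens of the form `<0xNN>`, so `substr(3, 2)` extracts the two hex digits and `strtol(..., 16)` decodes them. A self-contained sketch of just that branch (the function name is ours, for illustration):

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Decode a SentencePiece byte token such as "<0x41>" into the raw byte
// it represents: skip the "<0x" prefix, take the two hex digits, parse
// them base-16. Assumes the input is well-formed.
uint8_t spm_byte_token_to_byte(const std::string & text) {
    const std::string buf = text.substr(3, 2); // "<0x41>" -> "41"
    return (uint8_t) strtol(buf.c_str(), nullptr, 16);
}
```

For example, `"<0x41>"` decodes to `0x41` (the byte for `'A'`) and `"<0x0A>"` to a newline byte.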
