Skip to content

model : jina-embeddings-v3 support #13693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2460,15 +2460,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
Expand Down
2 changes: 2 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,8 @@ struct common_init_result common_init_from_params(common_params & params) {
}

la.ptr = lora.get();
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", la.task_name, sizeof(la.task_name));
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", la.prompt_prefix, sizeof(la.prompt_prefix));
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ struct common_adapter_lora_info {
std::string path;
float scale;

char task_name[64];
char prompt_prefix[256];

struct llama_adapter_lora * ptr;
};

Expand Down
77 changes: 75 additions & 2 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
Expand Down Expand Up @@ -4153,11 +4155,31 @@ def modify_tensors(self, data_torch, name, bid):
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
_lora_files = {}

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
hparams = ModelBase.load_hparams(dir_model)

if lora_names := hparams.get("lora_adaptations"):
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3

super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)

if lora_names:
for name in lora_names:
fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._xlmroberta_tokenizer_init()

def set_type(self):
for lora_writer in self._lora_files.values():
lora_writer.add_type(gguf.GGUFType.ADAPTER)
lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
super().set_type()

def set_vocab(self):
self._xlmroberta_set_vocab()

Expand All @@ -4167,13 +4189,64 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if name.startswith("roberta."):
name = name[8:]

# jina-embeddings-v3
if ".parametrizations." in name:
name = name.replace(".parametrizations.", ".")
if name.endswith(".original"):
name = name[:-9]

# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]

if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
if name.startswith("pooler.dense"):
return []

num_loras = data_torch.size(0)
assert num_loras == len(self._lora_files)

# Split out each LoRA in their own GGUF
for i, lora_writer in enumerate(self._lora_files.values()):
new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
data_qtype = gguf.GGMLQuantizationType.F32
data = data_torch[i, :, :]
# Transpose/flip token_embd/types into correct shape
if new_name == "token_embd.weight.lora_b":
data = data.T
elif new_name.startswith("token_types.weight."):
new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
data = gguf.quants.quantize(data.numpy(), data_qtype)
lora_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

return []

return super().modify_tensors(data_torch, name, bid)

def set_gguf_parameters(self):
super().set_gguf_parameters()

# jina-embeddings-v3
if rotary_emb_base := self.hparams.get("rotary_emb_base"):
self.gguf_writer.add_rope_freq_base(rotary_emb_base)
lora_alpha = self.hparams.get("lora_alpha")
if lora_prompt_prefixes := self.hparams.get("task_instructions"):
assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
for lora_name, lora_writer in self._lora_files.items():
lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
if lora_prompt_prefixes:
lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])

def write(self):
super().write()
for lora_writer in self._lora_files.values():
lora_writer.write_header_to_file()
lora_writer.write_kv_data_to_file()
lora_writer.write_tensors_to_file(progress=True)
lora_writer.close()


@ModelBase.register("GemmaForCausalLM")
class GemmaModel(TextModel):
Expand Down
20 changes: 18 additions & 2 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,10 @@ class Tokenizer:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"

class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"

class Clip:
PROJECTOR_TYPE = "clip.projector_type"
Expand Down Expand Up @@ -301,6 +303,7 @@ class MODEL_ARCH(IntEnum):
NOMIC_BERT_MOE = auto()
NEO_BERT = auto()
JINA_BERT_V2 = auto()
JINA_BERT_V3 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
Expand Down Expand Up @@ -604,6 +607,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.NEO_BERT: "neo-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
Expand Down Expand Up @@ -1161,6 +1165,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.LAYER_OUT_NORM,
MODEL_TENSOR.CLS,
],
MODEL_ARCH.JINA_BERT_V3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
18 changes: 18 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,24 @@ extern "C" {
struct llama_model * model,
const char * path_lora);

// Functions to access the adapter's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - When retrieving a string, an extra byte must be allocated to account for the null terminator
// - GGUF array values are not supported by these functions

// Get metadata value as a string by key name
LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size);

// Get the number of metadata key/value pairs
LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter);

// Get metadata key name by index
LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);

// Get metadata value as a string by index
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);

// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
Expand Down
72 changes: 68 additions & 4 deletions src/llama-adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_

// check metadata
{
const gguf_context * gguf_ctx = ctx_gguf.get();

LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);

// get metadata as string
for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
gguf_type type = gguf_get_kv_type(gguf_ctx, i);
const std::string type_name =
type == GGUF_TYPE_ARRAY
? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
: gguf_type_name(type);
const char * name = gguf_get_key(gguf_ctx, i);
const std::string value = gguf_kv_to_str(gguf_ctx, i);

if (type != GGUF_TYPE_ARRAY) {
adapter.gguf_kv.emplace(name, value);
}

const size_t MAX_VALUE_LEN = 40;
std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
replace_all(print_value, "\n", "\\n");

LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
}

auto get_kv_str = [&](const std::string & key) -> std::string {
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
int id = gguf_find_key(gguf_ctx, key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
};
auto get_kv_f32 = [&](const std::string & key) -> float {
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
int id = gguf_find_key(gguf_ctx, key.c_str());
return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
};
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);

Expand Down Expand Up @@ -383,6 +408,45 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr;
}

int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
const auto & it = adapter->gguf_kv.find(key);
if (it == adapter->gguf_kv.end()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
return snprintf(buf, buf_size, "%s", it->second.c_str());
}

int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
return (int)adapter->gguf_kv.size();
}

int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = adapter->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->first.c_str());
}

int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = adapter->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->second.c_str());
}

void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}
3 changes: 3 additions & 0 deletions src/llama-adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ struct llama_adapter_lora {

float alpha;

// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;

llama_adapter_lora() = default;
~llama_adapter_lora() = default;

Expand Down
21 changes: 19 additions & 2 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
Expand Down Expand Up @@ -216,8 +217,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
Expand Down Expand Up @@ -557,6 +560,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CLS, "cls" },
},
},
{
LLM_ARCH_JINA_BERT_V3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
LLM_ARCH_BLOOM,
{
Expand Down
3 changes: 3 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ enum llm_arch {
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_JINA_BERT_V3,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
Expand Down Expand Up @@ -214,6 +215,8 @@ enum llm_kv {

LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
LLM_KV_ADAPTER_LORA_TASK_NAME,
LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,

LLM_KV_POSNET_EMBEDDING_LENGTH,
LLM_KV_POSNET_BLOCK_COUNT,
Expand Down
Loading
Loading