18 changes: 16 additions & 2 deletions docs/source/models/supported_models.md
@@ -701,12 +701,22 @@ Specified using `--task embed`.
* ✅︎
* ✅︎
- * `GteModel`
-  * GteModel
+  * Arctic-Embed-2.0-M
  * `Snowflake/snowflake-arctic-embed-m-v2.0`.
  *
  * ✅︎
- * `GteNewModel`
  * mGTE-TRM (see note)
  * `Alibaba-NLP/gte-multilingual-base`, etc.
  * ✅︎
  * ✅︎
- * `ModernBertModel`
  * ModernBERT-based
  * `Alibaba-NLP/gte-modernbert-base`, etc.
  * ✅︎
  * ✅︎
- * `NomicBertModel`
-  * NomicBertModel
+  * Nomic BERT
  * `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
  * ✅︎
  * ✅︎
@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform
`jinaai/jina-embeddings-v3` supports multiple tasks through LoRA adapters; for now, vLLM supports only the text-matching task, by merging the corresponding LoRA weights.
:::

:::{note}
The second-generation GTE models (mGTE-TRM) use the architecture name `NewModel`. Because this name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` so that vLLM uses the `GteNewModel` architecture.
:::
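As a quick sketch of this override in the offline API (assuming `LLM` forwards `hf_overrides` to the engine and exposes `embed`, as in recent vLLM releases; the prompt is illustrative):

```python
from vllm import LLM

# Load mGTE while forcing the GteNewModel architecture.
llm = LLM(model="Alibaba-NLP/gte-multilingual-base",
          task="embed",
          hf_overrides={"architectures": ["GteNewModel"]})

(output, ) = llm.embed(["What is the capital of France?"])
print(len(output.outputs.embedding))  # embedding dimensionality
```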

If your model is not in the above list, we will try to automatically convert the model using
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
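A schematic of that default pooling (our own illustration, not vLLM's actual pooler code):

```python
import torch
import torch.nn.functional as F

def last_token_embedding(hidden_states: torch.Tensor) -> torch.Tensor:
    """hidden_states: [num_tokens, hidden_size] for a single prompt."""
    # Take the hidden state of the last token and L2-normalize it.
    return F.normalize(hidden_states[-1], p=2, dim=-1)
```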
17 changes: 12 additions & 5 deletions tests/models/language/pooling/mteb_utils.py
@@ -7,6 +7,7 @@
import pytest

from tests.models.utils import EmbedModelInfo
from vllm.model_executor.model_loader.utils import set_default_torch_dtype

# Most models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
return run_mteb_embed_task(model, tasks)


-def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
+def mteb_test_embed_models(hf_runner,
+                           vllm_runner,
+                           model_info: EmbedModelInfo,
+                           vllm_extra_kwargs=None):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")

vllm_extra_kwargs = vllm_extra_kwargs or {}

    with vllm_runner(model_info.name,
                     task="embed",
                     max_model_len=None,
-                    dtype=model_info.dtype) as vllm_model:
+                    dtype=model_info.dtype,
+                    **vllm_extra_kwargs) as vllm_model:

if model_info.architecture:
assert (model_info.architecture
@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)

-    with hf_runner(model_info.name,
-                   is_sentence_transformer=True,
-                   dtype=model_dtype) as hf_model:
+    # Enter both context managers so the default dtype is applied while
+    # the SentenceTransformer model is constructed.
+    with set_default_torch_dtype(model_dtype), hf_runner(
+            model_info.name, is_sentence_transformer=True,
+            dtype=model_dtype) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

print("VLLM:", vllm_dtype, vllm_main_score)
104 changes: 104 additions & 0 deletions tests/models/language/pooling/test_gte.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import pytest

from ...utils import EmbedModelInfo, run_embedding_correctness_test

MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
dtype="float32",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
dtype="float32",
enable_test=False),
########### NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=False),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")

from .mteb_utils import mteb_test_embed_models

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}

if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)


@pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
example_prompts) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")

# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}

if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
4 changes: 4 additions & 0 deletions tests/models/language/pooling/test_nomic.py
@@ -23,6 +23,7 @@
@pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)

@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
if not model_info.enable_test:
pytest.skip("Skipping test.")

# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]

with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
@@ -46,6 +46,7 @@ def test_models_mteb(
vllm_runner,
model_info: EmbedModelInfo,
) -> None:
pytest.skip("Skipping mteb test.")
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info)

@@ -60,6 +61,9 @@ def test_models_correctness(
if not model_info.enable_test:
pytest.skip("Skipping test.")

# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]

with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
6 changes: 6 additions & 0 deletions tests/models/registry.py
@@ -256,11 +256,17 @@ def check_available_online(
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures":
["GteNewModel"]}),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/activation.py
@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"gelu_and_mul": lambda: GeluAndMul(),
"geglu": lambda: GeluAndMul(),
})
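For reference, `GeluAndMul` implements the GEGLU-style gated activation, so `"geglu"` is the conventional name for this key. A minimal unfused sketch of the operation (our own illustration, not vLLM's fused kernel):

```python
import torch
import torch.nn.functional as F

def geglu(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into gate and value halves,
    # then gate the value with GELU.
    gate, value = x.chunk(2, dim=-1)
    return F.gelu(gate) * value
```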


42 changes: 42 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -456,6 +456,40 @@ def scaling_factor_to_offset(self) -> dict[float, int]:
return self._scaling_factor_to_offset


class NTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with fixed and mixed NTK scaling.
https://kexue.fm/archives/9706 """

def __init__(self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
mixed_b: Optional[float] = None) -> None:
self.scaling_factor = scaling_factor
self.mixed_b = mixed_b
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        # Fixed NTK rescales the base; mixed NTK keeps the base and
        # rescales each frequency individually below.
        base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
        inv_freq = super()._compute_inv_freq(base)

        if self.mixed_b is None:
            # Renormalize so that the lowest frequency is scaled by
            # exactly 1 / scaling_factor.
            inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
        else:
            # Mixed NTK: divide frequency i by lambda_i = exp(a * i^b),
            # where a is chosen so that lambda_{d/2} == scaling_factor.
            a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
                                                           2)**self.mixed_b
            lambda_1_m = (a * torch.arange(
                1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
            inv_freq = inv_freq / lambda_1_m

        return inv_freq
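As a reading aid, with $d$ = `rotary_dim`, $s$ = `scaling_factor`, $b$ = `mixed_b`, and base frequencies $\theta_i = \mathrm{base}^{-2i/d}$ for $i = 0, \dots, d/2 - 1$, `_compute_inv_freq` above reduces to:

```latex
% Fixed NTK (mixed_b is None):
\hat{\theta}_i = (s \cdot \mathrm{base})^{-2i/d} / s^{2/d} = \theta_i / s^{2(i+1)/d}
% Mixed NTK (mixed_b = b):
a = \frac{\ln s}{(d/2)^b}, \qquad \hat{\theta}_i = \theta_i / \exp\!\bigl(a\,(i+1)^b\bigr)
```

In both cases the lowest frequency ($i = d/2 - 1$) is divided by exactly $s$, matching the context-length extension factor.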


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.

@@ -1765,6 +1799,14 @@ def get_rope(
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get('mixed_b', None)
rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype,
mixed_b)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
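A hedged sketch of reaching the new `"ntk"` branch (we assume the dispatch key is `"rope_type"` and that `get_rope`'s keyword names match its positional parameters here; the numbers are illustrative):

```python
from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=32768,
    base=10000,
    is_neox_style=True,
    rope_scaling={
        "rope_type": "ntk",
        "factor": 8.0,   # required NTK scaling factor
        "mixed_b": 0.5,  # optional; enables mixed NTK when present
    },
)
```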