98 changes: 98 additions & 0 deletions tests/neuron/2_core/test_multi_lora.py
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0

from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def test_llama_single_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
use_v2_block_manager=True,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True,
"lora_modules": [{
"name": "lora_id_1",
"path": sql_lora_files
}]
},
enable_lora=True,
max_loras=1,
max_lora_rank=256,
              device="neuron")
    """For multi-LoRA requests with NxDI as the backend, only the LoRA name
    needs to be specified in the request. The adapter names and paths are
    supplied at LLM class/server initialization, after which the adapter
    checkpoints are handled by NxDI."""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_1])

expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]

for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)


def test_llama_multiple_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
use_v2_block_manager=True,
              override_neuron_config={
                  "sequence_parallel_enabled": False,
                  "skip_warmup": True,
                  "lora_modules": [{
                      "name": "lora_id_1",
                      "path": sql_lora_files
                  }, {
                      "name": "lora_id_2",
                      "path": sql_lora_files
                  }]
},
enable_lora=True,
max_loras=2,
max_lora_rank=256,
              device="neuron")
    """For multi-LoRA requests with NxDI as the backend, only the LoRA name
    needs to be specified in the request. The adapter names and paths are
    supplied at LLM class/server initialization, after which the adapter
    checkpoints are handled by NxDI."""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
lora_req_2 = LoRARequest("lora_id_2", 1, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_2])

expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]

for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)
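
Both tests wire adapters the same way. The sketch below is illustrative only (it is not part of this change) and assumes nothing beyond what the tests themselves exercise: the "lora_modules"/"name"/"path" keys in override_neuron_config, the rank bound of 256, and the dummy " " path in LoRARequest.

# Illustrative helper, not part of this change: registers N adapters through
# override_neuron_config the same way the tests above do.
from typing import Dict, Tuple

from vllm import LLM
from vllm.lora.request import LoRARequest


def build_nxdi_lora_llm(model: str, adapters: Dict[str, str],
                        **llm_kwargs) -> Tuple[LLM, Dict[str, LoRARequest]]:
    """adapters maps a LoRA adapter name to a local checkpoint path."""
    llm = LLM(model=model,
              enable_lora=True,
              max_loras=len(adapters),
              max_lora_rank=256,  # assumption: same rank bound the tests use
              device="neuron",
              override_neuron_config={
                  "sequence_parallel_enabled": False,
                  "skip_warmup": True,
                  "lora_modules": [{
                      "name": name,
                      "path": path
                  } for name, path in adapters.items()],
              },
              **llm_kwargs)
    # Only the adapter name matters in the request; NxDI already knows the
    # paths from lora_modules, so the LoRARequest path is a dummy " ".
    requests = {
        name: LoRARequest(name, idx, " ")
        for idx, name in enumerate(adapters)
    }
    return llm, requests

With such a helper, test_llama_multiple_lora amounts to building the LLM from {"lora_id_1": sql_lora_files, "lora_id_2": sql_lora_files} and passing the matching requests to llm.generate.
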
31 changes: 18 additions & 13 deletions vllm/model_executor/model_loader/neuronx_distributed.py
@@ -17,6 +17,8 @@
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
from neuronx_distributed_inference.models.mllama.utils import (
create_vision_mask)
from neuronx_distributed_inference.modules.lora_serving import (
LoraServingConfig)
from neuronx_distributed_inference.utils.hf_adapter import (
load_pretrained_config)
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
@@ -80,25 +82,26 @@ def __init__(
# Lazy initialized
self.model: nn.Module

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
) -> torch.Tensor:
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
prev_hidden: Optional[torch.Tensor] = None,
adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)

output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=sorted_input_block_ids,
sampling_params=sampling_params)
sampling_params=sampling_params,
prev_hidden=prev_hidden,
adapter_ids=adapter_ids)
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
output = output.hidden_states
@@ -522,7 +525,8 @@ def _get_model_architecture(config: PretrainedConfig) -> str:

def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig):
"""Generate a neuron config based on vllm config args."""
on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
deterministic=False)
@@ -541,7 +545,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
padding_side="right",
on_device_sampling_config=on_device_sampling_config,
sequence_parallel_enabled=True,
)
lora_serving_config=lora_serving_config)
return neuron_config


@@ -581,15 +585,16 @@ def _get_neuron_config_after_override(default_neuron_config,

def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
model_arch = _get_model_architecture(model_config.hf_config)
if model_arch == "MllamaForConditionalGeneration":
model = NeuronMllamaForCausalLM(model_config.hf_config)
else:
model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
model_config, parallel_config, scheduler_config, lora_serving_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)

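
The override_neuron_config dict used by the tests is applied by _get_neuron_config_after_override, whose body is collapsed in this diff. Conceptually it layers the user-supplied keys (for example sequence_parallel_enabled and lora_modules) over the defaults built by _get_default_neuron_config, which now include lora_serving_config. A rough conceptual sketch, assuming a simple per-key override and not the actual implementation:

# Conceptual sketch only; the real _get_neuron_config_after_override in this
# file may differ.  It illustrates how the override keys from the tests end up
# in the config passed to NeuronConfig.
def _apply_override_sketch(default_args: dict, override: dict) -> dict:
    merged = dict(default_args)
    merged.update(override or {})
    return merged


# e.g. _apply_override_sketch(default_neuron_config_args,
#                             model_config.override_neuron_config)
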
3 changes: 0 additions & 3 deletions vllm/platforms/neuron.py
@@ -49,9 +49,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"

assert (vllm_config.lora_config
is None), "LoRA is not supported for Neuron backend."

if vllm_config.cache_config and vllm_config.model_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
32 changes: 31 additions & 1 deletion vllm/worker/neuron_model_runner.py
@@ -2,13 +2,15 @@

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import nn

from vllm.config import DeviceConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
@@ -36,6 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: SamplingMetadata = None
multi_modal_kwargs: BatchedTensorInputs = None
adapter_ids: Optional[str] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -80,6 +83,7 @@ def __init__(
"The model will run without sliding window.")
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.lora_config = vllm_config.lora_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()

@@ -378,6 +382,7 @@ def execute_model(
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
adapter_ids=model_input.adapter_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
@@ -416,3 +421,28 @@ def execute_model(
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

def remove_all_loras(self):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

def add_lora(self, lora_request: LoRARequest):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

def list_loras(self) -> Set[int]:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
42 changes: 38 additions & 4 deletions vllm/worker/neuron_worker.py
@@ -1,27 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
"""A Neuron worker class."""
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)

logger = init_logger(__name__)


class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
class NeuronWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""

@@ -38,6 +38,7 @@ def __init__(self,
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
self.lora_config = vllm_config.lora_config

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@@ -59,6 +60,9 @@ def __init__(self,
"[transformers-neuronx, neuronx-distributed-inference]")

def get_tnx_model_runner(self, vllm_config):
assert (self.lora_config
is None), ("LoRA is not supported for TransformersNeuronX "
"framework.")
from vllm.worker.multi_step_neuron_model_runner import (
MultiStepNeuronModelRunner)
if self.speculative_config is not None:
@@ -72,6 +76,8 @@ def get_neuronx_distributed_model_runner(self, vllm_config):
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
if self.speculative_config is not None:
assert (self.lora_config
is None), "LoRA is not supported for Speculative Decoding"
return MultiStepNeuronxDistributedModelRunner(
vllm_config=vllm_config)
else:
@@ -156,3 +162,31 @@ def init_distributed_environment(self):
1,
1,
)

def add_lora(self, lora_request: LoRARequest) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.list_loras()
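
These worker hooks delegate adapter management to the model runner. The adapter_ids field added to ModelInputForNeuron is what ultimately reaches the NxDI forward pass shown in neuronx_distributed.py above; how that tensor is populated lives in neuronx_distributed_model_runner.py, which this diff does not touch. Purely as a hypothetical sketch, one way per-sequence LoRARequest names could be mapped to adapter ids, assuming adapters are numbered in registration order with 0 reserved for the base model:

# Hypothetical illustration only, not this PR's implementation.  It maps each
# sequence's LoRARequest (or None) to an integer adapter id, assuming adapters
# are indexed 1..N in the order they appear in lora_modules and 0 means
# "no adapter"; the real convention is defined by NxDI.
from typing import List, Optional

import torch

from vllm.lora.request import LoRARequest


def adapter_ids_from_requests(lora_requests: List[Optional[LoRARequest]],
                              registered_names: List[str]) -> torch.Tensor:
    name_to_id = {name: i for i, name in enumerate(registered_names, start=1)}
    return torch.tensor([
        name_to_id.get(req.lora_name, 0) if req is not None else 0
        for req in lora_requests
    ], dtype=torch.int32)
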