1 change: 1 addition & 0 deletions charts/model-engine/values_circleci.yaml
@@ -146,6 +146,7 @@ config:
tgi_repository: "text-generation-inference"
vllm_repository: "vllm"
lightllm_repository: "lightllm"
tensorrt_llm_repository: "tensorrt-llm"
user_inference_base_repository: "launch/inference"
user_inference_pytorch_repository: "hosted-model-inference/async-pytorch"
user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu"
1 change: 1 addition & 0 deletions model-engine/model_engine_server/common/config.py
@@ -57,6 +57,7 @@ class HostedModelInferenceServiceConfig:
tgi_repository: str
vllm_repository: str
lightllm_repository: str
tensorrt_llm_repository: str
user_inference_base_repository: str
user_inference_pytorch_repository: str
user_inference_tensorflow_repository: str
@@ -12,6 +12,7 @@ class LLMInferenceFramework(str, Enum):
TEXT_GENERATION_INFERENCE = "text_generation_inference"
VLLM = "vllm"
LIGHTLLM = "lightllm"
TENSORRT_LLM = "tensorrt_llm"


class Quantization(str, Enum):
@@ -62,6 +62,10 @@
from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway

# Hack for TensorRT-LLM. Remove when it supports returning output tokens only
# See https://github.com/NVIDIA/TensorRT-LLM/issues/227
from transformers import AutoTokenizer

from ...common.datadog_utils import add_trace_request_id
from ..authorization.live_authorization_module import LiveAuthorizationModule
from .model_bundle_use_cases import CreateModelBundleV2UseCase
@@ -150,13 +154,17 @@
"llama-2-70b": "meta-llama/Llama-2-70b-hf",
"llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
},
LLMInferenceFramework.TENSORRT_LLM: {
"llama-2-7b": "huggyllama/llama-7b",  # Hack to get the Llama tokenizer without signing in to Hugging Face
},
}

_SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = {
LLMInferenceFramework.DEEPSPEED: [],
LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES],
LLMInferenceFramework.VLLM: [Quantization.AWQ],
LLMInferenceFramework.LIGHTLLM: [],
LLMInferenceFramework.TENSORRT_LLM: [],
}

# We need a dict where if we need to override we can
@@ -340,6 +348,14 @@ async def create_model_bundle(
num_shards,
checkpoint_path,
)
elif framework == LLMInferenceFramework.TENSORRT_LLM:
bundle_id = await self.create_tensorrt_llm_bundle(
user,
framework_image_tag,
endpoint_name,
num_shards,
checkpoint_path,
)
else:
raise ObjectHasInvalidValueException(
f"Framework {framework} is not supported for source {source}."
@@ -384,7 +400,7 @@ async def create_text_generation_inference_bundle(
)
else:
raise ObjectHasInvalidValueException(
f"Not able to load checkpoint path {checkpoint_path}."
f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
)
else:
final_weights_folder = _SUPPORTED_MODEL_NAMES[
@@ -471,6 +487,32 @@ def load_model_weights_sub_commands(

return subcommands

def load_model_files_sub_commands_trt_llm(
self,
checkpoint_path,
):
"""
This function generates the subcommands to load model files for TensorRT-LLM.
Each model checkpoint consists of two folders: `model_weights`, which stores the model engine files,
and `model_tokenizer`, which stores the model tokenizer files.
See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt
"""
subcommands = []

base_path = checkpoint_path.split("/")[-1]

if base_path.endswith(".tar"):
raise ObjectHasInvalidValueException(
"Checkpoint for TensorRT-LLM models must be a folder, not a tar file."
)
else:
subcommands.append(
f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
)

return subcommands
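
As a rough sketch, the helper above resolves to a single s5cmd copy; assuming a hypothetical checkpoint path (not taken from this diff), the returned subcommands would be:

# Hypothetical illustration of load_model_files_sub_commands_trt_llm for
# checkpoint_path = "s3://my-bucket/llama-2-7b-trt" (an assumed example path).
# The copy is expected to produce ./model_weights (engine files) and
# ./model_tokenizer (tokenizer files), per the docstring above.
subcommands = [
    "./s5cmd --numworkers 512 cp --concurrency 50 s3://my-bucket/llama-2-7b-trt/* ./"
]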

async def create_deepspeed_bundle(
self,
user: User,
@@ -587,7 +629,7 @@ async def create_vllm_bundle(
)
else:
raise ObjectHasInvalidValueException(
f"Not able to load checkpoint path {checkpoint_path}."
f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
)
else:
final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name]
@@ -678,7 +720,7 @@ async def create_lightllm_bundle(
)
else:
raise ObjectHasInvalidValueException(
f"Not able to load checkpoint path {checkpoint_path}."
f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
)
else:
final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name]
@@ -721,6 +763,70 @@ async def create_lightllm_bundle(
)
).model_bundle_id

async def create_tensorrt_llm_bundle(
self,
user: User,
framework_image_tag: str,
endpoint_unique_name: str,
num_shards: int,
checkpoint_path: Optional[str],
):
command = []

subcommands = []
if checkpoint_path is not None:
if checkpoint_path.startswith("s3://"):
subcommands += self.load_model_files_sub_commands_trt_llm(
checkpoint_path,
)
else:
raise ObjectHasInvalidValueException(
f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
)
else:
raise ObjectHasInvalidValueException(
"Checkpoint must be provided for TensorRT-LLM models."
)

subcommands.append(
f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/"
)

command = [
"/bin/bash",
"-c",
";".join(subcommands),
]

return (
await self.create_model_bundle_use_case.execute(
user,
CreateModelBundleV2Request(
name=endpoint_unique_name,
schema_location="TBA",
flavor=StreamingEnhancedRunnableImageFlavor(
flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE,
repository=hmi_config.tensorrt_llm_repository,
tag=framework_image_tag,
command=command,
streaming_command=command,
protocol="http",
readiness_initial_delay_seconds=10,
healthcheck_route="/v2/health/ready",
# See https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md
predict_route="/v2/models/ensemble/generate",
streaming_predict_route="/v2/models/ensemble/generate_stream",
env={},
),
metadata={},
),
do_auth_check=False,
# Skip auth check because the LLM create endpoint is called as the user themselves,
# but the user isn't directly taking the action. It should come from the fine-tune
# job.
)
).model_bundle_id
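
For illustration, a minimal sketch of the container command this bundle assembles, assuming the hypothetical checkpoint path above and num_shards=2 (both assumed values); Triton then serves requests on the generate routes that predict_route and streaming_predict_route point to:

# Hypothetical sketch of the assembled command; the checkpoint path and
# shard count are made up for illustration.
command = [
    "/bin/bash",
    "-c",
    "./s5cmd --numworkers 512 cp --concurrency 50 s3://my-bucket/llama-2-7b-trt/* ./"
    ";python3 launch_triton_server.py --world_size=2 --model_repo=./model_repo/",
]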

async def execute(
self, user: User, request: CreateLLMModelEndpointV1Request
) -> CreateLLMModelEndpointV1Response:
@@ -741,6 +847,8 @@ async def execute(
if request.inference_framework in [
LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
LLMInferenceFramework.TENSORRT_LLM,
]:
if request.endpoint_type != ModelEndpointType.STREAMING:
raise ObjectHasInvalidValueException(
@@ -1002,9 +1110,26 @@ def validate_and_update_completion_params(
"presence_penalty and frequency_penalty are only supported in vllm, lightllm."
)

# return_token_log_probs
if inference_framework in [
LLMInferenceFramework.DEEPSPEED,
LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
]:
pass
else:
if request.return_token_log_probs:
raise ObjectHasInvalidValueException(
"return_token_log_probs is only supported in deepspeed, text-generation-inference, vllm, lightllm."
)

return request


tokenizer_cache: Dict[str, AutoTokenizer] = {}


class CompletionSyncV1UseCase:
"""
Use case for running a prompt completion on an LLM endpoint.
@@ -1024,6 +1149,7 @@ def model_output_to_completion_output(
model_output: Dict[str, Any],
model_endpoint: ModelEndpoint,
with_token_probs: Optional[bool],
prompt: Optional[str] = None,
) -> CompletionOutput:
model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
@@ -1083,6 +1209,28 @@ def model_output_to_completion_output(
num_completion_tokens=model_output["count_output_tokens"],
tokens=tokens,
)
elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
if not model_content.model_name:
raise InvalidRequestException(
f"Invalid endpoint {model_content.name} has no base model"
)
if not prompt:
raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.")
if model_content.model_name not in tokenizer_cache:
tokenizer_cache[model_content.model_name] = AutoTokenizer.from_pretrained(
_SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TENSORRT_LLM][
model_content.model_name
]
)
tokenizer = tokenizer_cache[model_content.model_name]
prompt_tokens = tokenizer.encode(prompt)

return CompletionOutput(
text=model_output["text_output"][
len(prompt) + 4 :
], # Output is "<s> prompt output"
num_completion_tokens=len(model_output["token_ids"]) - len(prompt_tokens),
)
else:
raise EndpointUnsupportedInferenceTypeException(
f"Unsupported inference framework {model_content.inference_framework}"
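
To make the TensorRT-LLM output handling above concrete, a small worked sketch with invented values: text_output echoes the prompt prefixed with "<s> ", so the code skips len(prompt) + 4 characters and subtracts the prompt's token count from the length of token_ids:

# Hypothetical worked example; the prompt, output text, and token counts are invented.
prompt = "Hello, my name is"  # 17 characters
model_output = {
    "text_output": "<s> Hello, my name is Kevin and I like",  # "<s> " + prompt + completion
    "token_ids": list(range(14)),  # prompt + completion token ids (placeholder values)
}
completion_text = model_output["text_output"][len(prompt) + 4:]  # -> " Kevin and I like"
# If tokenizer.encode(prompt) yields 6 ids, num_completion_tokens comes out to 14 - 6 = 8.
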
@@ -1187,6 +1335,7 @@ async def execute(
predict_result.result["result"][0],
model_endpoint,
request.return_token_log_probs,
request.prompt,
),
)
else:
@@ -1317,6 +1466,42 @@ async def execute(
output, model_endpoint, request.return_token_log_probs
),
)
elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
# TODO: stop_sequences is buggy and return_token_log_probs is not supported
# TODO: verify the implementation of presence_penalty and repetition_penalty
# and see if they fit our existing definition of presence_penalty and frequency_penalty
# Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu
trt_llm_args: Any = {
"text_input": request.prompt,
"max_tokens": request.max_new_tokens,
"stop_words": request.stop_sequences if request.stop_sequences else "",
"bad_words": "",
"temperature": request.temperature,
}

inference_request = SyncEndpointPredictV1Request(
args=trt_llm_args,
num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES,
timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
)
predict_result = await inference_gateway.predict(
topic=model_endpoint.record.destination,
predict_request=inference_request,
)

if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
return CompletionSyncV1Response(
request_id=request_id,
output=None,
)

output = json.loads(predict_result.result["result"])
return CompletionSyncV1Response(
request_id=request_id,
output=self.model_output_to_completion_output(
output, model_endpoint, request.return_token_log_probs, request.prompt
),
)
else:
raise EndpointUnsupportedInferenceTypeException(
f"Unsupported inference framework {endpoint_content.inference_framework}"
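
As a point of reference, a hedged sketch of how the same arguments map onto Triton's generate extension, which the bundle's predict routes target; the host, port, and prompt here are assumptions rather than values from this code:

# Hypothetical direct call to the Triton generate endpoint configured for this
# bundle; "localhost:8000" assumes Triton's default HTTP port.
import requests

payload = {
    "text_input": "Hello, my name is",  # made-up prompt
    "max_tokens": 8,
    "stop_words": "",
    "bad_words": "",
    "temperature": 0.7,
}
resp = requests.post("http://localhost:8000/v2/models/ensemble/generate", json=payload)
print(resp.json()["text_output"])  # completion text, parsed as in model_output_to_completion_output
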
@@ -1471,6 +1656,19 @@ async def execute(
args["parameters"]["do_sample"] = False
if request.return_token_log_probs:
args["parameters"]["return_details"] = True
elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
# TODO: stop_sequences is buggy and return_token_log_probs is not supported
# TODO: verify the implementation of presence_penalty and repetition_penalty
# and see if they fit our existing definition of presence_penalty and frequency_penalty
# Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu
args = {
"text_input": request.prompt,
"max_tokens": request.max_new_tokens,
"stop_words": request.stop_sequences if request.stop_sequences else "",
"bad_words": "",
"temperature": request.temperature,
"stream": True,
}
else:
raise EndpointUnsupportedInferenceTypeException(
f"Unsupported inference framework {model_content.inference_framework}"
@@ -1606,6 +1804,22 @@ async def execute(
request_id=request_id,
output=None,
)
elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM:
if res.status == TaskStatus.SUCCESS and result is not None:
num_completion_tokens += 1
yield CompletionStreamV1Response(
request_id=request_id,
output=CompletionStreamOutput(
text=result["result"]["text_output"],
finished=False, # Tracked by https://github.com/NVIDIA/TensorRT-LLM/issues/240
num_completion_tokens=num_completion_tokens,
),
)
else:
yield CompletionStreamV1Response(
request_id=request_id,
output=None,
)
else:
raise EndpointUnsupportedInferenceTypeException(
f"Unsupported inference framework {model_content.inference_framework}"
12 changes: 12 additions & 0 deletions model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile
@@ -0,0 +1,12 @@
FROM nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3

COPY requirements.txt /workspace/requirements.txt
WORKDIR /workspace
RUN pip install -r requirements.txt

# Install s5cmd
RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz
RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz

COPY launch_triton_server.py /workspace/launch_triton_server.py
COPY triton_model_repo /workspace/model_repo