diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 1cc777e3..a5e29f87 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -146,6 +146,7 @@ config: tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" + tensorrt_llm_repository: "tensorrt-llm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 958881e1..f2b33eea 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -57,6 +57,7 @@ class HostedModelInferenceServiceConfig: tgi_repository: str vllm_repository: str lightllm_repository: str + tensorrt_llm_repository: str user_inference_base_repository: str user_inference_pytorch_repository: str user_inference_tensorflow_repository: str diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 0624857f..30ec8993 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -12,6 +12,7 @@ class LLMInferenceFramework(str, Enum): TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt_llm" class Quantization(str, Enum): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index ab138747..1130e6e2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -62,6 +62,10 @@ from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +# Hack for TensorRT-LLM. 
Remove when it supports returning output tokens only +# See https://github.com/NVIDIA/TensorRT-LLM/issues/227 +from transformers import AutoTokenizer + from ...common.datadog_utils import add_trace_request_id from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase @@ -150,6 +154,9 @@ "llama-2-70b": "meta-llama/Llama-2-70b-hf", "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", }, + LLMInferenceFramework.TENSORRT_LLM: { + "llama-2-7b": "huggyllama/llama-7b", # Hack to get tokenizer for llama without sign in to huggingface + }, } _SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { @@ -157,6 +164,7 @@ LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES], LLMInferenceFramework.VLLM: [Quantization.AWQ], LLMInferenceFramework.LIGHTLLM: [], + LLMInferenceFramework.TENSORRT_LLM: [], } # We need a dict where if we need to override we can @@ -340,6 +348,14 @@ async def create_model_bundle( num_shards, checkpoint_path, ) + elif framework == LLMInferenceFramework.TENSORRT_LLM: + bundle_id = await self.create_tensorrt_llm_bundle( + user, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) else: raise ObjectHasInvalidValueException( f"Framework {framework} is not supported for source {source}." @@ -384,7 +400,7 @@ async def create_text_generation_inference_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[ @@ -471,6 +487,32 @@ def load_model_weights_sub_commands( return subcommands + def load_model_files_sub_commands_trt_llm( + self, + checkpoint_path, + ): + """ + This function generate subcommands to load model files for TensorRT-LLM. + Each model checkpoint is constituted of two folders: `model_weights` which stores the model engine files, + and `model_tokenizer` which stores the model tokenizer files. + See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt + and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt + """ + subcommands = [] + + base_path = checkpoint_path.split("/")[-1] + + if base_path.endswith(".tar"): + raise ObjectHasInvalidValueException( + "Checkpoint for TensorRT-LLM models must be a folder, not a tar file." + ) + else: + subcommands.append( + f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" + ) + + return subcommands + async def create_deepspeed_bundle( self, user: User, @@ -587,7 +629,7 @@ async def create_vllm_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] @@ -678,7 +720,7 @@ async def create_lightllm_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." 
) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] @@ -721,6 +763,70 @@ async def create_lightllm_bundle( ) ).model_bundle_id + async def create_tensorrt_llm_bundle( + self, + user: User, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + subcommands = [] + if checkpoint_path is not None: + if checkpoint_path.startswith("s3://"): + subcommands += self.load_model_files_sub_commands_trt_llm( + checkpoint_path, + ) + else: + raise ObjectHasInvalidValueException( + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." + ) + else: + raise ObjectHasInvalidValueException( + "Checkpoint must be provided for TensorRT-LLM models." + ) + + subcommands.append( + f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.tensorrt_llm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/v2/health/ready", + # See https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md + predict_route="/v2/models/ensemble/generate", + streaming_predict_route="/v2/models/ensemble/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -741,6 +847,8 @@ async def execute( if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + LLMInferenceFramework.TENSORRT_LLM, ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( @@ -1002,9 +1110,26 @@ def validate_and_update_completion_params( "presence_penalty and frequency_penalty are only supported in vllm, lightllm." ) + # return_token_log_probs + if inference_framework in [ + LLMInferenceFramework.DEEPSPEED, + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + pass + else: + if request.return_token_log_probs: + raise ObjectHasInvalidValueException( + "return_token_log_probs is only supported in deepspeed, text-generation-inference, vllm, lightllm." + ) + return request +tokenizer_cache: Dict[str, AutoTokenizer] = {} + + class CompletionSyncV1UseCase: """ Use case for running a prompt completion on an LLM endpoint. 
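As a quick reference for create_tensorrt_llm_bundle above: the bundle command is a single /bin/bash -c invocation that first pulls the checkpoint with s5cmd and then launches Triton. A minimal sketch of what it assembles, with a made-up bucket name and shard count:

# Sketch of the command assembled by create_tensorrt_llm_bundle for a hypothetical
# checkpoint_path="s3://my-bucket/llama-2-7b-trt" and num_shards=2.
checkpoint_path = "s3://my-bucket/llama-2-7b-trt"
num_shards = 2
subcommands = [
    f"./s5cmd --numworkers 512 cp --concurrency 50 {checkpoint_path}/* ./",
    f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/",
]
command = ["/bin/bash", "-c", ";".join(subcommands)]
# command[2] == "./s5cmd --numworkers 512 cp --concurrency 50 s3://my-bucket/llama-2-7b-trt/* ./;"
#               "python3 launch_triton_server.py --world_size=2 --model_repo=./model_repo/"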
@@ -1024,6 +1149,7 @@ def model_output_to_completion_output( model_output: Dict[str, Any], model_endpoint: ModelEndpoint, with_token_probs: Optional[bool], + prompt: Optional[str] = None, ) -> CompletionOutput: model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: @@ -1083,6 +1209,28 @@ def model_output_to_completion_output( num_completion_tokens=model_output["count_output_tokens"], tokens=tokens, ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + if not model_content.model_name: + raise InvalidRequestException( + f"Invalid endpoint {model_content.name} has no base model" + ) + if not prompt: + raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") + if model_content.model_name not in tokenizer_cache: + tokenizer_cache[model_content.model_name] = AutoTokenizer.from_pretrained( + _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TENSORRT_LLM][ + model_content.model_name + ] + ) + tokenizer = tokenizer_cache[model_content.model_name] + prompt_tokens = tokenizer.encode(prompt) + + return CompletionOutput( + text=model_output["text_output"][ + len(prompt) + 4 : + ], # Output is " prompt output" + num_completion_tokens=len(model_output["token_ids"]) - len(prompt_tokens), + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -1187,6 +1335,7 @@ async def execute( predict_result.result["result"][0], model_endpoint, request.return_token_log_probs, + request.prompt, ), ) else: @@ -1317,6 +1466,42 @@ async def execute( output, model_endpoint, request.return_token_log_probs ), ) + elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + trt_llm_args: Any = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + } + + inference_request = SyncEndpointPredictV1Request( + args=trt_llm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + return CompletionSyncV1Response( + request_id=request_id, + output=None, + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, model_endpoint, request.return_token_log_probs, request.prompt + ), + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {endpoint_content.inference_framework}" @@ -1471,6 +1656,19 @@ async def execute( args["parameters"]["do_sample"] = False if request.return_token_log_probs: args["parameters"]["return_details"] = True + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and 
return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + args = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + "stream": True, + } else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -1606,6 +1804,22 @@ async def execute( request_id=request_id, output=None, ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + if res.status == TaskStatus.SUCCESS and result is not None: + num_completion_tokens += 1 + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["text_output"], + finished=False, # Tracked by https://github.com/NVIDIA/TensorRT-LLM/issues/240 + num_completion_tokens=num_completion_tokens, + ), + ) + else: + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile new file mode 100644 index 00000000..7bae22fd --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile @@ -0,0 +1,12 @@ +FROM nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 + +COPY requirements.txt /workspace/requirements.txt +WORKDIR /workspace +RUN pip install -r requirements.txt + +# Install s5cmd +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +COPY launch_triton_server.py /workspace/launch_triton_server.py +COPY triton_model_repo /workspace/model_repo \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py new file mode 100644 index 00000000..0ce46d2b --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py @@ -0,0 +1,33 @@ +import argparse +import subprocess +from pathlib import Path + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=1, help="world size, only support tensor parallelism now" + ) + parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver") + parser.add_argument( + "--http-port", + type=int, + default=5005, + help="Default HTTP port to 5005. 
See llm-engine/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml", + ) + path = str(Path(__file__).parent.absolute()) + "/../all_models/gpt" + parser.add_argument("--model_repo", type=str, default=path) + return parser.parse_args() + + +def get_cmd(world_size, tritonserver, model_repo, http_port): + cmd = "mpirun --allow-run-as-root " + for i in range(world_size): + cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address ipv6:[::1] --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : " + return cmd + + +if __name__ == "__main__": + args = parse_arguments() + cmd = get_cmd(int(args.world_size), args.tritonserver, args.model_repo, args.http_port) + subprocess.call(cmd, shell=True) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt new file mode 100644 index 00000000..e2e60684 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt @@ -0,0 +1,2 @@ +sentencepiece==0.1.99 +protobuf==4.24.4 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt new file mode 100755 index 00000000..7a7662d3 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt @@ -0,0 +1,255 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
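For reference, get_cmd in launch_triton_server.py above emits one mpirun rank per shard, each running tritonserver against the same model repository with a distinct python-backend shared-memory prefix. A runnable sketch mirroring that logic (paths and shard count are illustrative):

# Mirrors get_cmd above for --world_size=2 and the default binary/port.
world_size, http_port = 2, 5005
tritonserver, model_repo = "/opt/tritonserver/bin/tritonserver", "./model_repo/"
cmd = "mpirun --allow-run-as-root "
for i in range(world_size):
    cmd += (
        f" -n 1 {tritonserver} --model-repository={model_repo}"
        f" --http-address ipv6:[::1] --http-port {http_port}"
        f" --disable-auto-complete-config"
        f" --backend-config=python,shm-region-prefix-name=prefix{i}_ : "
    )
print(cmd)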
+ +name: "ensemble" +platform: "ensemble" +max_batch_size: 128 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1, -1 ] + }, + { + name: "token_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + output_map { + key: "OUTPUT_TOKEN_IDS" + value: "token_ids" + } + } + ] +} diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py 
b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py new file mode 100644 index 00000000..1cd809d9 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py @@ -0,0 +1,156 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + output_token_ids_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT_TOKEN_IDS" + ) + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + self.output_token_ids_dtype = pb_utils.triton_string_to_numpy( + output_token_ids_config["data_type"] + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH").as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor("OUTPUT", np.array(outputs).astype(self.output_dtype)) + + output_token_ids = pb_utils.Tensor( + "OUTPUT_TOKEN_IDS", np.array(tokens_batch).astype(self.output_token_ids_dtype) + ) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. 
+ # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor, output_token_ids] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _postprocessing(self, tokens_batch): + outputs = [] + for beam_tokens in tokens_batch: + for tokens in beam_tokens: + output = self.tokenizer.decode(tokens) + outputs.append(output.encode("utf8")) + return outputs diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt new file mode 100755 index 00000000..cc61a24e --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt @@ -0,0 +1,69 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
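The postprocessing model above decodes each beam's token ids back to text and also returns the raw ids as OUTPUT_TOKEN_IDS, which the completion use case uses to count output tokens. A minimal sketch of the decode step, assuming the tokenizer is the huggyllama/llama-7b fallback referenced in llm_model_endpoint_use_cases.py and using illustrative token ids:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", padding_side="left")
tokens_batch = [[[1, 6189, 6509, 338, 263, 5443]]]  # [batch, beam, token_ids]
outputs = []
for beam_tokens in tokens_batch:
    for tokens in beam_tokens:
        outputs.append(tokenizer.decode(tokens).encode("utf8"))
# outputs[0] holds the decoded completion as UTF-8 bytes; the ids themselves are
# echoed back unchanged in the OUTPUT_TOKEN_IDS tensor.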
+ +name: "postprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + }, + { + name: "OUTPUT_TOKEN_IDS" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py new file mode 100644 index 00000000..b5996f87 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py @@ -0,0 +1,224 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import csv +import json +from typing import List + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = ["INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS"] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, input_name)["data_type"] + ), + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name(request, "BAD_WORDS_DICT").as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, "STOP_WORDS_DICT" + ).as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = self._to_word_list_format(bad_words_dict) + stop_words = self._to_word_list_format(stop_words_dict) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. 
+ input_id_tensor = pb_utils.Tensor( + "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) + ) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", + np.array(request_input_len).astype(self.request_input_len_dtype), + ) + request_output_len_tensor = pb_utils.Tensor("REQUEST_OUTPUT_LEN", request_output_len) + bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) + stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + bad_words_ids_tensor, + stop_words_ids_tensor, + request_input_len_tensor, + request_output_len_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [torch.IntTensor(self.tokenizer.encode(s[0].decode())) for s in query] + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + + start_ids = pad_sequence(start_ids, batch_first=True, padding_value=self.pad_id) + # input_len = min(start_lengths) + # attn_mask = torch.ones((batch_size, input_len, input_len)).tril() + + return start_ids, start_lengths + + def _to_word_list_format(self, word_dict: List[List[str]]): + """ + format of word_dict + len(word_dict) should be same to batch_size + word_dict[i] means the words for batch i + len(word_dict[i]) must be 1, which means it only contains 1 string + This string can contains several sentences and split by ",". + For example, if word_dict[2] = " I am happy, I am sad", then this function will return + the ids for two short sentences " I am happy" and " I am sad". 
+ """ + assert self.tokenizer is not None, "need to set tokenizer" + + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + if isinstance(word_dict_item[0], bytes): + word_dict_item = [word_dict_item[0].decode()] + + words = list(csv.reader(word_dict_item))[0] + for word in words: + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt new file mode 100644 index 00000000..89d9c91e --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt @@ -0,0 +1,99 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "preprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..e24a95b4 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,208 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 128 + +model_transaction_policy { + decoupled: true +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "./model_weights" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "0.9" + } +} +parameters: { + key: "max_num_sequences" + value: { + string_value: "${max_num_sequences}" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "${enable_trt_overlap}" + } +} diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini index 9abfbeaa..82c6107a 100644 --- a/model-engine/mypy.ini +++ b/model-engine/mypy.ini @@ -6,7 +6,7 @@ namespace_packages = True explicit_package_bases = True strict_optional = True plugins = pydantic.mypy -exclude = clients +exclude = clients|.*/triton_model_repo/.* [mypy-model_engine_server.cli.*] ignore_errors = True diff --git a/model-engine/requirements.in b/model-engine/requirements.in index e173eeef..ecdf78a1 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -45,7 +45,8 @@ sseclient-py==1.7.2 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 tqdm~=4.64 +transformers==4.34.1 twine==3.7.1 uvicorn==0.17.6 uvloop==0.17.0 -yarl~=1.4 +yarl~=1.4 \ No 
newline at end of file diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index c367da1e..87adb372 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -133,10 +133,16 @@ exceptiongroup==1.1.3 # cattrs fastapi==0.78.0 # via -r model-engine/requirements.in +filelock==3.13.1 + # via + # huggingface-hub + # transformers frozenlist==1.3.3 # via # aiohttp # aiosignal +fsspec==2023.10.0 + # via huggingface-hub gitdb==4.0.10 # via gitpython gitdb2==2.0.6 @@ -160,6 +166,10 @@ hpack==4.0.0 # via h2 httptools==0.5.0 # via -r model-engine/requirements.in +huggingface-hub==0.17.3 + # via + # tokenizers + # transformers hypercorn==0.14.4 # via quart hyperframe==6.0.1 @@ -175,7 +185,7 @@ importlib-metadata==6.8.0 # keyring # quart # twine -importlib-resources==6.0.1 +importlib-resources==6.1.0 # via # alembic # jsonschema @@ -249,6 +259,8 @@ mypy-boto3-sqs==1.26.148 # via boto3-stubs mypy-extensions==1.0.0 # via typing-inspect +numpy==1.24.4 + # via transformers oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 @@ -258,7 +270,9 @@ packaging==23.1 # build # ddtrace # deprecation + # huggingface-hub # marshmallow + # transformers pep517==0.13.0 # via build pg8000==1.29.8 @@ -314,9 +328,11 @@ python-multipart==0.0.6 # via -r model-engine/requirements.in pyyaml==6.0 # via + # huggingface-hub # kubeconfig # kubernetes # kubernetes-asyncio + # transformers quart==0.18.3 # via -r model-engine/requirements.in readme-renderer==40.0 @@ -327,15 +343,19 @@ referencing==0.30.2 # via # jsonschema # jsonschema-specifications +regex==2023.10.3 + # via transformers requests==2.31.0 # via # -r model-engine/requirements.in # datadog # docker + # huggingface-hub # kubernetes # requests-auth-aws-sigv4 # requests-oauthlib # requests-toolbelt + # transformers # twine requests-auth-aws-sigv4==0.7 # via -r model-engine/requirements.in @@ -355,6 +375,8 @@ rsa==4.9 # via google-auth s3transfer==0.6.1 # via boto3 +safetensors==0.4.0 + # via transformers scramp==1.4.4 # via pg8000 secretstorage==3.3.3 @@ -403,6 +425,8 @@ testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 # via -r model-engine/requirements.in +tokenizers==0.14.1 + # via transformers tomli==2.0.1 # via # build @@ -411,7 +435,11 @@ tomli==2.0.1 tqdm==4.65.0 # via # -r model-engine/requirements.in + # huggingface-hub + # transformers # twine +transformers==4.34.1 + # via -r model-engine/requirements.in twine==3.7.1 # via -r model-engine/requirements.in types-awscrt==0.16.23 @@ -430,6 +458,7 @@ typing-extensions==4.7.1 # cattrs # datadog-api-client # ddtrace + # huggingface-hub # kombu # mypy-boto3-cloudformation # mypy-boto3-dynamodb diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 3438f65d..68172acf 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -57,6 +57,7 @@ istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" +tensorrt_llm_repository: "tensorrt-llm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index c47c17ed..f40a2dd1 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -31,5 +31,7 @@ addopts = --mypy 
--mypy-ini-file=mypy.ini --ignore=clients +# Need to specify this since pytest override mypy.ini See https://github.com/realpython/pytest-mypy/issues/123 + --ignore-glob=*triton_model_repo* # --pylint # --pylint-rcfile=setup.cfg diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index e9dd1e44..445ab83d 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3942,3 +3942,76 @@ def llm_model_endpoint_text_generation_inference( image="test_image", ), ) + + +@pytest.fixture +def llm_model_endpoint_trt_llm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + return ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_3", + name="test_llm_model_endpoint_name_trt_llm", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-2-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "23.10", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_trt_llm", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index c27aaa52..06310666 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -330,6 +330,61 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( ) +@pytest.fixture +def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_trt_llm_streaming", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + +@pytest.fixture +def 
create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_tgi_async", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + @pytest.fixture def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( @@ -386,7 +441,7 @@ def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEn @pytest.fixture def completion_sync_request() -> CompletionSyncV1Request: return CompletionSyncV1Request( - prompt="test_prompt_1", + prompt="What is machine learning?", max_new_tokens=10, temperature=0.5, return_token_log_probs=True, @@ -396,7 +451,7 @@ def completion_sync_request() -> CompletionSyncV1Request: @pytest.fixture def completion_stream_request() -> CompletionStreamV1Request: return CompletionStreamV1Request( - prompt="test_prompt_1", + prompt="What is machine learning?", max_new_tokens=10, temperature=0.5, ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c83e1049..d7ec41f0 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -260,6 +260,63 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( ) +@pytest.mark.asyncio +async def test_create_model_endpoint_trt_llm_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_trt_llm_request_async: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_trt_llm_request_streaming: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + request=create_llm_model_endpoint_trt_llm_request_streaming, + ) + assert response_1.endpoint_creation_task_id + assert isinstance(response_1, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_trt_llm_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == 
diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py
index c83e1049..d7ec41f0 100644
--- a/model-engine/tests/unit/domain/test_llm_use_cases.py
+++ b/model-engine/tests/unit/domain/test_llm_use_cases.py
@@ -260,6 +260,63 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
     )
+
+@pytest.mark.asyncio
+async def test_create_model_endpoint_trt_llm_use_case_success(
+    test_api_key: str,
+    fake_model_bundle_repository,
+    fake_model_endpoint_service,
+    fake_docker_repository_image_always_exists,
+    fake_model_primitive_gateway,
+    fake_llm_artifact_gateway,
+    create_llm_model_endpoint_trt_llm_request_async: CreateLLMModelEndpointV1Request,
+    create_llm_model_endpoint_trt_llm_request_streaming: CreateLLMModelEndpointV1Request,
+):
+    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
+    bundle_use_case = CreateModelBundleV2UseCase(
+        model_bundle_repository=fake_model_bundle_repository,
+        docker_repository=fake_docker_repository_image_always_exists,
+        model_primitive_gateway=fake_model_primitive_gateway,
+    )
+    use_case = CreateLLMModelEndpointV1UseCase(
+        create_model_bundle_use_case=bundle_use_case,
+        model_bundle_repository=fake_model_bundle_repository,
+        model_endpoint_service=fake_model_endpoint_service,
+        llm_artifact_gateway=fake_llm_artifact_gateway,
+        docker_repository=fake_docker_repository_image_always_exists,
+    )
+    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
+    response_1 = await use_case.execute(
+        user=user,
+        request=create_llm_model_endpoint_trt_llm_request_streaming,
+    )
+    assert response_1.endpoint_creation_task_id
+    assert isinstance(response_1, CreateLLMModelEndpointV1Response)
+    endpoint = (
+        await fake_model_endpoint_service.list_model_endpoints(
+            owner=None,
+            name=create_llm_model_endpoint_trt_llm_request_streaming.name,
+            order_by=None,
+        )
+    )[0]
+    assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING
+    assert endpoint.record.metadata == {
+        "_llm": {
+            "model_name": create_llm_model_endpoint_trt_llm_request_streaming.model_name,
+            "source": create_llm_model_endpoint_trt_llm_request_streaming.source,
+            "inference_framework": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework,
+            "inference_framework_image_tag": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework_image_tag,
+            "num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards,
+            "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize,
+        }
+    }
+
+    with pytest.raises(ObjectHasInvalidValueException):
+        await use_case.execute(
+            user=user,
+            request=create_llm_model_endpoint_trt_llm_request_async,
+        )
+
+
 @pytest.mark.asyncio
 async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception(
     test_api_key: str,
@@ -545,6 +602,39 @@ async def test_completion_sync_text_generation_inference_use_case_success(
     )
+
+@pytest.mark.asyncio
+async def test_completion_sync_trt_llm_use_case_success(
+    test_api_key: str,
+    fake_model_endpoint_service,
+    fake_llm_model_endpoint_service,
+    llm_model_endpoint_trt_llm: ModelEndpoint,
+    completion_sync_request: CompletionSyncV1Request,
+):
+    completion_sync_request.return_token_log_probs = False  # not yet supported
+    fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm)
+    fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response(
+        status=TaskStatus.SUCCESS,
+        result={
+            "result": '{"model_name": "ensemble", "model_version": "1", "sequence_end": false, "sequence_id": 0, "sequence_start": false, "text_output": " What is machine learning? Machine learning is a branch", "token_ids": [1, 1724, 338, 4933, 6509, 29973, 6189, 6509, 338, 263, 5443]}'
+        },
+        traceback=None,
+    )
+    use_case = CompletionSyncV1UseCase(
+        model_endpoint_service=fake_model_endpoint_service,
+        llm_model_endpoint_service=fake_llm_model_endpoint_service,
+    )
+    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
+    response_1 = await use_case.execute(
+        user=user,
+        model_endpoint_name=llm_model_endpoint_trt_llm.record.name,
+        request=completion_sync_request,
+    )
+    assert response_1.output == CompletionOutput(
+        text=" Machine learning is a branch",
+        num_completion_tokens=5,
+    )
+
+
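Not part of the diff: the expected CompletionOutput in the sync test above can be sanity-checked against the mocked Triton "ensemble" payload. A minimal sketch of that arithmetic; the assumption that the prompt accounts for the first 6 of the 11 token_ids (BOS plus 5 prompt tokens under a LLaMA-style tokenizer) is mine, not stated by the test:

import json

# The mocked gateway result from the sync test above.
raw_result = (
    '{"model_name": "ensemble", "model_version": "1", "sequence_end": false, '
    '"sequence_id": 0, "sequence_start": false, '
    '"text_output": " What is machine learning? Machine learning is a branch", '
    '"token_ids": [1, 1724, 338, 4933, 6509, 29973, 6189, 6509, 338, 263, 5443]}'
)
payload = json.loads(raw_result)

prompt = "What is machine learning?"
num_prompt_tokens = 6  # assumed: BOS + 5 prompt tokens under a LLaMA-style tokenizer

completion_text = payload["text_output"][len(" " + prompt):]
num_completion_tokens = len(payload["token_ids"]) - num_prompt_tokens

assert completion_text == " Machine learning is a branch"
assert num_completion_tokens == 5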
result={"result": {"text_output": "a", "token_ids": 263}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "branch", "token_ids": 5443}}, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_stream_request, + ) + output_texts = ["Machine", "learning", "is", "a", "branch"] + i = 0 + async for message in response_1: + assert message.dict()["request_id"] + assert message.dict()["output"]["text"] == output_texts[i] + assert message.dict()["output"]["num_completion_tokens"] == i + 1 + i += 1 + + @pytest.mark.asyncio async def test_create_llm_fine_tune_model_name_valid(): assert is_model_name_suffix_valid("model-name")