diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 998c83014a..0b0353e607 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -87,6 +87,19 @@ class Frameworks(Enum):
 }
 
 
+TRTLLMVersionMapper: Dict[str, Dict[str, str]] = {
+    "0.18.0": {"min_trt": "10.9", "max_trt": "10.9"},
+    "0.17.0": {"min_trt": "10.8", "max_trt": "10.8"},
+    "0.16.0": {"min_trt": "10.7", "max_trt": "10.7"},
+    "0.15.0": {"min_trt": "10.6", "max_trt": "10.6"},
+    "0.14.0": {"min_trt": "10.4", "max_trt": "10.5"},
+    "0.13.0": {"min_trt": "10.4", "max_trt": "10.5"},
+    "0.12.0": {"min_trt": "10.3", "max_trt": "10.3"},
+    "0.11.0": {"min_trt": "10.1", "max_trt": "10.1"},
+    "0.10.0": {"min_trt": "10.0", "max_trt": "10.0"},
+}
+
+
 def delete_module(module: torch.fx.GraphModule) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
@@ -817,13 +830,27 @@ def is_tegra_platform() -> bool:
     return False
 
 
+def get_compatible_trtllm_version(trt_version: str) -> str:
+    for trtllm_version, version_range in TRTLLMVersionMapper.items():
+        if version_range["min_trt"] <= trt_version <= version_range["max_trt"]:
+            return trtllm_version
+    raise ValueError(
+        f"No compatible TRT-LLM version found for TRT version {trt_version}"
+    )
+
+
 def download_plugin_lib_path(py_version: str, platform: str) -> str:
     plugin_lib_path = None
 
     # Downloading TRT-LLM lib
-    # TODO: check how to fix the 0.18.0 hardcode below
     base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-    file_name = f"tensorrt_llm-0.18.0-{py_version}-{py_version}-{platform}.whl"
+
+    tensorrt_version = version.parse(trt.__version__)
+    tensorrt_major_minor = f"{tensorrt_version.major}.{tensorrt_version.minor}"
+    tensorrt_llm_version = get_compatible_trtllm_version(tensorrt_major_minor)
+    file_name = (
+        f"tensorrt_llm-{tensorrt_llm_version}-{py_version}-{py_version}-{platform}.whl"
+    )
     download_url = base_url + file_name
     if not (os.path.exists(file_name)):
         try:
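
As a quick sanity check of the lookup introduced above, the sketch below copies an abridged version of `TRTLLMVersionMapper` and `get_compatible_trtllm_version` out of the diff and exercises it directly. The `cp39-cp39-linux_x86_64` wheel tag at the end is only an illustrative placeholder for the `py_version`/`platform` values that `download_plugin_lib_path` receives; it is not taken from the patch.

```python
# Standalone sketch (not part of the patch): exercising the version lookup.
# The table below is an abridged copy of TRTLLMVersionMapper from the diff.
from typing import Dict

TRTLLMVersionMapper: Dict[str, Dict[str, str]] = {
    "0.18.0": {"min_trt": "10.9", "max_trt": "10.9"},
    "0.17.0": {"min_trt": "10.8", "max_trt": "10.8"},
    "0.14.0": {"min_trt": "10.4", "max_trt": "10.5"},
}


def get_compatible_trtllm_version(trt_version: str) -> str:
    # Return the first TRT-LLM release whose supported TensorRT range
    # contains the given "major.minor" TensorRT version string.
    for trtllm_version, version_range in TRTLLMVersionMapper.items():
        if version_range["min_trt"] <= trt_version <= version_range["max_trt"]:
            return trtllm_version
    raise ValueError(
        f"No compatible TRT-LLM version found for TRT version {trt_version}"
    )


# TensorRT 10.8 maps to the 0.17.0 wheel; 10.5 falls inside the 10.4-10.5 range.
assert get_compatible_trtllm_version("10.8") == "0.17.0"
assert get_compatible_trtllm_version("10.5") == "0.14.0"

# Illustrative wheel name (placeholder py_version/platform tags):
print(f"tensorrt_llm-{get_compatible_trtllm_version('10.8')}-cp39-cp39-linux_x86_64.whl")
```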