diff --git a/py/requirements.txt b/py/requirements.txt index dbb342dae5..00cb832331 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -5,4 +5,5 @@ pybind11==2.6.2 torch>=2.8.0.dev,<2.9.0 torchvision>=0.22.0.dev,<0.23.0 --extra-index-url https://pypi.ngc.nvidia.com -pyyaml \ No newline at end of file +pyyaml +dllist \ No newline at end of file diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 7ba704274e..6dbf52de86 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -9,7 +9,7 @@ import torch import torch.fx from torch_tensorrt._enums import dtype -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import ( @@ -301,6 +301,7 @@ def compile( raise RuntimeError("Module is an unknown format or the ir requested is unknown") +@needs_cross_compile def cross_compile_for_windows( module: torch.nn.Module, file_path: str, diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index bee0c3dbf0..3beccec6af 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -3,7 +3,10 @@ from collections import namedtuple from typing import Any, Callable, Dict, List, Optional, Type, TypeVar -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._utils import ( + check_cross_compile_trt_win_lib, + sanitized_torch_version, +) from packaging import version @@ -15,6 +18,7 @@ "dynamo_frontend", "fx_frontend", "refit", + "windows_cross_compile", ], ) @@ -38,9 +42,15 @@ _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") _FX_FE_AVAIL = True _REFIT_AVAIL = True +_WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib() ENABLED_FEATURES = FeatureSet( - _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL + _TS_FE_AVAIL, + _TORCHTRT_RT_AVAIL, + _DYNAMO_FE_AVAIL, + _FX_FE_AVAIL, + _REFIT_AVAIL, + _WINDOWS_CROSS_COMPILE, ) @@ -80,6 +90,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper +def needs_cross_compile(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if ENABLED_FEATURES.windows_cross_compile: + return f(*args, **kwargs) + else: + + def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + raise NotImplementedError( + "Windows cross compilation feature is not available" + ) + + return not_implemented(*args, **kwargs) + + return wrapper + + T = TypeVar("T") diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index 3d5f98b5e5..135740a536 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -1,6 +1,7 @@ from typing import Any import torch +from torch_tensorrt._enums import Platform def sanitized_torch_version() -> Any: @@ -9,3 +10,18 @@ def sanitized_torch_version() -> Any: if ".nv" not in torch.__version__ else torch.__version__.split(".nv")[0] ) + + +def check_cross_compile_trt_win_lib() -> bool: + # cross compile feature is only available on linux + # build engine on linux and run on windows + import dllist + + platform = Platform.current_platform() + platform = str(platform).lower() + if platform.startswith("linux"): + loaded_libs = dllist.dllist() + target_lib = "libnvinfer_builder_resource_win.so.*" + if target_lib in loaded_libs: + return True + return False diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index acd16a32f0..be9e1fae05 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -11,6 +11,7 @@ from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype +from torch_tensorrt._features import needs_cross_compile from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo._DryRunTracker import ( @@ -49,6 +50,7 @@ logger = logging.getLogger(__name__) +@needs_cross_compile def cross_compile_for_windows( exported_program: ExportedProgram, inputs: Optional[Sequence[Sequence[Any]]] = None, @@ -1190,6 +1192,7 @@ def convert_exported_program_to_serialized_trt_engine( return serialized_engine +@needs_cross_compile def save_cross_compiled_exported_program( gm: torch.fx.GraphModule, file_path: str, diff --git a/pyproject.toml b/pyproject.toml index 3bb857e3e0..70e36a3175 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ requires = [ "pybind11==2.6.2", "numpy", "sympy", + "dllist", ] build-backend = "setuptools.build_meta" @@ -63,6 +64,7 @@ dependencies = [ "packaging>=23", "numpy", "typing-extensions>=4.7.0", + "dllist", ] dynamic = ["version"] diff --git a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py index 44a14a74de..4a4a084dd8 100644 --- a/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py +++ b/tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py @@ -7,6 +7,7 @@ import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase +from torch_tensorrt._utils import check_cross_compile_trt_win_lib from ..testing_utilities import DECIMALS_OF_AGREEMENT @@ -16,6 +17,10 @@ class TestCrossCompileSaveForWindows(TestCase): platform.system() != "Linux" or platform.architecture()[0] != "64bit", "Cross compile for windows can only be enabled on linux x86-64 platform", ) + @unittest.skipIf( + not (check_cross_compile_trt_win_lib()), + "TRT windows lib for cross compile not found", + ) @pytest.mark.unit def test_cross_compile_for_windows(self): class Add(torch.nn.Module): @@ -40,6 +45,10 @@ def forward(self, a, b): platform.system() != "Linux" or platform.architecture()[0] != "64bit", "Cross compile for windows can only be enabled on linux x86-64 platform", ) + @unittest.skipIf( + not (check_cross_compile_trt_win_lib()), + "TRT windows lib for cross compile not found", + ) @pytest.mark.unit def test_dynamo_cross_compile_for_windows(self): class Add(torch.nn.Module): @@ -68,6 +77,10 @@ def forward(self, a, b): platform.system() != "Linux" or platform.architecture()[0] != "64bit", "Cross compile for windows can only be enabled on linux x86-64 platform", ) + @unittest.skipIf( + not (check_cross_compile_trt_win_lib()), + "TRT windows lib for cross compile not found", + ) @pytest.mark.unit def test_dynamo_cross_compile_for_windows_multiple_output(self): class Add(torch.nn.Module):