Bug Description
I tested the 2.9.0 nightly to use the debugger to plot the TensorRT graph produced by torch_tensorrt.dynamo.compile. However, running compile on a model quantized with nvidia-modelopt's mtq.quantize fails with:
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)
Running the exact same code with the latest release (2.7.0) causes no issues. Since neither torch, tensorrt, nor torch_tensorrt changed major version, I assume this isn't intended?
Besides that, are there other options to inspect what TensorRT compiles the model to, so I can verify it actually calls INT8 quantized kernels?
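What I am ultimately after is something like the sketch below, which dumps per-layer information with TensorRT's engine inspector. This is only a sketch: serialized_engine is a placeholder for the raw engine bytes (e.g. from torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine), and per-layer precision only shows up if the engine was built with detailed profiling verbosity.

import tensorrt as trt

# Placeholder: `serialized_engine` holds the raw engine bytes, e.g. obtained
# via torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine.
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)

# The inspector reports per-layer details (precision, tactic) as JSON, but
# only engines built with detailed profiling verbosity carry full information.
inspector = engine.create_engine_inspector()
print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))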
To Reproduce
Steps to reproduce the behavior:
- Install the 2.9.0 nightlies:
pip install --pre 'torch>=2.9.0.dev' torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --pre nvidia-modelopt[torch,onnx] --extra-index-url https://download.pytorch.org/whl/nightly/cu128
pip install 'tensorrt>=10.11.0,<10.12.0'
pip install --pre 'torch-tensorrt==2.9.0.dev20250706+cu128' --extra-index-url https://download.pytorch.org/whl/nightly/cu128
- Run the provided code:
import argparse

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm

import torch_tensorrt as torchtrt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode


def get_dataloaders(batch_size):
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    testset = datasets.CIFAR10(root="./data",
                               train=False,
                               download=True,
                               transform=transform_test)
    return torch.utils.data.DataLoader(testset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=8,
                                       drop_last=True)


def main():
    parser = argparse.ArgumentParser(description="quantized ResNet18")
    parser.add_argument("--batch-size",
                        default=32,
                        type=int,
                        help="Batch size for calibration and inference")
    parser.add_argument("--quantize-type",
                        default="int8",
                        choices=["int8", "fp8"],
                        help="Quantization type (currently supports int8)")
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    testloader_calib = get_dataloaders(args.batch_size)
    testloader_dummy = get_dataloaders(args.batch_size)

    model = models.resnet18(pretrained=True)
    # Replace final layer for CIFAR-10:
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)
    model.eval()

    quant_cfg = mtq.INT8_DEFAULT_CFG

    def forward_loop(m):
        for images, _ in tqdm(testloader_calib):
            m(images.to(device))

    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)

    images, _ = next(iter(testloader_dummy))
    with torch.no_grad(), export_torch_mode():
        input_tensor = images.to(device)
        model.eval()
        exp_program = torch.export.export(model, (input_tensor,),
                                          strict=False)
        with torchtrt.logging.graphs(), torchtrt.logging.debug():
            enabled_precisions = {torch.int8}
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions=enabled_precisions,
                min_block_size=1,
                require_full_compilation=True,
                disable_tf32=True,
                use_python_runtime=False,
            )
            out = trt_model(input_tensor)


if __name__ == "__main__":
    main()
Error log:
10:48:56 - INFO - Not found cached TRT engines. Start building engine.
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)
Traceback (most recent call last):
  File "/workspace/t.py", line 190, in <module>
    main()
  File "/workspace/t.py", line 160, in main
    trt_model = torchtrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 720, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 950, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 67, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 742, in run
    assert serialized_engine
           ^^^^^^^^^^^^^^^^^
AssertionError
Expected behavior
Running the same code with the 2.7.0 release works, though I am not sure the model is compiled "correctly" as I expect. (This is why I wanted to get the debugger running in the first place: to see if I actually get quantized kernels.)
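For now, the best check I have on the 2.7.0 build is looking at the CUDA kernel names with the PyTorch profiler, as in the rough sketch below (it reuses trt_model and input_tensor from the reproduction script; kernel names are only a heuristic, but many TensorRT INT8 kernels carry "int8"/"i8" tags):

import torch
from torch.profiler import ProfilerActivity, profile

# Assumes `trt_model` and `input_tensor` from the reproduction script above.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    trt_model(input_tensor)
    torch.cuda.synchronize()

# Print the CUDA kernels that actually ran; names containing "int8"/"i8"
# hint at quantized kernels, though naming is not guaranteed.
for event in prof.key_averages():
    print(event.key)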
Dependencies (working 2.7.0 environment):
tensorrt 10.9.0.34
tensorrt_cu12 10.9.0.34
tensorrt_cu12_bindings 10.9.0.34
tensorrt_cu12_libs 10.9.0.34
torch 2.7.1
torch_tensorrt 2.7.0
torchprofile 0.0.4
torchvision 0.22.1
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
OS: Linux x86_64 Ubuntu 24.04
CUDA version
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0
Other
I just can't get it to work. If you can give me another option to visualize or see which kernels TensorRT actually calls (so I can verify it really runs INT8 quantized kernels), I would be grateful :)