
🐛 [Bug] dynamo.compile fails in 2.9.0 nightly when mtq quantization is used #3659

Open
@robertos99

Description

Bug Description

I tested the 2.9.0 nightly in order to use the debugger to plot the TensorRT graph produced by torch_tensorrt.dynamo.compile. However, running compile on a model quantized with nvidia-modelopt's mtq.quantize fails with:
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)

Running the exact same code with the latest stable release causes no issues. Since neither torch, tensorrt, nor torch_tensorrt changed major versions, I assume this is not intended?

Besides that, are there other ways to inspect what TensorRT compiles the model to, so I can verify it actually calls int8 quantized kernels?
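
One approach I considered is TensorRT's engine inspector; below is a rough, untested sketch. It assumes torchtrt.dynamo.convert_exported_program_to_serialized_trt_engine is available in the installed build, and that per-layer details are only complete when the engine is built with detailed profiling verbosity:

import tensorrt as trt

# Sketch only (untested): dump per-layer engine info via TensorRT's
# engine inspector, starting from a serialized engine produced by
# torch_tensorrt (assumed API, mirroring dynamo.compile's arguments).
serialized_engine = torchtrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exp_program, inputs=[input_tensor], enabled_precisions={torch.int8}
)
runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
engine = runtime.deserialize_cuda_engine(bytes(serialized_engine))
inspector = engine.create_engine_inspector()
# Per-layer precision info requires the engine to have been built with
# ProfilingVerbosity.DETAILED; otherwise the output is partial.
print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))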

To Reproduce

Steps to reproduce the behavior:

  1. Install the 2.9.0 nightlies:
pip install --pre 'torch>=2.9.0.dev' torchvision --index-url https://download.pytorch.org/whl/nightly/cu128

pip install --pre nvidia-modelopt[torch,onnx] --extra-index-url https://download.pytorch.org/whl/nightly/cu128

pip install 'tensorrt>=10.11.0,<10.12.0'

pip install --pre 'torch-tensorrt==2.9.0.dev20250706+cu128' --extra-index-url https://download.pytorch.org/whl/nightly/cu128
  2. Run the provided code:
import argparse

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import torch_tensorrt as torchtrt

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode



def get_dataloaders(batch_size):
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    testset = datasets.CIFAR10(root="./data",
                               train=False,
                               download=True,
                               transform=transform_test)
    return torch.utils.data.DataLoader(testset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=8,
                                       drop_last=True)


def main():
    parser = argparse.ArgumentParser(
        description="quantized ResNet18")
    parser.add_argument("--batch-size",
                        default=32,
                        type=int,
                        help="Batch size for calibration and inference")
    parser.add_argument("--quantize-type",
                        default="int8",
                        choices=["int8", "fp8"],
                        help="Quantization type (currently supports int8)")
    args = parser.parse_args()

    device = torch.device("cuda:0"
                          if torch.cuda.is_available() else "cpu")

    # One loader for PTQ calibration, one to grab a sample batch for export.
    testloader_calib = get_dataloaders(args.batch_size)
    testloader_dummy = get_dataloaders(args.batch_size)

    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Replace final layer for CIFAR-10:
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)
    model.eval()

    # mtq.quantize inserts fake-quant (Q/DQ) modules and calibrates
    # activation ranges by running forward_loop over the calibration data.
    quant_cfg = mtq.INT8_DEFAULT_CFG

    def forward_loop(m):
        for images, _ in tqdm(testloader_calib):
            m(images.to(device))

    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)


    images, _ = next(iter(testloader_dummy))
    # export_torch_mode() switches modelopt's quantizers to export-friendly
    # torch ops so the Q/DQ nodes survive torch.export.
    with torch.no_grad(), export_torch_mode():
        input_tensor = images.to(device)
        model.eval()
        exp_program = torch.export.export(model, (input_tensor,),
                                          strict=False)

        with torchtrt.logging.graphs(), torchtrt.logging.debug():
            enabled_precisions = {torch.int8}
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions=enabled_precisions,
                min_block_size=1,
                require_full_compilation=True,
                disable_tf32=True,
                use_python_runtime=False,
            )

            out = trt_model(input_tensor)


if __name__ == "__main__":
    main()

Error log:

10:48:56 - INFO - Not found cached TRT engines. Start building engine.
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)
10:48:56 - ERROR - [layersImpl.cpp::validateQuantizationHelper::362] Error Code 2: API Usage Error ((Unnamed Layer* 6) [Quantize]: ScaleMode is illegal.)
Traceback (most recent call last):
  File "/workspace/t.py", line 190, in <module>
    main()
  File "/workspace/t.py", line 160, in main
    trt_model = torchtrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 720, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 950, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 67, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 742, in run
    assert serialized_engine
           ^^^^^^^^^^^^^^^^^
AssertionError

Expected behavior

Running the same code with the 2.7.0 release works. Though I am not sure whether the model is compiled "correctly" as I expect. (This is why I wanted to get the debugger running in the first place: to see whether I actually get quantized kernels.)
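
As a coarse check on the 2.7.0 build, I can at least compare the TRT module's output against the quantized PyTorch model; this only confirms numerics, not which kernels run. A minimal sketch reusing the variables from the script above:

# Sanity check (sketch): compare TRT output against the quantized torch
# model. Expect small differences, since int8 quantization perturbs results.
with torch.no_grad():
    ref = model(input_tensor)
    out = trt_model(input_tensor)
print(torch.max(torch.abs(ref - out)).item())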

Dependencies (working 2.7.0 environment):

tensorrt                 10.9.0.34
tensorrt_cu12            10.9.0.34
tensorrt_cu12_bindings   10.9.0.34
tensorrt_cu12_libs       10.9.0.34
torch                    2.7.1
torch_tensorrt           2.7.0
torchprofile             0.0.4
torchvision              0.22.1

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

OS: Linux x86_64 Ubuntu 24.04

CUDA version

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0

Other

I just can't get it to work. If you can suggest another way to visualize or inspect which kernels TensorRT actually calls (so I can verify it really runs int8 quantized kernels), I would be grateful :)
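
In the meantime, the one thing I can easily verify is that the Q/DQ ops survive torch.export before compilation:

# Pre-compile check: print the exported FX graph and look for the
# quantize/dequantize nodes inserted by modelopt. This confirms Q/DQ ops
# reach torch_tensorrt, but not which kernels TensorRT ultimately selects.
print(exp_program.graph_module.graph)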
