Bug Description
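FLUX fails to generate valid images when its transformer is wrapped in torch_tensorrt.MutableTorchTensorRTModule. There is an accuracy issue: the output of the compiled module does not match the PyTorch output.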
To Reproduce
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

DEVICE = "cuda:0"
enabled_precisions = {torch.float16}

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
).to(torch.float16)

# Swap in a small, randomly initialized transformer to keep the repro fast.
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
).to(torch.float16)
backbone = pipe.transformer
pipe.to(DEVICE)

batch_size = 1
settings = {
    "strict": False,
    "allow_complex_guards_as_runtime_asserts": True,
    "enabled_precisions": enabled_precisions,
    "truncate_double": True,
    "min_block_size": 1,
    "use_python_runtime": True,
    "immutable_weights": False,
    "offload_module_to_cpu": True,
}
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
pipe.transformer = trt_gm

# Run the pipeline once so the module compiles and records its inputs.
image = pipe(
    "Test",
    output_type="pil",
    num_inference_steps=2,
    num_images_per_prompt=batch_size,
).images

# Move the original module back to the GPU and replay the recorded inputs
# through both the TensorRT module and the PyTorch module.
backbone.to(DEVICE)
inp, kwinp = trt_gm.arg_inputs, trt_gm.kwarg_inputs
trt_result = trt_gm(*inp, **kwinp)[0]
pytorch_result = backbone(*inp, **kwinp)[0]
assert torch.allclose(trt_result, pytorch_result)
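The final assertion uses the default `torch.allclose` tolerances (`rtol=1e-05`, `atol=1e-08`), which are tight for float16, so a failure alone does not show how large the mismatch is. A minimal sketch of a helper that quantifies the gap may help with triage. The name `compare_outputs` and the `1e-2` tolerances are assumptions chosen for illustration and are not part of Torch-TensorRT.

```python
import torch


def compare_outputs(actual, expected, rtol=1e-2, atol=1e-2):
    """Summarize how far `actual` is from `expected` (hypothetical helper)."""
    # Compare in float32 on the CPU so the comparison itself adds no fp16 error.
    a = actual.detach().float().flatten().cpu()
    e = expected.detach().float().flatten().cpu()
    abs_diff = (a - e).abs()
    return {
        "max_abs_diff": abs_diff.max().item(),
        "mean_abs_diff": abs_diff.mean().item(),
        "cosine_similarity": torch.nn.functional.cosine_similarity(a, e, dim=0).item(),
        "allclose": torch.allclose(a, e, rtol=rtol, atol=atol),
        "has_nan": bool(torch.isnan(a).any()),
    }
```

In the repro above this could be called as `print(compare_outputs(trt_result, pytorch_result))`. A cosine similarity near 1 with a small maximum difference would suggest ordinary float16 rounding, while NaNs or a low similarity would point to a real conversion problem.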
Expected behavior
The TensorRT-compiled transformer should produce outputs that match the PyTorch transformer within numerical tolerance, so that the pipeline generates valid images.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): main branch
- PyTorch Version (e.g. 1.0): nightly
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
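Several of the fields above can be collected programmatically. The sketch below is one possible way to do that, using only the standard library plus `torch` when it is importable. The function names are hypothetical, and the distribution name `torch-tensorrt` is an assumption about how the package was installed.

```python
import platform
import sys
from importlib import metadata


def package_version(name):
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def collect_environment():
    """Gather the environment fields requested by the issue template."""
    info = {
        # Assumed distribution name; adjust if installed under another name.
        "Torch-TensorRT": package_version("torch-tensorrt"),
        "PyTorch": package_version("torch"),
        "CPU Architecture": platform.machine(),
        "OS": platform.platform(),
        "Python": sys.version.split()[0],
    }
    try:
        import torch

        info["CUDA"] = str(torch.version.cuda)
        if torch.cuda.is_available():
            info["GPU"] = torch.cuda.get_device_name(0)
    except ImportError:
        # torch is not installed; leave the CUDA and GPU fields out.
        pass
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"- {key}: {value}")
```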