diff --git a/examples/dynamo/llama2_flashinfer_rmsnorm.py b/examples/dynamo/llama2_flashinfer_rmsnorm.py
index 847d80238b..9d57794c6a 100644
--- a/examples/dynamo/llama2_flashinfer_rmsnorm.py
+++ b/examples/dynamo/llama2_flashinfer_rmsnorm.py
@@ -1,3 +1,20 @@
+"""
+.. _llama2_flashinfer_rmsnorm:
+
+Automatically generate a TensorRT Plugin for the RMSNorm module and apply it in Llama2
+=======================================================================================
+
+This example showcases how to optimize inference for a LLaMA2 model by replacing its RMSNorm layers with FlashInfer's high-performance implementation. It demonstrates the use of Torch-TensorRT's automatic plugin feature, which dynamically generates and integrates custom TensorRT plugins during compilation.
+
+Key features:
+- Leverages automatic plugin registration for the FlashInfer RMSNorm op.
+- Applies a custom TorchDynamo lowering pass to replace the standard RMSNorm ops.
+- Compiles the modified model using Torch-TensorRT's Dynamo path.
+- Benchmarks inference performance with and without FlashInfer RMSNorm.
+
+This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
+"""
+
 from typing import Callable, Optional, Sequence, Union
 
 import flashinfer
@@ -86,7 +103,7 @@ def replace_rmsnorm(
                     args=(node.args[0], 0),
                 )
                 b.meta["tensor_meta"] = TensorMetadata(
-                    shape=torch.Size([]),
+                    shape=torch.Size([1]),
                     dtype=torch.int64,
                     requires_grad=False,
                     stride=None,
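
For context on the "automatic plugin" flow the new docstring describes, the following is a minimal sketch of the registration step: a torch custom op wrapping FlashInfer's RMSNorm kernel, plus a request for Torch-TensorRT to auto-generate the matching TensorRT plugin. The helper names flashinfer.norm.rmsnorm and torch_tensorrt.dynamo.conversion.plugins.custom_op are assumptions inferred from the example's imports, not a verbatim excerpt of the patched file.

# Sketch only; names outside torch.library are assumptions and may differ from the example.
import flashinfer
import torch
import torch_tensorrt


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Delegate to FlashInfer's fused RMSNorm CUDA kernel.
    return flashinfer.norm.rmsnorm(input, weight, eps)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Fake (meta) implementation used during tracing: output matches the input's shape/dtype.
    return torch.empty_like(input)


# Assumed Torch-TensorRT helper: auto-generate a TensorRT plugin and converter for the op,
# so compiled graphs that contain flashinfer::rmsnorm need no hand-written converter.
torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "flashinfer::rmsnorm", supports_dynamic_shapes=True
)

A custom lowering pass (replace_rmsnorm in the patched file) then rewrites the standard RMSNorm subgraph to call this op before compilation, which is where the shape-metadata fix in the second hunk applies.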