From b0810188f48190616dbac242b43fd8e3d5960484 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 8 May 2025 22:29:58 +0000 Subject: [PATCH] chore: update the docstring for llama2 rmsnorm automatic plugin example --- examples/dynamo/llama2_flashinfer_rmsnorm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/examples/dynamo/llama2_flashinfer_rmsnorm.py b/examples/dynamo/llama2_flashinfer_rmsnorm.py index 847d80238b..9d57794c6a 100644 --- a/examples/dynamo/llama2_flashinfer_rmsnorm.py +++ b/examples/dynamo/llama2_flashinfer_rmsnorm.py @@ -1,3 +1,20 @@ +""" +.. _llama2_flashinfer_rmsnorm: + +Automatically generate a TensorRT Plugin for RMSNorm module and apply it in Llama2 +================================================================================== + +This example showcases how to optimize inference for a LLaMA2 model by replacing its RMSNorm layers with FlashInfer's high-performance implementation. It demonstrates the use of Torch-TensorRT's automatic plugin feature, which dynamically generates and integrates custom TensorRT plugins during compilation. + +Key features: +- Leverages automatic plugin registration for FlashInfer RMSNorm ops. +- Applies a custom TorchDynamo lowering pass to replace standard RMSNorm ops. +- Compiles the modified model using Torch-TensorRT's Dynamo path. +- Benchmarks inference performance with and without FlashInfer RMSNorm. + +This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization. +""" + from typing import Callable, Optional, Sequence, Union import flashinfer @@ -86,7 +103,7 @@ def replace_rmsnorm( args=(node.args[0], 0), ) b.meta["tensor_meta"] = TensorMetadata( - shape=torch.Size([]), + shape=torch.Size([1]), dtype=torch.int64, requires_grad=False, stride=None,