
Commit 3aa76ab

chore: update the docstring for llama2 rmsnorm automatic plugin example (#3512)

1 parent: c3ad86c

1 file changed: +18 -1 lines changed


examples/dynamo/llama2_flashinfer_rmsnorm.py (+18 -1)
@@ -1,3 +1,20 @@
+"""
+.. _llama2_flashinfer_rmsnorm:
+
+Automatically generate a TensorRT Plugin for RMSNorm module and apply it in Llama2
+===================================================================
+
+This example showcases how to optimize inference for a LLaMA2 model by replacing its RMSNorm layers with FlashInfer's high-performance implementation. It demonstrates the use of Torch-TensorRT's automatic plugin feature, which dynamically generates and integrates custom TensorRT plugins during compilation.
+
+Key features:
+- Leverages automatic plugin registration for FlashInfer RMSNorm ops.
+- Applies a custom TorchDynamo lowering pass to replace standard RMSNorm ops.
+- Compiles the modified model using Torch-TensorRT's Dynamo path.
+- Benchmarks inference performance with and without FlashInfer RMSNorm.
+
+This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
+"""
+
 from typing import Callable, Optional, Sequence, Union
 
 import flashinfer
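For context on what the FlashInfer plugin replaces: RMSNorm normalizes by the root mean square of the last dimension and applies a learned elementwise scale, with no mean subtraction and no bias. The sketch below is illustrative plain PyTorch, not the example's actual code; the function name and `eps` default are assumptions.

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square over the last dimension; eps guards against division by zero.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Normalize, then apply the learned per-feature scale (no mean centering, no bias).
    return x / rms * weight

x = torch.randn(2, 4096)
w = torch.ones(4096)
y = rms_norm(x, w)
print(y.shape)  # torch.Size([2, 4096])
```

A fused kernel such as FlashInfer's computes the same result in one pass instead of the several elementwise/reduction kernels the eager version launches, which is why swapping it in via an automatic plugin can pay off at inference time.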
@@ -86,7 +103,7 @@ def replace_rmsnorm(
             args=(node.args[0], 0),
         )
         b.meta["tensor_meta"] = TensorMetadata(
-            shape=torch.Size([]),
+            shape=torch.Size([1]),
             dtype=torch.int64,
             requires_grad=False,
             stride=None,
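The one-line code change in this hunk swaps the recorded shape metadata from a 0-d scalar to a 1-d, one-element tensor, so downstream passes see the correct rank for the node. The distinction between `torch.Size([])` and `torch.Size([1])` is easy to demonstrate in plain PyTorch; this snippet is illustrative only and unrelated to the repo's code.

```python
import torch

scalar = torch.tensor(3, dtype=torch.int64)    # 0-d tensor: shape torch.Size([])
vector = torch.tensor([3], dtype=torch.int64)  # 1-d tensor: shape torch.Size([1])

print(scalar.shape, scalar.dim())  # torch.Size([]) 0
print(vector.shape, vector.dim())  # torch.Size([1]) 1
```

Both hold the same int64 value, but they differ in rank, which is exactly what `TensorMetadata.shape` records for graph passes that validate or convert the node.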
