Commit 29b65b0
feat: support lowering rmsnorm module to flashinfer.rmsnorm
This commit adds support for lowering an RMSNorm module to flashinfer.rmsnorm. The example included in the PR demonstrates how to lower a PyTorch RMSNorm to flashinfer.rmsnorm and run the lowered node with the flashinfer library via the automatic plugin feature. The PR also resolves a unique-ID issue when creating constant layers.
1 parent b97237c commit 29b65b0
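
For reference, a minimal eager-mode sketch of the RMSNorm pattern the new lowering pass matches (modeled on the HF LlamaRMSNorm decomposition; this snippet is illustrative and not part of the commit):

import torch

def rmsnorm_eager(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x32 = x.to(torch.float32)                # aten._to_copy (dtype=float32), two users
    var = x32.pow(2).mean(-1, keepdim=True)  # aten.pow.Tensor_Scalar -> aten.mean.dim
    inv = 1.0 / torch.sqrt(var + eps)        # aten.add.Tensor -> aten.sqrt -> aten.div.Tensor
    y = x32 * inv                            # aten.mul.Tensor
    return weight * y.to(x.dtype)            # aten._to_copy -> aten.mul (weight_mul_node)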

File tree

3 files changed: +262 −7 lines changed
@@ -0,0 +1,241 @@
from typing import Sequence

import flashinfer
import torch
import torch_tensorrt
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
from transformers import LlamaConfig, LlamaForCausalLM


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Forward eps so the flashinfer kernel honors the module's epsilon.
    return flashinfer.norm.rmsnorm(input, weight, eps)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Fake (meta) implementation: rmsnorm preserves the input's shape and dtype.
    return input
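
# Illustrative sanity check (an assumption, not part of the original commit):
# once registered, the op is callable through torch.ops and preserves shape.
_x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
_w = torch.ones(4096, device="cuda", dtype=torch.float16)
assert torch.ops.flashinfer.rmsnorm(_x, _w, 1e-6).shape == _x.shape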

# Auto-generate a TensorRT plugin and converter for the custom op.
torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "flashinfer::rmsnorm", supports_dynamic_shapes=True
)


@_aten_lowering_pass
def replace_rmsnorm(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Match the aten decomposition of RMSNorm and rewrite it to flashinfer::rmsnorm."""
    modified_graph = False  # becomes True once at least one pattern is rewritten
    for node in gm.graph.nodes:
        # Pattern anchor: the fp32 upcast that feeds both pow(2) and the final
        # rescale mul, i.e. x.to(fp32) -> pow -> mean -> add(eps) -> sqrt -> div -> mul.
        if (
            node.target == torch.ops.aten._to_copy.default
            and node.kwargs.get("dtype") is torch.float32
            and len(node.users) == 2
        ):
            if (
                list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar
                and list(node.users)[1].target == torch.ops.aten.mul.Tensor
            ):
                pow_node = list(node.users)[0]
                if (
                    len(pow_node.users) == 1
                    and list(pow_node.users)[0].target == torch.ops.aten.mean.dim
                ):
                    mean_node = list(pow_node.users)[0]
                    if (
                        len(mean_node.users) == 1
                        and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor
                    ):
                        add_node = list(mean_node.users)[0]
                        if (
                            len(add_node.users) == 1
                            and list(add_node.users)[0].target
                            == torch.ops.aten.sqrt.default
                        ):
                            sqrt_node = list(add_node.users)[0]
                            if (
                                len(sqrt_node.users) == 1
                                and list(sqrt_node.users)[0].target
                                == torch.ops.aten.div.Tensor
                            ):
                                div_node = list(sqrt_node.users)[0]
                                if list(div_node.users)[0] == list(node.users)[1]:
                                    mul_node = list(div_node.users)[0]
                                    copy_node = list(mul_node.users)[0]
                                    weight_mul_node = list(copy_node.users)[0]

                                    weight = weight_mul_node.args[0]

                                    original_meta = weight_mul_node.meta.get(
                                        "tensor_meta", {}
                                    )
                                    memory_format = original_meta.memory_format

                                    with gm.graph.inserting_after(weight_mul_node):
                                        # Symbolic (b, s, d) sizes of the input.
                                        b = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 0),
                                        )
                                        b.meta["tensor_meta"] = TensorMetadata(
                                            shape=torch.Size([]),
                                            dtype=torch.int64,
                                            requires_grad=False,
                                            stride=None,
                                            memory_format=memory_format,
                                            is_quantized=False,
                                            qparams={},
                                        )
                                        s = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 1),
                                        )
                                        s.meta.update(b.meta)

                                        d = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 2),
                                        )
                                        d.meta.update(b.meta)

                                    with gm.graph.inserting_after(b):
                                        new_first_dim = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.mul.Scalar,
                                            args=(b, s),
                                        )
                                        new_first_dim.meta.update(b.meta)

                                    with gm.graph.inserting_after(new_first_dim):
                                        # Flatten (b, s, d) -> (b*s, d):
                                        # flashinfer.rmsnorm expects a 2D input.
                                        reshape_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.reshape.default,
                                            args=(node.args[0], [new_first_dim, d]),
                                        )
                                        b_val = original_meta.shape[0]
                                        s_val = original_meta.shape[1]
                                        d_val = original_meta.shape[2]

                                        reshape_node.meta["tensor_meta"] = TensorMetadata(
                                            shape=torch.Size([b_val * s_val, d_val]),
                                            dtype=original_meta.dtype,
                                            requires_grad=True,
                                            stride=None,
                                            memory_format=memory_format,
                                            is_quantized=False,
                                            qparams={},
                                        )

                                    with gm.graph.inserting_after(reshape_node):
                                        # eps is the second argument of the add node.
                                        flashinfer_rmsnorm_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.flashinfer.rmsnorm.default,
                                            args=(
                                                reshape_node,
                                                weight,
                                                add_node.args[1],
                                            ),
                                        )
                                        flashinfer_rmsnorm_node.meta.update(
                                            reshape_node.meta
                                        )

                                    with gm.graph.inserting_after(
                                        flashinfer_rmsnorm_node
                                    ):
                                        # Restore the original (b, s, d) shape.
                                        reshape_back_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.reshape.default,
                                            args=(
                                                flashinfer_rmsnorm_node,
                                                [b, s, d],
                                            ),
                                        )

                                    weight_mul_node.replace_all_uses_with(
                                        reshape_back_node
                                    )
                                    reshape_back_node.meta.update(weight_mul_node.meta)

                                    modified_graph = True

                                    # Erase the matched subgraph, users before producers.
                                    gm.graph.erase_node(weight_mul_node)
                                    gm.graph.erase_node(copy_node)
                                    gm.graph.erase_node(mul_node)
                                    gm.graph.erase_node(div_node)
                                    gm.graph.erase_node(sqrt_node)
                                    gm.graph.erase_node(add_node)
                                    gm.graph.erase_node(mean_node)
                                    gm.graph.erase_node(pow_node)
                                    gm.graph.erase_node(node)

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)

    return gm


# 1. Create a custom config with 1 layer
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=4096,  # LLaMA2-7B dimensions
    intermediate_size=11008,  # LLaMA2-7B FFN hidden dim (~8/3 * 4096, rounded)
    num_hidden_layers=1,  # Only 1 decoder layer
    num_attention_heads=32,
    max_position_embeddings=4096,
    use_cache=False,  # Disable KV caching for export
)

# 2. Initialize model (random weights)
with torch.no_grad():
    model = LlamaForCausalLM(config).eval().half()

# 3. Export with static shapes
input_ids = torch.randint(0, 32000, (1, 64))  # Static [batch=1, seq=64]
exported = torch.export.export(
    model,
    (input_ids,),
    dynamic_shapes=None,  # Fully static
)

# Test forward pass
input_ids = torch.randint(0, 32000, (1, 64))
output = model(input_ids)
print(output)

# Compile the exported program with Torch-TensorRT and validate

DEVICE = torch.device("cuda:0")

with torch_tensorrt.logging.errors():
    trt_model = torch_tensorrt.dynamo.compile(
        exported,
        inputs=[input_ids],
        enabled_precisions={torch.float32, torch.float16},
        truncate_double=True,
        device=DEVICE,
        disable_tf32=True,
        use_explicit_typing=False,
        use_fp32_acc=True,
    )

input_ids = input_ids.to(DEVICE)

res = trt_model.forward(input_ids)
print(res)
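
# Hedged end-to-end check (illustrative; the compiled module's output structure
# may differ from the eager model's, so treat this as a sketch, not a test):
#   ref = model.to(DEVICE)(input_ids)
#   torch.testing.assert_close(res, ref.logits, rtol=1e-2, atol=1e-2)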

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+15 −6
@@ -1,3 +1,4 @@
+import itertools
 import logging
 from types import FunctionType
 from typing import Any, Callable, Tuple
@@ -108,7 +109,6 @@ def generate_signature(
 
 def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
     shape_env = ShapeEnv()
-    fake_mode = FakeTensorMode(shape_env=shape_env)
     syms_args = []
     tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]
 
@@ -121,7 +121,7 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
         ]
         syms_args.append(syms_arg)
 
-    with FakeTensorMode() as fake_mode:
+    with FakeTensorMode(shape_env=shape_env) as fake_mode:
         fake_args = []
         for syms_arg in syms_args:
             fake_arg = torch.randn(syms_arg)
@@ -130,16 +130,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
         output = torch_op(*fake_args, **kwargs)
 
     # We assume the number of dimensions is the same as in the torch op
-    shape_calc_fns = [None] * args[0].ndim
-    for i in range(args[0].ndim):
-        input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
+    shape_calc_fns = [None] * output.ndim
+
+    for i in range(output.ndim):
+        input_node_expr = list(
+            itertools.chain.from_iterable(
+                [sym.node.expr for sym in syms_arg] for syms_arg in syms_args
+            )
+        )
+
         shape_calc_fns[i] = lambdify(
             tuple(input_node_expr), output.shape[i].node.expr, "math"
         )
 
     out_desc = tensor_args[0].like()
     for i in range(out_desc.ndim):
-        input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
+        input_shape_expr = list(
+            itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
+        )
+
         if output.shape[i].node.expr is None:
             raise ValueError(f"output.shape[{i}].node.expr cannot be None")
         out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)  # type: ignore[misc]
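
The fix above drives shape propagation through sympy: each symbolic output-dimension expression is lambdified into a plain callable over the input-dimension symbols, which is then evaluated on TensorRT's concrete shape expressions. A minimal self-contained sketch of that idea (symbol names are illustrative):

from sympy import lambdify, symbols

s0, s1 = symbols("s0 s1")  # symbolic input dims, e.g. batch*seq and hidden
out_expr = s0 * s1         # an output dim expressed in terms of input dims
calc = lambdify((s0, s1), out_expr, "math")
assert calc(2, 64) == 128  # concrete dims evaluate through the generated callable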

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

+6 −1
@@ -1,4 +1,5 @@
 import logging
+import uuid
 from typing import Callable, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -47,11 +48,15 @@ def custom_kernel_converter(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+
     plugin = getattr(getattr(trtp.op, namespace), op_name)
+
     tensor_inputs = plugin.input_tensor_names
     tensor_args = args[0 : len(tensor_inputs)]
+
+    unique_id = uuid.uuid4()
     itensor_args = [
-        get_trt_tensor(ctx, t, f"{t_name}")
+        get_trt_tensor(ctx, t, f"{t_name}_{unique_id}")
         for (t, t_name) in zip(tensor_args, tensor_inputs)
     ]
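
This hunk is the unique-ID fix from the commit message: suffixing each tensor name with a uuid4 keeps constant-layer names distinct across conversions of the same plugin op. A tiny illustrative sketch of the property being relied on:

import uuid

names = [f"input_{uuid.uuid4()}" for _ in range(2)]
assert names[0] != names[1]  # per-conversion names cannot collide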