|
| 1 | +from typing import Callable, Optional, Sequence, Union |
| 2 | + |
| 3 | +import flashinfer |
| 4 | +import torch |
| 5 | +import torch_tensorrt |
| 6 | +from torch.fx.passes.shape_prop import TensorMetadata |
| 7 | +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( |
| 8 | + _aten_lowering_pass, |
| 9 | +) |
| 10 | +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( |
| 11 | + clean_up_graph_after_modifications, |
| 12 | +) |
| 13 | +from transformers import LlamaConfig, LlamaForCausalLM |
| 14 | + |
| 15 | + |
| 16 | +@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] |
| 17 | +def flashinfer_rmsnorm( |
| 18 | + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 |
| 19 | +) -> torch.Tensor: |
| 20 | + return flashinfer.norm.rmsnorm(input, weight) |
| 21 | + |
| 22 | + |
| 23 | +@torch.library.register_fake("flashinfer::rmsnorm") |
| 24 | +def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: |
| 25 | + return input |
| 26 | + |
| 27 | + |
| 28 | +torch_tensorrt.dynamo.conversion.plugins.custom_op( |
| 29 | + "flashinfer::rmsnorm", supports_dynamic_shapes=True |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +@_aten_lowering_pass |
| 34 | +def replace_rmsnorm( |
| 35 | + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] |
| 36 | +) -> torch.fx.GraphModule: |
| 37 | + for node in gm.graph.nodes: |
| 38 | + if ( |
| 39 | + node.target == torch.ops.aten._to_copy.default |
| 40 | + and node.kwargs.get("dtype") is torch.float32 |
| 41 | + and len(node.users) == 2 |
| 42 | + ): |
| 43 | + if ( |
| 44 | + list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar |
| 45 | + and list(node.users)[1].target == torch.ops.aten.mul.Tensor |
| 46 | + ): |
| 47 | + pow_node = list(node.users)[0] |
| 48 | + if ( |
| 49 | + len(pow_node.users) == 1 |
| 50 | + and list(pow_node.users)[0].target == torch.ops.aten.mean.dim |
| 51 | + ): |
| 52 | + mean_node = list(pow_node.users)[0] |
| 53 | + if ( |
| 54 | + len(mean_node.users) == 1 |
| 55 | + and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor |
| 56 | + ): |
| 57 | + add_node = list(mean_node.users)[0] |
| 58 | + if ( |
| 59 | + len(add_node.users) == 1 |
| 60 | + and list(add_node.users)[0].target |
| 61 | + == torch.ops.aten.sqrt.default |
| 62 | + ): |
| 63 | + sqrt_node = list(add_node.users)[0] |
| 64 | + if ( |
| 65 | + len(sqrt_node.users) == 1 |
| 66 | + and list(sqrt_node.users)[0].target |
| 67 | + == torch.ops.aten.div.Tensor |
| 68 | + ): |
| 69 | + div_node = list(sqrt_node.users)[0] |
| 70 | + if list(div_node.users)[0] == list(node.users)[1]: |
| 71 | + mul_node = list(div_node.users)[0] |
| 72 | + copy_node = list(mul_node.users)[0] |
| 73 | + weight_mul_node = list(copy_node.users)[0] |
| 74 | + |
| 75 | + weight = weight_mul_node.args[0] |
| 76 | + |
| 77 | + original_meta = weight_mul_node.meta.get( |
| 78 | + "tensor_meta", {} |
| 79 | + ) |
| 80 | + memory_format = original_meta.memory_format |
| 81 | + |
| 82 | + with gm.graph.inserting_after(weight_mul_node): |
| 83 | + b = gm.graph.create_node( |
| 84 | + op="call_function", |
| 85 | + target=torch.ops.aten.sym_size.int, |
| 86 | + args=(node.args[0], 0), |
| 87 | + ) |
| 88 | + b.meta["tensor_meta"] = TensorMetadata( |
| 89 | + shape=torch.Size([]), |
| 90 | + dtype=torch.int64, |
| 91 | + requires_grad=False, |
| 92 | + stride=None, |
| 93 | + memory_format=memory_format, |
| 94 | + is_quantized=False, |
| 95 | + qparams={}, |
| 96 | + ) |
| 97 | + s = gm.graph.create_node( |
| 98 | + op="call_function", |
| 99 | + target=torch.ops.aten.sym_size.int, |
| 100 | + args=(node.args[0], 1), |
| 101 | + ) |
| 102 | + s.meta.update(b.meta) |
| 103 | + |
| 104 | + d = gm.graph.create_node( |
| 105 | + op="call_function", |
| 106 | + target=torch.ops.aten.sym_size.int, |
| 107 | + args=(node.args[0], 2), |
| 108 | + ) |
| 109 | + d.meta.update(b.meta) |
| 110 | + |
| 111 | + with gm.graph.inserting_after(b): |
| 112 | + new_first_dim = gm.graph.create_node( |
| 113 | + op="call_function", |
| 114 | + target=torch.ops.aten.mul.Scalar, |
| 115 | + args=(b, s), |
| 116 | + ) |
| 117 | + new_first_dim.meta.update(b.meta) |
| 118 | + |
| 119 | + with gm.graph.inserting_after(new_first_dim): |
| 120 | + # with gm.graph.inserting_after(weight_mul_node): |
| 121 | + reshape_node = gm.graph.create_node( |
| 122 | + op="call_function", |
| 123 | + target=torch.ops.aten.reshape.default, |
| 124 | + args=(node.args[0], [new_first_dim, d]), |
| 125 | + ) |
| 126 | + b_val = original_meta.shape[0] |
| 127 | + s_val = original_meta.shape[1] |
| 128 | + d_val = original_meta.shape[2] |
| 129 | + |
| 130 | + reshape_node.meta["tensor_meta"] = ( |
| 131 | + TensorMetadata( |
| 132 | + shape=torch.Size( |
| 133 | + [b_val * s_val, d_val] |
| 134 | + ), |
| 135 | + dtype=original_meta.dtype, |
| 136 | + requires_grad=True, |
| 137 | + stride=None, |
| 138 | + memory_format=memory_format, |
| 139 | + is_quantized=False, |
| 140 | + qparams={}, |
| 141 | + ) |
| 142 | + ) |
| 143 | + |
| 144 | + with gm.graph.inserting_after(reshape_node): |
| 145 | + flashinfer_rmsnorm_node = gm.graph.create_node( |
| 146 | + op="call_function", |
| 147 | + target=torch.ops.flashinfer.rmsnorm.default, |
| 148 | + args=( |
| 149 | + reshape_node, |
| 150 | + weight, |
| 151 | + add_node.args[1], |
| 152 | + ), |
| 153 | + ) |
| 154 | + flashinfer_rmsnorm_node.meta.update( |
| 155 | + reshape_node.meta |
| 156 | + ) |
| 157 | + |
| 158 | + with gm.graph.inserting_after( |
| 159 | + flashinfer_rmsnorm_node |
| 160 | + ): |
| 161 | + reshapback_node = gm.graph.create_node( |
| 162 | + op="call_function", |
| 163 | + target=torch.ops.aten.reshape.default, |
| 164 | + args=( |
| 165 | + flashinfer_rmsnorm_node, |
| 166 | + [b, s, d], |
| 167 | + ), |
| 168 | + ) |
| 169 | + |
| 170 | + weight_mul_node.replace_all_uses_with( |
| 171 | + reshapback_node |
| 172 | + ) |
| 173 | + reshapback_node.meta.update(weight_mul_node.meta) |
| 174 | + |
| 175 | + modified_graph = True |
| 176 | + |
| 177 | + gm.graph.erase_node(weight_mul_node) |
| 178 | + gm.graph.erase_node(copy_node) |
| 179 | + gm.graph.erase_node(mul_node) |
| 180 | + gm.graph.erase_node(div_node) |
| 181 | + gm.graph.erase_node(sqrt_node) |
| 182 | + gm.graph.erase_node(add_node) |
| 183 | + gm.graph.erase_node(mean_node) |
| 184 | + gm.graph.erase_node(pow_node) |
| 185 | + gm.graph.erase_node(node) |
| 186 | + |
| 187 | + if modified_graph: |
| 188 | + gm = clean_up_graph_after_modifications(gm) |
| 189 | + |
| 190 | + return gm |
| 191 | + |
| 192 | + |
| 193 | +# 1. Create a custom config with 1 layer |
| 194 | +config = LlamaConfig( |
| 195 | + vocab_size=32000, |
| 196 | + hidden_size=4096, # LLaMA2-7B dimensions |
| 197 | + intermediate_size=11008, # FFN hidden_dim = 4 * 4096 * 0.7 (SwiGLU scaling) |
| 198 | + num_hidden_layers=1, # Only 1 decoder layer |
| 199 | + num_attention_heads=32, |
| 200 | + max_position_embeddings=4096, |
| 201 | + use_cache=False, # Disable KV caching for export |
| 202 | +) |
| 203 | + |
| 204 | +# 2. Initialize model (random weights) |
| 205 | +with torch.no_grad(): |
| 206 | + model = LlamaForCausalLM(config).eval().half() |
| 207 | + |
| 208 | +# 3. Export with static shapes |
| 209 | +input_ids = torch.randint(0, 32000, (1, 64)) # Static [batch=1, seq=64] |
| 210 | +exported = torch.export.export( |
| 211 | + model, |
| 212 | + (input_ids,), |
| 213 | + dynamic_shapes=None, # Fully static |
| 214 | +) |
| 215 | + |
| 216 | +# Test forward pass |
| 217 | +input_ids = torch.randint(0, 32000, (1, 64)) |
| 218 | +output = model(input_ids) |
| 219 | +print(output) |
| 220 | + |
| 221 | +# Export validation |
| 222 | + |
| 223 | +DEVICE = torch.device("cuda:0") |
| 224 | + |
| 225 | +with torch_tensorrt.logging.errors(): |
| 226 | + trt_model = torch_tensorrt.dynamo.compile( |
| 227 | + exported, |
| 228 | + inputs=[input_ids], |
| 229 | + enabled_precisions={torch.float32, torch.float16}, |
| 230 | + truncate_double=True, |
| 231 | + device=DEVICE, |
| 232 | + disable_tf32=True, |
| 233 | + use_explicit_typing=False, |
| 234 | + use_fp32_acc=True, |
| 235 | + # debug=True, |
| 236 | + ) |
| 237 | + |
| 238 | +input_ids = input_ids.to(DEVICE) |
| 239 | + |
| 240 | +res = trt_model.forward(input_ids) |
| 241 | +print(res) |
0 commit comments