|
| 1 | +from torch_tensorrt.dynamo.backend.lowering import partition |
| 2 | +from torch.testing._internal.common_utils import run_tests, TestCase |
| 3 | +import torch |
| 4 | +from copy import deepcopy |
| 5 | +from torch_tensorrt.dynamo import compile |
| 6 | +from utils import lower_graph_testing |
| 7 | +from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT |
| 8 | + |
| 9 | + |
| 10 | +class TestTRTModuleNextCompilation(TestCase): |
| 11 | + def test_trt_module_next_full_support(self): |
| 12 | + class FullySupportedMultiOp(torch.nn.Module): |
| 13 | + def forward(self, x, y): |
| 14 | + out = x - y |
| 15 | + out = out + x |
| 16 | + out = 2 * out |
| 17 | + out = out + y |
| 18 | + return torch.mean(out, dim=1) |
| 19 | + |
| 20 | + fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) |
| 21 | + partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3) |
| 22 | + |
| 23 | + self.assertEquals( |
| 24 | + len(list(partitioned_graph.named_children())), |
| 25 | + 1, |
| 26 | + "All operators are supported, there should be one segment", |
| 27 | + ) |
| 28 | + |
| 29 | + inputs = [ |
| 30 | + torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(), |
| 31 | + torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(), |
| 32 | + ] |
| 33 | + |
| 34 | + torch._dynamo.reset() |
| 35 | + |
| 36 | + # Validate that the results between Torch and Torch-TRT are similar |
| 37 | + optimized_model = compile( |
| 38 | + fx_graph, |
| 39 | + inputs, |
| 40 | + min_block_size=1, |
| 41 | + pass_through_build_failures=True, |
| 42 | + torch_executed_ops={"torch.ops.aten.add.Tensor"}, |
| 43 | + use_experimental_rt=True, |
| 44 | + debug=True, |
| 45 | + ) |
| 46 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 47 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 48 | + |
| 49 | + max_diff = float( |
| 50 | + torch.max(torch.abs(optimized_model_results - torch_model_results)) |
| 51 | + ) |
| 52 | + self.assertAlmostEqual( |
| 53 | + max_diff, |
| 54 | + 0, |
| 55 | + DECIMALS_OF_AGREEMENT, |
| 56 | + f"TRT outputs don't match with the original model.", |
| 57 | + ) |
| 58 | + |
| 59 | + def test_trt_module_next_partial_support(self): |
| 60 | + class PartiallySupportedMultiOp(torch.nn.Module): |
| 61 | + def forward(self, x, y): |
| 62 | + out = x - y |
| 63 | + out = out - 3 * x |
| 64 | + out = out + y |
| 65 | + out = out.to(torch.float) |
| 66 | + out = 2 * out |
| 67 | + return torch.mean(out, dim=-1) |
| 68 | + |
| 69 | + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) |
| 70 | + unexpected_ops = {torch.ops.aten.add.Tensor} |
| 71 | + |
| 72 | + inputs = [ |
| 73 | + torch.randint(-40, 40, (16, 7, 5), dtype=torch.int).cuda(), |
| 74 | + torch.randint(1, 40, (16, 7, 5), dtype=torch.int).cuda(), |
| 75 | + ] |
| 76 | + |
| 77 | + (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( |
| 78 | + fx_graph, |
| 79 | + inputs, |
| 80 | + unexpected_ops=unexpected_ops, |
| 81 | + min_block_size=1, |
| 82 | + torch_executed_ops={"torch.ops.aten.add.Tensor"}, |
| 83 | + testing_partitioning=True, |
| 84 | + ) |
| 85 | + |
| 86 | + self.assertEquals( |
| 87 | + len(unexpected_ops_seen), |
| 88 | + 0, |
| 89 | + f"The following unexpected ops were encountered: {unexpected_ops_seen}", |
| 90 | + ) |
| 91 | + self.assertEquals( |
| 92 | + len(partitioned_graphs), |
| 93 | + 1, |
| 94 | + "Without control flow breaks, there should only be a single graph", |
| 95 | + ) |
| 96 | + self.assertEquals( |
| 97 | + len(list(partitioned_graphs[0].named_children())), |
| 98 | + 2, |
| 99 | + "Certain operators are set to run in Torch, expected 2 segments", |
| 100 | + ) |
| 101 | + |
| 102 | + torch._dynamo.reset() |
| 103 | + |
| 104 | + # Validate that the results between Torch and Torch-TRT are similar |
| 105 | + optimized_model = compile( |
| 106 | + fx_graph, |
| 107 | + inputs, |
| 108 | + min_block_size=1, |
| 109 | + pass_through_build_failures=True, |
| 110 | + torch_executed_ops={"torch.ops.aten.add.Tensor"}, |
| 111 | + use_experimental_rt=True, |
| 112 | + debug=True, |
| 113 | + ) |
| 114 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 115 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 116 | + |
| 117 | + max_diff = float( |
| 118 | + torch.max(torch.abs(optimized_model_results - torch_model_results)) |
| 119 | + ) |
| 120 | + self.assertAlmostEqual( |
| 121 | + max_diff, |
| 122 | + 0, |
| 123 | + DECIMALS_OF_AGREEMENT, |
| 124 | + f"TRT outputs don't match with the original model.", |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +class TestCompilationOptions(TestCase): |
| 129 | + def test_trt_specific_options(self): |
| 130 | + class SupportedMultiOp(torch.nn.Module): |
| 131 | + def forward(self, x, y): |
| 132 | + out = x - y |
| 133 | + out = out - 3 * x |
| 134 | + out = out + y |
| 135 | + out = out - y / 5 |
| 136 | + out = 2 * out |
| 137 | + return torch.mean(out, dim=-1) |
| 138 | + |
| 139 | + fx_graph = torch.fx.symbolic_trace(SupportedMultiOp()) |
| 140 | + |
| 141 | + inputs = [ |
| 142 | + torch.randint(-40, 40, (16, 7, 5), dtype=torch.float).cuda(), |
| 143 | + torch.randint(1, 40, (16, 7, 5), dtype=torch.float).cuda(), |
| 144 | + ] |
| 145 | + |
| 146 | + # Validate that the results between Torch and Torch-TRT are similar |
| 147 | + optimized_model = compile( |
| 148 | + fx_graph, |
| 149 | + inputs, |
| 150 | + min_block_size=1, |
| 151 | + pass_through_build_failures=True, |
| 152 | + use_experimental_rt=True, |
| 153 | + optimization_level=4, |
| 154 | + version_compatible=True, |
| 155 | + max_aux_streams=5, |
| 156 | + debug=True, |
| 157 | + ) |
| 158 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 159 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 160 | + |
| 161 | + max_diff = float( |
| 162 | + torch.max(torch.abs(optimized_model_results - torch_model_results)) |
| 163 | + ) |
| 164 | + self.assertAlmostEqual( |
| 165 | + max_diff, |
| 166 | + 0, |
| 167 | + DECIMALS_OF_AGREEMENT, |
| 168 | + f"TRT outputs don't match with the original model.", |
| 169 | + ) |
| 170 | + |
| 171 | + |
| 172 | +if __name__ == "__main__": |
| 173 | + run_tests() |
0 commit comments