
Commit f91fc25

Author: Wei Wei (committed)
[fx2trt] setitem pass improvement, disable isinf and embedding (#86)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/86

1. Dead-code elimination should be used cautiously: we need to keep the other in-place ops for a future optimization pass, so the setitem pass now erases only the replaced node.
2. Add support for the `type` op.
3. Disable the isinf converter, since its implementation generates inaccurate results in the hf_T5 model; the reason is not yet clear.
4. Disable the embedding converter for a few reasons: 1) embeddings in hf models all use int64 indices, which skips this op anyway; 2) if the indices are a constant node, it throws an error and the whole subgraph falls back to non-TRT. So it is better to leave the converter commented out and defer it to a future implementation.

HF models in torchbench are summarized below; 3 models do not work well with TRT (speedup close to 1x) since major subgraphs fall back to non-TRT.

Reviewed By: 842974287

Differential Revision: D36731186

fbshipit-source-id: 8f56fb875419b9a3f03432f0ef5c00fdfc6b6741
1 parent cffd9c2 commit f91fc25
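
As context for point 4 above, here is a minimal PyTorch snippet (not part of this commit) showing that embedding indices default to int64, which is why the HF models would bypass a TRT embedding converter limited to int32:

import torch

emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
indices = torch.tensor([[1, 2, 3]])  # integer tensors default to torch.int64
print(indices.dtype)                 # torch.int64
print(emb(indices).shape)            # torch.Size([1, 3, 4])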

File tree

7 files changed, +171 -100 lines changed

fx/converters/acc_ops_converters.py

+101-97
@@ -1677,35 +1677,41 @@ def acc_ops_logical_xor(
     )
 
 
-@tensorrt_converter(acc_ops.isinf)
-def acc_ops_isinf(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_t = kwargs["input"]
-    if not isinstance(input_t, TRTTensor):
-        raise RuntimeError(
-            f"isinf received input {input_t} that is not part "
-            "of the TensorRT region!"
-        )
-    inf_t = torch.ones(tuple(input_t.shape))
-    inf_t = inf_t * float("inf")
-    inf_t = get_trt_tensor(network, inf_t, f"{name}_inf_t")
-
-    ninf_t = torch.ones(tuple(input_t.shape))
-    ninf_t = ninf_t * float("-inf")
-    ninf_t = get_trt_tensor(network, ninf_t, f"{name}_ninf_t")
-
-    kwargs_new = {"input": input_t, "other": inf_t}
-    inf_output = acc_ops_eq(network, target, None, kwargs_new, name + "_compare_inf")
-    kwargs_new = {"input": input_t, "other": ninf_t}
-    ninf_output = acc_ops_eq(network, target, None, kwargs_new, name + "_compare_ninf")
-    kwargs_new = {"input": inf_output, "other": ninf_output}
-    output = acc_ops_logical_or(network, target, None, kwargs_new, name + "_compare")
-    return output
+# T113156424 Have some accuracy problems in hf_T5.
+# [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.
+# @tensorrt_converter(acc_ops.isinf)
+# def acc_ops_isinf(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     input_t = kwargs["input"]
+#     if not isinstance(input_t, TRTTensor):
+#         raise RuntimeError(
+#             f"isinf received input {input_t} that is not part "
+#             "of the TensorRT region!"
+#         )
+#     tdtype = torch_dtype_from_trt(input_t.dtype)
+
+#     inf_t = torch.ones(tuple(input_t.shape))
+#     inf_t = inf_t * float("inf")
+#     inf_t = inf_t.to(tdtype)
+#     inf_t = get_trt_tensor(network, inf_t, f"{name}_inf_t")
+
+#     ninf_t = torch.ones(tuple(input_t.shape))
+#     ninf_t = ninf_t * float("-inf")
+#     ninf_t = ninf_t.to(tdtype)
+#     ninf_t = get_trt_tensor(network, ninf_t, f"{name}_ninf_t")
+
+#     kwargs_new = {"input": input_t, "other": inf_t}
+#     inf_output = acc_ops_eq(network, target, None, kwargs_new, name + "_compare_inf")
+#     kwargs_new = {"input": input_t, "other": ninf_t}
+#     ninf_output = acc_ops_eq(network, target, None, kwargs_new, name + "_compare_ninf")
+#     kwargs_new = {"input": inf_output, "other": ninf_output}
+#     output = acc_ops_logical_or(network, target, None, kwargs_new, name + "_compare")
+#     return output
 
 
 @tensorrt_converter(acc_ops.any)
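
For reference, the decomposition used by the now-disabled converter (eq against +inf and -inf, combined with logical_or) can be sketched in plain PyTorch; this snippet is illustrative only and not part of the commit:

import torch

x = torch.tensor([1.0, float("inf"), float("-inf"), 3.0])
# eq against +inf, eq against -inf, then logical_or, mirroring the commented-out converter
mask = torch.logical_or(x == float("inf"), x == float("-inf"))
print(mask)            # tensor([False,  True,  True, False])
print(torch.isinf(x))  # same result
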
@@ -1785,68 +1791,70 @@ def acc_ops_fmod(
     return sub_value
 
 
-@tensorrt_converter(acc_ops.embedding, no_implicit_batch_dim=True)
-def acc_ops_embedding(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `embedding` function should be called with explicit batch dimension."
-        )
-
-    indices_tensor = kwargs["input"]
-    embedding_tensor = kwargs["weight"]
-    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
-        indices_tensor = indices_tensor.to(torch.int32)
-        warnings.warn(
-            "Embedding op has indices_tensor dtype=int64. Reduce it to int32 to run on TRT. Accuracy may not be correct!"
-        )
-    if (
-        isinstance(embedding_tensor, torch.Tensor)
-        and embedding_tensor.dtype == torch.int64
-    ):
-        embedding_tensor = embedding_tensor.to(torch.int32)
-        warnings.warn(
-            "Embedding op has embedding_tensor dtype=int64. Reduce it to int32 to run on TRT. Accuracy may not be correct!"
-        )
-    indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor")
-    embedding_tensor = get_trt_tensor(
-        network, embedding_tensor, f"{name}_embedding_tensor"
-    )
-
-    # unsupported parameters
-    # ignore padding_idx since it is meaningful for training only
-    max_norm = kwargs["max_norm"]
-    norm_type = kwargs["norm_type"]
-    scale_grad_by_freq = kwargs["scale_grad_by_freq"]
-    sparse = kwargs["sparse"]
-
-    if max_norm is not None:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-
-    if norm_type != 2.0:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
-        )
-
-    if scale_grad_by_freq:
-        raise RuntimeError(
-            "Currently we don't support scale gradient by word frequency."
-        )
-
-    if sparse:
-        raise RuntimeError("Currently we don't support sparse gradient.")
-
-    # Implement embedding lookup with gather layer
-    gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
-    set_layer_name(gather_layer, target, name + "_gather")
-    return gather_layer.get_output(0)
+# T113156424 The embedding implementation is very limited and shows no usage in hf models because the indices are int64.
+# If we cast to int32, it will create accuracy issues. We'd better leave it to a future implementation.
+# @tensorrt_converter(acc_ops.embedding, no_implicit_batch_dim=True)
+# def acc_ops_embedding(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     if network.has_implicit_batch_dimension:
+#         raise RuntimeError(
+#             "The `embedding` function should be called with explicit batch dimension."
+#         )
+
+#     indices_tensor = kwargs["input"]
+#     embedding_tensor = kwargs["weight"]
+#     if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
+#         indices_tensor = indices_tensor.to(torch.int32)
+#         warnings.warn(
+#             "Embedding op has indices_tensor dtype=int64. Reduce it to int32 to run on TRT. Accuracy may not be correct!"
+#         )
+#     if (
+#         isinstance(embedding_tensor, torch.Tensor)
+#         and embedding_tensor.dtype == torch.int64
+#     ):
+#         embedding_tensor = embedding_tensor.to(torch.int32)
+#         warnings.warn(
+#             "Embedding op has embedding_tensor dtype=int64. Reduce it to int32 to run on TRT. Accuracy may not be correct!"
+#         )
+#     indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor")
+#     embedding_tensor = get_trt_tensor(
+#         network, embedding_tensor, f"{name}_embedding_tensor"
+#     )
+
+#     # unsupported parameters
+#     # ignore padding_idx since it is meaningful for training only
+#     max_norm = kwargs["max_norm"]
+#     norm_type = kwargs["norm_type"]
+#     scale_grad_by_freq = kwargs["scale_grad_by_freq"]
+#     sparse = kwargs["sparse"]
+
+#     if max_norm is not None:
+#         raise RuntimeError(
+#             f"Currently we don't support specifying max_norm, got {max_norm}."
+#         )
+
+#     if norm_type != 2.0:
+#         raise RuntimeError(
+#             f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
+#         )
+
+#     if scale_grad_by_freq:
+#         raise RuntimeError(
+#             "Currently we don't support scale gradient by word frequency."
+#         )
+
+#     if sparse:
+#         raise RuntimeError("Currently we don't support sparse gradient.")
+
+#     # Implement embedding lookup with gather layer
+#     gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
+#     set_layer_name(gather_layer, target, name + "_gather")
+#     return gather_layer.get_output(0)
 
 
 @tensorrt_converter(acc_ops.max_pool1d)
@@ -2342,12 +2350,8 @@ def acc_ops_reshape(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Reshape received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+    # for case where input_val is TRTensor
+    input_val = get_trt_tensor(network, input_val, f"{name}_input_val")
 
     shape = kwargs["acc_out_ty"].shape  # type: ignore[misc]
     if network.has_implicit_batch_dimension:
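
The reshape change follows a common converter pattern: promote constant inputs to network tensors via get_trt_tensor instead of rejecting anything that is not already a TRTTensor. A generic, hedged sketch of that pattern follows; the BackendTensor class and to_backend_tensor helper are hypothetical stand-ins, not the fx2trt API:

import torch

class BackendTensor:
    # hypothetical stand-in for a TRT ITensor registered on the network
    def __init__(self, name, value):
        self.name, self.value = name, value

def to_backend_tensor(network_constants, value, name):
    # pass through values already on the backend; register constants instead of raising
    if isinstance(value, BackendTensor):
        return value
    const = BackendTensor(name, torch.as_tensor(value))
    network_constants[name] = const
    return const

constants = {}
t = to_backend_tensor(constants, torch.randn(2, 3), "reshape_input_val")
print(type(t).__name__, t.value.shape)  # BackendTensor torch.Size([2, 3])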

fx/passes/lower_basic_pass.py

+1-1
@@ -451,8 +451,8 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
             continue
         node.replace_all_uses_with(concat_node_0)
         map_replace[input_node] = concat_node_0
+        gm.graph.erase_node(node)
 
-    gm.graph.eliminate_dead_code()
     gm.graph.lint()
     gm.recompile()
     return gm
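
Below is a minimal, self-contained torch.fx sketch (not the setitem pass itself) of the replace_all_uses_with + erase_node pattern used above. Erasing only the rerouted node keeps other unused in-place ops in the graph, whereas a blanket gm.graph.eliminate_dead_code() could drop them, which is what point 1 of the summary cautions against:

import operator
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y * 2

gm = fx.symbolic_trace(M())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        # reroute users of the add node to its first input, then erase just that node
        node.replace_all_uses_with(node.args[0])
        graph.erase_node(node)
graph.lint()
gm.recompile()
print(gm.code)  # forward now returns x * 2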

test/converters/acc_op/test_embedding.py

+4-2
@@ -1,10 +1,12 @@
+import unittest
+
 import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
 import torch
 from parameterized import param, parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
+from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
 
-
+@unittest.skip("Current implementation is limited. All implementations in hf use int64. T113156424")
 class TestEmbeddingConverter(AccTestCase):
     @parameterized.expand(
         [

test/converters/acc_op/test_isinf.py

+3
@@ -1,9 +1,12 @@
+import unittest
+
 import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
 import torch
 from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
 
 
+@unittest.skip("Implementation is commented out due to accuracy issue T113156424")
 class TestInfConverter(AccTestCase):
     def test_isinf(self):
         class Test(torch.nn.Module):

test/converters/acc_op/test_type_as.py

+37
@@ -86,6 +86,43 @@ def forward(self, input, other):
             precision=LowerPrecision.FP16,
         )
 
+    def test_type_tensor(self):
+        class Type_as(torch.nn.Module):
+            def forward(self, input):
+                return input.type(dtype=torch.float16)
+
+        input = torch.randn(2, 2)
+
+        inputs = [
+            input,
+        ]
+        self.run_test(
+            Type_as(),
+            inputs,
+            expected_ops={acc_ops.to_dtype},
+            precision=LowerPrecision.FP16,
+        )
+
+    def test_type_tensor_ext(self):
+        class Type_as(torch.nn.Module):
+            def forward(self, input, other):
+                t = input.type()
+                return other.type(t)
+
+        input = torch.randn(2, 2).to(dtype=torch.float16)
+        other = torch.randn(2, 2)
+
+        inputs = [
+            input,
+            other,
+        ]
+        self.run_test(
+            Type_as(),
+            inputs,
+            expected_ops={acc_ops.to_dtype, acc_ops.dtype},
+            precision=LowerPrecision.FP16,
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tracer/acc_tracer/acc_ops.py

+24
@@ -223,6 +223,30 @@ def avg_pool2d(
 def sign(*, input):
     return torch.sign(input)
 
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "type"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dtype", "dtype", this_arg_is_optional),
+    ],
+)
+def custom_type_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    input_obj = node.kwargs["input"]
+    dtype_obj = node.kwargs.get("dtype")
+    with node.graph.inserting_before(node):
+        if dtype_obj == None:
+            dtype_node = node.graph.call_function(dtype, kwargs={"input": input_obj})
+            dtype_node.meta["type"] = torch.dtype
+            return dtype_node
+        else:
+            new_kwargs = {
+                "input": input_obj,
+                "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dtype_obj),
+            }
+            new_node = node.graph.call_function(to_dtype, kwargs=new_kwargs)
+            new_node.meta = node.meta
+            return new_node
+
 
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "type_as"),
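
As a plain-PyTorch illustration (independent of the tracer) of the two branches the mapper above handles, Tensor.type() with no dtype returns the type string, while Tensor.type(dtype) performs a cast:

import torch

x = torch.randn(2, 2)
print(x.type())            # 'torch.FloatTensor' -- the no-dtype branch maps to acc_ops.dtype
y = x.type(torch.float16)  # the dtype branch maps to acc_ops.to_dtype
print(y.dtype)             # torch.float16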

tracer/acc_tracer/acc_tracer.py

+1
@@ -272,6 +272,7 @@ def create_node(
             name_target[-1] == "_"
             and name_target[0] != "_"
             and not (name_target in allow_list)
+            and kind != "placeholder"
         ):
             raise RuntimeError(
                 f"Tried to trace mutable operation {name_target}. FX only supports functional code"
