
Commit cffd9c2

jerryzh168 authored and Wei Wei committed
[fx2trt] Fix dummy weight initialization in conv1d converter (#78402)
Summary:
X-link: pytorch/pytorch#78402
Pull Request resolved: pytorch/fx2trt#84

As titled: currently the conv1d converter errors out with the following error:

```
---> 72 dummy_weight = trt.Weights(weight_shape)
     73 layer = network.add_convolution_nd(
     74     input=input_val,

TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. tensorrt.tensorrt.Weights(type: tensorrt.tensorrt.DataType = <DataType.FLOAT: 0>)
    2. tensorrt.tensorrt.Weights(a: numpy.ndarray)
```

Full error: https://www.internalfb.com/phabricator/paste/view/P503598381

We need to pass around a numpy ndarray (or default-construct empty weights) instead of a shape tuple here.

Reviewed By: wushirong

Differential Revision: D36721313

fbshipit-source-id: c13583c3accb37f781429b5905cfda8e856cebbc
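The error message lists the only two supported `trt.Weights` constructors: a `tensorrt.DataType` (yielding empty weights) or a `numpy.ndarray`. A minimal sketch of the failure mode and both valid forms, assuming a standard `tensorrt` Python install; the shape values here are placeholders:

```python
import numpy as np
import tensorrt as trt

weight_shape = (8, 4, 3)  # placeholder (out_channels, in_channels, kernel_size)

# Fails: a shape tuple matches neither supported constructor.
#   trt.Weights(weight_shape)  ->  TypeError: incompatible constructor arguments

# Valid: empty FLOAT weights, to be filled in later (the fix this commit applies).
dummy_weight = trt.Weights()

# Also valid: weights backed by a concrete numpy array.
real_weight = trt.Weights(np.zeros(weight_shape, dtype=np.float32))
```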
1 parent f64f4c1 commit cffd9c2

File tree

2 files changed: +3 −5


fx/converters/acc_ops_converters.py (+2 −2)
```
@@ -65,11 +65,11 @@ def acc_ops_conv1d(
         unsqueeze_weight_layer = network.add_shuffle(input=weight)
         unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
         set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
-        weiht = unsqueeze_weight_layer.get_output(0)
+        weight = unsqueeze_weight_layer.get_output(0)
     weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
     # will need to use uninitialized weight and set it later to support
     # ITensor weights
-    dummy_weight = trt.Weights(weight_shape)
+    dummy_weight = trt.Weights()
     layer = network.add_convolution_nd(
         input=input_val,
         num_output_maps=weight.shape[0],
```
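As the comment in the hunk notes, the layer is built with uninitialized weights so that the real kernel, which may be a runtime `ITensor`, can be attached afterwards. A self-contained sketch of that pattern, assuming a working TensorRT install; the input and kernel shapes are placeholders, and attaching the kernel via `layer.set_input(1, ...)` is an assumption about which layer input slot TensorRT uses for the kernel tensor:

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

input_val = network.add_input("x", trt.float32, (1, 4, 32, 1))
# Hypothetical kernel as a network tensor (e.g. produced by another layer).
weight = network.add_constant(
    (8, 4, 3, 1), trt.Weights(np.zeros((8, 4, 3, 1), dtype=np.float32))
).get_output(0)

# Build the convolution with empty placeholder weights ...
dummy_weight = trt.Weights()  # empty FLOAT weights; no shape tuple
layer = network.add_convolution_nd(
    input=input_val,
    num_output_maps=8,
    kernel_shape=(3, 1),
    kernel=dummy_weight,
)
# ... then override the placeholder with the actual ITensor kernel.
layer.set_input(1, weight)  # assumption: input index 1 is the kernel slot
```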

test/quant/test_quant_trt.py (+1 −3)
```
@@ -473,7 +473,7 @@ def forward(self, x):

         # just testing conv2d since conv1d and conv3d are not supported in fx2trt
         for dim, has_relu, f_relu, is_qat in itertools.product(
-            [2], [True, False], [True, False], [True, False]
+            [1, 2], [True, False], [True, False], [True, False]
         ):
             # when has_relu=False, we have torch.nn.Identity, which would introduce
             # extra quant-dequat pair
@@ -564,7 +564,6 @@ def forward(self, x):
             ns.call_module(torch.nn.quantized._reference.Linear): 1,
             ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
         }
-        print(m)
         self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

     def test_unsupported_qconfig(self):
@@ -725,7 +724,6 @@ def conv_add_extra_inputs_getter(pattern):
             example_inputs,
             backend_config_dict=modified_backend_config_dict,
         )
-        print(m)
         node_occurrence = {
             ns.call_module(torch.ao.quantization.HistogramObserver): 3,
         }
```
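The widened `itertools.product` sweep is what pulls conv1d into test coverage. A standalone sketch of what the parametrization expands to, using only the standard library:

```python
import itertools

# [1, 2] for dim doubles the sweep: 2 * 2 * 2 * 2 = 16 combinations,
# now exercising conv1d alongside conv2d.
for dim, has_relu, f_relu, is_qat in itertools.product(
    [1, 2], [True, False], [True, False], [True, False]
):
    print(dim, has_relu, f_relu, is_qat)
```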
