Any idea on how to support ConvTranspose2d? #2
Hi,
It doesn't seem to work well:
import torch.nn.functional as F
from torch import nn
import tensorrt as trt
import torch
from torch2trt import torch2trt

class TestModel(torch.nn.Module):
    def __init__(self, out_dims):
        super(TestModel, self).__init__()
        self.layer = nn.ConvTranspose2d(3, out_dims, 2, 2)

    def forward(self, x):
        return self.layer(x)

test_model = TestModel(256).cuda()
input_shape = (1, 3, 300, 400)
dummy_tensor = torch.randn(input_shape, dtype=torch.float32).cuda()
# output is (1, 256, 600, 800)
print(test_model(dummy_tensor).shape)
# convert test model to trt (tensorrt is already imported above)
opt_shape_param = [
    [
        [1, 3, 160, 240],   # min
        [1, 3, 800, 1200],  # opt
        [1, 3, 1600, 2400], # max
    ]
]
with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
    )

# test trt model with a different input size inside the [min, max] range
dummy_tensor = torch.randn((1, 3, 400, 400), dtype=torch.float32).cuda()
# expected output is (1, 256, 800, 800)
print(trt_model(dummy_tensor).shape)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-6-571ab0488d9e> in <module>
41 dummy_tensor = torch.randn((1, 3, 400, 400), dtype=torch.float32).cuda()
42 # except output is (1, 3, 800, 800), but actually the output shape is still (1, 3, 600, 800)
---> 43 print(trt_model(dummy_tensor).shape)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/venv/ocr-detection-detectron2/lib/python3.6/site-packages/torch2trt/torch2trt.py in forward(self, *inputs)
421
422 for i, input_name in enumerate(self.input_names):
--> 423 idx = self.engine.get_binding_index(input_name)
424
425 self.context.set_binding_shape(idx, tuple(inputs[i].shape))
AttributeError: 'NoneType' object has no attribute 'get_binding_index'
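For reference, the expected shapes in the comments above follow the output-size formula given in the PyTorch `nn.ConvTranspose2d` documentation: H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1, and likewise for W. A minimal pure-Python sketch (the helper names are illustrative, not part of torch2trt) shows that with kernel_size=2 and stride=2 the layer simply doubles H and W, so the max profile (1, 3, 1600, 2400) produces a (1, 256, 3200, 4800) output.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output length along one spatial axis of nn.ConvTranspose2d."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def output_shape(input_shape, out_channels, kernel_size, stride):
    """NCHW output shape for ConvTranspose2d(C_in, out_channels, kernel_size, stride)."""
    n, _, h, w = input_shape
    return (n, out_channels,
            conv_transpose2d_out(h, kernel_size, stride),
            conv_transpose2d_out(w, kernel_size, stride))


print(output_shape((1, 3, 300, 400), 256, 2, 2))    # (1, 256, 600, 800)
print(output_shape((1, 3, 400, 400), 256, 2, 2))    # (1, 256, 800, 800)
print(output_shape((1, 3, 1600, 2400), 256, 2, 2))  # (1, 256, 3200, 4800)
```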
Ok, I see.
Hi, I have reduced the …
It works. cuDNN needs a large workspace to do the deconvolution (transposed convolution). I guess (1600, 2400) reaches the memory limit.
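A rough back-of-the-envelope check of why the max profile is expensive: the sketch below counts only the float32 output activation of the layer at each profile shape. It does not model cuDNN's or TensorRT's own workspace, which comes on top of this, so treat it as a lower bound under those assumptions. The `tensor_bytes` helper is hypothetical, and the doubling of H and W holds only because kernel_size == stride == 2 with no padding.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Bytes for a dense tensor of the given shape (float32 by default)."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# Same profile as in the issue: [min, opt, max] input shapes for one input.
opt_shape_param = [[
    [1, 3, 160, 240],    # min
    [1, 3, 800, 1200],   # opt
    [1, 3, 1600, 2400],  # max
]]
OUT_CHANNELS, STRIDE = 256, 2  # ConvTranspose2d(3, 256, 2, 2) doubles H and W

for label, (n, _, h, w) in zip(("min", "opt", "max"), opt_shape_param[0]):
    out_shape = (n, OUT_CHANNELS, h * STRIDE, w * STRIDE)
    gib = tensor_bytes(out_shape) / 2**30
    print(f"{label}: output {out_shape} needs {gib:.2f} GiB")
```

The max entry alone comes to roughly 14.65 GiB for the output tensor, which makes it plausible that the engine build could not fit in GPU memory. The `'NoneType' object has no attribute 'get_binding_index'` error in the traceback appears consistent with `self.engine` never having been created.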
WTF, thanks! |
Besides, I found it's not working with interpolate. Test code as follows: …
When setting
layer.scales = tuple([float(output_shape / input_shape) for input_shape, output_shape in zip(input.shape, output.shape)])
the output shape is as expected.
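The interpolate fix amounts to replacing an output size fixed at build time with per-axis scale factors. The plain-Python illustration below makes no TensorRT calls, and all function names are hypothetical. It contrasts a resize whose output shape is frozen with one driven by the same output/input ratios as the `layer.scales` line quoted above, reproducing the "(1, 3, 800, 800) expected, (1, 3, 600, 800) actual" symptom. Rounding down non-integer products is an assumption made here for simplicity.

```python
def resize_fixed_shape(input_shape, built_output_shape):
    """Resize whose output size was frozen when the engine was built."""
    return tuple(built_output_shape)  # input_shape is deliberately ignored


def make_scales(built_input_shape, built_output_shape):
    """Per-axis output/input ratios, as in the layer.scales line quoted above."""
    return tuple(float(o / i) for i, o in zip(built_input_shape, built_output_shape))


def resize_with_scales(input_shape, scales):
    """Resize that follows the run-time input shape (floor rounding is an assumption)."""
    return tuple(int(d * s) for d, s in zip(input_shape, scales))


built_in, built_out = (1, 3, 300, 400), (1, 3, 600, 800)
scales = make_scales(built_in, built_out)       # (1.0, 1.0, 2.0, 2.0)
new_in = (1, 3, 400, 400)
print(resize_fixed_shape(new_in, built_out))    # (1, 3, 600, 800): stale size
print(resize_with_scales(new_in, scales))       # (1, 3, 800, 800): expected size
```

Scales derived from the traced shapes keep working for any run-time input inside the profile, as long as the intended resize really is a constant ratio.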